Tutorial on Model-Assisted Combichem

This tutorial shows how to use combichem in conjunction with a generative model to optimize a score function.

Performance Notes

Parts of this notebook are CPU-constrained. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime type to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.train.reward import Reward
from mrl.train.agent import PredictiveAgent
from mrl.model_zoo import *
from mrl.combichem import *

Model Assisted Combichem

Model-assisted combichem involves using a generative model in conjunction with a combichem process to optimize some score function.

Standard combichem consists of the following steps:

  1. Library generation - create the next iteration of the library
  2. Library scoring - apply a numeric score to each item in the library
  3. Library pruning - remove low scoring compounds

For more details on library generation, see the Combichem Tutorial.
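
Schematically, one plain combichem generation looks like the following. This is a minimal sketch using the CombiChem methods introduced later in this notebook:

new_library = cbc.build_generation()          # 1. library generation: mutate/cross over
new_library = cbc.clean_library(new_library)  #    apply template filters
cbc.append_data(new_library)
cbc.score_library()                           # 2. library scoring
cbc.prune_library()                           # 3. library pruning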

Model-assisted combichem incorporates a generative model into the combichem workflow:

  1. Combichem library generation - generate compounds with a combichem process
  2. Model sampling - generate compounds from a generative model
  3. Library scoring - apply a numeric score to each item in the library
  4. Library pruning - remove low scoring compounds
  5. Model training - train the generative model on high scoring compounds

We use the generative model to add compounds to the library. After scoring and pruning, we train the generative model on high scoring compounds; the step and train_from_cbc helpers below implement this loop.

Template

We will use the following template to constrain chemical space. The SMARTS patterns below define substructures that will be excluded:

smarts = ['[#6](=[#16])(-[#7])-[#7]',
        '[#6]=[#6]=[#6]',
        '[#7][F,Cl,Br,I]',
        '[*]#[Cl,Br]',
        '[#6;!R]=[#6;!R]-[#6;!R]=[#6;!R]',
        '[#6]#[#6]',
        '[#15]',
        '[#16]',
        '[*]=[#17,#9,#35]',
        '[*]=[*]=[*]',
        '[*]-[#6]=[#6H2]',
        '[#7]~[#8]',
        '[#7]~[#7]',
        '[*;R]=[*;!R]']
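
As a quick illustration using plain RDKit (which mrl.chem builds on), a molecule containing any of these substructures will be rejected; here we check the alkyne pattern '[#6]#[#6]':

from rdkit import Chem

mol = Chem.MolFromSmiles('C#Cc1ccccc1')   # phenylacetylene, contains an alkyne
pattern = Chem.MolFromSmarts('[#6]#[#6]')
print(mol.HasSubstructMatch(pattern))     # True -> excluded by the template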

template = Template([ValidityFilter(),                 # structure must be chemically valid
                     SingleCompoundFilter(),           # no multi-fragment SMILES
                     RotBondFilter(None, 8),           # at most 8 rotatable bonds
                     HeteroatomFilter(None, 8),        # at most 8 heteroatoms
                     ChargeFilter(0, 0),               # formal charge must be exactly 0
                     MaxRingFilter(None, 6),           # rings no larger than 6 atoms
                     MinRingFilter(5, None),           # rings no smaller than 5 atoms
                     HBDFilter(None, 5),               # at most 5 hydrogen bond donors
                     HBAFilter(None, 10),              # at most 10 hydrogen bond acceptors
                     MolWtFilter(None, 500),           # molecular weight at most 500
                     LogPFilter(None, 5),              # logP at most 5
                     SAFilter(None, 7),                # synthetic accessibility score at most 7
                     BridgeheadFilter(None, 0),        # no bridgehead atoms
                     PAINSAFilter(),                   # PAINS-A structural alerts
                     ExclusionFilter(smarts, criteria='any'),  # exclude the SMARTS above
                     RotChainFilter(None, 7),          # longest rotatable chain at most 7
                     RingFilter(None, 4)               # at most 4 rings
                    ],
                    [], 
                    fail_score=-1., log=False, use_lookup=False)

Reward

For the reward, we will load a scikit-learn linear regression model trained to predict affinity against erbB1 using molecular fingerprints.

This score function is extremely simple and won't translate well to real binding affinity. It is used here as a lightweight example.

class FP_Regression_Score():
    def __init__(self, fname):
        self.model = torch.load(fname)                        # pre-trained sklearn regressor
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)
        
    def __call__(self, samples):
        mols = to_mols(samples)                               # SMILES strings -> RDKit mols
        fps = maybe_parallel(self.fp_function, mols)          # compute ECFP6 fingerprints
        fps = [fp_to_array(i) for i in fps]                   # fingerprints -> numpy arrays
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)                    # predicted affinity per compound
        return preds
    
# if in the repo
reward_function = FP_Regression_Score('../files/erbB1_regression.sklearn')

# if on Colab
# download_files()
# reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')

reward = Reward(reward_function, weight=1.)
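
As a quick sanity check, the reward function can be called directly on a list of SMILES strings and returns one predicted score per compound (the molecules below are arbitrary examples):

test_smiles = ['c1ccccc1', 'CC(=O)Nc1ccc(O)cc1']  # benzene, paracetamol
print(reward_function(test_smiles))               # array of predicted scores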

Combichem

Here we set up our combichem module with a list of mutators and crossovers.

mutators = [
    ChangeAtom(['6', '7', '8', '9', '17', '35']),
    AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
    AppendAtomsDouble(['C', 'N', 'O']),
    AppendAtomsTriple(),
    DeleteAtom(),
    ChangeBond(),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),
    InsertAtomTriple(),
    AddRing(),
    ShuffleNitrogen(20)
]

mc = MutatorCollection(mutators)

crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[reward],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
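
Combichem mutates and crosses over an existing population, so the library typically needs seed compounds before the first generation. A minimal sketch, using the same calls as the step helper below (the seed SMILES here are arbitrary examples):

seed_smiles = ['CC(=O)Nc1ccc(O)cc1', 'c1ccc2[nH]ccc2c1']  # arbitrary seed molecules
cbc.append_data(seed_smiles)   # add seeds to the library
cbc.score_library()            # score them so pruning has values to work with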

Generative Model

We load a pre-trained generative model: an LSTM-based language model trained on the ZINC database.

agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3)
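
To verify the model loaded correctly, we can draw a few samples directly; sample_no_grad and reconstruct are the same calls used in the step helper below:

preds, _ = agent.model.sample_no_grad(8, 90)  # 8 samples, max sequence length 90
smiles = agent.reconstruct(preds)             # decode token tensors to SMILES strings
print(smiles[:3])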

Helper Functions

def step(cbc, agent, n_gen, sl):
    new_library = cbc.build_generation()              # mutate/cross over the current library
    new_library = cbc.clean_library(new_library)      # apply template filters
    cbc.append_data(new_library)
    preds, _ = agent.model.sample_no_grad(n_gen, sl)  # sample n_gen sequences, max length sl
    smiles = agent.reconstruct(preds)                 # decode to SMILES strings
    cbc.append_data(smiles)                           # add model samples to the library
    cbc.score_library()                               # score every new compound
    cbc.prune_library()                               # drop low scoring compounds
    
def train_from_cbc(cbc, agent, ds_size, epochs, bs, lr):
    # reset the index after concat so the positional lookup below is unambiguous
    df = pd.concat([cbc.library, cbc.old_library]).reset_index(drop=True)
    subset = df.iloc[df.score.nlargest(ds_size).index]  # top ds_size compounds by score
    agent.update_dataset_from_inputs(subset.smiles.values)
    agent.train_supervised(bs, epochs, lr, silent=True)

Running Model-Assisted Combichem

We run five outer iterations. Each performs five combichem/model-sampling steps, then fine-tunes the generative model on the top 6000 compounds and prints the mean library score.

for i in range(5):
    for j in range(5):
        step(cbc, agent, 1024, 90)
        
    train_from_cbc(cbc, agent, 6000, 3, 128, 1e-4)
    print(cbc.library.score.mean())
10.300824455022813
13.433134039044381
15.450325035452842
16.517450484484435
17.25814539551735
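
The mean library score climbs steadily across outer iterations (from roughly 10.3 to 17.3), showing that alternating combichem generation, model sampling, pruning, and fine-tuning progressively enriches the library with high scoring compounds.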