Small molecule affinity optimization

Affinity Optimization

This tutorial runs an end-to-end workflow for designing high-affinity ligands using generative screening. We will design potential ligands against EGFR, a protein implicated in several types of cancer.

Here is an outline of the workflow we will follow:

  1. Design score function
  2. Design chemical space
  3. Load pre-trained model
  4. Fine-tune pre-trained model
  5. Reinforcement learning
  6. Analysis

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
from sklearn.metrics import r2_score
os.makedirs('untracked_files', exist_ok=True)

The Target

EGFR is a transmembrane protein receptor for growth factors. When growth factors bind to EGFR on the cell surface, it initiates a cascade of signals inside the cell. This triggers the cell to initiate gene transcription and cell proliferation.

Overexpression of EGFR can lead to accelerated cell division, which can result in tumors. This mechanism has been implicated in several types of cancer, including adenocarcinoma of the lung and glioblastoma.

Data

We will specifically target the ErbB1 form of EGFR, which has the ChEMBL ID CHEMBL203. Our dataset comes from ChEMBL and contains affinity measurements for 6509 small molecules.

df = pd.read_csv('../files/erbB1_affinity_data.csv')

# if in Collab
# download_files()
# df = pd.read_csv('files/erbB1_affinity_data.csv')
df.head()
smiles neg_log_ic50 chembl_id set
0 Brc1cc2c(NCc3ccccc3)ncnc2s1 6.617983 CHEMBL3416599 valid
1 Brc1cc2c(NCc3ccccn3)ncnc2s1 5.102153 CHEMBL3416616 train
2 Brc1cc2c(NCc3cccs3)ncnc2s1 5.862013 CHEMBL3416619 train
3 Brc1cc2c(NCc3ccncc3)ncnc2s1 5.410833 CHEMBL3416614 train
4 Brc1cc2c(Nc3ccccc3)ncnc2s1 7.096910 CHEMBL3416621 train
df.shape
(6467, 4)
df['smiles_ns'] = df.smiles.map(lambda x: remove_stereo(x))

The affinity value is given as -log10(IC50). These values range from 0 to 12, where higher values correspond to higher affinity.
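Each unit on this scale corresponds to a tenfold change in potency, so it can help to translate a few values back into concentrations. The sketch below assumes the standard pIC50 convention, in which IC50 is expressed in molar before taking the log. The helper names `to_neg_log_ic50` and `to_ic50_nm` are our own and are not part of MRL.

```python
import math

def to_neg_log_ic50(ic50_nm: float) -> float:
    """Convert an IC50 in nanomolar to -log10(IC50 in molar)."""
    if ic50_nm <= 0:
        raise ValueError("IC50 must be positive")
    return -math.log10(ic50_nm * 1e-9)  # nM -> M, then -log10

def to_ic50_nm(neg_log_ic50: float) -> float:
    """Inverse conversion: -log10(IC50) back to IC50 in nanomolar."""
    return 10 ** (-neg_log_ic50) * 1e9  # M -> nM

print(to_neg_log_ic50(1000.0))  # a 1 uM inhibitor: about 6
print(to_neg_log_ic50(1.0))     # a 1 nM inhibitor: about 9
print(to_ic50_nm(9.12))         # roughly 0.76 nM
```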

df.neg_log_ic50.hist()

Approved Drugs

There are seven approved drugs in the dataset. All are FDA-approved except Icotinib, which is approved only in China.

drug_smiles = [
    'COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1',
    'C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1',
    'CN(C)CC=CC(=O)Nc1cc2c(Nc3ccc(F)c(Cl)c3)ncnc2cc1OC1CCOC1',
    'COc1cc(N2CCC(N3CCN(C)CC3)CC2)ccc1Nc1ncc(Cl)c(Nc2ccccc2P(C)(C)=O)n1',
    'C#Cc1cccc(Nc2ncnc3cc4c(cc23)OCCOCCOCCO4)c1',
    'CS(=O)(=O)CCNCc1ccc(-c2ccc3ncnc(Nc4ccc(OCc5cccc(F)c5)c(Cl)c4)c3c2)o1',
    'C=CC(=O)Nc1cc(Nc2nccc(-c3cn(C)c4ccccc34)n2)c(OC)cc1N(C)CCN(C)C'
]

drug_names = [
    'Gefitinib',
    'Erlotinib',
    'Afatinib',
    'Brigatinib',
    'Icotinib',
    'Lapatinib',
    'Osimertinib'
]

drug_dict = {drug_smiles[i]:drug_names[i] for i in range(len(drug_names))}
subset = df[df.smiles_ns.isin(drug_smiles)].copy()
subset.reset_index(inplace=True, drop=True)
subset['name'] = subset.smiles_ns.map(lambda x: drug_dict[x])
subset.drop_duplicates(subset='name', inplace=True)
mols = to_mols(subset.smiles_ns.values)
names = subset.name.values
acts = subset.neg_log_ic50.values
legends = [f"{names[i]}, -log(IC50)={acts[i]:.3f}" for i in range(len(names))]
draw_mols(mols, legends=legends, mols_per_row=4)

It is worth noting that the approved drugs are not the highest-affinity binders in the dataset. The eight highest-affinity binders are shown below. This illustrates that affinity alone is not sufficient to make a drug.

subset2 = df.loc[df.neg_log_ic50.nlargest(8).index]  # nlargest returns index labels, so select with .loc

mols = to_mols(subset2.smiles_ns.values)
acts = subset2.neg_log_ic50.values
legends = [f"-log(IC50)={acts[i]:.3f}" for i in range(len(mols))]
draw_mols(mols, legends=legends, mols_per_row=4)

Score Function

Now we want to develop a score function for predicting binding affinity. We will use an MLP to predict affinity from an ECFP6 fingerprint.

We will train on 90% of the data and validate on the held-out 10%.

df_train = df[df.set=='train']
df_valid = df[df.set=='valid']
reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[0,15])
r_ds = Vec_Prediction_Dataset(df_train.smiles_ns.values, df_train.neg_log_ic50.values, ECFP6)
r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 10, 1e-3, percent_valid=0.1)
Epoch Train Loss Valid Loss Time
0 1.54977 0.28657 00:06
1 1.05367 0.65758 00:06
2 1.08841 0.31990 00:06
3 0.43909 0.24786 00:06
4 0.46526 0.33699 00:06
5 0.30168 0.30161 00:06
6 0.20798 0.51298 00:06
7 0.18604 0.34009 00:06
8 0.11575 0.30871 00:06
9 0.19196 0.32855 00:06

Optional: save score function weights

 

Optional: to load the exact weights used, run the following:

# r_agent.load_state_dict(model_from_url('egfr_affinity_mlp.pt'))
valid_ds = Vec_Prediction_Dataset(df_valid.smiles_ns.values, df_valid.neg_log_ic50.values, ECFP6)

valid_dl = valid_ds.dataloader(256, num_workers=0, shuffle=False)
r_agent.model.eval();

preds = []
targs = []

with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
        
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()

preds = preds.squeeze(-1)

Our score function has an R-squared value of about 0.72 on the validation set.

fig, ax = plt.subplots()

ax.scatter(targs, preds, alpha=0.5, s=1)
plt.xlabel('Target')
plt.ylabel('Prediction')

slope, intercept = np.polyfit(targs, preds, 1)
ax.plot(np.array(ax.get_xlim()), intercept + slope*np.array(ax.get_xlim()), c='r')

plt.text(5., 9., 'R-squared = %0.3f' % r2_score(targs, preds));
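The R-squared value printed on the plot is the coefficient of determination, 1 - SS_res / SS_tot. It is computed directly between targets and predictions. The red fitted line is a separate visual aid and does not enter the calculation. Below is a dependency-free sketch (the function name is our own) that should agree with sklearn's `r2_score(y_true, y_pred)` on the same inputs.

```python
from typing import Sequence

def r_squared(targets: Sequence[float], predictions: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot (mean taken over targets)."""
    if len(targets) != len(predictions) or len(targets) < 2:
        raise ValueError("need two equal-length sequences with at least 2 items")
    mean_t = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, predictions))  # residual sum of squares
    ss_tot = sum((t - mean_t) ** 2 for t in targets)                  # total sum of squares
    if ss_tot == 0:
        raise ValueError("R-squared is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot

# Toy values, not taken from the EGFR dataset.
toy_targets = [5.0, 6.0, 7.0, 8.0]
toy_preds = [5.5, 6.0, 6.5, 8.0]
print(r_squared(toy_targets, toy_preds))  # 0.9
```

Unlike a squared correlation, this metric is not symmetric. Swapping targets and predictions generally changes the result, which is why `r2_score(targs, preds)` takes the targets first.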

We should also look at the distribution of predictions for our known actives. The predicted values will differ somewhat from the measured values, so this gives us a sense of what score we want to see from the model.

df['preds'] = r_agent.predict_data(df.smiles.values).detach().cpu().numpy()
df.preds.max()
9.820236
np.percentile(df.preds, 99)
9.123940906524659

The maximum predicted value is 9.8. A compound scoring 9.12 or higher would be in the top 1% of all known actives.
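The 99th-percentile cutoff and "top X%" statements like the one above come down to two small computations. The sketch below reimplements them without NumPy. `percentile_linear` follows the linear-interpolation convention that `np.percentile` uses by default. Both function names are our own.

```python
from typing import Sequence

def percentile_linear(values: Sequence[float], q: float) -> float:
    """q-th percentile with linear interpolation between sorted neighbours."""
    xs = sorted(values)
    if not xs:
        raise ValueError("values must be non-empty")
    pos = (len(xs) - 1) * q / 100.0  # fractional rank
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)

def fraction_at_or_above(values: Sequence[float], score: float) -> float:
    """Share of reference values that are >= score (0.01 means 'top 1%')."""
    return sum(v >= score for v in values) / len(values)

toy_scores = list(range(1, 101))  # toy reference distribution: 1..100
cutoff = percentile_linear(toy_scores, 99)
print(cutoff)                                   # about 99.01
print(fraction_at_or_above(toy_scores, cutoff))  # 0.01
```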

Chemical Space

Next we need to develop our chemical space. This is where we decide what compounds will be allowed and what compounds will be removed.

Getting the filtering parameters right makes a huge difference in compound quality. In practice, finding the right constraints is an iterative process:

  1. Run a generative screen.
  2. Examine the highest-scoring compounds for undesirable properties or structural features.
  3. Update the template and repeat.

smarts = ['[#6](=[#16])(-[#7])-[#7]',
        '[#6]=[#6]=[#6]',
        '[#7][F,Cl,Br,I]',
        '[*]#[Cl,Br,I]',
        '[*]-[Cl,Br,I]-[*]',
        '[#6;!R]=[#6;!R]-[#6;!R]=[#6;!R]',
        '[#6]#[#6]',
        '[#15]',
        '[#16]',
        '[*]=[#17,#9,#35]',
        '[*]=[*]=[*]',
        '[*]-[#6]=[#6H2]',
        '[#7]~[#8]',
        '[#7]~[#7]',
        '[*;R]=[*;!R]']

template = Template([ValidityFilter(), 
                     SingleCompoundFilter(), 
                     RotBondFilter(None, 6),
                     HeteroatomFilter(None, 8),
                     ChargeFilter(None, 0),
                     MaxRingFilter(None, 6),
                     MinRingFilter(5, None),
                     RingFilter(None, 5),
                     HBDFilter(None, 5),
                     HBAFilter(None, 10),
                     MolWtFilter(None, 500),
                     LogPFilter(None, 5),
                     SAFilter(None, 7),
                     BridgeheadFilter(None,0),
                     PAINSAFilter(),
                     ExclusionFilter(smarts, criteria='any'),
                     RotChainFilter(None, 5)
                    ],
                    [], 
                    fail_score=-1., log=False, use_lookup=False)

template_cb = TemplateCallback(template, prefilter=True)
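Conceptually, a template like this acts as a conjunction of hard filters: a compound must pass all of them to be scored. Computing real descriptors requires RDKit, so the sketch below works on precomputed property dictionaries instead. `RangeFilter`, `template_score`, and the property names are hypothetical. They only mirror the (min, max) argument pattern of filters such as `RotBondFilter(None, 6)`, assuming `None` means unbounded. We also assume `fail_score` is the score a failing compound receives. This is not MRL's implementation.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

Props = Dict[str, float]

@dataclass
class RangeFilter:
    """Hypothetical range filter: passes when min_val <= value <= max_val (None = unbounded)."""
    prop: str
    min_val: Optional[float] = None
    max_val: Optional[float] = None

    def __call__(self, props: Props) -> bool:
        value = props[self.prop]
        if self.min_val is not None and value < self.min_val:
            return False
        if self.max_val is not None and value > self.max_val:
            return False
        return True

def template_score(props: Props, filters: List[Callable[[Props], bool]],
                   score_fn: Callable[[Props], float], fail_score: float = -1.0) -> float:
    """Score a compound only if it passes every hard filter; otherwise return fail_score."""
    if all(f(props) for f in filters):
        return score_fn(props)
    return fail_score

toy_filters = [
    RangeFilter("rot_bonds", None, 6),  # cf. RotBondFilter(None, 6)
    RangeFilter("hbd", None, 5),        # cf. HBDFilter(None, 5)
    RangeFilter("mol_wt", None, 500),   # cf. MolWtFilter(None, 500)
    RangeFilter("logp", None, 5),       # cf. LogPFilter(None, 5)
]

def toy_affinity(props: Props) -> float:
    return props["pred_affinity"]  # stand-in for the affinity model

passing = {"rot_bonds": 4, "hbd": 2, "mol_wt": 420.0, "logp": 3.1, "pred_affinity": 8.4}
floppy = {"rot_bonds": 9, "hbd": 2, "mol_wt": 480.0, "logp": 4.2, "pred_affinity": 9.5}
print(template_score(passing, toy_filters, toy_affinity))  # 8.4
print(template_score(floppy, toy_filters, toy_affinity))   # -1.0 (too many rotatable bonds)
```

Structural checks such as the SMARTS exclusion list would slot into the same list as further boolean filters.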

Load Model

We load the LSTM_LM_Small_ZINC_NC model. This is a basic LSTM-based language model trained on part of the ZINC database.

agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})

Fine-Tune Model

The pretrained model we loaded is a very general model that can produce a high diversity of structures. However, what we actually want are structures with high affinity for EGFR that conform to our template. We can get our model into better starting shape by fine-tuning the pretrained weights.

Fine-Tuning Part 1: Template Tuning

Currently, about 60-70% of compounds generated by the model won't pass the template. This can slow training by wasting time generating, evaluating and filtering unwanted compounds. We can avoid this by first fine-tuning on a dataset of compounds that pass our filter.

Note: we recommend not running this section on Colab because it will be extremely slow. If you're just perusing the notebook on Colab, skip to the next fine-tuning section.

For template tuning, we first generate a dataset of compounds from the model. Then we screen these compounds with the template and filter for passing compounds. Finally, we fine-tune on the passing compounds.

This can also be done on a dataset of your choosing if you have one on hand.

 

Here we generate a set of SMILES. If you have time, feel free to generate more. The generative model can produce millions of unique SMILES with a low rate of invalid or duplicate SMILES.

%%time
all_smiles = set()
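for batch_idx in range(200):
    # sample 2500 token sequences (second argument: max sequence length)
    preds, _ = agent.model.sample_no_grad(2500, 120)
    smiles = agent.reconstruct(preds)
    mols = to_mols(smiles)
    # keep only SMILES that parse to a valid molecule (to_mols gives None otherwise)
    smiles = [smi for smi, mol in zip(smiles, mols) if mol is not None]
    all_smiles.update(smiles)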
CPU times: user 22min 7s, sys: 43.8 s, total: 22min 51s
Wall time: 3min 3s
 

Here we filter the SMILES with the template

all_smiles = list(all_smiles)
hps = np.array(template(all_smiles))
df2 = pd.DataFrame(all_smiles, columns=['smiles'])
df2 = df2[hps]

Optional: save dataset

df2.to_csv('untracked_files/pretrain_data.csv', index=False)
 

Now we fine-tune on our new template-conforming dataset

agent.update_dataset_from_inputs(df2.smiles.values)
agent.train_supervised(128, 4, 1e-5)
agent.base_to_model()
Epoch Train Loss Valid Loss Time
0 0.39135 0.39605 01:11
1 0.32032 0.39432 01:11
2 0.36329 0.39362 01:12
3 0.37961 0.39354 01:11

Fine-Tuning Part 2: Dataset Tuning

Next we fine-tune on the dataset of known actives.

This step deserves careful thought. Fine-tuning on known actives biases the model toward known chemotypes. This significantly reduces the number of iterations the model needs to find high-scoring compounds, but the model will likely converge on compounds similar to known chemotypes.

On the other hand, we could skip this stage of fine-tuning. The model would then take longer to converge (at least 3x more batches), but we would have a much greater chance of finding new chemotypes.

In the interest of time, we do the additional fine-tuning to speed up convergence.

agent.update_dataset_from_inputs(df[df.neg_log_ic50>7.].smiles.values)
agent.train_supervised(32, 6, 5e-5)

agent.update_dataset_from_inputs(df[df.neg_log_ic50>8.].smiles.values)
agent.train_supervised(32, 5, 5e-5)

agent.update_dataset_from_inputs(df[df.neg_log_ic50>9.].smiles.values)
agent.train_supervised(32, 3, 5e-5)

agent.base_to_model()
Epoch Train Loss Valid Loss Time
0 0.68854 0.67482 00:08
1 0.44465 0.48596 00:08
2 0.38428 0.40745 00:08
3 0.46983 0.37164 00:08
4 0.41797 0.35819 00:08
5 0.36653 0.35644 00:08
Epoch Train Loss Valid Loss Time
0 0.36252 0.35782 00:07
1 0.37530 0.34022 00:07
2 0.26538 0.32968 00:07
3 0.30116 0.32402 00:07
4 0.36090 0.32323 00:07
Epoch Train Loss Valid Loss Time
0 0.33252 0.25572 00:06
1 0.28551 0.25221 00:06
2 0.21042 0.25156 00:06
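The three `train_supervised` calls above form a simple curriculum. Each stage keeps only compounds above a stricter affinity cutoff (7, 8, then 9) and trains for fewer epochs (6, 5, then 3). The dependency-free sketch below shows how the stage datasets nest. `curriculum_stages` and the toy records are hypothetical and are not MRL API.

```python
from typing import List, Sequence, Tuple

def curriculum_stages(records: Sequence[Tuple[str, float]],
                      stages: Sequence[Tuple[float, int]]) -> List[Tuple[float, int, List[str]]]:
    """Return (threshold, epochs, smiles_subset) for each stage.

    Mirrors df[df.neg_log_ic50 > threshold]: the comparison is strict.
    """
    plan = []
    for threshold, epochs in stages:
        stage_subset = [smi for smi, affinity in records if affinity > threshold]
        plan.append((threshold, epochs, stage_subset))
    return plan

# Toy records: placeholder names and made-up affinities.
toy_records = [("mol_a", 7.5), ("mol_b", 8.2), ("mol_c", 9.4), ("mol_d", 6.1)]
for threshold, epochs, stage_subset in curriculum_stages(toy_records, [(7.0, 6), (8.0, 5), (9.0, 3)]):
    print(f"> {threshold}: {len(stage_subset)} compounds, {epochs} epochs")
```

Each stage's subset is contained in the previous one, so later stages narrow the model toward the highest-affinity chemotypes on progressively smaller datasets.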

Optional: save fine-tuned weights

 
 

Reinforcement Learning

Now we enter the reinforcement learning stage.

Loss

We use PPO as our policy gradient loss.

pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})
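To build intuition for `lam`, `cliprange`, and the first positional argument, here is a minimal sketch of two standard PPO ingredients: generalized advantage estimation (GAE) and the clipped surrogate policy loss. It assumes 0.99 is the discount factor and that `cliprange=0.3` plays the role of the clipping epsilon. It follows the textbook formulation rather than MRL's code. It also omits the value loss, entropy bonus, and KL control that `v_coef`, `ent_coef`, `kl_target`, and `kl_horizon` suggest MRL's `PPO` also handles.

```python
import math
from typing import List, Sequence

def gae(rewards: Sequence[float], values: Sequence[float],
        gamma: float = 0.99, lam: float = 0.95) -> List[float]:
    """Generalized advantage estimation for a single sampled sequence.

    values holds V(s_0) ... V(s_T): one more entry than rewards, where the last
    entry is the bootstrap value (0.0 for a finished sequence).
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values must have len(rewards) + 1 entries")
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages

def ppo_clip_loss(new_logps: Sequence[float], old_logps: Sequence[float],
                  advantages: Sequence[float], cliprange: float = 0.3) -> float:
    """Negated PPO clipped surrogate objective, averaged over steps (minimize this)."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_logps, old_logps, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
        total += min(ratio * adv, clipped * adv)  # pessimistic of the two estimates
    return -total / len(advantages)

# Toy episode: the reward (1.0) arrives only at the final token.
toy_adv = gae([0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0])
print(toy_adv)  # earlier tokens receive discounted credit
# With cliprange=0.3, a ratio of 2.0 on a positive advantage is capped at 1.3.
print(ppo_clip_loss([math.log(2.0)], [0.0], [1.0]))  # -1.3
```

The `min` in the loss is what makes PPO conservative. Once the probability ratio moves beyond 1 ± cliprange in the direction the advantage favors, the gradient through that term vanishes.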

Reward

Here we rebuild the reward model we trained earlier, load its saved weights, and pass it to a reward callback.

reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[0,15])

r_ds = Vec_Prediction_Dataset(['C'], [0], partial(failsafe_fp, fp_function=ECFP6))

r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})

r_agent.load_weights('untracked_files/egfr_affinity_mlp.pt')
# r_agent.load_state_dict(model_from_url('egfr_affinity_mlp.pt')) # optional - load exact weights

r_agent.model.eval();

freeze(r_agent.model)

reward = Reward(r_agent.predict_data, weight=1.)

aff_reward = RewardCallback(reward, 'affinity')

Samplers

We create the following samplers:

  • sampler1 (ModelSampler): samples from the main model
  • sampler2 (ModelSampler): samples from the baseline model
  • sampler3 (LogSampler): draws high-scoring samples from the log
  • sampler4 (CombichemSampler): runs combichem generation on the top-scoring compounds. Combining generative models with combichem greatly accelerates the search for high-scoring compounds
  • sampler5 (DatasetSampler): adds a small number of known actives to each buffer build

A note on combichem: adding it to the sampler mix greatly accelerates convergence to high-scoring compounds. However, combichem tends to produce "weirder" structures that exploit the score function more aggressively. Such structures can be removed by updating the template.
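CombiChem's internals are not shown in this notebook. The toy sketch below illustrates why it converges quickly and why it can exploit the score function. It shows one expand-and-prune round, based on our own reading of the `prune_percentile`, `max_library_size`, and `p_explore` arguments. Strings over the letters C, N, and O stand in for molecules, and the score simply counts N characters. `mutate` and `combichem_round` are hypothetical. Real combichem applies chemically valid mutations and crossovers to actual molecules.

```python
import random
from typing import Callable, List, Optional

ALPHABET = "CNO"  # toy "atoms"

def mutate(s: str, rng: random.Random) -> str:
    """Toy mutation: replace one character of a non-empty string."""
    i = rng.randrange(len(s))
    return s[:i] + rng.choice(ALPHABET) + s[i + 1:]

def combichem_round(library: List[str], score: Callable[[str], float],
                    n_children: int = 5, prune_percentile: float = 70,
                    max_library_size: int = 400, p_explore: float = 0.2,
                    rng: Optional[random.Random] = None) -> List[str]:
    """One hypothetical expand-and-prune round over a library of toy strings."""
    if rng is None:
        rng = random.Random(0)
    if not library:
        return []
    # 1. Expand: keep the parents and add several mutated children of each.
    candidates = set(library)
    for s in library:
        for _ in range(n_children):
            candidates.add(mutate(s, rng))
    # Sort best-first; ties broken alphabetically so results are reproducible.
    ranked = sorted(candidates, key=lambda s: (-score(s), s))
    # 2. Prune: keep everything at or above the prune_percentile score...
    ascending = sorted(score(s) for s in ranked)
    cutoff = ascending[int((len(ascending) - 1) * prune_percentile / 100)]
    kept = [s for s in ranked if score(s) >= cutoff]
    # ...plus a random p_explore share of the rest, to preserve some diversity.
    kept += [s for s in ranked if score(s) < cutoff and rng.random() < p_explore]
    # 3. Cap the library size; the top-scoring survivors come first.
    return kept[:max_library_size]

rng = random.Random(0)
library = ["CCCCCC"]
for round_idx in range(5):
    library = combichem_round(library, lambda s: s.count("N"), max_library_size=50, rng=rng)
print(library[0])  # best toy "molecule" after five rounds
```

In this toy, the best member always survives a round, so the top score can only go up. That greedy pressure makes this kind of search fast. It also lets the search drift toward structures that score well without being sensible molecules.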

gen_bs = 1500

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)

mutators = [
    ChangeAtom(['6', '7', '8', '9', '17', '35']),
    AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
    AppendAtomsDouble(['C', 'N', 'O']),
    AppendAtomsTriple(),
    DeleteAtom(),
    ChangeBond(),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),
    InsertAtomTriple(),
    AddRing(),
    ShuffleNitrogen(20)
]

mc = MutatorCollection(mutators)

crossovers = [FragmentCrossover()]

cbc = CombiChem(mc, crossovers, template=template, rewards=[reward],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)

sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')

sampler5 = DatasetSampler(df[df.neg_log_ic50>8].smiles.values, 'data', buffer_size=4)

samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]

Callbacks

Additional callbacks

  • SupervisedCB: runs supervised training on the top 2% of samples every 20 batches
  • MaxCallback: prints the max reward for each batch
  • PercentileCallback: prints the 90th percentile score each batch
  • Timeout: prohibits training on the same sample more than once every 10 batches
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=2)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
tm_cb = Timeout(10)

cbs = [supervised_cb, live_p90, live_max, tm_cb]
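The Timeout callback is described only in one line above. Below is a minimal sketch of the idea, assuming it acts as a per-sample cooldown keyed on the SMILES string. The `SampleTimeout` class and its `filter` method are hypothetical and are not MRL's implementation.

```python
from typing import Dict, Iterable, List

class SampleTimeout:
    """Hypothetical per-sample cooldown: a sample used at batch b is
    blocked until batch b + timeout."""

    def __init__(self, timeout: int):
        self.timeout = timeout
        self.last_used: Dict[str, int] = {}

    def filter(self, samples: Iterable[str], batch: int) -> List[str]:
        """Return the samples allowed for training at this batch and record their use."""
        allowed = []
        for s in samples:
            last = self.last_used.get(s)
            if last is None or batch - last >= self.timeout:
                allowed.append(s)
                self.last_used[s] = batch  # start (or restart) the cooldown
        return allowed

cooldown = SampleTimeout(10)
print(cooldown.filter(["CCO", "CCN"], batch=0))   # ['CCO', 'CCN']
print(cooldown.filter(["CCO"], batch=5))          # [] (still cooling down)
print(cooldown.filter(["CCO", "CCO"], batch=10))  # ['CCO']
```

Duplicates within the same batch are also dropped, because the first copy starts the cooldown.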

Environment and Train

Now we can put together our Environment and run the training process.

env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=[loss],
                 cbs=cbs)
set_global_pool(min(10, os.cpu_count()))
env.fit(128, 100, 300, 20)
iterations rewards rewards_final new diversity bs template valid affinity timeout PPO rewards_live_p90 rewards_live_max
0 5.666 5.666 1.000 1.000 92 0.000 1.000 5.666 1.000 0.219 6.893 8.482
20 6.631 6.631 0.981 1.000 105 0.000 1.000 6.631 0.820 0.500 6.278 7.010
40 6.673 6.673 0.926 1.000 95 0.000 1.000 6.673 0.742 0.514 6.375 8.306
60 6.683 6.683 0.899 1.000 89 0.000 1.000 6.683 0.695 0.565 6.708 7.328
80 6.789 6.789 0.872 1.000 94 0.000 1.000 6.789 0.734 0.620 6.420 7.791
100 7.236 7.236 0.823 1.000 96 0.000 1.000 7.236 0.750 0.701 6.758 7.478
120 6.947 6.947 0.738 1.000 80 0.000 1.000 6.947 0.625 0.622 6.978 8.673
140 6.584 6.584 0.809 1.000 89 0.000 1.000 6.584 0.695 0.556 5.933 6.508
160 7.251 7.251 0.756 1.000 90 0.000 1.000 7.251 0.703 0.576 7.361 8.551
180 7.141 7.141 0.724 1.000 87 0.000 1.000 7.141 0.680 0.627 6.860 8.366
200 7.375 7.375 0.600 1.000 100 0.000 1.000 7.375 0.781 0.568 8.632 8.732
220 7.598 7.598 0.614 1.000 88 0.000 1.000 7.598 0.688 0.487 8.303 8.915
240 7.762 7.762 0.566 1.000 99 0.000 1.000 7.762 0.773 0.406 8.181 8.630
260 7.568 7.568 0.650 1.000 100 0.000 1.000 7.568 0.781 0.510 8.346 9.289
280 7.590 7.590 0.659 1.000 88 0.000 1.000 7.590 0.688 0.449 8.636 8.872
env.log.plot_metrics()

Evaluation

We determined earlier that a compound scoring 9.12 or higher would be in the top 1% of known actives. Excluding samples drawn from the dataset (the data_buffer source), the log contains 1017 samples created by the generative process with a score of 9.12 or higher.

The generative run also found 9 samples with a score of 9.9 or higher, putting them in the top 0.14% of compounds relative to the original dataset.

env.log.df[(env.log.df.affinity>=9.12) & (env.log.df.sources != 'data_buffer')].shape
(1017, 7)
subset = env.log.df[(env.log.df.affinity>=9.9) & ~(env.log.df.sources=='data_buffer')]
mols = to_mols(subset.samples.values)
legends = [f'{subset.rewards.values[i]:.3f}' for i in range(subset.shape[0])]
draw_mols(mols, legends=legends)

In the interest of time, we will stop here. That said, the model isn't done: with more training, it will continue to find higher-scoring compounds. Eventually, however, the model will fully exploit the score function and produce compounds that make no sense from a medchem perspective but still receive a high score.

Optional: save log dataframe