Affinity Optimization
This tutorial runs an end-to-end workflow for designing high-affinity ligands using generative screening. We will design potential ligands against EGFR, a protein implicated in several types of cancer.
Here is an outline of the workflow we will follow:
- Design score function
- Design chemical space
- Load pre-trained model
- Fine-tune pre-trained model
- Reinforcement learning
- Analysis
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which will result in 2-4x speedups.
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime to GPU.
import sys
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
from sklearn.metrics import r2_score
os.makedirs('untracked_files', exist_ok=True)
Overexpression of EGFR can lead to accelerated cell division which can result in tumors. This mechanism has been implicated in several types of cancer, including adenocarcinoma of the lung and glioblastoma.
df = pd.read_csv('../files/erbB1_affinity_data.csv')
# if in Colab
# download_files()
# df = pd.read_csv('files/erbB1_affinity_data.csv')
df['smiles_ns'] = df.smiles.map(lambda x: remove_stereo(x))
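Removing stereochemistry makes stereoisomers of the same compound collapse onto one SMILES string, which is why duplicates can appear after this step (and why the drug subset below is deduplicated). remove_stereo is MRL's function; the sketch below only illustrates the idea with a deliberately naive, string-level version. The helper name strip_stereo_naive is hypothetical, and it does not canonicalize its output, so it is not a substitute for an RDKit-based approach.

```python
import re

def strip_stereo_naive(smiles: str) -> str:
    """Drop chirality tags (@, @@) and bond-direction marks (/, \\) from a SMILES string.

    Hypothetical helper for illustration only; the result is not canonicalized.
    """
    no_chirality = re.sub(r"@+", "", smiles)   # [C@@H] -> [CH]
    return re.sub(r"[/\\]", "", no_chirality)  # F/C=C/F -> FC=CF

# Two enantiomers map to the same string once stereo is removed.
pair = ["C[C@H](N)O", "C[C@@H](N)O"]
print({strip_stereo_naive(s) for s in pair})  # {'C[CH](N)O'}
```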
The affinity value is given as -log10(IC50).
These values range from 0 to 12, where higher values correspond to higher affinity.
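Because the scale is logarithmic, each unit is a tenfold change in IC50, and with IC50 in molar units a value of 9 corresponds to 1 nM. The sketch below converts between the two, assuming raw IC50 values are reported in nanomolar; the function names are illustrative and not part of MRL.

```python
import math

def pic50_from_nm(ic50_nm: float) -> float:
    """-log10(IC50) for an IC50 given in nanomolar (assumed unit)."""
    # -log10(x * 1e-9 M) = 9 - log10(x)
    return 9.0 - math.log10(ic50_nm)

def ic50_nm_from_pic50(pic50: float) -> float:
    """Inverse conversion back to nanomolar."""
    return 10 ** (9.0 - pic50)

for ic50 in (1000.0, 10.0, 1.0):
    print(f"IC50 = {ic50:g} nM -> -log10(IC50) = {pic50_from_nm(ic50):.1f}")
# IC50 = 1000 nM -> -log10(IC50) = 6.0
# IC50 = 10 nM -> -log10(IC50) = 8.0
# IC50 = 1 nM -> -log10(IC50) = 9.0
```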
drug_smiles = [
drug_names = [
drug_dict = {drug_smiles[i]:drug_names[i] for i in range(len(drug_names))}
subset = df[df.smiles_ns.isin(drug_smiles)].copy()
subset.reset_index(inplace=True, drop=True)
subset['name'] = subset.smiles_ns.map(lambda x: drug_dict[x])
subset.drop_duplicates(subset='name', inplace=True)
mols = to_mols(subset.smiles_ns.values)
names = subset['name'].values
acts = subset.neg_log_ic50.values
legends = [f"{names[i]}, -log(IC50)={acts[i]:.3f}" for i in range(len(names))]
draw_mols(mols, legends=legends, mols_per_row=4)
It is worth noting that the approved drugs are not the highest-affinity binders in the dataset. The eight highest-affinity binders are shown below. This illustrates that affinity alone is not sufficient to make a drug.
subset2 = df.loc[df.neg_log_ic50.nlargest(8).index]
mols = to_mols(subset2.smiles_ns.values)
acts = subset2.neg_log_ic50.values
legends = [f"-log(IC50)={acts[i]:.3f}" for i in range(len(mols))]
draw_mols(mols, legends=legends, mols_per_row=4)
Score Function
Now we want to develop a score function for predicting binding affinity. We are going to use an MLP-type model to predict affinity from an ECFP6 fingerprint.
We will train on 90% of the data and validate on the held-out 10%.
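The data file already carries a set column, so nothing needs to be split here. If you bring your own data, a reproducible random 90/10 split could look like the sketch below; assign_split is a hypothetical helper, and the provided column may have been built differently (for example, by scaffold rather than at random).

```python
import random
from typing import List

def assign_split(n_rows: int, valid_frac: float = 0.1, seed: int = 42) -> List[str]:
    """Return a 'train'/'valid' label per row (hypothetical stand-in for a `set` column)."""
    rng = random.Random(seed)        # fixed seed -> identical split on every run
    idxs = list(range(n_rows))
    rng.shuffle(idxs)
    n_valid = round(n_rows * valid_frac)
    valid = set(idxs[:n_valid])
    return ["valid" if i in valid else "train" for i in range(n_rows)]

labels = assign_split(1000)
print(labels.count("train"), labels.count("valid"))  # 900 100
```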
df_train = df[df.set=='train']
df_valid = df[df.set=='valid']
reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[0,15])
r_ds = Vec_Prediction_Dataset(df_train.smiles_ns.values, df_train.neg_log_ic50.values, ECFP6)
r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 10, 1e-3, percent_valid=0.1)
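The MLP above is built with outrange=[0,15], which keeps predictions inside a plausible -log10(IC50) range. One common way to bound a regression head is a scaled sigmoid; the sketch below shows that technique, but whether MRL's MLP implements outrange exactly this way is an assumption, and bounded_output is a hypothetical name.

```python
import math

def bounded_output(raw: float, low: float = 0.0, high: float = 15.0) -> float:
    """Map an unbounded network output into [low, high] with a scaled sigmoid.

    Assumption: illustrates the idea behind an `outrange` option; not MRL's code.
    """
    # numerically stable sigmoid: avoid exp overflow for large |raw|
    if raw >= 0:
        s = 1.0 / (1.0 + math.exp(-raw))
    else:
        e = math.exp(raw)
        s = e / (1.0 + e)
    return low + (high - low) * s

print(bounded_output(0.0))  # 7.5, the midpoint of the range
```

Saturating at the range edges means the model can never predict an out-of-range affinity, at the cost of compressed gradients near the bounds.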
Optional: save score function weights
Optional: to load the exact weights used, run the following:
valid_ds = Vec_Prediction_Dataset(df_valid.smiles_ns.values, df_valid.neg_log_ic50.values, ECFP6)
valid_dl = valid_ds.dataloader(256, num_workers=0, shuffle=False)
preds = []
targs = []
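with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()
preds = preds.squeeze(-1)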
Our score function has an R-squared value of about 0.72 on the validation set.
fig, ax = plt.subplots()
ax.scatter(targs, preds, alpha=0.5, s=1)
slope, intercept = np.polyfit(targs, preds, 1)
ax.plot(np.array(ax.get_xlim()), intercept + slope*np.array(ax.get_xlim()), c='r')
plt.text(5., 9., 'R-squared = %0.3f' % r2_score(targs, preds));
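r2_score reports the coefficient of determination, 1 - SS_res / SS_tot: it compares the model's squared errors with those of always predicting the mean target. The plain-Python sketch below computes the same quantity for a single output; r_squared is an illustrative name, not a library function.

```python
from typing import Sequence

def r_squared(targets: Sequence[float], preds: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_t = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, preds))  # model error
    ss_tot = sum((t - mean_t) ** 2 for t in targets)            # error of predicting the mean
    return 1.0 - ss_res / ss_tot

y = [5.0, 6.0, 7.0, 8.0]
print(r_squared(y, y))          # 1.0, perfect predictions
print(r_squared(y, [6.5] * 4))  # 0.0, always predicting the mean
```

Note that this is not the same thing as the slope of the red fitted line in the plot, which comes from a separate least-squares fit of predictions on targets.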
We should also look at the distribution of predictions for our known actives. The predicted values will differ somewhat from the actual values, so this gives us a sense of what score we want to see from the model.
df['preds'] = r_agent.predict_data(df.smiles.values).detach().cpu().numpy()
np.percentile(df.preds, 99)
The maximum predicted value is 9.8. A compound scoring 9.12 or higher would be in the top 1% of all known actives.
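np.percentile(df.preds, 99) returns the score below which 99% of the predictions fall, and that value becomes the "top 1%" cutoff. The sketch below reproduces numpy's default linear-interpolation percentile on a toy list and checks what share of scores clears the cutoff; the helper names are illustrative.

```python
import math
from typing import Sequence

def percentile(values: Sequence[float], q: float) -> float:
    """Linear-interpolation percentile, matching numpy.percentile's default method."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)

def fraction_at_or_above(values: Sequence[float], threshold: float) -> float:
    """Share of values that meet or exceed a threshold."""
    return sum(v >= threshold for v in values) / len(values)

scores = [float(i) for i in range(101)]  # toy scores 0.0 ... 100.0
cutoff = percentile(scores, 99)
print(cutoff, round(fraction_at_or_above(scores, cutoff), 4))  # 99.0 0.0198
```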
Chemical Space
Next we need to develop our chemical space. This is where we decide what compounds will be allowed and what compounds will be removed.
Getting the right filtering parameters makes a huge difference in compound quality. In practice, finding the right constraints is an iterative process: first run a generative screen, then examine the highest-scoring compounds for undesirable properties or structural features, then update the template and iterate.
smarts = ['[#6](=[#16])(-[#7])-[#7]',
template = Template([ValidityFilter(),
                     RotBondFilter(None, 6),
                     HeteroatomFilter(None, 8),
                     ChargeFilter(None, 0),
                     MaxRingFilter(None, 6),
                     MinRingFilter(5, None),
                     RingFilter(None, 5),
                     HBDFilter(None, 5),
                     HBAFilter(None, 10),
                     MolWtFilter(None, 500),
                     LogPFilter(None, 5),
                     SAFilter(None, 7),
                     ExclusionFilter(smarts, criteria='any'),
                     RotChainFilter(None, 5)],
                    fail_score=-1., log=False, use_lookup=False)
template_cb = TemplateCallback(template, prefilter=True)
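Reading the filter calls above, each property filter appears to take a (minimum, maximum) pair where None leaves that side open, and a compound must pass every filter; failing compounds get fail_score. The sketch below reproduces that reading in plain Python over precomputed properties. RangeFilter and screen are hypothetical names, not MRL classes, and the semantics are inferred from the calls rather than confirmed.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

@dataclass
class RangeFilter:
    """Pass if low <= prop(x) <= high; None means that side is unbounded (hypothetical)."""
    prop: Callable[[Dict[str, float]], float]
    low: Optional[float] = None
    high: Optional[float] = None

    def __call__(self, x: Dict[str, float]) -> bool:
        value = self.prop(x)
        if self.low is not None and value < self.low:
            return False
        if self.high is not None and value > self.high:
            return False
        return True

def screen(filters: List[RangeFilter], compounds: List[Dict[str, float]], fail_score: float = -1.0):
    """Return (passes, scores); in this sketch failing compounds get fail_score."""
    passes = [all(f(c) for f in filters) for c in compounds]
    scores = [None if ok else fail_score for ok in passes]
    return passes, scores

# Precomputed properties stand in for RDKit descriptors.
filters = [
    RangeFilter(lambda c: c["rot_bonds"], None, 6),  # analogous to RotBondFilter(None, 6)
    RangeFilter(lambda c: c["min_ring"], 5, None),   # analogous to MinRingFilter(5, None)
    RangeFilter(lambda c: c["mol_wt"], None, 500),   # analogous to MolWtFilter(None, 500)
]
compounds = [
    {"rot_bonds": 4, "min_ring": 6, "mol_wt": 420.0},
    {"rot_bonds": 9, "min_ring": 6, "mol_wt": 380.0},
]
print(screen(filters, compounds))  # ([True, False], [None, -1.0])
```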
Load Model
We load the LSTM_LM_Small_ZINC_NC model. This is a basic LSTM-based language model trained on part of the ZINC database.
agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})
Fine-Tune Model
The pretrained model we loaded is a very general model that can produce a highly diverse set of structures. However, what we actually want are structures with high affinity for EGFR that conform to our template. We can get our model into better starting shape by fine-tuning the pretrained weights.
Fine-Tuning Part 1: Template Tuning
Currently, about 60-70% of compounds generated by the model won't pass the template. This can slow training by wasting time generating, evaluating and filtering unwanted compounds. We can avoid this by first fine-tuning on a dataset of compounds that pass our filter.
Note: it is recommended that you do not run this section on Colab because it will be extremely slow. If you're just perusing the notebook on Colab, skip to the next fine-tuning section.
For template tuning, we first generate a dataset of compounds from the model. Then we screen these compounds with the template and filter for passing compounds. Finally, we fine-tune on those compounds.
This can also be done on a dataset of your choosing if you have one on hand.
Here we generate a set of SMILES. If you have some time on your hands, feel free to generate more. The generative model is capable of producing millions of unique SMILES with a low rate of invalid or duplicate SMILES.
all_smiles = set()
for i in range(200):
    preds, _ = agent.model.sample_no_grad(2500, 120)
    smiles = agent.reconstruct(preds)
    mols = to_mols(smiles)
    # keep only SMILES that parse into valid molecules
    smiles = [s for s, m in zip(smiles, mols) if m is not None]
    all_smiles.update(smiles)
Here we filter the SMILES with the template.
all_smiles = list(all_smiles)
hps = np.array(template(all_smiles))
df2 = pd.DataFrame(all_smiles, columns=['smiles'])
df2 = df2[hps]
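Stripped of the model and the chemistry, the template-tuning data step is a generate, validate, deduplicate, filter loop. The sketch below mirrors that structure with toy stand-ins for the sampler, the validity check, and the template; none of these names come from MRL.

```python
import random
from typing import Callable, List

def build_template_dataset(
    sample: Callable[[int], List[str]],
    is_valid: Callable[[str], bool],
    passes_template: Callable[[str], bool],
    n_rounds: int,
    batch_size: int,
) -> List[str]:
    """Generate -> keep valid -> deduplicate (via a set) -> keep template-passing strings."""
    seen = set()
    for _ in range(n_rounds):
        seen.update(s for s in sample(batch_size) if is_valid(s))
    return sorted(s for s in seen if passes_template(s))

rng = random.Random(0)

def toy_sample(n: int) -> List[str]:
    # carbon chains of random length; a trailing "!" marks an "invalid" string
    return ["C" * rng.randint(1, 12) + rng.choice(["", "!"]) for _ in range(n)]

def toy_valid(s: str) -> bool:
    return not s.endswith("!")

def toy_template(s: str) -> bool:
    return len(s) <= 8  # stand-in for a property filter

data = build_template_dataset(toy_sample, toy_valid, toy_template, n_rounds=20, batch_size=50)
print(len(data), data[:3])
```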
Optional: save dataset
df2.to_csv('untracked_files/pretrain_data.csv', index=False)
Now we fine-tune on our new template-conforming dataset.
agent.train_supervised(128, 4, 1e-5)
Fine-Tuning Part 2: Dataset Tuning
Next we fine-tune on the dataset of known actives.
This step deserves careful thought. Fine-tuning on known actives biases the model towards known chemotypes. This significantly reduces the number of iterations needed for the model to find high-scoring compounds. However, the model will likely converge on compounds similar to known chemotypes.
On the other hand, we could skip this stage of fine-tuning. If we did, the model would take longer to converge (at least 3x more batches), but we would have a much greater chance of finding new chemotypes.
In the interest of time, we do the additional fine-tuning to speed up convergence.
agent.train_supervised(32, 6, 5e-5)
agent.train_supervised(32, 5, 5e-5)
agent.train_supervised(32, 3, 5e-5)
Optional: save fine-tuned weights
We use PPO as our policy gradient loss.
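PPO limits how far each update moves the policy by clipping the probability ratio between the new and old policies, so a single batch of lucky high scores cannot push the model too far. The sketch shows the clipped surrogate objective and a discounted-return helper in plain Python. Reading the first argument of PPO(0.99, ...) as the discount factor is an assumption, as are the clip value of 0.2 and the function names; MRL's PolicyLoss handles all of this internally.

```python
from typing import List

def discounted_returns(rewards: List[float], gamma: float = 0.99) -> List[float]:
    """G_t = r_t + gamma * G_{t+1}, computed right to left."""
    out, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        out.append(running)
    return out[::-1]

def ppo_clip_loss(ratios: List[float], advantages: List[float], eps: float = 0.2) -> float:
    """Negative mean of min(r*A, clip(r, 1-eps, 1+eps)*A), PPO's clipped surrogate."""
    terms = []
    for r, a in zip(ratios, advantages):
        clipped = min(max(r, 1.0 - eps), 1.0 + eps)
        terms.append(min(r * a, clipped * a))
    return -sum(terms) / len(terms)

print(discounted_returns([0.0, 0.0, 1.0], gamma=0.5))  # [0.25, 0.5, 1.0]
print(ppo_clip_loss([2.0], [1.0]))                     # -1.2: the gain is capped at 1 + eps
```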
pg = PPO(0.99,
loss = PolicyLoss(pg, 'PPO',
reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[0,15])
r_ds = Vec_Prediction_Dataset(['C'], [0], partial(failsafe_fp, fp_function=ECFP6))
r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
# r_agent.load_state_dict(model_from_url('')) # optional - load exact weights
reward = Reward(r_agent.predict_data, weight=1.)
aff_reward = RewardCallback(reward, 'affinity')
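With a single reward of weight 1.0, the total reward is simply the predicted affinity. The weight matters once several reward terms are combined, and one natural reading is a weighted sum. The sketch illustrates that reading with hypothetical helpers (fake_affinity stands in for r_agent.predict_data, and size_penalty is an invented second term); it is not MRL's Reward implementation.

```python
from typing import Callable, List, Sequence, Tuple

RewardTerm = Tuple[Callable[[Sequence[str]], List[float]], float]

def total_reward(samples: Sequence[str], terms: List[RewardTerm]) -> List[float]:
    """Sum weight * score over every reward term, per sample (hypothetical helper)."""
    totals = [0.0] * len(samples)
    for score_fn, weight in terms:
        for i, s in enumerate(score_fn(samples)):
            totals[i] += weight * s
    return totals

def fake_affinity(smiles: Sequence[str]) -> List[float]:
    # stand-in for r_agent.predict_data
    return [7.0 + 0.1 * len(s) for s in smiles]

def size_penalty(smiles: Sequence[str]) -> List[float]:
    # hypothetical second objective: penalize long SMILES strings
    return [-1.0 if len(s) > 30 else 0.0 for s in smiles]

print(total_reward(["CCO", "c1ccccc1"], [(fake_affinity, 1.0), (size_penalty, 0.5)]))
```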
We create the following samplers:
- sampler1 (ModelSampler): samples from the main model
- sampler2 (ModelSampler): samples from the baseline model
- sampler3 (LogSampler): samples high-scoring compounds from the log
- sampler4 (CombichemSampler): runs combichem generation on the top-scoring compounds. The combination of generative models with combichem greatly accelerates finding high-scoring compounds
- sampler5 (DatasetSampler): sprinkles a small number of known actives into each buffer build
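Each sampler contributes candidates to a shared buffer, and each candidate keeps a tag recording where it came from (the analysis at the end filters on a sources value of 'data_buffer'). A minimal sketch of that bookkeeping, assuming first-come deduplication across samplers; build_buffer is a hypothetical name, and MRL's actual buffer logic may differ.

```python
from typing import Callable, Dict, List, Tuple

def build_buffer(samplers: Dict[str, Callable[[], List[str]]]) -> List[Tuple[str, str]]:
    """Collect (sample, source) pairs from every sampler, keeping the first source seen."""
    buffer, seen = [], set()
    for source, draw in samplers.items():
        for s in draw():
            if s not in seen:  # avoid scoring the same compound twice
                seen.add(s)
                buffer.append((s, source))
    return buffer

samplers = {
    "live": lambda: ["CCO", "CCN", "CCC"],
    "base": lambda: ["CCC", "c1ccccc1"],
    "data_buffer": lambda: ["CCO"],  # known actives; duplicates are dropped
}
print(build_buffer(samplers))
# [('CCO', 'live'), ('CCN', 'live'), ('CCC', 'live'), ('c1ccccc1', 'base')]
```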
A note on the use of combichem: adding combichem to the sampler mix greatly accelerates convergence to high-scoring compounds. However, combichem tends to produce "weirder" structures that exploit the score function more. These structures can be removed by updating the template.
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)
mutators = [
    ChangeAtom(['6', '7', '8', '9', '17', '35']),
    AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
    AppendAtomsDouble(['C', 'N', 'O']),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),
]
mc = MutatorCollection(mutators)
crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[reward],
prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')
sampler5 = DatasetSampler(df[df.neg_log_ic50>8].smiles.values, 'data', buffer_size=4)
samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]
Additional callbacks:
- SupervisedCB: runs supervised training on the top 2% of samples every 20 batches
- MaxCallback: prints the max reward for each batch
- PercentileCallback: prints the 90th percentile score for each batch
- Timeout: prohibits training on the same sample more than once every 10 batches
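The Timeout rule can be implemented by remembering the batch index at which each sample was last used and skipping it until the window has passed. This is a sketch of the rule as described above, with a hypothetical class name, not MRL's implementation.

```python
from typing import Dict, List

class TrainingTimeout:
    """Skip samples that were trained on within the last `window` batches (hypothetical)."""

    def __init__(self, window: int = 10):
        self.window = window
        self.last_used: Dict[str, int] = {}

    def filter(self, samples: List[str], batch_idx: int) -> List[str]:
        kept = []
        for s in samples:
            last = self.last_used.get(s)
            if last is None or batch_idx - last >= self.window:
                kept.append(s)
                self.last_used[s] = batch_idx  # restart this sample's timeout
        return kept

tm = TrainingTimeout(window=10)
print(tm.filter(["CCO", "CCN"], batch_idx=0))  # ['CCO', 'CCN']
print(tm.filter(["CCO"], batch_idx=5))         # [] (still timed out)
print(tm.filter(["CCO"], batch_idx=10))        # ['CCO']
```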
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=2)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
tm_cb = Timeout(10)
cbs = [supervised_cb, live_p90, live_max, tm_cb]
env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=[loss],
                  cbs=cbs)
# set_global_pool(min(10, os.cpu_count()))
100, 300, 20)
We determined before that a compound scoring 9.12 or higher would be in the top 1% of known actives. Looking at the log, we see 1017 novel samples with a score of at least 9.12 created by the generative process.
The generative run also found 9 samples with a score of 9.9 or higher, putting them in the top 0.14% of compounds relative to the original dataset.
env.log.df[(env.log.df.affinity>=9.12) & (env.log.df.sources != 'data_buffer')].shape
subset = env.log.df[(env.log.df.affinity>=9.9) & ~(env.log.df.sources=='data_buffer')]
mols = to_mols(subset.samples.values)
legends = [f'{subset.rewards.values[i]:.3f}' for i in range(subset.shape[0])]
draw_mols(mols, legends=legends)
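The log queries above exclude compounds injected by the dataset sampler, but a compound generated by the model could still coincide with a known active. The sketch below adds an explicit membership check against the known SMILES (ideally after canonicalizing both sides); the toy rows and the novel_hits name are illustrative, with field names mirroring the log columns used above.

```python
from typing import Dict, List, Set

def novel_hits(log: List[Dict], known: Set[str], threshold: float) -> List[Dict]:
    """Rows scoring >= threshold that were generated (not replayed from data) and are not known actives."""
    return [
        row for row in log
        if row["affinity"] >= threshold
        and row["sources"] != "data_buffer"
        and row["samples"] not in known
    ]

log = [
    {"samples": "CCO", "affinity": 9.5, "sources": "live"},
    {"samples": "CCN", "affinity": 9.3, "sources": "data_buffer"},
    {"samples": "CCC", "affinity": 8.0, "sources": "base"},
    {"samples": "CCCl", "affinity": 9.9, "sources": "combichem"},
]
hits = novel_hits(log, known={"CCCl"}, threshold=9.12)
print([row["samples"] for row in hits])  # ['CCO']
```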
In the interest of time, we will stop here. That said, the model isn't done. If you invest more time into training, the model will continue to find higher-scoring compounds. Eventually, the model will fully exploit the score function and produce compounds that make no sense from a medchem perspective but still receive a high score.
Optional: save log dataframe