Small molecule affinity optimization with active learning

Affinity Optimization with Active Learning

This tutorial shows how to use active learning in conjunction with generative models to find high scoring compounds in a way that minimizes the number of compounds that need to be scored. This technique is particularly advantageous when the score function being used has a high compute requirement.

This notebook follows from the Affinity Optimization tutorial. Consult the Affinity Optimization tutorial for details on the affinity task and the score function developed.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.

import sys

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
from sklearn.metrics import r2_score


The dataset contains binding data for compounds against EGFR, obtained from ChEMBL.

df = pd.read_csv('files/erbB1_affinity_data.csv')
   smiles                       neg_log_ic50  chembl_id      set
0  Brc1cc2c(NCc3ccccc3)ncnc2s1  6.617983      CHEMBL3416599  valid
1  Brc1cc2c(NCc3ccccn3)ncnc2s1  5.102153      CHEMBL3416616  train
2  Brc1cc2c(NCc3cccs3)ncnc2s1   5.862013      CHEMBL3416619  train
3  Brc1cc2c(NCc3ccncc3)ncnc2s1  5.410833      CHEMBL3416614  train
4  Brc1cc2c(Nc3ccccc3)ncnc2s1   7.096910      CHEMBL3416621  train

(6467, 4)
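
For readers unfamiliar with the neg_log_ic50 column, the sketch below shows the usual pIC50 convention, assuming the column stores -log10 of the IC50 expressed in molar units (the dataset itself does not state its units here). The function names are illustrative and are not part of MRL.

```python
import math


def ic50_nm_to_pic50(ic50_nm: float) -> float:
    """Convert an IC50 in nanomolar to pIC50 = -log10(IC50 in molar)."""
    if ic50_nm <= 0:
        raise ValueError("IC50 must be positive")
    return -math.log10(ic50_nm * 1e-9)


def pic50_to_ic50_nm(pic50: float) -> float:
    """Invert the conversion: pIC50 back to an IC50 in nanomolar."""
    return 10 ** (-pic50) * 1e9


# A 100 nM compound has a pIC50 of 7; each unit of pIC50 is a 10x change in potency.
print(round(ic50_nm_to_pic50(100.0), 3))
```

Under this assumption, the first row above (6.617983) would correspond to an IC50 of roughly 240 nM, and the neg_log_ic50 > 8 cutoff used later would correspond to compounds more potent than 10 nM.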
df['smiles_ns'] = df.smiles.map(lambda x: remove_stereo(x))

Chemical Space

Here we construct the template from the Affinity Optimization tutorial

smarts = ['[#6](=[#16])(-[#7])-[#7]',

template = Template([ValidityFilter(), 
                     RotBondFilter(None, 8),
                     HeteroatomFilter(None, 8),
                     ChargeFilter(None, 0),
                     MaxRingFilter(None, 6),
                     MinRingFilter(5, None),
                     HBDFilter(None, 5),
                     HBAFilter(None, 10),
                     MolWtFilter(None, 500),
                     LogPFilter(None, 5),
                     SAFilter(None, 7),
                     ExclusionFilter(smarts, criteria='any'),
                     RotChainFilter(None, 7)],
                    fail_score=-1., log=False, use_lookup=False)

template_cb = TemplateCallback(template, prefilter=True)
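
To make the (min, max) pattern in the filter arguments concrete, here is a minimal sketch of a range filter, assuming that the two arguments are lower and upper bounds and that None means "no bound on that side". RangeFilter and passes_all are hypothetical stand-ins that operate on plain strings; the real filters compute chemical properties on molecules.

```python
from typing import Callable, Optional, Sequence


class RangeFilter:
    """Hypothetical property filter with optional lower and upper bounds."""

    def __init__(self, prop: Callable[[str], float],
                 min_val: Optional[float], max_val: Optional[float]):
        self.prop = prop
        self.min_val = min_val
        self.max_val = max_val

    def __call__(self, item: str) -> bool:
        value = self.prop(item)
        # None on either side means that side is unbounded
        if self.min_val is not None and value < self.min_val:
            return False
        if self.max_val is not None and value > self.max_val:
            return False
        return True


def passes_all(item: str, filters: Sequence[RangeFilter]) -> bool:
    """An item is kept only if every filter accepts it."""
    return all(f(item) for f in filters)


# Toy properties: string length (at most 10) and nitrogen count (at least 1)
toy_filters = [
    RangeFilter(len, None, 10),
    RangeFilter(lambda s: s.count('N'), 1, None),
]

print(passes_all('CCN', toy_filters))
```

Items that fail any one filter are rejected, which mirrors how a template with several hard constraints narrows the chemical space.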

Load Model

We load the LSTM_LM_Small_ZINC_NC model. This is a basic LSTM-based language model trained on part of the ZINC database

agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})


Here we load the reward function from the Affinity Optimization tutorial

reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[0,15])

r_ds = Vec_Prediction_Dataset(['C'], [0], partial(failsafe_fp, fp_function=ECFP6))

r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})

r_agent.load_state_dict(model_from_url('')) # optional - load exact weights



reward = Reward(r_agent.predict_data, weight=1.)

aff_reward = RewardCallback(reward, 'affinity')


We create the following samplers:

  • sampler1 ModelSampler: this samples from the main model
  • sampler2 ModelSampler: this samples from the baseline model
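  • sampler3 CombichemSampler: this sampler runs combichem generation on the top scoring compounds. The combination of generative models with combichem greatly accelerates finding high scoring compounds
  • sampler4 LogSampler: this samples high scoring compounds from the log
  • sampler5 DatasetSampler: this samples known actives from the dataset (compounds with neg_log_ic50 above 8)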

Note on the use of combichem. Adding combichem to the sampler mix greatly accelerates convergence to high scoring compounds. However, combichem tends to end up with "weirder" structures that are more exploitative of the score function. These structures can be removed by updating the template
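
The mutate, cross over, score and prune cycle behind combichem can be sketched on toy strings. This is only an illustration of the general idea, not the CombiChem implementation: evolve, mutate, crossover, keep_fraction and the nitrogen-counting score are all hypothetical, and real mutators edit molecular graphs rather than characters.

```python
import random


def mutate(s, alphabet, rng):
    """Replace one randomly chosen character."""
    i = rng.randrange(len(s))
    return s[:i] + rng.choice(alphabet) + s[i + 1:]


def crossover(a, b, rng):
    """Join the head of one parent to the tail of another."""
    cut = rng.randrange(1, min(len(a), len(b)))
    return a[:cut] + b[cut:]


def evolve(library, score, alphabet, steps, keep_fraction, max_size, seed=0):
    rng = random.Random(seed)
    library = list(dict.fromkeys(library))  # deduplicate, keep order
    for _ in range(steps):
        children = [mutate(s, alphabet, rng) for s in library]
        children += [crossover(rng.choice(library), rng.choice(library), rng)
                     for _ in library]
        # Parents stay in the pool, so the best score never decreases
        pool = list(dict.fromkeys(library + children))
        pool.sort(key=score, reverse=True)
        keep = max(1, int(len(pool) * keep_fraction))
        library = pool[:min(keep, max_size)]
    return library


def count_n(s):
    """Toy score: number of 'N' characters."""
    return s.count('N')


start = ['CCCCCCCC', 'CCCOCCCC']
final = evolve(start, count_n, 'CNO', steps=20, keep_fraction=0.3, max_size=50)
print(max(count_n(s) for s in final))
```

Because selection only looks at the score, a loop like this will happily drift toward whatever the score function rewards, which is one way to picture the "exploitative" structures mentioned above.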

mutators = [
    ChangeAtom(['6', '7', '8', '9', '17', '35']),
    AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
    AppendAtomsDouble(['C', 'N', 'O']),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),

mc = MutatorCollection(mutators)

crossovers = [FragmentCrossover()]

cbc = CombiChem(mc, crossovers, template=template, rewards=[reward],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
gen_bs = 1500

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 500, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 500, 0., gen_bs)
sampler3 = CombichemSampler(cbc, 150, 98, 0.2, 1, 'rewards', 'combichem')
sampler4 = LogSampler('samples', 'rewards', 10, 95, 100)
sampler5 = DatasetSampler(df[df.neg_log_ic50>8].smiles.values, 'data', buffer_size=10)

samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]


Additional callbacks

supervised_cb = SupervisedCB(agent, 10, 0.5, 98, 5e-4, 64, epochs=5, silent=True)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
bfr_cb = BufferSizeCallback()

cbs = [supervised_cb, live_p90, live_max, bfr_cb]

Active Learning

Now we get to the active learning part. We will construct a buffer that contains a machine learning model. This model estimates the score of each item in the buffer. During sampling, the buffer uses these scores to select promising compounds.

The idea here is that the buffer model is a very fast model that can generate predictions in ~1s per compound. We can pair this quick model with a more costly score function (e.g. docking or molecular dynamics). The fast buffer model lets us predict over a large number of compounds and choose the most promising ones, which limits the compute invested in the slow score function.

Since this is a tutorial, we're using a score function that's also fast. This notebook serves as a guide to implementing an active learning setup. Developing big fancy score functions is left to the user.
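
The loop described above can be sketched in a few lines. This is a toy under stated assumptions, not the PredictiveBuffer logic: the candidates are numbers, expensive_score is a made-up stand-in for a slow score function, and NearestNeighborModel is a deliberately simple surrogate chosen so the example has no dependencies.

```python
import random


def expensive_score(x):
    """Stand-in for a slow score function (docking, simulation, ...)."""
    return -(x - 3.0) ** 2


class NearestNeighborModel:
    """Tiny surrogate: predict the score of the closest labeled point."""

    def __init__(self):
        self.data = {}

    def fit(self, labeled):
        self.data = dict(labeled)

    def predict(self, x):
        nearest = min(self.data, key=lambda k: abs(k - x))
        return self.data[nearest]


def active_learning(pool, n_rounds, batch_size, seed=0):
    rng = random.Random(seed)
    pool = list(pool)
    # Start from a small random batch scored with the expensive function
    labeled = {x: expensive_score(x) for x in rng.sample(pool, batch_size)}
    model = NearestNeighborModel()
    for _ in range(n_rounds):
        model.fit(labeled)
        unlabeled = [x for x in pool if x not in labeled]
        if not unlabeled:
            break
        # Cheap predictions over everything, expensive scores only for the top few
        unlabeled.sort(key=model.predict, reverse=True)
        for x in unlabeled[:batch_size]:
            labeled[x] = expensive_score(x)
    return labeled


pool = [i / 10 for i in range(101)]
labeled = active_learning(pool, n_rounds=5, batch_size=5)
print(len(labeled), 'of', len(pool), 'candidates scored with the expensive function')
```

The point of the design is the asymmetry: the surrogate is queried on every unlabeled candidate each round, while the expensive function is called only batch_size times per round.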

d_vocab = len(agent.vocab.itos)
d_embedding = 256
d_latent = 1
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1

encoder = Conv_Encoder(

active_model = ScaledEncoder(encoder, outrange=[0,15])

r_ds = Text_Prediction_Dataset(['C'], [0], agent.vocab)

active_agent = PredictiveAgent(active_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})

p_total = 1.
refresh_predictions = 10     # update predictions every 10 batches
pred_bs = 2048              # prediction batch size
supervised_frequency = 8    # how often to do offline supervised training 
supervised_epochs = 2       # number of epochs for offline training
supervised_bs = 32          # supervised training bs
supervised_lr = 1e-3        # supervised training learning rate
pct_argmax=0.6              # percent of batch to sample via argmax rather than weighted sampling

buffer = PredictiveBuffer(p_total,
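
One way to picture the pct_argmax setting is the selection sketch below, which takes a fraction of each batch greedily by predicted score and fills the rest by score-weighted sampling without replacement. This is an assumption about the general technique, not the actual PredictiveBuffer code; select_batch and its arguments are hypothetical, and the weights are assumed to be positive.

```python
import random


def select_batch(items, weights, batch_size, pct_argmax, seed=0):
    """Pick part of a batch by argmax and the remainder by weighted sampling."""
    rng = random.Random(seed)
    n_top = round(batch_size * pct_argmax)
    order = sorted(range(len(items)), key=lambda i: weights[i], reverse=True)
    top_idx = order[:n_top]      # greedy part: highest predicted scores
    rest = order[n_top:]         # candidates for the sampled part
    rest_w = [weights[i] for i in rest]
    sampled = []
    n_sample = min(batch_size - len(top_idx), len(rest))
    for _ in range(n_sample):
        # Draw one index in proportion to its weight, then remove it
        j = rng.choices(range(len(rest)), weights=rest_w, k=1)[0]
        sampled.append(rest.pop(j))
        rest_w.pop(j)
    return [items[i] for i in top_idx + sampled]


items = list('abcdefghij')
weights = [float(w) for w in range(1, 11)]
batch = select_batch(items, weights, batch_size=5, pct_argmax=0.6)
print(batch[:3])
```

Mixing the two modes keeps most of the batch focused on the best predictions while still giving lower-ranked items a chance, which guards against a surrogate model that is confidently wrong.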

Environment and Train

Now we can put together our Environment and run the training process

env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=[],
                 cbs=cbs, buffer=buffer)

With our current setup, our generative model will be trained offline by our SupervisedCB callback. This means we don't need to train the model on each iteration, or generate model outputs (logits, etc.) during the training loop. By setting the flags below, we can avoid unnecessary overhead.
agent.compute_outputs = False
set_global_pool(min(8, os.cpu_count()))
env.fit(64, 100, 200, 10, buffer_frequency=10)
iterations rewards rewards_final new diversity bs template valid predictive_buffer_loss predictive_buffer_preds affinity buffer size rewards_live_p90 rewards_live_max
0 5.493 5.493 1.000 1.000 64 0.000 1.000 2.739 6.860 5.493 125 6.041 6.745
10 5.668 5.668 0.906 1.000 64 0.000 1.000 1.754 6.440 5.668 741 5.769 5.881
20 6.429 6.429 0.875 1.000 64 0.000 1.000 2.909 5.021 6.429 918 5.843 5.916
30 7.783 7.783 0.703 1.000 64 0.000 1.000 2.500 6.397 7.783 1340 5.145 5.145
40 7.835 7.835 0.766 1.000 64 0.000 1.000 0.447 7.709 7.835 1736 0.000 0.000
50 8.000 8.000 0.688 1.000 64 0.000 1.000 2.832 6.444 8.000 2271 5.349 5.380
60 8.067 8.067 0.609 1.000 64 0.000 1.000 1.036 7.265 8.067 2851 5.419 5.494
70 8.199 8.199 0.688 1.000 64 0.000 1.000 0.542 7.764 8.199 3457 4.236 4.236
80 8.236 8.236 0.828 1.000 64 0.000 1.000 0.534 8.587 8.236 4068 8.699 8.826
90 8.387 8.387 0.656 1.000 64 0.000 1.000 2.548 6.894 8.387 4801 9.001 9.147
100 8.694 8.694 0.719 1.000 64 0.000 1.000 0.599 8.062 8.694 5593 9.570 9.597
110 7.890 7.890 0.844 1.000 64 0.000 1.000 0.418 8.170 7.890 6381 8.566 8.724
120 8.305 8.305 0.750 1.000 64 0.000 1.000 0.193 8.378 8.305 7200 9.680 9.766
130 8.513 8.513 0.656 1.000 64 0.000 1.000 1.491 7.378 8.513 7951 9.870 10.003
140 8.610 8.610 0.656 1.000 64 0.000 1.000 0.536 8.044 8.610 8706 9.766 10.003
150 8.563 8.563 0.703 1.000 64 0.000 1.000 0.240 8.287 8.563 9411 9.699 10.030
160 8.716 8.716 0.484 1.000 64 0.000 1.000 0.155 8.728 8.716 10014 10.004 10.030
170 8.708 8.708 0.625 1.000 64 0.000 1.000 0.883 7.861 8.708 10572 9.843 10.011
180 8.787 8.787 0.625 1.000 64 0.000 1.000 0.840 7.994 8.787 11143 10.007 10.054
190 8.677 8.677 0.625 1.000 64 0.000 1.000 0.181 8.714 8.677 11671 9.947 10.030


Following from the Affinity Optimization tutorial, we determined that compounds with a score above 9.12 would be in the top 1% of samples. The top compounds found in the previous tutorial scored around 9.9-10.
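
A "top 1%" cutoff such as the one quoted above is a percentile threshold over a set of scores. The sketch below shows one simple way to compute such a cutoff on synthetic data; top_fraction_cutoff is a hypothetical helper and the scores are made up, so the numbers do not correspond to the tutorial's results.

```python
def top_fraction_cutoff(scores, fraction):
    """Smallest score that still falls inside the top `fraction` of `scores`."""
    ranked = sorted(scores, reverse=True)
    n_top = max(1, round(len(ranked) * fraction))
    return ranked[n_top - 1]


# Synthetic scores: 0.00, 0.01, ..., 9.99
scores = [i / 100 for i in range(1000)]
cutoff = top_fraction_cutoff(scores, 0.01)
n_above = sum(s >= cutoff for s in scores)
print(cutoff, n_above)
```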

With active learning, we found top scoring compounds in a similar range, and we did so while screening 3x fewer compounds overall than in the previous tutorial. If we are dealing with a compute-intensive score function, this 3x reduction in compounds screened could translate to major cost savings.

subset = env.log.df[(env.log.df.affinity>=9.99) & ~(env.log.df.sources=='data_buffer')]
mols = to_mols(subset.samples.values)
legends = [f'{subset.rewards.values[i]:.3f}' for i in range(subset.shape[0])]
draw_mols(mols, legends=legends)

We can also look at predicted scores for items still in the buffer

pred_df = pd.DataFrame(zip(env.buffer.buffer, env.buffer.weights), columns=['sample', 'weight'])
       sample                                            weight
0      CC(O)CNC1CN(C(=O)CN(C)Cc2ccccc2)CC1c1ccccc1       10.313317
464    CC1CCN(c2cccc(NC(=O)Cc3ccc(NCCN)cc3)c2)CC1        10.415010
498    CNc1cc2ncnc(CCO)c2cn1                             10.340527
506    CC1CCC(C(=O)NC(C)(CC(=O)O)c2cccc(N(C)CCO)c2)CC1   10.343094
1034   CCCCC(O)CNc1c(Br)cccc1NC                          10.306750
...    ...                                               ...
11081  CNC(C)CC(O)CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1     10.457481
11083  CN=CC=CC(C=NC)CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1  10.308384
11090  CN=CC(C=O)CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1      10.380414
11092  CNC(O)=CC(C)c1cc2c(Nc3cccc(I)c3)ncnc2cn1          10.343092
11093  CN=C(C=CO)CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1      10.344720

716 rows × 2 columns

(reward(pred_df[pred_df.weight>10.3]['sample'].values) > 9.9).float().sum()
tensor(15., device='cuda:0')