Affinity Optimization with Active Learning
This tutorial shows how to use active learning in conjunction with generative models to find high-scoring compounds while minimizing the number of compounds that need to be scored. This technique is particularly advantageous when the score function is computationally expensive.
This notebook follows from the Affinity Optimization tutorial; consult that tutorial for details on the affinity task and the score function developed there.
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
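For reference, the relevant cell (also run later in this notebook) is:
set_global_pool(min(8, os.cpu_count()))  # parallelize CPU-bound chemistry functions across up to 8 workers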
This notebook may run slowly on Colab due to CPU limitations. If running on Colab, remember to change the runtime type to GPU.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
from sklearn.metrics import r2_score
download_files()
df = pd.read_csv('files/erbB1_affinity_data.csv')
df.head()
df.shape
df['smiles_ns'] = df.smiles.map(lambda x: remove_stereo(x))  # strip stereochemistry from the SMILES
Chemical Space
Here we construct the same template used in the Affinity Optimization tutorial.
smarts = ['[#6](=[#16])(-[#7])-[#7]',
'[#6]=[#6]=[#6]',
'[#7][F,Cl,Br,I]',
'[*]#[Cl,Br]',
'[#6;!R]=[#6;!R]-[#6;!R]=[#6;!R]',
'[#6]#[#6]',
'[#15]',
'[#16]',
'[*]=[#17,#9,#35]',
'[*]=[*]=[*]',
'[*]-[#6]=[#6H2]',
'[#7]~[#8]',
'[#7]~[#7]',
'[*;R]=[*;!R]']
template = Template([ValidityFilter(),
SingleCompoundFilter(),
RotBondFilter(None, 8),
HeteroatomFilter(None, 8),
ChargeFilter(None, 0),
MaxRingFilter(None, 6),
MinRingFilter(5, None),
HBDFilter(None, 5),
HBAFilter(None, 10),
MolWtFilter(None, 500),
LogPFilter(None, 5),
SAFilter(None, 7),
BridgeheadFilter(None,0),
PAINSAFilter(),
ExclusionFilter(smarts, criteria='any'),
RotChainFilter(None, 7)
],
[],
fail_score=-1., log=False, use_lookup=False)
template_cb = TemplateCallback(template, prefilter=True)
Load Model
We load the LSTM_LM_Small_ZINC_NC model, a basic LSTM-based language model trained on part of the ZINC database.
agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})
Reward
Here we load the reward function developed in the Affinity Optimization tutorial.
reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[0,15])
r_ds = Vec_Prediction_Dataset(['C'], [0], partial(failsafe_fp, fp_function=ECFP6))
r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.load_state_dict(model_from_url('egfr_affinity_mlp.pt')) # optional - load exact weights
r_agent.model.eval();
freeze(r_agent.model)
reward = Reward(r_agent.predict_data, weight=1.)
aff_reward = RewardCallback(reward, 'affinity')
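As a quick sanity check (assuming, as implied by the Reward wrapper above, that r_agent.predict_data accepts a list of SMILES strings), we can score an arbitrary compound:
# hedged sanity check on the loaded reward model
r_agent.predict_data(['CC(=O)Oc1ccccc1C(=O)O'])  # aspirin; we'd expect a low predicted affinity for EGFR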
Samplers
We create the following samplers:

- sampler1 ModelSampler: samples from the main model
- sampler2 ModelSampler: samples from the baseline model
- sampler3 CombichemSampler: runs combichem generation on the top-scoring compounds. Combining generative models with combichem greatly accelerates finding high-scoring compounds
- sampler4 LogSampler: samples high-scoring compounds from the log
- sampler5 DatasetSampler: adds known active compounds (neg_log_ic50 > 8) from the source dataset to the buffer
A note on the use of combichem: adding combichem to the sampler mix greatly accelerates convergence to high-scoring compounds. However, combichem tends to produce "weirder" structures that exploit quirks of the score function. These structures can be removed by updating the template.
mutators = [
ChangeAtom(['6', '7', '8', '9', '17', '35']),
AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
AppendAtomsDouble(['C', 'N', 'O']),
AppendAtomsTriple(),
DeleteAtom(),
ChangeBond(),
InsertAtomSingle(['C', 'N', 'O']),
InsertAtomDouble(['C', 'N']),
InsertAtomTriple(),
AddRing(),
ShuffleNitrogen(20)
]
mc = MutatorCollection(mutators)
crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[reward],
prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 500, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 500, 0., gen_bs)
sampler3 = CombichemSampler(cbc, 150, 98, 0.2, 1, 'rewards', 'combichem')
sampler4 = LogSampler('samples', 'rewards', 10, 95, 100)
sampler5 = DatasetSampler(df[df.neg_log_ic50>8].smiles.values, 'data', buffer_size=10)
samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]
Callbacks
We use the following additional callbacks:

- SupervisedCB: runs supervised training on the top 2% of samples every 10 batches
- MaxCallback: prints the max reward for each batch
- PercentileCallback: prints the 90th percentile score for each batch
- BufferSizeCallback: prints the current buffer size
supervised_cb = SupervisedCB(agent, 10, 0.5, 98, 5e-4, 64, epochs=5, silent=True)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
bfr_cb = BufferSizeCallback()
cbs = [supervised_cb, live_p90, live_max, bfr_cb]
Active Learning
Now we get to the active learning part. We will construct a buffer that contains a machine learning model. This model estimates the score of each item in the buffer. During sampling, the buffer uses these scores to select promising compounds.
The idea here is that the buffer model is very fast, generating predictions in ~1s per compound, while the true score function (e.g. docking or molecular dynamics) is far more costly. The fast model lets us predict over a large number of compounds and choose the most promising ones while limiting the compute invested in the slow score function.
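To make the selection step concrete, here is a minimal, self-contained sketch of an argmax-plus-weighted-sampling scheme of the kind controlled by the pct_argmax parameter further down. The function select_batch is our own illustration, not the PredictiveBuffer implementation:
import numpy as np

def select_batch(candidates, predicted_scores, n_select, pct_argmax=0.6):
    # Illustrative sketch only. Greedily take the top pct_argmax fraction of
    # the batch by predicted score, then fill the rest by weighted sampling.
    scores = np.asarray(predicted_scores, dtype=float)
    order = np.argsort(scores)[::-1]
    n_greedy = int(n_select * pct_argmax)
    greedy = order[:n_greedy]
    rest = order[n_greedy:]
    weights = np.clip(scores[rest], 1e-8, None)
    sampled = np.random.choice(rest, size=n_select - n_greedy,
                               replace=False, p=weights / weights.sum())
    return [candidates[i] for i in np.concatenate([greedy, sampled])]
Compounds selected this way would then be sent to the expensive score function, and the resulting labels used to retrain the fast model.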
Since this is a tutorial, we're using a score function that is itself fast; this notebook serves as a guide to implementing an active learning setup. Developing big fancy score functions is left to the user.
d_vocab = len(agent.vocab.itos)
d_embedding = 256
d_latent = 1
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
encoder = Conv_Encoder(
d_vocab,
d_embedding,
d_latent,
filters,
kernel_sizes,
strides,
dropouts,
)
active_model = ScaledEncoder(encoder, outrange=[0,15])
r_ds = Text_Prediction_Dataset(['C'], [0], agent.vocab)
active_agent = PredictiveAgent(active_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
p_total = 1.
refresh_predictions = 10 # update predictions every 10 batches
pred_bs = 2048 # prediction batch size
supervised_frequency = 8 # how often to do offline supervised training
supervised_epochs = 2 # number of epochs for offline training
supervised_bs = 32 # supervised training bs
supervised_lr = 1e-3 # supervised training learning rate
pct_argmax = 0.6 # fraction of batch to sample via argmax rather than weighted sampling
buffer = PredictiveBuffer(p_total,
refresh_predictions,
active_agent,
pred_bs,
supervised_frequency,
supervised_epochs,
supervised_bs,
supervised_lr,
pct_argmax=pct_argmax)
env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=[],
cbs=cbs, buffer=buffer)
With our current setup, our generative model will be trained offline by our SupervisedCB callback. This means we don't need to train the model on each iteration or generate model outputs (logits, etc.) during the training loop. By setting the flags below, we can avoid this unnecessary overhead.
agent.training = False
agent.compute_outputs = False
set_global_pool(min(8, os.cpu_count()))
env.fit(64, 100, 200, 10, buffer_frequency=10)
env.log.plot_metrics()
Evaluation
Following from the Affinity Optimization tutorial, we determined that compounds with a score above 9.12 would be in the top 1% of samples. The top compounds found in the previous tutorial scored around 9.9-10.
With active learning, we found top-scoring compounds in a similar range, while screening roughly 3x fewer compounds overall than in the previous tutorial. If we were dealing with a compute-intensive score function, this 3x reduction in compounds screened could translate to major cost savings.
subset = env.log.df[(env.log.df.affinity>=9.99) & ~(env.log.df.sources=='data_buffer')]
mols = to_mols(subset.samples.values)
legends = [f'{subset.rewards.values[i]:.3f}' for i in range(subset.shape[0])]
draw_mols(mols, legends=legends)
We can also look at the predicted scores for items still in the buffer.
pred_df = pd.DataFrame(zip(env.buffer.buffer, env.buffer.weights), columns=['sample', 'weight'])
pred_df[pred_df.weight>10.3]
(reward(pred_df[pred_df.weight>10.3]['sample'].values) > 9.9).float().sum()  # count how many predicted hits have a reward-model score above 9.9
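As a small follow-up of our own, the same comparison can be expressed as a hit rate:
# fraction of buffer compounds predicted above 10.3 whose reward-model score exceeds 9.9
high_pred = pred_df[pred_df.weight > 10.3]['sample'].values
(reward(high_pred) > 9.9).float().mean()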