Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which results in 2-4x speedups.
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime to GPU.
import sys
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
Here we create the model we want to optimize. We will use LSTM_LM_Small_ZINC_NC, an LSTM-based language model pretrained on part of the ZINC database without chirality. Training without chirality prevents a form of mode collapse where the model converges to generating different isomers of the same compound.
Whether to use a model trained with or without chirality depends on the reward function you are trying to optimize. You should use a model with chirality if your reward function handles chirality in a meaningful way. Specifically, this means your reward function should give different scores to different isomers. This difference should relate to a real aspect of the property being predicted (e.g. the affinity of different isomers) rather than being a spurious feature learned by the model (this happens surprisingly often).
Our score isn't influenced by chirality, so we will use a nonchiral model.
agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})
We will jointly optimize three molecular properties - QED score, LogP and SA score.
First we set up our hard filters. We set constraints following the Rule of 5.
We set the following hard filters:
- ValidityFilter: Only valid compounds
- SingleCompoundFilter: Only single compounds
- MolWtFilter: Molecular weight less than 500 g/mol
- HBDFilter: Less than or equal to 5 hydrogen bond donors
- HBAFilter: Less than or equal to 10 hydrogen bond acceptors
- LogPFilter: LogP less than 5
Next we set up our soft filters, which will serve as our score. For each property (QED, LogP, SA), we scale the value:
- QED scaled between [0, 2]
- LogP scaled between [0, 1]
- SA score scaled between [0, 1]
This weights LogP and SA score evenly, giving double weight to QED score.
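def scale_sa(sa):
    # SA score runs from 1 (easy) to 10 (hard); map it to [0, 1] with higher = easier
    return (10-sa)/9

def scale_logp(logp):
    # Divide by the Rule of 5 limit and clamp to [0, 1]
    logp = logp/5
    logp = min(max(logp,0),1)
    return logp

def scale_qed(qed):
    # QED is already in [0, 1]; doubling gives it twice the weight
    return 2*qed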
template = Template([ValidityFilter(),
                     SingleCompoundFilter(),
                     MolWtFilter(None, 500),
                     HBDFilter(None, 5),
                     HBAFilter(None, 10),
                     LogPFilter(None, 5)
                     ],
                    [QEDFilter(None, None, score=PropertyFunctionScore(scale_qed)),
                     SAFilter(None, None, score=PropertyFunctionScore(scale_sa)),
                     LogPFilter(None, None, score=PropertyFunctionScore(scale_logp))],
                    fail_score=-1., log=False)

template_cb = TemplateCallback(template, prefilter=True)
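To make the scoring rule concrete, here is a minimal plain-Python sketch of how a template of this shape could turn properties into a reward. It assumes, as the weighting described above implies, that a compound failing any hard filter receives the fail score and that the soft scores are otherwise summed. The helpers `passes_hard_filters` and `template_score` and the property dictionaries are hypothetical and are not part of the MRL API; the property values are made up, and the filter boundaries follow the prose description.

```python
# Illustrative sketch only: mimics the hard-filter / soft-score logic
# with precomputed property values instead of RDKit molecules.

def scale_sa(sa):
    return (10 - sa) / 9

def scale_logp(logp):
    return min(max(logp / 5, 0), 1)

def scale_qed(qed):
    return 2 * qed

def passes_hard_filters(props):
    """Rule of 5 style checks on a dict of precomputed properties (hypothetical helper)."""
    return (props["mol_wt"] < 500
            and props["hbd"] <= 5
            and props["hba"] <= 10
            and props["logp"] < 5)

def template_score(props, fail_score=-1.0):
    """Assumed behaviour: fail_score on hard-filter failure, else the sum of soft scores."""
    if not passes_hard_filters(props):
        return fail_score
    return scale_qed(props["qed"]) + scale_sa(props["sa"]) + scale_logp(props["logp"])

# Made-up property values for two hypothetical compounds.
drug_like = {"mol_wt": 320.0, "hbd": 1, "hba": 4, "logp": 3.5, "qed": 0.9, "sa": 2.8}
too_heavy = {"mol_wt": 640.0, "hbd": 2, "hba": 6, "logp": 4.0, "qed": 0.4, "sa": 3.5}

print(template_score(drug_like))  # 2*0.9 + (10-2.8)/9 + 3.5/5, roughly 3.3
print(template_score(too_heavy))  # fails the molecular weight filter, so fail_score
```

Under these assumptions the best possible score is 4 (2 from QED, 1 each from LogP and SA score).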
Loss Function
We will use the PPO policy gradient algorithm.
pg = PPO(0.99,
loss = PolicyLoss(pg, 'PPO',
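For orientation, the core of PPO is a clipped surrogate objective that limits how far a single update can move the policy away from the one that generated the samples. The sketch below is plain Python and is not MRL's PPO class: the function name `ppo_clipped_loss`, the default `cliprange=0.2`, and the example numbers are assumptions for illustration, and the value-function, entropy, and KL terms that full implementations usually add are omitted.

```python
import math

def ppo_clipped_loss(new_logps, old_logps, advantages, cliprange=0.2):
    """Mean clipped surrogate loss (to be minimized) over a batch of samples."""
    losses = []
    for new_lp, old_lp, adv in zip(new_logps, old_logps, advantages):
        # Probability ratio pi_new / pi_old, computed from log probabilities.
        ratio = math.exp(new_lp - old_lp)
        clipped = min(max(ratio, 1 - cliprange), 1 + cliprange)
        # Pessimistic bound: keep the smaller of the two surrogate objectives.
        surrogate = min(ratio * adv, clipped * adv)
        losses.append(-surrogate)
    return sum(losses) / len(losses)

# With an unchanged policy the ratio is 1 and the loss is minus the mean advantage.
unchanged = ppo_clipped_loss([0.0, 0.0], [0.0, 0.0], [1.0, 3.0])

# A sample whose probability doubled only gets credit up to 1 + cliprange.
doubled = ppo_clipped_loss([math.log(2)], [0.0], [1.0])
```

The clipping means that once a high-advantage sample's probability has risen by more than the clip range, it stops contributing gradient, which keeps updates conservative.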
We create the following samplers:
- sampler1 ModelSampler: this samples from the main model
- sampler2 ModelSampler: this samples from the baseline model
- sampler3 LogSampler: this samples high scoring samples from the log
- sampler4 CombichemSampler: this sampler runs combichem generation on the top scoring compounds. The combination of generative models with combichem greatly accelerates finding high scoring compounds
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)
mutators = [
    ChangeAtom(['6', '7', '8', '9']),
    AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
    AppendAtomsDouble(['C', 'N', 'O']),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),
]
mc = MutatorCollection(mutators)
crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[],
prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')
samplers = [sampler1, sampler2, sampler3, sampler4]
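The idea of replaying high scoring samples from the log can be sketched in a few lines of plain Python. This is not the LogSampler implementation: `top_percentile` and `sample_from_log` are hypothetical helpers, the log is a toy list of placeholder names, and reading the value 98 as a percentile cutoff is an assumption about the sampler's arguments.

```python
import random

def top_percentile(records, percentile):
    """Return (sample, reward) pairs whose reward is at or above the given percentile."""
    rewards = sorted(r for _, r in records)
    # Nearest-rank style cutoff index into the sorted rewards.
    idx = min(len(rewards) - 1, int(len(rewards) * percentile / 100))
    cutoff = rewards[idx]
    return [(s, r) for s, r in records if r >= cutoff]

def sample_from_log(records, percentile, n, seed=None):
    """Draw up to n distinct high scoring samples from the log."""
    rng = random.Random(seed)
    pool = top_percentile(records, percentile)
    return rng.sample(pool, min(n, len(pool)))

# A toy log of 100 placeholder samples with rewards 0.00 .. 0.99.
log = [(f"sample_{i}", i / 100) for i in range(100)]
best = top_percentile(log, 98)           # the two highest scoring entries
replay = sample_from_log(log, 90, 5, seed=0)
```

Feeding these replayed samples back into each batch keeps the best compounds found so far in front of the model, alongside fresh samples from the live and baseline models.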
Other Callbacks
We add the following callbacks:
- supervised_cb: every 200 batches, this callback grabs the top 3% of samples from the log and runs supervised training with these samples
- live_max: prints the maximum score from sampler1 each batch
- live_p90: prints the 90th percentile score (the top 10% cutoff) from sampler1 each batch
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=4)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
cbs = [supervised_cb, live_p90, live_max]
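As a rough illustration of what the two metric callbacks report, the sketch below computes the maximum and 90th percentile of one batch of rewards in plain Python. The helpers `percentile` and `batch_metrics` are hypothetical and not part of MRL, and the linear interpolation rule used here is an assumption; the library may compute percentiles differently.

```python
def percentile(values, q):
    """q-th percentile (0-100) with linear interpolation between closest ranks."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    frac = pos - lo
    return ordered[lo] + (ordered[hi] - ordered[lo]) * frac

def batch_metrics(rewards):
    """Summaries analogous to the live_max and live_p90 columns."""
    return {"max": max(rewards), "p90": percentile(rewards, 90)}

rewards = [float(i) for i in range(11)]  # toy batch: 0.0 .. 10.0
metrics = batch_metrics(rewards)
```

Tracking both values is useful because the maximum shows the single best compound in a batch, while the 90th percentile shows whether the upper tail of the distribution as a whole is improving.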
env = Environment(agent, template_cb, samplers=samplers, rewards=[], losses=[loss],
set_global_pool(min(10, os.cpu_count())), 100, 400, 25)
subset = env.log.df[env.log.df.rewards>3.657]
samples = subset.samples.values
values = subset.rewards.values
mols = [to_mol(i) for i in samples]
qeds = [qed(i) for i in mols]
logps = [logp(i) for i in mols]
sas = [sa_score(i) for i in mols]
legends = [f'QED: {qeds[i]:.3f}\tLogP: {logps[i]:.3f}\nSA Score: {sas[i]:.3f}\tOverall: {values[i]:.3f}'
for i in range(len(samples))]
draw_mols(mols, legends=legends)