QED Optimization
This notebook shows how to optimize a generative model with respect to QED score, a standard benchmark in many generative design papers.
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime type to GPU.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
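If you have a multi-core machine, the multiprocessing speedup mentioned in the performance notes can be enabled with a cell like the one below. The exact arguments to set_global_pool are an assumption here; check the mrl documentation for the current signature.
# uncomment to evaluate samples with a multiprocessing pool (assumed signature)
# set_global_pool(min(10, os.cpu_count()))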
Agent
Here we create the model we want to optimize. We will use LSTM_LM_Small_ZINC_NC, an LSTM-based language model pretrained on part of the ZINC database without chirality. Training without chirality prevents a form of mode collapse where the model converges to generating different isomers of the same compound.
Whether to use a model trained with or without chirality depends on the reward function you are trying to optimize. Use a chiral model if your reward function handles chirality in a meaningful way. Specifically, this means your reward function should give different scores to different isomers, and that difference should reflect a real aspect of the predicted property (e.g. the affinity of different isomers) rather than a spurious feature learned by the model (which happens surprisingly often).
QED score isn't impacted by chirality, so using a non-chiral model makes sense.
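As a quick illustration of the reward being optimized, QED can be computed directly with RDKit (which mrl uses under the hood). The sketch below also shows that two enantiomers of the same compound receive identical QED scores.
from rdkit import Chem
from rdkit.Chem.QED import qed

# QED is a drug-likeness score in [0, 1]
print(qed(Chem.MolFromSmiles('CC(=O)Nc1ccc(O)cc1')))

# stereochemistry does not change the score: both alanine enantiomers match
print(qed(Chem.MolFromSmiles('C[C@H](N)C(=O)O')) == qed(Chem.MolFromSmiles('C[C@@H](N)C(=O)O')))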
agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})
Template
Here we create our template.
We set the following hard filters:
- ValidityFilter: screens for valid compounds
- SingleCompoundFilter: screens for single compounds
We set the following soft filters:
- QEDFilter: evaluates the QED score of a compound. By passing score=PassThroughScore(), this filter simply returns the raw QED score.
template = Template([ValidityFilter(),
                     SingleCompoundFilter()],
                    [QEDFilter(None, None, score=PassThroughScore())],
                    fail_score=-1.)
template_cb = TemplateCallback(template, prefilter=True)
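Conceptually, the prefilter template assigns fail_score to any sample that fails a hard filter and otherwise returns the soft filter score, which here is just the raw QED value. The snippet below is a plain-RDKit sketch of that logic, not the mrl implementation.
from rdkit import Chem
from rdkit.Chem.QED import qed

def template_like_score(smile, fail_score=-1.):
    mol = Chem.MolFromSmiles(smile)
    # hard filters: must be a valid, single compound
    if mol is None or '.' in smile:
        return fail_score
    # soft filter: pass through the raw QED score
    return qed(mol)

print(template_like_score('c1ccccc1O'))  # valid single compound -> its QED
print(template_like_score('CCO.CCN'))    # two compounds -> -1.0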
Loss Function
We will use the PPO policy gradient algorithm.
pg = PPO(0.99,
         0.5,
         lam=0.95,
         v_coef=0.5,
         cliprange=0.3,
         v_cliprange=0.3,
         ent_coef=0.01,
         kl_target=0.03,
         kl_horizon=3000,
         scale_rewards=True)
loss = PolicyLoss(pg, 'PPO',
                  value_head=ValueHead(256),
                  v_update_iter=2,
                  vopt_kwargs={'lr':1e-3})
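For reference, the heart of PPO's policy update is the clipped surrogate objective controlled by cliprange above; the value head gets an analogous clipped loss (v_cliprange), and kl_target/kl_horizon typically parameterize an adaptive penalty that keeps the updated policy close to the sampling policy. Below is a minimal PyTorch sketch of the clipped policy term, illustrative only and not the PolicyLoss internals.
import torch

def clipped_policy_loss(logp_new, logp_old, advantages, cliprange=0.3):
    # probability ratio between the updated and sampling policies
    ratio = torch.exp(logp_new - logp_old)
    # pessimistic minimum of the unclipped and clipped surrogate terms
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - cliprange, 1 + cliprange) * advantages
    return -torch.min(unclipped, clipped).mean()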
Samplers
We create the following samplers:
- sampler1 (ModelSampler): samples from the main model
- sampler2 (ModelSampler): samples from the baseline model
- sampler3 (LogSampler): samples high-scoring compounds from the log
- sampler4 (CombichemSampler): runs combichem generation on the top-scoring compounds. Combining generative models with combichem greatly accelerates finding high-scoring compounds
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)
mutators = [
    ChangeAtom(['6', '7', '8', '9']),
    AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
    AppendAtomsDouble(['C', 'N', 'O']),
    AppendAtomsTriple(),
    DeleteAtom(),
    ChangeBond(),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),
    InsertAtomTriple(),
    AddRing(),
    ShuffleNitrogen(10)
]
mc = MutatorCollection(mutators)
crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')
samplers = [sampler1, sampler2, sampler3, sampler4]
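To make the mutator idea concrete, here is a hedged RDKit sketch of an "append atom" style mutation in the spirit of AppendAtomSingle. It is an illustration only, not the mrl implementation.
from rdkit import Chem
import random

def append_atom_single(smile, symbols=('C', 'N', 'O', 'F')):
    # attach a randomly chosen atom to a random heavy atom via a single bond
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None
    rw = Chem.RWMol(mol)
    anchor = random.randrange(rw.GetNumAtoms())
    new_idx = rw.AddAtom(Chem.Atom(random.choice(symbols)))
    rw.AddBond(anchor, new_idx, Chem.BondType.SINGLE)
    try:
        Chem.SanitizeMol(rw)
    except Exception:
        return None  # mutation produced an invalid structure
    return Chem.MolToSmiles(rw)

print(append_atom_single('c1ccccc1'))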
Other Callbacks
We add the following callbacks:
- supervised_cb: every 20 batches, grabs the top 2% of samples from the log and runs supervised training on them (see the pandas sketch after the code below)
- live_mean: prints the mean score from sampler1 each batch
- live_max: prints the maximum score from sampler1 each batch
- live_p90: prints the 90th percentile score from sampler1 each batch
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=2)
live_max = MaxCallback('rewards', 'live')
live_mean = MeanCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
cbs = [supervised_cb, live_mean, live_p90, live_max]
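The percentile selection used by supervised_cb (and by the log and combichem samplers) amounts to thresholding the sample log. A rough pandas sketch, assuming a log DataFrame with samples and rewards columns as used below:
import pandas as pd

def top_percentile(df, percentile=98):
    # keep rows at or above the given reward percentile
    cutoff = df['rewards'].quantile(percentile / 100)
    return df[df['rewards'] >= cutoff]

# e.g. top_percentile(env.log.df, 98) selects the top 2% of logged samples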
env = Environment(agent, template_cb, samplers=samplers, rewards=[], losses=[loss],
                  cbs=cbs)
env.fit(128, 100, 300, 20)
env.log.plot_metrics()
subset = env.log.df[env.log.df.template>.9483]
draw_mols(to_mols(subset.samples.values), legends=[f"{i:.5f}" for i in subset.rewards.values])