Penalized LogP Optimization
This notebook shows how to optimize a generative model with respect to the Penalized LogP score, a standard benchmark in many generative design papers.
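For reference, a common formulation of Penalized LogP is logP minus the synthetic accessibility (SA) score minus a penalty for rings with more than 6 atoms. The sketch below is illustrative only, built from RDKit's Crippen logP and the contrib SA scorer; the exact variant computed by PenalizedLogPFilter may differ (some papers additionally normalize each term using ZINC statistics).
import os, sys
from rdkit import Chem, RDConfig
from rdkit.Chem import Crippen
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer  # RDKit contrib synthetic accessibility scorer

def penalized_logp_sketch(mol):
    log_p = Crippen.MolLogP(mol)                    # estimated octanol-water partition coefficient
    sa = sascorer.calculateScore(mol)               # synthetic accessibility (1 = easy, 10 = hard)
    ring_sizes = [len(r) for r in mol.GetRingInfo().AtomRings()]
    ring_penalty = max([0] + [size - 6 for size in ring_sizes])  # penalize rings larger than 6 atoms
    return log_p - sa - ring_penalty

penalized_logp_sketch(Chem.MolFromSmiles('CCOc1ccccc1'))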
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime type to GPU.
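For reference, such a pool setup cell looks like the one below (a similar call appears uncommented near the end of this notebook, and it relies on set_global_pool from the MRL imports that follow); adjust the worker count to your machine.
# uncomment to enable a multiprocessing pool for the CPU-bound scoring steps
# set_global_pool(min(10, os.cpu_count()))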
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
Agent
Here we create the model we want to optimize. We will use LSTM_LM_Small_ZINC_NC, an LSTM-based language model pretrained on part of the ZINC database without chirality. Training without chirality prevents a form of mode collapse where the model converges to generating different isomers of the same compound.
Whether to use a model trained with or without chirality depends on the reward function you are trying to optimize. Use a chiral model if your reward function handles chirality in a meaningful way. Specifically, this means your reward function should give different scores to different isomers, and this difference should reflect a real aspect of the predicted property (i.e. the affinity of different isomers) rather than a spurious feature learned by the model (which happens surprisingly often).
The Penalized LogP score isn't affected by chirality, so using a non-chiral model makes sense. To be precise, Penalized LogP includes an SA score component that is influenced by the number of stereocenters in a molecule, but this does not result in different isomers getting different scores.
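As a quick illustrative check of the logP term (which dominates Penalized LogP), the snippet below uses RDKit's Crippen logP to show that two enantiomers of the same compound receive identical values, so nothing is lost here by sampling without stereochemistry.
from rdkit import Chem
from rdkit.Chem import Crippen

# (R)- and (S)-alanine differ only in stereochemistry and get the same logP
for smile in ['C[C@H](N)C(=O)O', 'C[C@@H](N)C(=O)O']:
    print(smile, Crippen.MolLogP(Chem.MolFromSmiles(smile)))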
agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})
Template
Here we create our template.
We set the following hard filters:
- ValidityFilter: screens for valid compounds
- SingleCompoundFilter: screens for single compounds
We set the following soft filters:
- PenalizedLogPFilter: evaluates the Penalized LogP score of a compound. By passing score=PassThroughScore(), this filter simply returns the Penalized LogP score.
template = Template([ValidityFilter(),
SingleCompoundFilter(),
],
[PenalizedLogPFilter(None, None, score=PassThroughScore())],
fail_score=-1., log=False)
template_cb = TemplateCallback(template, prefilter=True)
Loss Function
We will use the PPO policy gradient algorithm.
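As context for the cliprange and v_cliprange arguments below, this is a minimal sketch of PPO's clipped surrogate objective (not MRL's internal implementation): the probability ratio between the updated and old policies is clipped so a single batch cannot push the policy too far from the one that generated the samples.
import torch

def ppo_clip_loss(logp_new, logp_old, advantages, cliprange=0.3):
    ratio = torch.exp(logp_new - logp_old)                            # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - cliprange, 1 + cliprange) * advantages
    return -torch.min(unclipped, clipped).mean()                      # negate to minimize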
pg = PPO(0.99,
0.5,
lam=0.95,
v_coef=0.5,
cliprange=0.3,
v_cliprange=0.3,
ent_coef=0.01,
kl_target=0.03,
kl_horizon=3000,
scale_rewards=True)
loss = PolicyLoss(pg, 'PPO',
value_head=ValueHead(256),
v_update_iter=2,
vopt_kwargs={'lr':1e-3})
Samplers
We create the following samplers:
- sampler1 ModelSampler: samples from the main model
- sampler2 ModelSampler: samples from the baseline model
- sampler3 LogSampler: samples high-scoring compounds from the log
- sampler4 CombichemSampler: runs combichem generation on the top-scoring compounds. The combination of generative models with combichem greatly accelerates finding high-scoring compounds
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)
mutators = [
ChangeAtom(['6', '7', '8', '9']),
AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
AppendAtomsDouble(['C', 'N', 'O']),
AppendAtomsTriple(),
DeleteAtom(),
ChangeBond(),
InsertAtomSingle(['C', 'N', 'O']),
InsertAtomDouble(['C', 'N']),
InsertAtomTriple(),
AddRing(),
ShuffleNitrogen(10)
]
mc = MutatorCollection(mutators)
crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[],
prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')
samplers = [sampler1, sampler2, sampler3, sampler4]
Other Callbacks
We add the following callbacks:
- supervised_cb: every 20 batches, this callback grabs the top 2% of samples from the log and runs supervised training with these samples
- live_max: prints the maximum score from sampler1 each batch
- live_p90: prints the top 10% score from sampler1 each batch
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 5e-4, 64, epochs=5)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
cbs = [supervised_cb, live_p90, live_max]
env = Environment(agent, template_cb, samplers=samplers, rewards=[], losses=[loss],
cbs=cbs)
set_global_pool(min(10, os.cpu_count()))
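# note: the second argument to Environment.fit (150 here) is the maximum sequence length
# discussed in the note at the end of this notebook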
env.fit(128, 150, 400, 25)
env.log.plot_metrics()
subset = env.log.df[env.log.df.rewards>23.6]
draw_mols(to_mols(subset.samples.values), legends=[f"{i:.5f}" for i in subset.rewards.values])
Note that Penalized LogP is strongly influenced by the size of the generated molecule. For this reason it's not a very good benchmark; nonetheless, it is very common in the literature. Increasing the maximum sequence length from 150 (what we used in Environment.fit above) to something like 200 will result in higher Penalized LogP scores for the same amount of training.
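As a rough, illustrative check of that size dependence, the snippet below computes Crippen logP (the dominant term in Penalized LogP) for linear alkanes of increasing length; the value grows roughly linearly with the number of carbons, so allowing longer SMILES strings directly enables higher scores.
from rdkit import Chem
from rdkit.Chem import Crippen

for n in (10, 30, 50):
    smile = 'C' * n  # linear alkane with n carbons
    print(n, Crippen.MolLogP(Chem.MolFromSmiles(smile)))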