Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime to GPU.
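For reference, enabling the pool is a one-liner. The exact pool size is a judgment call; capping it at the machine's core count, as in the set_global_pool cell near the end of this notebook, is a reasonable default:
# uncomment to enable multiprocessing for CPU-bound sample evaluation
# set_global_pool(min(10, os.cpu_count()))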
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.combichem import *
from mrl.model_zoo import *
Agent
Here we create the model we want to optimize. We will use the LSTM_LM_Small_PI1M model, which was trained on the PI1M polymer library.
agent = LSTM_LM_Small_PI1M(drop_scale=0.3, opt_kwargs={'lr':5e-5})
Template
Here we create our template.
We set the following hard filters:
ValidityFilter: screens for valid compounds
SingleCompoundFilter: screens for single compounds
AttachmentFilter: ensures our polymers always have exactly two attachment points, representing the polymerization locations
MolWtFilter: constrains the maximum monomer molecular weight
RingFilter: constrains the maximum number of rings per monomer
We set the following soft filters:
SAFilter: evaluates the SA score of a compound. By passing score=PropertyFunctionScore(scale_sa), this filter returns the SA score scaled to between 0 and 1
def scale_sa(sa):
    # SA scores range from 1 (easy to synthesize) to 10 (difficult);
    # rescale to [0, 1] so that higher values mean easier synthesis
    return (10-sa)/9
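As a quick illustrative check of the scaling (example values, not taken from the notebook), an easily synthesized monomer with SA around 2.5 scores roughly 0.83, while a difficult one around 8 scores roughly 0.22:
print(scale_sa(2.5))  # ~0.83
print(scale_sa(8.0))  # ~0.22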
template = Template([ValidityFilter(),
SingleCompoundFilter(),
AttachmentFilter(min_val=2, max_val=2),
MolWtFilter(None, 600),
RingFilter(None, 4)
],
[SAFilter(None, None, score=PropertyFunctionScore(scale_sa))],
fail_score=-1., use_lookup=False)
template_cb = TemplateCallback(template, prefilter=True)
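To illustrate the attachment-point convention enforced by AttachmentFilter(min_val=2, max_val=2), monomer SMILES use '*' dummy atoms to mark the two polymerization locations. The SMILES below is a hypothetical styrene-like repeat unit chosen purely for illustration:
from rdkit import Chem

# hypothetical monomer: two '*' wildcard atoms mark where the repeat unit
# connects to its neighbors during polymerization
monomer = Chem.MolFromSmiles('*CC(*)c1ccccc1')
n_attachments = sum(1 for a in monomer.GetAtoms() if a.GetAtomicNum() == 0)
print(n_attachments)  # 2, so this monomer would pass the attachment filter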
Loss Function
We will use the PPO policy gradient algorithm.
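As background, PPO optimizes a clipped surrogate objective. The sketch below shows only the standard clipped policy term, not MRL's internal implementation (which also includes the value, entropy, and adaptive KL terms configured below):
import torch

def ppo_clipped_objective(logp_new, logp_old, advantages, cliprange=0.3):
    # probability ratio between the updated policy and the policy that generated the samples
    ratio = torch.exp(logp_new - logp_old)
    # take the more pessimistic of the unclipped and clipped surrogate terms
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - cliprange, 1 + cliprange) * advantages
    return -torch.min(unclipped, clipped).mean()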
pg = PPO(0.99,
0.5,
lam=0.95,
v_coef=0.5,
cliprange=0.3,
v_cliprange=0.3,
ent_coef=0.01,
kl_target=0.03,
kl_horizon=3000,
scale_rewards=True)
loss = PolicyLoss(pg, 'PPO',
value_head=ValueHead(256),
v_update_iter=2,
vopt_kwargs={'lr':1e-3})
Samplers
We create the following samplers:
sampler1 ModelSampler: samples from the main model
sampler2 ModelSampler: samples from the baseline model
sampler3 LogSampler: samples high scoring compounds from the log
sampler4 CombichemSampler: runs combichem generation on the top scoring compounds. Combining generative models with combichem greatly accelerates finding high scoring compounds.
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)
mutators = [
ChangeAtom(['6', '7', '8', '9']),
AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
AppendAtomsDouble(['C', 'N', 'O']),
DeleteAtom(),
ChangeBond(),
InsertAtomSingle(['C', 'N', 'O']),
InsertAtomDouble(['C', 'N']),
AddRing(),
ShuffleNitrogen(10)
]
mc = MutatorCollection(mutators)
crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[],
prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')
samplers = [sampler1, sampler2, sampler3, sampler4]
Other Callbacks
We add the following callbacks:
supervised_cb: every 20 batches, this callback grabs the top 2% of samples from the log and runs supervised training with these samples
live_max: prints the maximum score from sampler1 each batch
live_p90: prints the 90th percentile score from sampler1 each batch
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=2)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
cbs = [supervised_cb, live_p90, live_max]
env = Environment(agent, template_cb, samplers=samplers, rewards=[], losses=[loss],
cbs=cbs)
set_global_pool(min(10, os.cpu_count()))  # enable multiprocessing for CPU-bound sample evaluation
env.fit(128, 120, 300, 20)  # batch size, max sample length, number of training batches, report interval
env.log.plot_metrics()