Small molecule QED optimization

QED Optimization

This notebook shows how to optimize a generative model with respect to QED score. This is a standard benchmark in many generative design papers.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained because generated samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended to uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
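For example, after running the imports below, a call along the following lines turns on the worker pool. This is a sketch: it assumes set_global_pool takes the number of worker processes, so check the commented set_global_pool cell in the notebook for the exact call.

import os

set_global_pool(min(8, os.cpu_count()))  # assumed signature; adjust to the cell provided in the notebook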

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *

Agent

Here we create the model we want to optimize. We will use LSTM_LM_Small_ZINC_NC, an LSTM-based language model pretrained on part of the ZINC database without chirality. Training without chirality prevents a form of mode collapse in which the model converges to generating different isomers of the same compound.

Whether to use a model trained with or without chirality depends on the reward function you are trying to optimize. You should use a chiral model only if your reward function handles chirality in a meaningful way. Specifically, this means your reward function should give different scores to different isomers, and that difference should reflect a real aspect of the predicted property (e.g. different isomers having different binding affinities) rather than a spurious feature learned by the model (which happens surprisingly often).

QED score isn't impacted by chirality, so using a non-chiral model makes sense.
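As a quick sanity check, the snippet below computes QED directly with RDKit (which MRL builds on) for two enantiomers of alanine. The SMILES strings are arbitrary illustrative examples; both return the same QED value.

from rdkit import Chem
from rdkit.Chem.QED import qed

# Two enantiomers: same connectivity, opposite stereocenter
qed(Chem.MolFromSmiles('C[C@H](N)C(=O)O'))   # (S)-alanine
qed(Chem.MolFromSmiles('C[C@@H](N)C(=O)O'))  # (R)-alanine, returns the same QED value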

agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})

Template

Here we create our template.

We set the following hard filters:

  • ValidityFilter: requires generated samples to be valid structures
  • SingleCompoundFilter: requires generated samples to be single compounds (no disconnected fragments)

We set the following soft filters:

  • QEDFilter: evaluates the QED score of a compound. By passing score=PassThroughScore(), this filter simply returns the QED score
template = Template([ValidityFilter(), 
                     SingleCompoundFilter(), 
                     ],
                    [QEDFilter(None, None, score=PassThroughScore())], 
                    fail_score=-1.)

template_cb = TemplateCallback(template, prefilter=True)

Reward

We are only optimizing towards QED score, which is scored by our template. For this reason, we don't add any additional reward terms.

Loss Function

We will use the PPO (Proximal Policy Optimization) policy gradient algorithm.

pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})
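For intuition, the heart of PPO is a clipped surrogate objective on the probability ratio between the updated policy and the policy that generated the samples. The sketch below is a minimal PyTorch illustration of that objective only; it is not MRL's implementation, and the function and tensor names are hypothetical.

import torch

def clipped_surrogate_loss(new_logprobs, old_logprobs, advantages, cliprange=0.3):
    # Ratio of the new policy's probability to the sampling policy's probability
    ratio = torch.exp(new_logprobs - old_logprobs)
    # Unclipped and clipped surrogate terms
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - cliprange, 1 + cliprange) * advantages
    # PPO maximizes the elementwise minimum of the two; return the negative mean as a loss
    return -torch.min(unclipped, clipped).mean()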

Samplers

We create the following samplers:

  • sampler1 ModelSampler: this samples from the main model
  • sampler2 ModelSampler: this samples from the baseline model
  • sampler3 LogSampler: this samples high-scoring compounds from the log
  • sampler4 CombichemSampler: this sampler runs combichem generation on the top-scoring compounds. Combining generative models with combichem greatly accelerates finding high-scoring compounds (a conceptual sketch of one combichem round follows the code below)
gen_bs = 1500

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)

mutators = [
    ChangeAtom(['6', '7', '8', '9']),
    AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
    AppendAtomsDouble(['C', 'N', 'O']),
    AppendAtomsTriple(),
    DeleteAtom(),
    ChangeBond(),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),
    InsertAtomTriple(),
    AddRing(),
    ShuffleNitrogen(10)
]

mc = MutatorCollection(mutators)

crossovers = [FragmentCrossover()]

cbc = CombiChem(mc, crossovers, template=template, rewards=[],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)

sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')

samplers = [sampler1, sampler2, sampler3, sampler4]
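Roughly, each combichem round mutates and crosses over the current library, scores the results, and prunes back to the best compounds. The sketch below is purely conceptual and is not MRL's CombiChem implementation; combichem_round_sketch, mutate, crossover, and score are hypothetical placeholders, and the exact pruning rule in MRL may differ.

import random

def combichem_round_sketch(library, mutate, crossover, score,
                           prune_percentile=70, max_library_size=400):
    # Propose new structures by mutating library members and crossing over random pairs
    candidates = [mutate(smile) for smile in library]
    partners = random.sample(library, len(library))
    candidates += [crossover(a, b) for a, b in zip(library, partners)]
    # Score the combined pool and keep only the highest-scoring fraction
    pool = sorted(set(library + candidates), key=score, reverse=True)
    keep = max(1, int(len(pool) * (100 - prune_percentile) / 100))
    return pool[:min(keep, max_library_size)]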

Other Callbacks

We add the following callbacks:

  • supervised_cb: every 20 batches, this callback grabs the top 2% of samples from the log and runs supervised training with these samples (a conceptual sketch of the selection step follows the code below)
  • live_mean: prints the mean score from sampler1 each batch
  • live_max: prints the maximum score from sampler1 each batch
  • live_p90: prints the 90th percentile score from sampler1 each batch
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=2)
live_max = MaxCallback('rewards', 'live')
live_mean = MeanCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)

cbs = [supervised_cb, live_mean, live_p90, live_max]
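The log DataFrame stores generated SMILES under 'samples' and their scores under 'rewards' (see the LogSampler above), so the selection step is conceptually just a percentile cut. The snippet below only illustrates that idea and is not SupervisedCB's actual code; select_top_samples is a hypothetical helper.

def select_top_samples(log_df, percentile=98):
    # Keep samples whose reward is at or above the given percentile of the log
    cutoff = log_df['rewards'].quantile(percentile / 100)
    return log_df.loc[log_df['rewards'] >= cutoff, 'samples'].tolist()

These SMILES are then used for a few epochs of ordinary supervised (language-model) fine-tuning of the agent.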

Environment

We create our environment with the objects assembled so far.

env = Environment(agent, template_cb, samplers=samplers, rewards=[], losses=[loss],
                 cbs=cbs)

Train

 
env.fit(128, 100, 300, 20)
| iterations | rewards | rewards_final | new | diversity | bs | template | valid | PPO | rewards_live_mean | rewards_live_p90 | rewards_live_max |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.729 | 0.729 | 1.000 | 1.000 | 128 | 0.729 | 1.000 | -0.006 | 0.739 | 0.889 | 0.936 |
| 20 | 0.761 | 0.761 | 0.961 | 1.000 | 128 | 0.761 | 1.000 | 0.565 | 0.744 | 0.848 | 0.944 |
| 40 | 0.764 | 0.764 | 0.914 | 1.000 | 128 | 0.764 | 1.000 | 0.458 | 0.717 | 0.836 | 0.895 |
| 60 | 0.761 | 0.761 | 0.898 | 1.000 | 128 | 0.761 | 1.000 | 2.310 | 0.747 | 0.853 | 0.910 |
| 80 | 0.790 | 0.790 | 0.859 | 1.000 | 128 | 0.790 | 1.000 | 0.404 | 0.752 | 0.840 | 0.917 |
| 100 | 0.764 | 0.764 | 0.898 | 1.000 | 128 | 0.764 | 1.000 | 0.814 | 0.695 | 0.845 | 0.904 |
| 120 | 0.772 | 0.772 | 0.859 | 1.000 | 128 | 0.772 | 1.000 | -0.004 | 0.751 | 0.870 | 0.900 |
| 140 | 0.761 | 0.761 | 0.859 | 1.000 | 128 | 0.761 | 1.000 | 0.126 | 0.730 | 0.862 | 0.939 |
| 160 | 0.788 | 0.788 | 0.875 | 1.000 | 128 | 0.788 | 1.000 | 0.333 | 0.767 | 0.865 | 0.927 |
| 180 | 0.790 | 0.790 | 0.805 | 1.000 | 128 | 0.790 | 1.000 | 0.107 | 0.763 | 0.892 | 0.934 |
| 200 | 0.786 | 0.786 | 0.852 | 1.000 | 128 | 0.786 | 1.000 | 0.083 | 0.780 | 0.898 | 0.947 |
| 220 | 0.790 | 0.790 | 0.828 | 1.000 | 128 | 0.790 | 1.000 | 0.034 | 0.758 | 0.881 | 0.924 |
| 240 | 0.784 | 0.784 | 0.844 | 1.000 | 128 | 0.784 | 1.000 | 0.355 | 0.790 | 0.894 | 0.924 |
| 260 | 0.810 | 0.810 | 0.875 | 1.000 | 128 | 0.810 | 1.000 | 0.006 | 0.792 | 0.910 | 0.939 |
| 280 | 0.803 | 0.803 | 0.805 | 1.000 | 128 | 0.803 | 1.000 | -0.040 | 0.796 | 0.920 | 0.941 |
env.log.plot_metrics()
# Select high-scoring samples from the log and draw them, labeled with their rewards
subset = env.log.df[env.log.df.template > 0.9483]
draw_mols(to_mols(subset.samples.values), legends=[f"{i:.5f}" for i in subset.rewards.values])