Small molecule MPO

Multiparameter Optimization

This notebook shows how to optimize a generative model with respect to an MPO score that combines multiple desired properties.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *

Agent

Here we create the model we want to optimize. We will use LSTM_LM_Small_ZINC_NC, an LSTM-based language model pretrained on part of the ZINC database without chirality. Training without chirality prevents a form of mode collapse where the model converges to generating different isomers of the same compound.

Whether to use a model trained with or without chirality depends on the reward function you are trying to optimize. You should use a model with chirality if your reward function handles chirality in a meaningful way. Specifically, this means your reward function should give different scores to different isomers. This difference should relate to a real aspect of the property being predicted (e.g. the affinity of different isomers) rather than being a spurious feature learned by the model (this happens surprisingly often).

Our score isn't influenced by chirality, so we will use a nonchiral model.
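A quick way to probe the first condition is to score a few isomers of the same compound and see whether the scores differ. The sketch below is a minimal, library-free illustration: `is_stereo_sensitive` and both toy score functions are hypothetical helpers, not part of MRL, and a real check would use your actual reward function.

```python
from typing import Callable, Sequence

def is_stereo_sensitive(score_fn: Callable[[str], float],
                        isomers: Sequence[str],
                        tol: float = 1e-6) -> bool:
    """Return True if score_fn gives different scores to the given isomers."""
    scores = [score_fn(smi) for smi in isomers]
    return max(scores) - min(scores) > tol

# The two enantiomers of alanine written as isomeric SMILES.
alanine_isomers = ["C[C@@H](N)C(=O)O", "C[C@H](N)C(=O)O"]

# Toy stand-ins for a reward function (hypothetical, for illustration only).
def count_cno(smiles: str) -> float:
    # Counts C, N and O characters, so chirality markers are ignored.
    return float(sum(ch in "CNO" for ch in smiles))

def count_at_signs(smiles: str) -> float:
    # Reacts to the '@' chirality markers in the string itself.
    return float(smiles.count("@"))

insensitive = is_stereo_sensitive(count_cno, alanine_isomers)
sensitive = is_stereo_sensitive(count_at_signs, alanine_isomers)
```

Note that the second toy function is sensitive only to SMILES syntax, which is exactly the kind of spurious difference described above. A positive result from this check is a prompt to look closer, not proof that the reward handles chirality meaningfully.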

agent = LSTM_LM_Small_ZINC_NC(drop_scale=0.3, opt_kwargs={'lr':5e-5})

Template

We will jointly optimize three molecular properties: QED score, LogP, and SA score.

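First we set up our hard filters, with constraints following the Rule of 5. We set the following hard filters:

  • ValidityFilter: rejects invalid compounds
  • SingleCompoundFilter: rejects samples that are not a single compound
  • MolWtFilter: molecular weight upper bound of 500
  • HBDFilter: upper bound of 5 hydrogen bond donors
  • HBAFilter: upper bound of 10 hydrogen bond acceptors
  • LogPFilter: LogP upper bound of 5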

Next we set up our soft filters which will serve as our score. For each property (QED, LogP, SA), we scale the value:

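  • QED is doubled, giving a range of [0,2]
  • LogP is divided by 5 and clipped to [0,1]
  • SA score is mapped to [0,1] by (10 - SA)/9, so lower (easier to synthesize) SA scores give higher values

This weights LogP and SA score evenly and gives double weight to QED score.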

def scale_sa(sa):
    return (10-sa)/9

def scale_logp(logp):
    logp = logp/5
    logp = min(max(logp,0),1)
    return logp

def scale_qed(qed):
    return 2*qed

template = Template([ValidityFilter(), 
                     SingleCompoundFilter(), 
                     MolWtFilter(None, 500),
                     HBDFilter(None, 5),
                     HBAFilter(None, 10),
                     LogPFilter(None, 5)
                     ],
                    [QEDFilter(None, None, score=PropertyFunctionScore(scale_qed)),
                     SAFilter(None, None, score=PropertyFunctionScore(scale_sa)),
                     LogPFilter(None, None, score=PropertyFunctionScore(scale_logp))], 
                    fail_score=-1., log=False)

template_cb = TemplateCallback(template, prefilter=True)
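To get a feel for the scale of the resulting score, the scaling functions above can be combined by hand. The sketch below assumes the template adds the three soft-filter scores into a single reward; `mpo_score` is a hypothetical helper written for illustration and is not part of MRL. The property values in the second call are made up.

```python
def scale_sa(sa):
    return (10 - sa) / 9

def scale_logp(logp):
    return min(max(logp / 5, 0), 1)

def scale_qed(qed):
    return 2 * qed

def mpo_score(qed, logp, sa):
    # Assumption: the soft-filter scores are summed into one reward.
    return scale_qed(qed) + scale_logp(logp) + scale_sa(sa)

# Upper bound: QED of 1, LogP at the hard-filter limit of 5, SA score of 1.
best = mpo_score(qed=1.0, logp=5.0, sa=1.0)

# Hypothetical property values for a single compound.
example = mpo_score(qed=0.9, logp=4.5, sa=2.0)
```

Under this assumption the score tops out at 4.0, and the hypothetical compound scores about 3.59. A compound that fails a hard filter would instead receive the fail_score of -1.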

Reward

We are only optimizing the MPO score, which is contained in our template. For this reason, we don't have any additional score terms.

Loss Function

We will use the PPO policy gradient algorithm.

pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})
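For intuition about the cliprange setting, here is a minimal sketch of the standard PPO clipped surrogate loss for a single sample. This is the textbook formulation and is not taken from MRL's implementation, which may differ in detail; `clipped_surrogate_loss` is a hypothetical name.

```python
def clipped_surrogate_loss(ratio, advantage, cliprange=0.3):
    """Per-sample PPO clipped policy loss (to be minimized).

    ratio is pi_new(a|s) / pi_old(a|s); advantage is the advantage estimate.
    """
    clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
    # Take the more pessimistic of the unclipped and clipped objectives.
    return -min(ratio * advantage, clipped_ratio * advantage)

# Positive advantage: gains from raising the probability are capped at 1.3.
loss_up = clipped_surrogate_loss(ratio=1.5, advantage=1.0)

# Negative advantage: gains from lowering the probability are capped at 0.7.
loss_down = clipped_surrogate_loss(ratio=0.5, advantage=-1.0)
```

With cliprange=0.3, the policy gets no extra credit for moving a sample's probability ratio outside [0.7, 1.3] in the direction the advantage favors, which limits how far a single update can move the model.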

Samplers

We create the following samplers:

  • sampler1 ModelSampler: this samples from the main model
  • sampler2 ModelSampler: this samples from the baseline model
  • sampler3 LogSampler: this samples high-scoring samples from the log
  • sampler4 CombichemSampler: this sampler runs combichem generation on the top-scoring compounds. The combination of generative models with combichem greatly accelerates finding high-scoring compounds
gen_bs = 1500

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)

mutators = [
    ChangeAtom(['6', '7', '8', '9']),
    AppendAtomSingle(['C', 'N', 'O', 'F', 'Cl', 'Br']),
    AppendAtomsDouble(['C', 'N', 'O']),
    AppendAtomsTriple(),
    DeleteAtom(),
    ChangeBond(),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),
    InsertAtomTriple(),
    AddRing(),
    ShuffleNitrogen(10)
]

mc = MutatorCollection(mutators)

crossovers = [FragmentCrossover()]

cbc = CombiChem(mc, crossovers, template=template, rewards=[],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)

sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')

samplers = [sampler1, sampler2, sampler3, sampler4]
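The idea behind drawing high-scoring samples from the log can be sketched without MRL. The example below assumes the 98 passed to LogSampler and CombichemSampler is a percentile cutoff on the logged rewards; that reading, the helper names, and the placeholder log entries are all assumptions made for illustration.

```python
import math
import random

def percentile_cutoff(values, percentile):
    """Nearest-rank percentile of a non-empty list of numbers."""
    ordered = sorted(values)
    rank = max(1, math.ceil(len(ordered) * percentile / 100))
    return ordered[rank - 1]

def sample_top(log, percentile, n, seed=0):
    """Draw up to n (sample, reward) pairs with reward at or above the cutoff."""
    cutoff = percentile_cutoff([reward for _, reward in log], percentile)
    pool = [item for item in log if item[1] >= cutoff]
    return random.Random(seed).sample(pool, min(n, len(pool)))

# Placeholder log: 100 fake entries with rewards 0.0 .. 99.0.
log = [(f"mol_{i}", float(i)) for i in range(100)]
picked = sample_top(log, percentile=98, n=2)
```

Restricting the pool to the top of the log keeps the replayed samples focused on the best compounds found so far, at the cost of less variety in what gets replayed.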

Other Callbacks

We add the following callbacks:

  • supervised_cb: every 20 batches, this callback grabs the top 2% of samples (98th percentile and above) from the log and runs supervised training with these samples
  • live_max: prints the maximum score from sampler1 each batch
  • live_p90: prints the 90th percentile score from sampler1 each batch
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=4)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)

cbs = [supervised_cb, live_p90, live_max]

Environment

We create our environment with the objects assembled so far.

env = Environment(agent, template_cb, samplers=samplers, rewards=[], losses=[loss],
                 cbs=cbs)

Train

set_global_pool(min(10, os.cpu_count()))
env.fit(128, 100, 400, 25)
iterations  rewards  rewards_final  new    diversity  bs   template  valid  PPO     rewards_live_p90  rewards_live_max
0           2.704    2.704          1.000  1.000      128  2.704     1.000  0.071   3.136             3.486
25          2.801    2.801          0.961  1.000      128  2.801     1.000  0.017   3.047             3.309
50          2.980    2.980          0.867  1.000      128  2.980     1.000  0.005   3.158             3.338
75          2.970    2.970          0.836  1.000      128  2.970     1.000  -0.001  3.238             3.535
100         3.005    3.005          0.797  1.000      128  3.005     1.000  -0.026  3.268             3.466
125         3.081    3.081          0.742  1.000      128  3.081     1.000  0.055   3.496             3.622
150         3.099    3.099          0.828  1.000      128  3.099     1.000  -0.040  3.473             3.581
175         3.132    3.132          0.844  1.000      128  3.132     1.000  0.000   3.519             3.628
200         3.322    3.322          0.719  1.000      128  3.322     1.000  0.666   3.611             3.651
225         3.213    3.213          0.742  1.000      128  3.213     1.000  0.075   3.618             3.630
250         3.402    3.402          0.648  1.000      128  3.402     1.000  0.225   3.651             3.653
275         3.420    3.420          0.523  1.000      128  3.420     1.000  0.189   3.639             3.654
300         3.354    3.354          0.477  1.000      128  3.354     1.000  0.269   3.642             3.649
325         3.328    3.328          0.453  1.000      128  3.328     1.000  7.340   3.649             3.657
350         3.486    3.486          0.367  1.000      128  3.486     1.000  0.043   3.653             3.656
375         3.478    3.478          0.336  1.000      128  3.478     1.000  2.132   3.648             3.654
env.log.plot_metrics()
subset = env.log.df[env.log.df.rewards>3.657]
samples = subset.samples.values
values = subset.rewards.values
mols = [to_mol(i) for i in samples]
qeds = [qed(i) for i in mols]
logps = [logp(i) for i in mols]
sas = [sa_score(i) for i in mols]

legends = [f'QED: {qeds[i]:.3f}\tLogP: {logps[i]:.3f}\nSA Score: {sas[i]:.3f}\tOverall: {values[i]:.3f}'
          for i in range(len(samples))]

draw_mols(mols, legends=legends)