Prior Optimization Workflows

Basic prior optimization workflow

This notebook shows a basic workflow for optimizing a prior distribution relative to a generative model. The focus here is on showing how to set up the code, rather than maximizing performance. For this reason, we will use a simple template and a simple reward function.

During prior optimization, we will create a prior distribution with a specific mean and variance. We will then optimize the weights of the prior while keeping the generative model fixed.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended to uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
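For reference, the set_global_pool cell further down in the notebook looks like the following (shown here as a minimal sketch; set_global_pool and os become available once the mrl imports below are run, and the CPU count is just an example):

# create a global multiprocessing pool for CPU-bound steps such as scoring samples
set_global_pool(cpus=min(10, os.cpu_count()))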

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *

Agent

Here we create the model we want to optimize. We will use FP_Cond_LSTM_LM_Small_ZINC, an LSTM-based conditional language model pretrained on part of the ZINC database.

Note that prior optimization specifically requires a conditional generative model.

agent = FP_Cond_LSTM_LM_Small_ZINC(drop_scale=0.5,opt_kwargs={'lr':5e-5}, base_model=None)
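As a quick check that the loaded model is conditional, we can inspect the attributes this notebook relies on later (a minimal sketch; d_latent and x_to_latent are the encoder attributes used below when building and initializing priors):

# a conditional model exposes an encoder whose latent space the priors will live in
print(agent.model.encoder.d_latent)           # dimensionality of the latent space
assert hasattr(agent.model, 'x_to_latent')    # used in Method 2 to embed compounds into that space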

Here we freeze the weights of the model so that they won't be updated during training.

freeze(agent.model)
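We can confirm the freeze with a standard PyTorch check (a minimal sketch, assuming freeze simply disables gradients on the model's parameters):

# count parameters that still require gradients (expected: 0 after freezing)
print(sum(p.requires_grad for p in agent.model.parameters()))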

Template

We will set up a very basic template that will only check compounds for structural validity.

template = Template([ValidityFilter(), 
                     SingleCompoundFilter()],
                    [])

template_cb = TemplateCallback(template, prefilter=True)

Reward

For the reward, we will load a scikit-learn linear regression model trained to predict affinity against erbB1 using molecular fingerprints.

This score function is extremely simple and won't translate well to real affinity values. It is used here as a lightweight example.

class FP_Regression_Score():
    "Score SMILES strings with a pretrained sklearn fingerprint regression model"
    def __init__(self, fname):
        self.model = torch.load(fname)
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)
        
    def __call__(self, samples):
        mols = to_mols(samples)                        # SMILES strings -> RDKit mols
        fps = maybe_parallel(self.fp_function, mols)   # mols -> ECFP6 fingerprints
        fps = [fp_to_array(i) for i in fps]            # fingerprints -> numpy arrays
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)             # one predicted affinity per compound
        return preds
    
# if in the repo:
reward_function = FP_Regression_Score('../files/erbB1_regression.sklearn')

# if in Colab:
# download_files()
# reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')

reward = Reward(reward_function, weight=1.)

aff_reward = RewardCallback(reward, 'aff')
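The score function can also be sanity-checked directly on a couple of SMILES strings before training (a minimal sketch; the molecules are arbitrary placeholders and the printed values depend entirely on the loaded sklearn model):

# one predicted affinity value per input SMILES
print(reward_function(['CCO', 'c1ccccc1']))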

Priors

Here we show two ways to set up priors to train. Run one of the two methods below; each one builds the priors and samplers lists used later.

Method 1: Initialize Priors as Normal Distributions

Here we initialize priors as normal distributions. We create 5 priors with the same initialization. These priors will be trained separately (i.e. prior 1 won't be trained on samples from prior 2), so the identical initializations aren't an issue.

genbatch = 1500

priors = []
samplers = []

n_priors = 5

for i in range(n_priors):
    prior = SphericalPrior(torch.zeros((agent.model.encoder.d_latent)),
                       torch.zeros((agent.model.encoder.d_latent)),
                       trainable=True)
    priors.append(prior)
    
    prior_loss = PriorLoss(prior)
    
    sampler = PriorSampler(agent.vocab, agent.model, prior, f'prior_{i}', 
                           0, 1./n_priors, genbatch, 
                           train=True, train_all=False, prior_loss=prior_loss,
                           track_losses=False,
                           opt_kwargs={'lr':5e-3})

    samplers.append(sampler)

Method 2: Initialize Prior from Data

Here we grab 5 high-scoring samples (neg_log_ic50 > 10) from the erbB1 training dataset. We then convert these samples into latent vectors and use those latent vectors to initialize our priors.

df = pd.read_csv('../files/erbB1_affinity_data.csv')

# if in Colab:
# download_files()
# df = pd.read_csv('files/erbB1_affinity_data.csv')

df = df[df.neg_log_ic50>10]

smiles = df.sample(n=5).smiles.values

print(reward_function(smiles))
[ 8.94284966  9.06363406  8.04805186 10.02078208  9.89236912]

new_ds = agent.dataset.new(smiles)

batch = collate_ds(new_ds)
x,y = batch

latents = agent.model.x_to_latent(to_device(x))
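Each seed compound gives one latent vector, which is used to initialize its own trainable prior (a minimal sketch just confirming the shape):

# expected shape: (number of seed compounds, d_latent)
print(latents.shape)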

genbatch = 1500

priors = []
samplers = []

n_priors = latents.shape[0]

for i in range(n_priors):
    prior = SphericalPrior(latents[i],
                       torch.zeros((agent.model.encoder.d_latent))-1,
                       trainable=True)
    priors.append(prior)
    
    prior_loss = PriorLoss(prior)
    
    sampler = PriorSampler(agent.vocab, agent.model, prior, f'prior_{i}', 
                           0, 1./n_priors, genbatch, 
                           train=True, train_all=False, prior_loss=prior_loss,
                           track_losses=False,
                           opt_kwargs={'lr':5e-3})

    samplers.append(sampler)

Optional: Policy Gradient Loss

The priors we have set up so far will be trained through the PriorLoss passed to each PriorSampler. We can additionally add a policy gradient loss term. This isn't necessary, but it tends to speed up convergence.

losses = []

pg = PolicyGradient(discount=True, gamma=0.97)

loss = PolicyLoss(pg, 'PG')

losses.append(loss)

Environment

We create our environment with the objects assembled so far.

env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=losses,
                 cbs=[])

Train

# enable multiprocessing for the CPU-bound scoring steps
set_global_pool(cpus=min(10, os.cpu_count()))
# fit arguments (as suggested by the log below): batch size 128, max length 90, 500 batches, report every 25
env.fit(128, 90, 500, 25)
iterations rewards rewards_final new diversity bs template valid prior_0_diversity prior_0_valid prior_0_rewards prior_0_new prior_1_diversity prior_1_valid prior_1_rewards prior_1_new prior_2_diversity prior_2_valid prior_2_rewards prior_2_new prior_3_diversity prior_3_valid prior_3_rewards prior_3_new prior_4_diversity prior_4_valid prior_4_rewards prior_4_new aff PG
0 5.008 5.008 1.000 1.000 122 0.000 0.976 1.000 1.000 4.907 1.000 1.000 0.960 5.250 1.000 1.000 1.000 5.371 1.000 1.000 0.920 4.717 1.000 1.000 1.000 4.780 1.000 5.008 0.001
25 5.468 5.468 1.000 1.000 123 0.000 0.984 1.000 1.000 5.258 1.000 1.000 0.960 5.379 1.000 1.000 1.000 5.469 1.000 1.000 1.000 5.407 1.000 1.000 0.960 5.836 1.000 5.468 0.012
50 5.393 5.393 1.000 1.000 124 0.000 0.992 1.000 0.960 5.395 1.000 1.000 1.000 5.340 1.000 1.000 1.000 5.555 1.000 1.000 1.000 5.237 1.000 1.000 1.000 5.438 1.000 5.393 -0.029
75 5.541 5.541 1.000 1.000 121 0.000 0.968 1.000 0.960 5.703 1.000 1.000 0.960 5.642 1.000 1.000 1.000 5.483 1.000 1.000 1.000 5.335 1.000 1.000 0.920 5.552 1.000 5.541 0.010
100 5.392 5.392 1.000 1.000 123 0.000 0.984 1.000 0.960 5.358 1.000 1.000 0.960 5.089 1.000 1.000 1.000 5.560 1.000 1.000 1.000 5.247 1.000 1.000 1.000 5.694 1.000 5.392 -0.016
125 5.615 5.615 1.000 1.000 123 0.000 0.984 1.000 0.960 5.434 1.000 1.000 1.000 5.407 1.000 1.000 1.000 5.877 1.000 1.000 1.000 5.260 1.000 1.000 0.960 6.110 1.000 5.615 -0.018
150 5.727 5.727 1.000 1.000 123 0.000 0.984 1.000 0.960 5.429 1.000 1.000 1.000 5.826 1.000 1.000 1.000 5.813 1.000 1.000 0.960 5.216 1.000 1.000 1.000 6.317 1.000 5.727 0.005
175 6.060 6.060 1.000 1.000 123 0.000 0.984 1.000 0.960 5.765 1.000 1.000 1.000 5.785 1.000 1.000 0.960 6.345 1.000 1.000 1.000 6.299 1.000 1.000 1.000 6.107 1.000 6.060 0.016
200 6.081 6.081 1.000 1.000 123 0.000 0.984 1.000 1.000 6.077 1.000 1.000 1.000 5.889 1.000 1.000 0.960 6.357 1.000 1.000 0.960 5.825 1.000 1.000 1.000 6.258 1.000 6.081 0.003
225 6.119 6.119 1.000 1.000 122 0.000 0.976 1.000 1.000 6.331 1.000 1.000 0.920 6.044 1.000 1.000 1.000 6.573 1.000 1.000 1.000 5.211 1.000 1.000 0.960 6.444 1.000 6.119 0.013
250 6.303 6.303 1.000 1.000 123 0.000 0.984 1.000 1.000 6.454 1.000 1.000 0.960 6.348 1.000 1.000 1.000 6.346 1.000 1.000 1.000 5.763 1.000 1.000 0.960 6.620 1.000 6.303 -0.031
275 6.393 6.393 1.000 1.000 122 0.000 0.976 1.000 1.000 6.990 1.000 1.000 1.000 5.661 1.000 1.000 0.880 6.732 1.000 1.000 1.000 6.168 1.000 1.000 1.000 6.454 1.000 6.393 0.014
300 6.485 6.485 1.000 1.000 121 0.000 0.968 1.000 0.960 6.737 1.000 1.000 0.960 6.410 1.000 1.000 0.960 6.646 1.000 1.000 1.000 6.324 1.000 1.000 0.960 6.312 1.000 6.485 0.024
325 6.628 6.628 1.000 1.000 124 0.000 0.992 1.000 0.960 6.596 1.000 1.000 1.000 6.677 1.000 1.000 1.000 6.947 1.000 1.000 1.000 6.366 1.000 1.000 1.000 6.556 1.000 6.628 0.016
350 6.374 6.374 1.000 1.000 123 0.000 0.984 1.000 1.000 6.168 1.000 1.000 0.960 6.455 1.000 1.000 1.000 6.801 1.000 1.000 1.000 6.468 1.000 1.000 0.960 5.967 1.000 6.374 -0.029
375 6.706 6.706 1.000 1.000 121 0.000 0.968 1.000 1.000 6.855 1.000 1.000 0.960 7.070 1.000 1.000 0.960 6.825 1.000 1.000 1.000 6.527 1.000 1.000 0.920 6.236 1.000 6.706 -0.023
400 6.779 6.779 1.000 1.000 122 0.000 0.976 1.000 1.000 7.011 1.000 1.000 1.000 7.029 1.000 1.000 1.000 6.565 1.000 1.000 0.920 6.803 1.000 1.000 0.960 6.475 1.000 6.779 -0.070
425 6.730 6.730 0.992 1.000 125 0.000 1.000 1.000 1.000 6.413 1.000 1.000 1.000 6.844 0.960 1.000 1.000 6.678 1.000 1.000 1.000 6.430 1.000 1.000 1.000 7.283 1.000 6.730 0.011
450 6.742 6.742 0.984 1.000 125 0.000 1.000 1.000 1.000 6.160 1.000 1.000 1.000 7.398 0.920 1.000 1.000 6.745 1.000 1.000 1.000 6.820 1.000 1.000 1.000 6.585 1.000 6.742 -0.007
475 6.822 6.822 0.968 1.000 124 0.000 0.992 1.000 1.000 6.495 1.000 1.000 1.000 6.942 0.840 1.000 1.000 6.968 1.000 1.000 1.000 7.205 1.000 1.000 0.960 6.489 1.000 6.822 -0.031
env.log.plot_metrics()