Latent Optimization Workflows

Basic latent optimization workflow

This notebook shows a basic workflow for optimizing a latent vector relative to a generative model. The focus here is on showing how to set up the code, rather than maximizing performance. For this reason, we will use a simple template and a simple reward function.

During latent optimization, we will create several latent vectors. Upon initialization, these latent vectors will map to compounds with some score. We will then use gradient descent to optimize the vectors so that they produce high-scoring compounds, while keeping the weights of the generative model fixed.
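
Conceptually, each optimization step follows the pattern sketched below. This is only a minimal sketch in plain PyTorch: the tiny linear layer and the differentiable placeholder loss are hypothetical stand-ins for the pretrained generative model and the policy gradient loss used later in this notebook.

import torch

# Minimal sketch: model weights frozen, latent vectors trainable
decoder = torch.nn.Linear(16, 32)                 # stand-in for the frozen generative model
for p in decoder.parameters():
    p.requires_grad_(False)                       # keep the generative model constant

latents = torch.randn(8, 16, requires_grad=True)  # only the latents are optimized
opt = torch.optim.Adam([latents], lr=1e-3)

for step in range(10):
    out = decoder(latents)                        # "generate" from the latents
    loss = -out.mean()                            # placeholder for the negative reward
    opt.zero_grad()
    loss.backward()                               # gradients flow into the latents only
    opt.step()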

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained because generated samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
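
For reference, such a cell looks like the one below (the same call also appears uncommented in the Train section at the end of this notebook):

# set_global_pool(cpus=min(10, os.cpu_count()))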

This notebook may run slowly on Google Colab due to CPU limitations.

If running on Colab, remember to change the runtime type to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *

Agent

Here we create the model we want to optimize. We will use FP_Cond_LSTM_LM_Small_ZINC, an LSTM-based conditional language model pretrained on part of the ZINC database.

Note that latent optimization specifically requires a conditional generative model.

agent = FP_Cond_LSTM_LM_Small_ZINC(drop_scale=0.2,opt_kwargs={'lr':5e-5}, base_model=None)

Here we freeze the weights of the model so that they won't be updated during training.

freeze(agent.model)
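
To confirm the freeze took effect, a quick check like the one below should pass. This assumes freeze works by setting requires_grad to False on the model's parameters, which is the standard PyTorch approach:

# sanity check: no parameter should require gradients after freezing
assert not any(p.requires_grad for p in agent.model.parameters())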

Template

We will set up a very basic template that only checks compounds for structural validity and ensures each sample is a single compound.

template = Template([ValidityFilter(), 
                     SingleCompoundFilter()],
                    [])

template_cb = TemplateCallback(template, prefilter=True)

Reward

For the reward, we will load a scikit-learn linear regression model trained to predict affinity against erbB1 using molecular fingerprints.

This score function is extremely simple and won't translate well to real binding affinity; it is used here as a lightweight example.

class FP_Regression_Score():
    "Scores SMILES strings with a pretrained fingerprint-based regression model"
    def __init__(self, fname):
        self.model = torch.load(fname) # load the pickled sklearn model
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)
        
    def __call__(self, samples):
        mols = to_mols(samples)                      # SMILES -> RDKit mols
        fps = maybe_parallel(self.fp_function, mols) # compute ECFP6 fingerprints
        fps = [fp_to_array(i) for i in fps]          # convert to numpy arrays
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)           # predicted affinity per sample
        return preds
    
# if in the repo
reward_function = FP_Regression_Score('../files/erbB1_regression.sklearn')

# if in Colab
# download_files()
# reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')

reward = Reward(reward_function, weight=1.)

aff_reward = RewardCallback(reward, 'aff')
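
As a quick sanity check, the score function can be called directly on a list of SMILES strings. The molecules below are arbitrary examples, and the returned values depend on the trained regression model:

# returns an array of predicted affinities, one per input SMILES
reward_function(['CCO', 'c1ccccc1'])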

Latents

Here we show two ways to set up the latent vectors we will train.

Method 1: Initialize Latents as Random Vectors

Here we get our initial latents by sampling from a normal distribution. We create 500 latent vectors. During training, these vectors will be randomly sampled to assemble each batch.

latents = torch.randn((500, agent.model.encoder.d_latent))
gen_bs = 1500

sampler = LatentSampler(agent.vocab, agent.model, latents, 'latent', 0, 1., gen_bs, opt_kwargs={'lr':1e-3})

samplers = [sampler]

Method 2: Initialize Latents from Data

Here we grab 5 high-scoring samples from the erbB1 training dataset and convert them into latent vectors. We then make 200 copies of these vectors (sampled with replacement) and add a small amount of random noise.

df = pd.read_csv('../files/erbB1_affinity_data.csv')

# if in Colab
# download_files()
# df = pd.read_csv('files/erbB1_affinity_data.csv')

df = df[df.neg_log_ic50>10]

smiles = df.sample(n=5).smiles.values

print(reward_function(smiles))
[ 8.94284966  9.31949995 10.02078208  9.06363406  9.89236912]
new_ds = agent.dataset.new(smiles)

batch = collate_ds(new_ds)
x,y = batch

latents = agent.model.x_to_latent(to_device(x))

genbatch = 1500
n_latents = 200

latents = latents[np.random.choice(range(latents.shape[0]), n_latents)]
latents = latents + to_device(torch.randn(latents.shape)/100)

sampler = LatentSampler(agent.vocab, agent.model, latents, 'latent', 0, 1., genbatch, opt_kwargs={'lr':1e-3})

samplers = [sampler]

Loss Function

We will use the PolicyGradient class, which implements the simplest policy gradient algorithm. A simplified sketch of this style of loss is shown after the code below.

pg = PolicyGradient(discount=True, gamma=0.97)

loss = PolicyLoss(pg, 'PG')

losses = [loss]
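
For intuition, a simplified per-sample version of this kind of discounted REINFORCE loss looks roughly like the sketch below. This is not MRL's exact implementation; it only illustrates how the terminal reward is discounted back through the sequence and weighted by the token log-probabilities.

def reinforce_loss(log_probs, reward, gamma=0.97):
    # log_probs: (seq_len,) log-probabilities of the tokens in one sampled SMILES
    # reward:    scalar score of the finished compound
    seq_len = log_probs.shape[0]
    # discount the terminal reward back through the sequence
    discounts = gamma ** torch.arange(seq_len - 1, -1, -1, dtype=log_probs.dtype)
    returns = reward * discounts
    # REINFORCE-style loss: maximize reward-weighted log-likelihood
    return -(log_probs * returns).sum()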

Environment

We create our environment from the objects assembled so far.

env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=losses,
                 cbs=[])

Train

set_global_pool(cpus=min(10, os.cpu_count()))
# fit(batch_size, max_sequence_length, iterations, report_interval)
env.fit(128, 90, 500, 25)
iterations rewards rewards_final new diversity bs template valid latent_diversity latent_valid latent_rewards latent_new aff PG
0 5.365 5.365 1.000 1.000 115 0.000 0.992 0.906 0.991 5.365 1.000 5.365 -0.006
25 5.833 5.833 0.650 1.000 117 0.000 0.969 0.945 0.967 5.833 0.650 5.833 -0.030
50 6.255 6.255 0.593 1.000 118 0.000 0.984 0.938 0.983 6.255 0.593 6.255 -0.125
75 6.306 6.306 0.619 1.000 118 0.000 0.992 0.930 0.992 6.306 0.619 6.306 -0.142
100 6.930 6.930 0.500 1.000 112 0.000 0.977 0.898 0.974 6.930 0.500 6.930 -0.152
125 7.037 7.037 0.383 1.000 115 0.000 0.992 0.906 0.991 7.037 0.383 7.037 -0.220
150 7.120 7.120 0.459 1.000 109 0.000 0.977 0.875 0.973 7.120 0.459 7.120 -0.192
175 7.203 7.203 0.352 1.000 108 0.000 0.953 0.891 0.947 7.203 0.352 7.203 -0.210
200 7.146 7.146 0.434 1.000 113 0.000 1.000 0.883 1.000 7.146 0.434 7.146 -0.272
225 7.570 7.570 0.386 1.000 114 0.000 0.984 0.906 0.983 7.570 0.386 7.570 -0.206
250 7.712 7.712 0.307 1.000 114 0.000 0.992 0.898 0.991 7.712 0.307 7.712 -0.178
275 8.018 8.018 0.269 1.000 104 0.000 0.977 0.844 0.963 8.018 0.269 8.018 -0.170
300 7.657 7.657 0.255 1.000 110 0.000 0.977 0.883 0.973 7.657 0.255 7.657 -0.169
325 7.896 7.896 0.259 1.000 108 0.000 0.953 0.891 0.947 7.896 0.259 7.896 -0.146
350 8.129 8.129 0.269 1.000 108 0.000 0.938 0.906 0.931 8.129 0.269 8.129 -0.136
375 8.225 8.225 0.218 1.000 101 0.000 0.922 0.867 0.910 8.225 0.218 8.225 -0.192
400 7.798 7.798 0.223 1.000 103 0.000 0.969 0.836 0.963 7.798 0.223 7.798 -0.154
425 8.135 8.135 0.221 1.000 104 0.000 0.953 0.836 0.972 8.135 0.221 8.135 -0.102
450 7.995 7.995 0.243 1.000 107 0.000 0.969 0.867 0.964 7.995 0.243 7.995 -0.230
475 8.272 8.272 0.205 1.000 112 0.000 0.984 0.891 0.982 8.272 0.205 8.272 -0.140
env.log.plot_metrics()