Latent Optimization Workflows
This notebook shows a basic workflow for optimizing latent vectors with respect to a generative model. The focus here is on how to set up the code rather than on maximizing performance. For this reason, we will use a simple template and a simple reward function.
During latent optimization, we create a set of latent vectors. Upon initialization, these latent vectors map to compounds with some score. We then use gradient descent to optimize the vectors to produce high-scoring compounds while keeping the generative model's weights fixed.
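To make the idea concrete, here is a minimal sketch of the setup in plain PyTorch. Everything in it (the latent width, the toy differentiable score) is illustrative only; in the real workflow the decode-and-score step is non-differentiable, which is why the notebook uses MRL's LatentSampler and a policy gradient loss instead.

import torch

# the latents are the only trainable parameters; the model stays frozen
latents = torch.randn(8, 64, requires_grad=True)
opt = torch.optim.Adam([latents], lr=1e-3)

def toy_score(z):
    # stand-in for "decode z with the frozen model, score the compounds";
    # a real reward is non-differentiable, hence the policy gradient below
    return -(z**2).sum(-1)

for step in range(100):
    loss = -toy_score(latents).mean()  # gradient ascent on the score
    opt.zero_grad()
    loss.backward()                    # gradients flow only into the latents
    opt.step()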
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime type to GPU.
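Concretely, once the imports below have been run, enabling the pool is a single call (the same call appears again before training at the end of this notebook):

set_global_pool(cpus=min(10, os.cpu_count()))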
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
Agent
Here we create the model we want to optimize. We will use FP_Cond_LSTM_LM_Small_ZINC, an LSTM-based conditional language model pretrained on part of the ZINC database.
Note that latent optimization specifically requires a conditional generative model.
agent = FP_Cond_LSTM_LM_Small_ZINC(drop_scale=0.2,opt_kwargs={'lr':5e-5}, base_model=None)
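The width of the model's latent space can be read off its encoder; this attribute is used later when creating latent vectors:

print(agent.model.encoder.d_latent)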
Here we freeze the weights of the generative model so they won't be updated during training.
freeze(agent.model)
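As a quick sanity check (assuming freeze works by setting requires_grad=False on the model's parameters), we can confirm that no parameter will receive gradient updates:

# every parameter should now be excluded from gradient updates
assert not any(p.requires_grad for p in agent.model.parameters())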
Template
Here we create a minimal template that only requires samples to be valid, single-compound structures.
template = Template([ValidityFilter(), SingleCompoundFilter()], [])
template_cb = TemplateCallback(template, prefilter=True)
Reward
Our reward function wraps a scikit-learn regression model trained on ECFP6 fingerprints to predict affinity against erbB1 (the model file is loaded below).
class FP_Regression_Score():
    def __init__(self, fname):
        self.model = torch.load(fname)  # pickled sklearn regressor (saved via torch.save)
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)

    def __call__(self, samples):
        mols = to_mols(samples)                       # SMILES -> RDKit mols
        fps = maybe_parallel(self.fp_function, mols)  # ECFP6 fingerprints
        fps = [fp_to_array(i) for i in fps]
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)
        return preds
# if running from the repo
reward_function = FP_Regression_Score('../files/erbB1_regression.sklearn')
# if running on Colab
# download_files()
# reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')
reward = Reward(reward_function, weight=1.)
aff_reward = RewardCallback(reward, 'aff')
Latent Vectors
Now we create the latent vectors we will optimize. One option is to initialize them randomly:
latents = torch.randn((500, agent.model.encoder.d_latent))
gen_bs = 1500
sampler = LatentSampler(agent.vocab, agent.model, latents, 'latent', 0, 1., gen_bs, opt_kwargs={'lr':1e-3})
samplers = [sampler]
Alternatively, we can initialize the latent vectors from known active compounds. Here we load erbB1 affinity data, keep only high-affinity compounds (neg_log_ic50 > 10), sample a few, and check their predicted scores:
df = pd.read_csv('../files/erbB1_affinity_data.csv')
# if running on Colab
# download_files()
# df = pd.read_csv('files/erbB1_affinity_data.csv')
df = df[df.neg_log_ic50>10]
smiles = df.sample(n=5).smiles.values
print(reward_function(smiles))
Next we convert the sampled compounds into latent vectors with the model's encoder:
new_ds = agent.dataset.new(smiles)
batch = collate_ds(new_ds)
x,y = batch
latents = agent.model.x_to_latent(to_device(x))  # encode compounds into latent vectors
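We can sanity-check that we got one latent vector per input compound, each with the model's latent width:

print(latents.shape)  # expected: (len(smiles), agent.model.encoder.d_latent)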
genbatch = 1500
n_latents = 200
# sample n_latents seed vectors (with replacement) from the encoded compounds
latents = latents[np.random.choice(range(latents.shape[0]), n_latents)]
# perturb each copy with a small amount of noise to diversify the starting points
latents = latents + to_device(torch.randn(latents.shape)/100)
sampler = LatentSampler(agent.vocab, agent.model, latents, 'latent', 0, 1., genbatch, opt_kwargs={'lr':1e-3})
samplers = [sampler]
Loss Function
We will use the PolicyGradient class, the simplest policy gradient algorithm.
pg = PolicyGradient(discount=True, gamma=0.97)
loss = PolicyLoss(pg, 'PG')
losses = [loss]
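With discount=True, per-step rewards are discounted along the generated sequence using gamma=0.97. As a toy illustration of how discounted returns behave (plain PyTorch arithmetic, not MRL's internals), a terminal reward of 1.0 on a 5-token sample propagates back as:

rewards = torch.zeros(5); rewards[-1] = 1.  # terminal reward only
returns = torch.zeros(5)
running = 0.
for t in reversed(range(5)):
    running = rewards[t] + 0.97*running     # discounted return at step t
    returns[t] = running
print(returns)  # tensor([0.8853, 0.9127, 0.9409, 0.9700, 1.0000])

In this toy case, tokens closer to the end of the sample receive more of the credit for the final score.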
Environment
Now we assemble the environment from the agent, the template callback, our samplers, rewards, and losses:
env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=losses, cbs=[])
Train
Finally we train: here we fit with a batch size of 128 and a maximum sample length of 90, for 500 batches, logging every 25 batches.
set_global_pool(cpus=min(10, os.cpu_count()))
env.fit(128, 90, 500, 25)
env.log.plot_metrics()