Prior Optimization Workflows
This notebook shows a basic workflow for optimizing a prior distribution relative to a generative model. The focus here is on showing how to set up the code, rather than maximizing performance. For this reason, we will use a simple template and a simple reward function.
During prior optimization, we will create a prior distribution with a specific mean and variance. We will then optimize the weights of the prior while keeping the generative model fixed.
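As a rough mental model of what prior optimization does, here is a standalone PyTorch sketch (this is not the MRL API used below; the latent size, reward function, and hyperparameters are purely illustrative): the prior's mean and log-variance are the only trainable parameters, and we adjust them to raise the expected reward of samples drawn from the prior.
import torch

d_latent = 8                                         # illustrative latent size
mu = torch.zeros(d_latent, requires_grad=True)       # trainable prior mean
log_var = torch.zeros(d_latent, requires_grad=True)  # trainable prior log-variance
opt = torch.optim.Adam([mu, log_var], lr=5e-3)

def toy_reward(z):
    # stand-in for "decode z with the frozen generative model, then score the samples"
    return -(z - 2.).pow(2).sum(-1)

for step in range(100):
    eps = torch.randn(64, d_latent)
    z = mu + eps * (0.5 * log_var).exp()             # reparameterized samples from the prior
    loss = -toy_reward(z).mean()                     # minimize negative reward
    opt.zero_grad()
    loss.backward()
    opt.step()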
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because generated samples must be evaluated on the CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
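For reference, enabling the pool is a single call (it appears uncommented near the end of this notebook; the cpu count here is just an example):
set_global_pool(cpus=min(10, os.cpu_count()))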
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime type to GPU.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
Agent
Here we create the model we want to optimize. We will use the FP_Cond_LSTM_LM_Small_ZINC model - an LSTM-based conditional language model pretrained on part of the ZINC database.
Note that prior optimization specifically requires a conditional generative model.
agent = FP_Cond_LSTM_LM_Small_ZINC(drop_scale=0.5,opt_kwargs={'lr':5e-5}, base_model=None)
Here we freeze the weights of the model so that the generative model won't be updated during training.
freeze(agent.model)
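As an optional sanity check (assuming freeze works by setting requires_grad=False on the model's PyTorch parameters), you can confirm nothing in the generative model will receive gradient updates:
# optional check: every parameter of the generative model should now be frozen
assert not any(p.requires_grad for p in agent.model.parameters())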
# template requiring valid, single-compound structures
template = Template([ValidityFilter(),
                     SingleCompoundFilter()],
                    [])
template_cb = TemplateCallback(template, prefilter=True)
class FP_Regression_Score():
    # Scores generated SMILES with a pretrained sklearn regression model
    # that predicts affinity from ECFP6 fingerprints
    def __init__(self, fname):
        self.model = torch.load(fname)
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)

    def __call__(self, samples):
        mols = to_mols(samples)                       # SMILES -> RDKit mols
        fps = maybe_parallel(self.fp_function, mols)  # mols -> ECFP6 fingerprints
        fps = [fp_to_array(i) for i in fps]
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)            # predicted score per sample
        return preds
# if in the repo:
reward_function = FP_Regression_Score('../files/erbB1_regression.sklearn')
# if in Colab:
# download_files()
# reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')
reward = Reward(reward_function, weight=1.)
aff_reward = RewardCallback(reward, 'aff')
genbatch = 1500
priors = []
samplers = []
n_priors = 5
for i in range(n_priors):
    # trainable spherical prior over the latent space, initialized to a standard Gaussian
    prior = SphericalPrior(torch.zeros((agent.model.encoder.d_latent)),
                           torch.zeros((agent.model.encoder.d_latent)),
                           trainable=True)
    priors.append(prior)
    prior_loss = PriorLoss(prior)
    # one sampler per prior, each trained against its own PriorLoss
    sampler = PriorSampler(agent.vocab, agent.model, prior, f'prior_{i}',
                           0, 1./n_priors, genbatch,
                           train=True, train_all=False, prior_loss=prior_loss,
                           track_losses=False,
                           opt_kwargs={'lr':5e-3})
    samplers.append(sampler)
df = pd.read_csv('../files/erbB1_affinity_data.csv')
# if in Colab:
# download_files()
# df = pd.read_csv('files/erbB1_affinity_data.csv')
df = df[df.neg_log_ic50>10]  # keep only high-affinity compounds
smiles = df.sample(n=5).smiles.values
print(reward_function(smiles))
new_ds = agent.dataset.new(smiles)
batch = collate_ds(new_ds)
x,y = batch
latents = agent.model.x_to_latent(to_device(x))  # encode the selected compounds into latent vectors
genbatch = 1500
priors = []
samplers = []
n_priors = latents.shape[0]
for i in range(n_priors):
    # spherical prior centered on the latent vector of a high-affinity compound,
    # with a reduced initial spread (log-scale parameter of -1)
    prior = SphericalPrior(latents[i],
                           torch.zeros((agent.model.encoder.d_latent))-1,
                           trainable=True)
    priors.append(prior)
    prior_loss = PriorLoss(prior)
    sampler = PriorSampler(agent.vocab, agent.model, prior, f'prior_{i}',
                           0, 1./n_priors, genbatch,
                           train=True, train_all=False, prior_loss=prior_loss,
                           track_losses=False,
                           opt_kwargs={'lr':5e-3})
    samplers.append(sampler)
Optional: Policy Gradient Loss
The priors we have set up so far will be trained by the PriorLoss passed to the PriorSampler. We can additionally add a policy gradient loss term. This isn't necessary, but it tends to speed up convergence.
losses = []
pg = PolicyGradient(discount=True, gamma=0.97)
loss = PolicyLoss(pg, 'PG')
losses.append(loss)
env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=losses,
                  cbs=[])
set_global_pool(cpus=min(10, os.cpu_count()))
env.fit(128, 90, 500, 25)
env.log.plot_metrics()