Basic contrastive optimization workflow

Contrastive Optimization Workflows

Contrastive optimization is a type of conditional generation task. The goal is to feed a sample into a conditional generative model and use it to generate a new sample that is similar to the input but has improved properties.

Samples in this workflow therefore take the form of (source, target) pairs: the source is the input compound and the target is the generated compound.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which results in 2-4x speedups.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.
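The notes above say that set_global_pool turns on multiprocessing, and the reward function below calls maybe_parallel. The sketch below shows the general optional-pool pattern under our own assumptions. `parallel_map` is a hypothetical helper, not MRL's implementation, and it uses a thread pool only to keep the example self-contained. CPU-bound work such as fingerprinting would need a process pool (`multiprocessing.Pool`) to see a real speedup.

```python
from multiprocessing.pool import ThreadPool
from typing import Callable, Iterable, List


def parallel_map(func: Callable, items: Iterable, pool=None) -> List:
    """Apply func to every item, using a worker pool only if one is provided.

    Hypothetical helper illustrating the optional-pool pattern.
    """
    items = list(items)
    if pool is None:
        # No pool configured: fall back to a plain serial loop.
        return [func(item) for item in items]
    # Pool configured: distribute the work across its workers (order is preserved).
    return pool.map(func, items)


def square(x: int) -> int:
    return x * x


serial = parallel_map(square, range(5))
with ThreadPool(4) as pool:
    pooled = parallel_map(square, range(5), pool=pool)
print(serial, pooled)
```

The calling code stays the same whether or not a pool exists, which is why a single global setting can switch the whole workflow between serial and parallel execution.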

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *

Agent

Here we create the model we want to optimize. We will use FP_Cond_LSTM_LM_Small_ZINC, an LSTM-based conditional language model pretrained on part of the ZINC database.

Note that contrastive optimization specifically requires a conditional generative model.

agent = FP_Cond_LSTM_LM_Small_ZINC(drop_scale=0.2,opt_kwargs={'lr':5e-5}, base_model=True)

Template

We will set up a very basic template that will only check compounds for structural validity.

For contrastive generation, we use the ContrastiveTemplate callback rather than TemplateCallback. The contrastive template checks that both the source and the target samples pass the template.

The ContrastiveTemplate also allows us to impose a similarity constraint on (source, target) pairs, which is important for controlling the quality of results. Without a similarity constraint, the model can learn to ignore the source input and simply generate high-scoring compounds, which is not what we want here. On the flip side, if we reward the model for producing highly similar outputs, it will simply learn to reproduce the source input exactly, which is also not what we want.

We want our (source, target) pairs to be different, but not too different. We use FPSimilarity to impose this constraint by setting minimum and maximum similarity values. In the code below, the allowed Tanimoto similarity for (source, target) pairs is between 0.3 and 0.9, computed on ECFP6 fingerprints.
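Here is a minimal sketch of the similarity window. It assumes fingerprints are represented as Python sets of on-bit indices, so RDKit bit vectors are not used. `tanimoto` and `passes_similarity` are hypothetical helpers. Whether FPSimilarity treats the 0.3 and 0.9 bounds as inclusive is also an assumption.

```python
from typing import Set


def tanimoto(fp_a: Set[int], fp_b: Set[int]) -> float:
    """Tanimoto similarity |A & B| / |A | B| for fingerprints given as sets of on-bits."""
    if not fp_a and not fp_b:
        return 0.0  # convention for two empty fingerprints (an assumption)
    shared = len(fp_a & fp_b)
    return shared / (len(fp_a) + len(fp_b) - shared)


def passes_similarity(fp_source: Set[int], fp_target: Set[int],
                      min_sim: float = 0.3, max_sim: float = 0.9) -> bool:
    """True if the pair is similar, but not too similar (inclusive bounds assumed)."""
    return min_sim <= tanimoto(fp_source, fp_target) <= max_sim


source_fp = set(range(1, 11))          # 10 on-bits
close_fp = set(range(1, 10)) | {11}    # shares 9 of them: similarity 9/11
far_fp = {20, 21, 22}                  # shares none: similarity 0

print(passes_similarity(source_fp, source_fp),  # identical copy: rejected
      passes_similarity(source_fp, close_fp),   # different but related: accepted
      passes_similarity(source_fp, far_fp))     # unrelated: rejected
```

The upper bound rejects exact copies of the source, and the lower bound rejects targets that ignore it. These are the two failure modes described above.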

template = Template([ValidityFilter(), 
                     SingleCompoundFilter()],
                    [])

sf = FPSimilarity(partial(failsafe_fp, fp_function=ECFP6), tanimoto_rd, 0.3, 0.9, 0.05, -1.)
template_cb = ContrastiveTemplate(sf, template=template, prefilter=True)

Reward

For the reward, we will load a scikit-learn linear regression model trained to predict affinity against erbB1 using molecular fingerprints.

This score function is extremely simple and won't translate well to real binding affinity. It is used here as a lightweight example.

Similar to how we used the ContrastiveTemplate wrapper for our template, we use the ContrastiveReward wrapper for our reward. The contrastive reward will evaluate reward(target) - reward(source).

This raises an interesting question about how to treat the score. Consider a score that ranges over [0, 1], and two sample pairs with (source, target) scores of (0.2, 0.9) and (0.5, 0.9). Both targets have the same score, but the sources differ. We could evaluate the contrastive score as the raw difference, which gives:

  • (0.2, 0.9) -> (0.9-0.2)=.7
  • (0.5, 0.9) -> (0.9-0.5)=.4

We can also scale the scores relative to the maximum possible score. This can be thought of as rewarding the model for how much of the potential maximum score it achieved:

  • (0.2, 0.9) -> (0.9-0.2)/(1-0.2)=.875
  • (0.5, 0.9) -> (0.9-0.5)/(1-0.5)=.8
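Both scoring options above can be written as one small function. This sketch reproduces the arithmetic in the bullets and is not ContrastiveReward's actual code. `contrastive_score` is hypothetical, and returning 0.0 when the source already sits at the maximum is our own assumption.

```python
from typing import Optional


def contrastive_score(source_score: float, target_score: float,
                      max_score: Optional[float] = None) -> float:
    """Score a (source, target) pair from the two individual scores.

    Without max_score: the raw improvement, target - source.
    With max_score: the improvement as a fraction of the headroom above the
    source, (target - source) / (max_score - source).
    """
    diff = target_score - source_score
    if max_score is None:
        return diff
    headroom = max_score - source_score
    # Assumed guard: a source already at the maximum has no headroom left.
    return diff / headroom if headroom > 0 else 0.0


for src, tgt in [(0.2, 0.9), (0.5, 0.9)]:
    print(src, tgt,
          round(contrastive_score(src, tgt), 3),
          round(contrastive_score(src, tgt, max_score=1.0), 3))
```

With max_score set, a pair whose source starts close to the ceiling is not penalized for having less room to improve.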

The code below does the latter, passing max_score=20 to ContrastiveReward. The underlying reward function is clipped to the range [0, 15].

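class FP_Regression_Score():
    def __init__(self, fname):
        # pretrained scikit-learn regression model saved to disk
        self.model = torch.load(fname)
        # ECFP6 fingerprint function wrapped in failsafe_fp
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)
        
    def __call__(self, samples):
        # SMILES -> molecules -> ECFP6 fingerprints
        mols = to_mols(samples)
        fps = maybe_parallel(self.fp_function, mols)
        # stack fingerprints into a (n_samples, n_bits) feature matrix
        fps = [fp_to_array(i) for i in fps]
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)
        # clip predictions to the range [0, 15]
        preds = np.clip(preds, 0, 15)
        return preds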
    
# if in the repo
reward_function = FP_Regression_Score('../files/erbB1_regression.sklearn')

# if in Colab
# download_files
# reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')

reward = Reward(reward_function, weight=1.)

aff_reward = RewardCallback(reward, 'aff')

aff_reward_contrastive = ContrastiveReward(aff_reward, max_score=20)
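The reward above is a linear regression on fingerprint bits, followed by clipping. The sketch below reproduces that logic without RDKit or scikit-learn. It assumes fingerprints are sets of on-bit indices and uses made-up weights. `LinearFPScorer` is a hypothetical stand-in, not part of MRL.

```python
from typing import Dict, List, Set


class LinearFPScorer:
    """Hypothetical linear model over binary fingerprint bits, with output clipping."""

    def __init__(self, weights: Dict[int, float], intercept: float,
                 lo: float = 0.0, hi: float = 15.0):
        self.weights = weights
        self.intercept = intercept
        self.lo, self.hi = lo, hi

    def __call__(self, fingerprints: List[Set[int]]) -> List[float]:
        preds = []
        for fp in fingerprints:
            # For a 0/1 feature vector, a linear model reduces to intercept + weights of the on-bits.
            raw = self.intercept + sum(self.weights.get(bit, 0.0) for bit in fp)
            # Clip to [lo, hi], mirroring np.clip(preds, 0, 15) in FP_Regression_Score.
            preds.append(min(max(raw, self.lo), self.hi))
        return preds


scorer = LinearFPScorer({0: 4.0, 1: 6.0, 2: -3.0, 3: 10.0}, intercept=1.0)
scores = scorer([{0, 1}, {2}, {0, 1, 3}])  # raw values 11, -2, 21 before clipping
print(scores)
```

Clipping keeps the score inside a known range. This matters here because ContrastiveReward's scaling divides by the distance to a fixed max_score.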

Loss Function

We will use the PolicyGradient class, which implements the simplest policy gradient algorithm. Here it is configured with reward discounting (gamma=0.97).

pg = PolicyGradient(discount=True, gamma=0.97)

loss = PolicyLoss(pg, 'PG')
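With discount=True and gamma=0.97, each token is credited with a discounted sum of the rewards that follow it. The sketch below assumes a single sequence-level reward on the final token, which is a common way to assign a whole-molecule score. `discounted_returns` and `pg_loss` are hypothetical helpers, and MRL's PolicyGradient may normalize or batch these quantities differently.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float = 0.97) -> List[float]:
    """Reward-to-go G_t = r_t + gamma * G_{t+1}, computed right to left."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def pg_loss(log_probs: List[float], returns: List[float]) -> float:
    """REINFORCE-style loss: -sum_t log pi(a_t | s_t) * G_t (minimized during training)."""
    return -sum(lp * g for lp, g in zip(log_probs, returns))


# A 4-token sequence whose only reward arrives at the final token.
token_rewards = [0.0, 0.0, 0.0, 1.0]
returns = discounted_returns(token_rewards, gamma=0.97)
print([round(g, 4) for g in returns])
```

Earlier tokens receive geometrically smaller credit (gamma**3 for the first token here), so the final reward has the most direct effect on the tokens nearest the end of the sequence.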

Optional: PPO

PPO is a more sophisticated policy gradient algorithm. To try it, uncomment the code below.

#         0.5,
#         lam=0.95,
#         v_coef=0.5,
#         cliprange=0.3,
#         v_cliprange=0.3,
#         ent_coef=0.01,
#         kl_target=0.03,
#         kl_horizon=3000,
#         scale_rewards=True)

# loss = PolicyLoss(pg, 'PPO', 
#                    value_head=ValueHead(256), 
#                    v_update_iter=2, 
#                    vopt_kwargs={'lr':1e-3})
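PPO's central idea is the clipped surrogate objective. For a probability ratio r = pi_new(a|s) / pi_old(a|s) and an advantage A, each term is min(r * A, clip(r, 1 - eps, 1 + eps) * A). This stops a single update from moving the policy too far. The minimal sketch below uses the cliprange=0.3 from the commented configuration. It omits the value-function loss, entropy bonus, and KL terms that MRL's PPO also takes, and `ppo_clip_objective` is a hypothetical helper.

```python
from typing import List


def ppo_clip_objective(ratios: List[float], advantages: List[float],
                       cliprange: float = 0.3) -> float:
    """Mean clipped surrogate objective (to be maximized)."""
    terms = []
    for r, a in zip(ratios, advantages):
        # Clip the probability ratio into [1 - eps, 1 + eps].
        clipped = min(max(r, 1.0 - cliprange), 1.0 + cliprange)
        # Take the more pessimistic of the unclipped and clipped terms.
        terms.append(min(r * a, clipped * a))
    return sum(terms) / len(terms)


print(ppo_clip_objective([1.5], [1.0]),   # gain capped at 1.3 * A
      ppo_clip_objective([0.5], [-1.0]),  # loss floored at 0.7 * A
      ppo_clip_objective([1.1], [2.0]))   # inside the clip range: unchanged
```

Because the minimum is taken, clipping only limits changes that would improve the objective. A ratio moving in a harmful direction is still fully penalized.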

Samplers

The contrastive task views the model as a translator between source and target samples, which makes sampling a little different from other workflows. Instead of sampling compounds freely from the model, we give the model a predefined set of source samples and use it to generate target samples. To do this, we first generate a small dataset of ~200,000 compounds.

%%time
gen_bs = 1500  # passed to ContrastiveSampler below

all_smiles = set()
for _ in range(100):
    # draw 2000 samples from the model without tracking gradients
    preds, _ = agent.model.sample_no_grad(2000, 90)
    smiles = agent.reconstruct(preds)
    # keep only SMILES that parse into valid molecules
    mols = to_mols(smiles)
    smiles = [s for s, m in zip(smiles, mols) if m is not None]
    all_smiles.update(smiles)
    
len(all_smiles)
CPU times: user 1min 8s, sys: 118 ms, total: 1min 8s
Wall time: 1min 8s
197091

Now we can set up our sampler. Similar to how we used the ContrastiveTemplate wrapper for our template, we use the ContrastiveSampler wrapper for our sampler.

The ContrastiveSampler takes another Sampler as input (the base sampler). It draws a set of source samples from the base sampler, then generates a set of target samples on the fly from the specified model. This wrapper can be applied to any Sampler class.

We'll create a DatasetSampler from the samples we just generated, then pass the DatasetSampler to ContrastiveSampler. Every batch we will sample source compounds from DatasetSampler, generate new target compounds on the fly, and train.

Note that ContrastiveSampler is only necessary if the target samples are not known. If a dataset of (source, target) pairs has already been defined, the pairs can be passed as tuples directly to DatasetSampler.
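The sampling scheme described above can be sketched as follows, assuming `repeats` means the number of targets generated per source. `contrastive_sample` and the toy `generate` callable are hypothetical stand-ins for ContrastiveSampler and the conditional model. The real sampler also handles batching and model inference.

```python
import random
from typing import Callable, List, Sequence, Tuple


def contrastive_sample(sources: Sequence[str],
                       generate: Callable[[str], str],
                       n_sources: int,
                       repeats: int,
                       rng: random.Random) -> List[Tuple[str, str]]:
    """Draw source samples from a fixed pool, then generate `repeats` targets per source."""
    picked = rng.sample(list(sources), n_sources)
    pairs = []
    for src in picked:
        for _ in range(repeats):
            # In the real workflow, this is the conditional model generating a target from src.
            pairs.append((src, generate(src)))
    return pairs


pool = ["CCO", "c1ccccc1", "CC(=O)O", "CCN"]
toy_generate = lambda s: s + "C"  # toy "model": append a carbon atom
pairs = contrastive_sample(pool, toy_generate, n_sources=2, repeats=3, rng=random.Random(0))
print(pairs)
```

Every training pair shares its source with the pool, so the sources stay fixed while the model learns only how to transform them.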

sampler1 = DatasetSampler(list(all_smiles), 'smiles_data', 1000)
sampler1 = ContrastiveSampler(sampler1, agent.vocab, agent.dataset, agent.model, gen_bs, repeats=6)

sampler2 = LogSampler('samples', 'rewards', 50, 97, 500)

samplers = [sampler1, sampler2]

Environment

We create our environment with the objects assembled so far.

env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward_contrastive], losses=[loss],
                 cbs=[])
set_global_pool(min(10, os.cpu_count()))
env.fit(128, 90, 500, 25)
iterations rewards rewards_final new diversity bs template template_temp template_sim valid aff PG
0 0.048 0.048 1.000 1.000 128 0.050 0.000 0.050 1.000 -0.002 0.028
25 0.046 0.046 0.992 1.000 128 0.050 0.000 0.050 1.000 -0.004 -0.004
50 0.046 0.046 1.000 1.000 128 0.050 0.000 0.050 1.000 -0.004 -0.025
75 0.068 0.068 0.898 1.000 128 0.050 0.000 0.050 1.000 0.018 0.005
100 0.064 0.064 0.914 1.000 128 0.050 0.000 0.050 1.000 0.014 -0.009
125 0.078 0.078 0.852 1.000 128 0.050 0.000 0.050 1.000 0.028 -0.023
150 0.081 0.081 0.836 1.000 128 0.050 0.000 0.050 1.000 0.031 0.013
175 0.086 0.086 0.797 1.000 128 0.050 0.000 0.050 1.000 0.036 -0.001
200 0.093 0.093 0.828 1.000 128 0.050 0.000 0.050 1.000 0.043 0.003
225 0.076 0.076 0.852 1.000 128 0.050 0.000 0.050 1.000 0.026 -0.010
250 0.076 0.076 0.859 1.000 128 0.050 0.000 0.050 1.000 0.026 0.016
275 0.101 0.101 0.750 1.000 128 0.050 0.000 0.050 1.000 0.051 -0.019
300 0.078 0.078 0.844 1.000 128 0.050 0.000 0.050 1.000 0.028 -0.013
325 0.090 0.090 0.836 1.000 128 0.050 0.000 0.050 1.000 0.040 -0.021
350 0.092 0.092 0.859 1.000 128 0.050 0.000 0.050 1.000 0.042 -0.042
375 0.074 0.074 0.820 1.000 128 0.050 0.000 0.050 1.000 0.024 -0.012
400 0.086 0.086 0.852 1.000 128 0.050 0.000 0.050 1.000 0.036 -0.006
425 0.082 0.082 0.820 1.000 128 0.050 0.000 0.050 1.000 0.032 -0.013
450 0.115 0.115 0.766 1.000 128 0.050 0.000 0.050 1.000 0.065 -0.025
475 0.064 0.064 0.836 1.000 128 0.050 0.000 0.050 1.000 0.014 -0.059
env.log.plot_metrics()
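In every row of the log above, rewards equals template plus aff. For example, at iteration 450, 0.050 + 0.065 = 0.115. This suggests that the total reward is the template score plus the contrastive affinity reward. The constant 0.050 template score matches the 0.05 passed to FPSimilarity, but that attribution is our inference and is not stated in the log. The check below parses a few rows copied verbatim from the log.

```python
# A few rows copied from the training log above.
LOG = """\
iterations rewards rewards_final new diversity bs template template_temp template_sim valid aff PG
0 0.048 0.048 1.000 1.000 128 0.050 0.000 0.050 1.000 -0.002 0.028
75 0.068 0.068 0.898 1.000 128 0.050 0.000 0.050 1.000 0.018 0.005
200 0.093 0.093 0.828 1.000 128 0.050 0.000 0.050 1.000 0.043 0.003
450 0.115 0.115 0.766 1.000 128 0.050 0.000 0.050 1.000 0.065 -0.025
475 0.064 0.064 0.836 1.000 128 0.050 0.000 0.050 1.000 0.014 -0.059
"""

lines = LOG.strip().splitlines()
header = lines[0].split()
rows = [dict(zip(header, map(float, line.split()))) for line in lines[1:]]

# Logged values are rounded to 3 decimals, so compare with a small tolerance.
TOL = 1e-3
for row in rows:
    total = row["template"] + row["aff"]
    print(int(row["iterations"]), row["rewards"], round(total, 3))
    assert abs(row["rewards"] - total) < TOL
    assert abs(row["template"] - (row["template_temp"] + row["template_sim"])) < TOL
```

Read this way, the aff column is the part of the reward that reflects actual improvement, and its upward drift over training is the signal to watch.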

Generating Samples

Here we grab 100 random SMILES from our dataset and generate a set of candidate improved SMILES for each. We then plot each source score against the maximum score among its generated targets.

sampler1.repeats = 100
inputs = np.random.choice(list(all_smiles), 100, replace=False)
inputs = [remove_stereo(i) for i in inputs]
samples = sampler1.sample_outputs(inputs, 90)
samples = list(set(samples))

hps = template_cb.get_hps(samples)
samples = [samples[i] for i in range(len(samples)) if hps[i]]
samples = [(i[0], remove_stereo(i[1])) for i in samples]
samples = list(set(samples))
sample_df = pd.DataFrame(samples, columns=['source', 'target'])
sample_df['source_reward'] = aff_reward_contrastive.compute_and_clean(sample_df.source.values)
sample_df['target_reward'] = aff_reward_contrastive.compute_and_clean(sample_df.target.values)
gb = sample_df.groupby('source')
sr = gb.source_reward.mean()
tr = gb.target_reward.max()

fig, ax = plt.subplots()
ax.scatter(sr, tr, c=tr>sr)
ax.set_xlabel('Source Score')
ax.set_ylabel('Target Score')

lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),
    np.max([ax.get_xlim(), ax.get_ylim()]),
]

ax.plot(lims, lims, 'b', alpha=0.75, zorder=0, label='x=y line')
plt.legend();
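The scatter plot can be summarized with a single number: the share of sources whose best target beats the source, which corresponds to the points above the x=y line. Below is a pure-Python sketch of the same groupby logic (mean source score and max target score per source), using toy values. `best_target_per_source` and `fraction_improved` are hypothetical helpers.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def best_target_per_source(rows: List[Tuple[str, float, float]]) -> Dict[str, Tuple[float, float]]:
    """Map each source to (mean source score, max target score).

    rows holds (source, source_score, target_score) triples, mirroring
    groupby('source') with mean() on source scores and max() on target scores.
    """
    src_scores = defaultdict(list)
    best: Dict[str, float] = {}
    for src, s_score, t_score in rows:
        src_scores[src].append(s_score)
        best[src] = max(best.get(src, float("-inf")), t_score)
    return {src: (sum(v) / len(v), best[src]) for src, v in src_scores.items()}


def fraction_improved(summary: Dict[str, Tuple[float, float]]) -> float:
    """Share of sources whose best target scores above the source."""
    improved = sum(1 for s, t in summary.values() if t > s)
    return improved / len(summary)


toy_rows = [("A", 2.0, 1.5), ("A", 2.0, 3.0), ("B", 5.0, 4.0), ("B", 5.0, 4.5)]
summary = best_target_per_source(toy_rows)
print(summary, fraction_improved(summary))
```

With the `sr` and `tr` Series computed above, the same fraction in pandas is `(tr > sr).mean()`.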

We can also apply the model iteratively, feeding each improved compound back in as the next source, to build a chain of progressively improved versions.

agent.model.eval();
progressions = []
rewards = []
sampler1.repeats = 100

# every chain starts from the same compound; sampling is stochastic, so chains can diverge
start_smile = remove_stereo(list(all_smiles)[1])

for i in range(10):
    progression = [start_smile]
    # scores of the accepted compounds in this chain
    chain_reward = [np.atleast_1d(aff_reward_contrastive.compute_and_clean([start_smile]))[0]]

    for j in range(10):
        current_smile = progression[-1]
        # generate candidate targets from the current compound
        new_samples = sampler1.sample_outputs([current_smile], 90)
        new_samples = list(set(new_samples))
        clean_samples = []
        
        for source, target in new_samples:
            # keep valid, non-identical targets that pass the contrastive template
            if to_mol(target) is not None:
                target = remove_stereo(target)
                if source != target and template_cb.get_hps([(source, target)])[0]:
                    clean_samples.append((source, target))

        if clean_samples:
            output_smiles = [t for _, t in clean_samples]
            r = np.atleast_1d(aff_reward_contrastive.compute_and_clean(output_smiles))
            # accept the best candidate only if it improves on the chain's last score
            if r.max() > chain_reward[-1]:
                progression.append(output_smiles[r.argmax()])
                chain_reward.append(r[r.argmax()])
        else:
            break
            
    progressions.append(progression)
    rewards.append(chain_reward)

for r in rewards:
    plt.plot(r)
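The loop above is a greedy hill climb. It proposes candidates from the current compound, keeps the best one only if it improves the score, and stops early when no valid candidates remain. Below is a stripped-down sketch in which the model and reward are replaced by hypothetical `propose` and `score` callables, using toy strings instead of SMILES.

```python
from typing import Callable, List, Sequence, Tuple


def greedy_chain(start: str,
                 propose: Callable[[str], Sequence[str]],
                 score: Callable[[str], float],
                 n_steps: int = 10) -> Tuple[List[str], List[float]]:
    """Greedy hill climb: accept the best proposal only if it beats the last accepted score."""
    chain, scores = [start], [score(start)]
    for _ in range(n_steps):
        current = chain[-1]
        candidates = [c for c in propose(current) if c != current]
        if not candidates:
            break  # nothing valid to try: stop early, as the notebook loop does
        best = max(candidates, key=score)
        best_score = score(best)
        if best_score > scores[-1]:
            chain.append(best)
            scores.append(best_score)
    return chain, scores


def count_n(s: str) -> float:
    """Toy score: number of 'N' characters."""
    return float(s.count("N"))


chain, scores = greedy_chain("C", lambda s: [s + "C", s + "N"], count_n, n_steps=3)
stuck_chain, stuck_scores = greedy_chain("C", lambda s: [s + "C"], count_n, n_steps=3)
print(chain, scores)
print(stuck_chain, stuck_scores)
```

As in the notebook, a step that produces no improvement leaves the chain unchanged and simply tries again, so the recorded scores are monotonically increasing by construction.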

Here's a series of iterative designs showing repeated steps of similarity-constrained optimization.

(If you're a medicinal chemist and you think that 4-N linkage is weird, you can prevent it by updating the Template.)

idx = np.array([i[-1] for i in rewards]).argmax()
draw_mols(to_mols(progressions[idx]), legends=[f"{i:.3f}" for i in rewards[idx]])