Implementing Deep Drug Decoder with MRL

Deep Drug Decoder

This tutorial shows how to implement the conditional LSTM model from the paper "Direct Steering of de novo Molecular Generation using Descriptor Conditional Recurrent Neural Networks (cRNNs)".

In the paper, a vector of molecular features (molecular weight, TPSA, etc.) is used to condition generation. We can implement this in MRL with a conditional LSTM model.
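
To make the conditioning concrete, here is what one such feature vector looks like when computed with plain RDKit (a standalone sketch for illustration; the tutorial itself uses MRL's wrappers such as molwt and tpsa below):

from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, Lipinski, QED
import numpy as np

mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O') # aspirin
features = np.array([
    Descriptors.MolWt(mol),        # molecular weight
    Descriptors.TPSA(mol),         # topological polar surface area
    Crippen.MolLogP(mol),          # logP
    QED.qed(mol),                  # drug-likeness score
    Lipinski.NumHAcceptors(mol),   # hydrogen bond acceptors
    Lipinski.NumHDonors(mol),      # hydrogen bond donors
    Descriptors.RingCount(mol),    # ring count
])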

Performance Notes

Parts of this notebook are CPU-constrained. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically yields a 2-4x speedup.
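
For example (a sketch, assuming set_global_pool accepts a core count as in other MRL tutorials):

import os
# set_global_pool(cpus=min(10, os.cpu_count())) # uncomment on a multi-core machine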

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime type to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *

Setup - Data Generation

First we need to generate some SMILES data. We'll do this by sampling from a pre-trained model.

agent = LSTM_LM_Small_ZINC_NC(base_model=False) # pre-trained LSTM language model from the model zoo
agent.model.eval();
 
smiles = set()

# sample batches of SMILES, deduplicate, and keep only valid structures
for i in range(50):
    preds, _ = agent.model.sample_no_grad(2048, 90) # 2048 sequences, max length 90
    s = list(set(agent.reconstruct(preds)))
    mols = to_mols(s)
    s = [s[i] for i in range(len(s)) if mols[i] is not None] # filter invalid SMILES
    smiles.update(set(s))

smiles = list(smiles)
len(smiles)
102126
# free the pre-trained model to reclaim GPU memory before training our own
del preds
del agent.model
del agent
gc.collect()
torch.cuda.empty_cache()

Setup - Dataset

Now we set up our dataset using the Vec_To_Text_Dataset class. To create our conditional vectors of molecular properties, we need a function that maps SMILES strings to vectors. We've created the SmilesFeaturizer below to do so. The featurizer maps a SMILES string to a vector of [weight, tpsa, logp, qed, hba, hbd, rings] values and normalizes those values relative to the dataset.

class SmilesFeaturizer():
    "Maps a SMILES string to a normalized vector of molecular properties"
    def __init__(self, means=None, stds=None):
        self.means = means
        self.stds = stds
        
    def __call__(self, smile):
        output = self.featurize_smile(smile)
        
        # normalize relative to the dataset statistics, if available
        if self.means is not None:
            output = output - self.means
            
        if self.stds is not None:
            output = output/self.stds
            
        return output
        
    def featurize_smile(self, smile):
        mol = to_mol(smile)

        if mol is not None:
            weight = molwt(mol)
            t = tpsa(mol)
            lp = logp(mol)
            q = qed(mol)
            ha = hba(mol)
            hd = hbd(mol)
            r = rings(mol)
            output = np.array([weight, t, lp, q, ha, hd, r])
        else:
            # invalid SMILES get a zero vector
            output = np.array([0.]*7)

        return output
    
    def stats_from_smiles(self, smiles):
        "Compute per-feature means and stds over a reference set of SMILES"
        self.means = None
        self.stds = None
        
        stats = maybe_parallel(self.featurize_smile, smiles)
        stats = np.array(stats)
        self.means = stats.mean(0)
        self.stds = stats.std(0)
 
sf = SmilesFeaturizer()
sf.stats_from_smiles(smiles) # fit normalization stats on the generated dataset
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(smiles, vocab, sf)

Now we can inspect our data. Each item pairs the vector of properties with the integer token codes of the SMILES string the model should generate.

ds[0]
(tensor([-0.9333,  0.1813, -0.3171,  0.9053, -0.9953,  0.7025,  0.2339]),
 tensor([ 0, 23, 34, 11, 34, 34, 34,  5, 23,  5, 20, 28,  6, 27, 23, 12, 23, 23,
         27,  5, 23,  5, 20, 28,  6, 23, 20, 23, 34, 13, 34, 34, 34, 31, 37, 25,
         33, 13,  6, 23, 12,  6, 38, 11,  1]))
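
To sanity-check the tokenization, we can map the integer codes back to characters (a quick check, relying only on the vocab's itos list used later in this notebook):

x_vec, x_tokens = ds[0]
print(''.join(vocab.itos[t] for t in x_tokens)) # includes the bos/eos special tokens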

Model Creation

We will use the Conditional_LSTM_LM class as our model. This model conditions the hidden and cell states of an LSTM with the condition vector.

One important difference from previous tutorials using Conditional_LSTM_LM is that we do not normalize the latent vectors. This makes it harder to sample directly from the latent space, but the model reproduces the condition vectors more faithfully because they are not constrained by a norm prior.
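
Conceptually, conditioning the hidden state means projecting the condition vector into the LSTM's initial hidden and cell states, so every generation step "sees" the target properties. Below is a minimal sketch of the mechanism (illustrative only, not MRL's actual implementation; all names and shapes here are made up):

import torch
import torch.nn as nn

class ToyConditionalLSTM(nn.Module):
    # illustrative only: project the condition vector into the LSTM's initial states
    def __init__(self, d_latent, d_embedding, d_hidden, n_layers, d_vocab):
        super().__init__()
        self.n_layers, self.d_hidden = n_layers, d_hidden
        self.embedding = nn.Embedding(d_vocab, d_embedding)
        self.lstm = nn.LSTM(d_embedding, d_hidden, n_layers, batch_first=True)
        self.to_h = nn.Linear(d_latent, n_layers*d_hidden) # latent -> h0
        self.to_c = nn.Linear(d_latent, n_layers*d_hidden) # latent -> c0
        self.head = nn.Linear(d_hidden, d_vocab)

    def forward(self, tokens, condition):
        bs = tokens.shape[0]
        # reshape to (n_layers, batch, d_hidden) as nn.LSTM expects
        h0 = self.to_h(condition).reshape(bs, self.n_layers, self.d_hidden).transpose(0,1).contiguous()
        c0 = self.to_c(condition).reshape(bs, self.n_layers, self.d_hidden).transpose(0,1).contiguous()
        out, _ = self.lstm(self.embedding(tokens), (h0, c0))
        return self.head(out) # per-token logits over the vocab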

enc_drops = [0.1, 0.1, 0.1]

# MLP encoder: 7 property features -> 512-dim condition vector
encoder = MLP_Encoder(7, [256, 256, 256], 512, enc_drops)

d_vocab = len(vocab.itos)
bos_idx = vocab.stoi['bos']

d_latent = 512      # size of the condition vector produced by the encoder
d_embedding = 256   # token embedding size
d_hidden = 1024     # LSTM hidden size
n_layers = 3
bidir = False
tie_weights = True
condition_hidden = True   # condition the LSTM hidden/cell states
condition_output = False  # don't concatenate the condition vector to the LSTM output
norm_latent = False       # no norm prior on the latent vectors (see note above)

input_dropout = 0.3
lstm_dropout = 0.3
model = Conditional_LSTM_LM(encoder,
                            d_vocab,
                            d_embedding,
                            d_hidden,
                            d_latent,
                            n_layers,
                            input_dropout,
                            lstm_dropout,
                            norm_latent,
                            condition_hidden,
                            condition_output,
                            bos_idx)
agent = GenerativeAgent(model, vocab, CrossEntropy(), ds, base_model=False)
agent.train_supervised(512, 10, 1e-3) # batch size, epochs, learning rate
Epoch Train Loss Valid Loss Time
0 1.55650 1.46464 00:31
1 0.69325 0.74226 00:31
2 0.49820 0.50996 00:31
3 0.39854 0.43493 00:31
4 0.27718 0.36608 00:31
5 0.32794 0.34877 00:31
6 0.29620 0.32851 00:31
7 0.30384 0.32074 00:31
8 0.28538 0.31851 00:31
9 0.30455 0.31727 00:31
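
At this point it can be worth checkpointing the trained weights. A plain PyTorch save works; the filename here is arbitrary, not an MRL convention:

torch.save(agent.model.state_dict(), 'cond_lstm_lm.pt')
# restore later with: agent.model.load_state_dict(torch.load('cond_lstm_lm.pt'))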

Sampling

We can now sample compounds from the trained model, using the property vector of the first dataset item as the condition.

agent.model.eval();
condition = to_device(ds[0][0].unsqueeze(0).repeat(512, 1)) # sample 512x from the same condition

z = agent.model.x_to_latent((None, condition)) # encode the condition vector to a latent
preds, _ = agent.model.sample_no_grad(512, 90, z=z)
smiles = agent.reconstruct(preds)
smiles = [i for i in smiles if to_mol(i) is not None] # keep only valid SMILES
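
As a quick sanity check, we can look at what fraction of the 512 samples parsed as valid molecules (the exact value varies from run to run):

valid_frac = len(smiles)/512
valid_frac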

# compare the property distributions of the samples to the conditioning values
stats = np.array([sf(i) for i in smiles])
fig, axes = plt.subplots(3, 3, figsize=(12,8))
properties = ['Weight', 'TPSA', 'LogP', 'QED', 'HBA', 'HBD', 'Rings']

for i, ax in enumerate(axes.flat):
    if i<len(properties):
        ax.hist(stats[:,i])
        ax.axvline(condition[0][i].detach().cpu(), color='r') # red line marks the target value
        ax.set_xlabel(properties[i])
    else:
        ax.axis('off')
# draw a random subset of samples with their (unnormalized) property values
subset = np.random.choice(smiles, size=8, replace=False)

stats = [np.array([molwt(i), tpsa(i), logp(i), qed(i), hba(i), hbd(i), rings(i)]) for i in to_mols(subset)]
legends = [f'Weight: {i[0]:.1f}, TPSA: {i[1]:.1f},\
 LogP: {i[2]:.1f}\nQED: {i[3]:.2f}, \
 HBA: {i[4]:.0f}, HBD: {i[5]:.0f}, Rings: {i[6]:.0f}'
          for i in stats]

draw_mols(to_mols(subset), legends=legends)