Deep Drug Decoder
This tutorial shows how to implement the conditional LSTM model from Direct Steering of de novo Molecular Generation using Descriptor Conditional Recurrent Neural Networks (cRNNs).
In the paper, a vector of molecular descriptors (molecular weight, TPSA, etc.) is used to condition generation. We can implement this in MRL with a conditional LSTM model.
Performance Notes
Parts of this notebook are CPU-constrained. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will enable multiprocessing, which can give 2-4x speedups.
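For reference, a set_global_pool cell looks something like the sketch below (we assume here that set_global_pool accepts a worker count; check the MRL source for the exact signature):

# uncomment after running the imports below to enable multiprocessing
# set_global_pool(min(10, os.cpu_count()))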
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime to GPU.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
agent = LSTM_LM_Small_ZINC_NC(base_model=False)  # pretrained SMILES LSTM from the model zoo
agent.model.eval();
# sample a training set of SMILES from the pretrained model
smiles = set()
for i in range(50):
    preds, _ = agent.model.sample_no_grad(2048, 90)  # 2048 sequences, up to 90 tokens
    s = list(set(agent.reconstruct(preds)))
    mols = to_mols(s)
    s = [smi for smi, mol in zip(s, mols) if mol is not None]  # keep only valid molecules
    smiles.update(set(s))

smiles = list(smiles)
len(smiles)
# free GPU memory used by the sampling model before building the cRNN
del preds
del agent.model
del agent
gc.collect()
torch.cuda.empty_cache()
Setup - Dataset
Now we set up our dataset. We'll use the Vec_To_Text_Dataset dataset class. To create our conditional vectors of molecular properties, we need a function that maps SMILES strings to vectors. We've created the SmilesFeaturizer to do this. The featurizer maps a SMILES string to a vector of [weight, tpsa, logp, qed, hba, hbd, rings] values and normalizes those values relative to the dataset.
class SmilesFeaturizer():
    def __init__(self, means=None, stds=None):
        self.means = means
        self.stds = stds

    def __call__(self, smile):
        # featurize, then z-score normalize using dataset statistics
        output = self.featurize_smile(smile)
        if self.means is not None:
            output = output - self.means
        if self.stds is not None:
            output = output / self.stds
        return output

    def featurize_smile(self, smile):
        mol = to_mol(smile)
        if mol is not None:
            weight = molwt(mol)
            t = tpsa(mol)
            lp = logp(mol)
            q = qed(mol)
            ha = hba(mol)
            hd = hbd(mol)
            r = rings(mol)
            output = np.array([weight, t, lp, q, ha, hd, r])
        else:
            # invalid SMILES get a zero vector
            output = np.array([0.]*7)
        return output

    def stats_from_smiles(self, smiles):
        # compute dataset means/stds from unnormalized features
        self.means = None
        self.stds = None
        stats = maybe_parallel(self.featurize_smile, smiles)
        stats = np.array(stats)
        self.means = stats.mean(0)
        self.stds = stats.std(0)
sf = SmilesFeaturizer()
sf.stats_from_smiles(smiles)  # fit normalization stats on the sampled dataset
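As a quick sanity check, we can featurize a single SMILES string; after normalization each of the seven values is a z-score relative to the dataset:

example_vec = sf(smiles[0])
example_vec.round(2)  # array of 7 z-scored descriptor values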
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(smiles, vocab, sf)
Now we can inspect our data. We are going to give the model the vector of properties and the integer token codes of the SMILES string it should generate.
ds[0]
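Here ds[0][0] is the normalized descriptor vector; we reuse it later as the sampling condition. A quick shape check, based on how the item is used below:

cond = ds[0][0]  # normalized descriptor vector for the first molecule
cond.shape       # expect a length-7 tensor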
Model Creation
We will use the Conditional_LSTM_LM class as our model. This model conditions the hidden and cell states of an LSTM on the condition vector.
One important difference from previous tutorials using Conditional_LSTM_LM is that we do not normalize the latent vectors. This makes it harder to sample from the latent space directly, but the model produces better results from the condition vectors because they are not constrained by the norm prior.
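To make the conditioning mechanism concrete, here is a minimal standalone sketch in plain PyTorch. This is not the MRL implementation; TinyConditionalLSTM is an illustrative stand-in showing how a condition-derived latent can initialize an LSTM's hidden and cell states:

import torch
import torch.nn as nn

class TinyConditionalLSTM(nn.Module):
    "Illustrative sketch: project a condition vector into the LSTM initial states"
    def __init__(self, d_vocab, d_embed, d_hidden, d_latent, n_layers):
        super().__init__()
        self.n_layers, self.d_hidden = n_layers, d_hidden
        self.embedding = nn.Embedding(d_vocab, d_embed)
        self.lstm = nn.LSTM(d_embed, d_hidden, n_layers, batch_first=True)
        # separate projections for the initial hidden and cell states
        self.to_h = nn.Linear(d_latent, n_layers*d_hidden)
        self.to_c = nn.Linear(d_latent, n_layers*d_hidden)
        self.head = nn.Linear(d_hidden, d_vocab)

    def forward(self, tokens, z):
        # tokens: (batch, seq), z: (batch, d_latent)
        bs = tokens.shape[0]
        # reshape projections to (n_layers, batch, d_hidden) as nn.LSTM expects
        h = self.to_h(z).view(bs, self.n_layers, self.d_hidden).transpose(0, 1).contiguous()
        c = self.to_c(z).view(bs, self.n_layers, self.d_hidden).transpose(0, 1).contiguous()
        out, _ = self.lstm(self.embedding(tokens), (h, c))
        return self.head(out)  # (batch, seq, d_vocab) logits

# quick smoke test with dimensions similar to the tutorial's
m = TinyConditionalLSTM(d_vocab=64, d_embed=256, d_hidden=1024, d_latent=512, n_layers=3)
logits = m(torch.randint(0, 64, (2, 10)), torch.randn(2, 512))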
enc_drops = [0.1, 0.1, 0.1]
encoder = MLP_Encoder(7, [256, 256, 256], 512, enc_drops)  # maps 7 descriptors to a 512-dim latent
d_vocab = len(vocab.itos)
bos_idx = vocab.stoi['bos']
d_latent = 512
d_embedding = 256
d_hidden = 1024
n_layers = 3
bidir = False
tie_weights = True
condition_hidden = True   # condition the LSTM hidden/cell states on the latent
condition_output = False  # do not condition the output layer
norm_latent = False       # latent vectors are not normalized (see note above)
input_dropout = 0.3
lstm_dropout = 0.3
model = Conditional_LSTM_LM(encoder,
                            d_vocab,
                            d_embedding,
                            d_hidden,
                            d_latent,
                            n_layers,
                            input_dropout,
                            lstm_dropout,
                            norm_latent,
                            condition_hidden,
                            condition_output,
                            bos_idx)
agent = GenerativeAgent(model, vocab, CrossEntropy(), ds, base_model=False)
agent.train_supervised(512, 10, 1e-3)  # batch size 512, 10 epochs, lr 1e-3
agent.model.eval();
condition = to_device(ds[0][0].unsqueeze(0).repeat(512, 1))  # sample 512x from the same condition
z = agent.model.x_to_latent((None, condition))  # encoder only consumes the condition vector
preds, _ = agent.model.sample_no_grad(512, 90, z=z)
smiles = agent.reconstruct(preds)
smiles = [i for i in smiles if to_mol(i) is not None]  # keep valid molecules
stats = np.array([sf(i) for i in smiles])  # normalized descriptors of the samples
fig, axes = plt.subplots(3, 3, figsize=(12,8))
properties = ['Weight', 'TPSA', 'LogP', 'QED', 'HBA', 'HBD', 'Rings']

for i, ax in enumerate(axes.flat):
    if i < len(properties):
        ax.hist(stats[:,i])
        ax.axvline(condition[0][i].detach().cpu(), color='r')  # red line: target value
        ax.set_xlabel(properties[i])
    else:
        ax.axis('off')
subset = np.random.choice(smiles, size=8, replace=False)
stats = [np.array([molwt(i), tpsa(i), logp(i), qed(i), hba(i), hbd(i), rings(i)])
         for i in to_mols(subset)]
legends = [f'Weight: {i[0]:.1f}, TPSA: {i[1]:.1f}, LogP: {i[2]:.1f}\n'
           f'QED: {i[3]:.2f}, HBA: {i[4]:.0f}, HBD: {i[5]:.0f}, Rings: {i[6]:.0f}'
           for i in stats]
draw_mols(to_mols(subset), legends=legends)
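Because the condition is just a normalized descriptor vector, we can also steer generation by editing it before sampling. A minimal sketch reusing the calls above (the +1.0 offset corresponds to one dataset standard deviation of molecular weight, since the featurizer z-scores its outputs):

new_condition = ds[0][0].clone()
new_condition[0] += 1.0  # shift normalized weight up by one standard deviation
new_condition = to_device(new_condition.unsqueeze(0).repeat(512, 1))

z = agent.model.x_to_latent((None, new_condition))
preds, _ = agent.model.sample_no_grad(512, 90, z=z)
steered = [s for s in agent.reconstruct(preds) if to_mol(s) is not None]
np.mean([molwt(to_mol(s)) for s in steered])  # should trend above the unshifted samples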