Training and using LSTM language models

LSTM Language Models

LSTM language models are a type of autoregressive generative model built on the LSTM architecture. They are a good fit for RL-based optimization because they are lightweight, robust, and easy to optimize.

Language models are trained in a self-supervised fashion by next token prediction. Given a series of tokens, the model predicts a probability distribution over the next token. Self-supervised training is very fast and doesn't require any data labels - each text string in the dataset labels itself.
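
Concretely, the target for each position is just the next token in the sequence, so the labels are the inputs shifted by one, and the loss is cross entropy between the predicted distribution and that next token. Here is a minimal sketch in plain PyTorch (model and tokens are placeholders, not part of the MRL API):

import torch.nn.functional as F

# a minimal sketch of the next token prediction objective
# `tokens` is a (batch, seq_len) tensor of integer token ids and `model` is any
# network mapping token ids to logits of shape (batch, seq_len, vocab_size)
def next_token_loss(model, tokens):
    logits = model(tokens[:, :-1])                 # predictions for positions 1..seq_len-1
    targets = tokens[:, 1:]                        # the "labels" are the inputs shifted by one
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                           targets.reshape(-1))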

During generation, we sample from the model in an autoregressive fashion. Given an input token, the model predicts a distribution over the next token. We then sample from that distribution and feed the selected token back into the model. We repeat this process until either an end of sentence (EOS) token is predicted or the generated sequence reaches a maximum allowed length.

During sampling, we save the log probability of each predicted token. Together these give the model's estimated likelihood of the generated compound, and we can also backpropagate through this value.
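
Put together, the generation loop is quite short. A simplified sketch in plain PyTorch (illustrative only; step stands in for a single decoding step of the LSTM, which MRL handles internally):

import torch
import torch.nn.functional as F

# illustrative autoregressive sampling loop that also tracks log probabilities
# `step(token, state) -> (logits, state)` is a placeholder for one LSTM decoding step
def sample_one(step, bos_idx, eos_idx, max_len, temperature=1.0):
    token, state = torch.tensor([bos_idx]), None
    tokens, log_probs = [], []
    for _ in range(max_len):
        logits, state = step(token, state)                    # (1, vocab_size) logits
        logp = F.log_softmax(logits / temperature, dim=-1)
        token = torch.multinomial(logp.exp(), 1).squeeze(-1)  # sample the next token
        tokens.append(token.item())
        log_probs.append(logp[0, token.item()])               # save its log probability
        if token.item() == eos_idx:                           # stop on end of sentence
            break
    return tokens, torch.stack(log_probs)                     # tokens plus per-token log probs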

Performance Notes

If running on Colab, remember to change the runtime to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.train.agent import *
from mrl.vocab import *

Setup

Before creating a model, we need to set up our data.

Our raw data is in the form of SMILES strings. We need to convert these to tensors.

First we need a Vocab to handle converting strings to tokens and mapping those tokens to integers. We will use the CharacterVocab class with the SMILES_CHAR_VOCAB vocabulary. This will tokenize SMILES on a character basis.

More sophisticated tokenization schemes exist, but character tokenization is nice for its simplicity: it gives a small, compact vocabulary. Other strategies tokenize by more meaningful subwords, but they tend to create a long tail of low-frequency tokens and lots of unknown (UNK) tokens.

df = pd.read_csv('../files/smiles.csv')

# if in Colab:
# download_files()
# df = pd.read_csv('files/smiles.csv')
df.head()
smiles
0 CNc1nc(SCC(=O)Nc2cc(Cl)ccc2OC)nc2ccccc12
1 COc1ccc(C(=O)Oc2ccc(/C=C3\C(=N)N4OC(C)=CC4=NC3...
2 Cc1sc(NC(=O)c2ccccc2)c(C(N)=O)c1C
3 COc1ccc(NCc2noc(-c3ccoc3)n2)cc1OC(F)F
4 O=C(COC(=O)c1cccc(Br)c1)c1ccc2c(c1)OCCCO2
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

The vocab first tokenizes SMILES into characters, then numericalizes the tokens into integer keys.

' '.join(vocab.tokenize(df.smiles.values[0]))
'bos C N c 1 n c ( S C C ( = O ) N c 2 c c ( C l ) c c c 2 O C ) n c 2 c c c c c 1 2 eos'
' '.join([str(i) for i in vocab.numericalize(vocab.tokenize(df.smiles.values[0]))])
'0 23 27 34 11 37 34 5 30 23 23 5 20 28 6 27 34 12 34 34 5 23 36 6 34 34 34 12 28 23 6 37 34 12 34 34 34 34 34 11 12 1'

Now we need a dataset. We will use the Text_Dataset class.

dataset = Text_Dataset(df.smiles.values, vocab)
dataloader = dataset.dataloader(32)

Now we can look at the actual data.

x,y = next(iter(dataloader))
x
tensor([[ 0, 23, 27,  ...,  2,  2,  2],
        [ 0, 23, 28,  ...,  2,  2,  2],
        [ 0, 23, 34,  ...,  2,  2,  2],
        ...,
        [ 0, 23,  4,  ...,  2,  2,  2],
        [ 0, 23, 23,  ...,  2,  2,  2],
        [ 0, 23, 28,  ...,  2,  2,  2]])
y
tensor([[23, 27, 34,  ...,  2,  2,  2],
        [23, 28, 34,  ...,  2,  2,  2],
        [23, 34, 11,  ...,  2,  2,  2],
        ...,
        [23,  4, 23,  ...,  2,  2,  2],
        [23, 23, 11,  ...,  2,  2,  2],
        [23, 28, 34,  ...,  2,  2,  2]])

You will notice the y tensor is the same as the x tensor with the values shifted by one. This is because the goal of autoregressive language modeling is to predict the next character given the previous series of characters.
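
You can reproduce this pairing yourself from a single string (an illustrative check, not the dataset's internal code):

# the target sequence is just the input sequence shifted by one token
toks = vocab.numericalize(vocab.tokenize(df.smiles.values[0]))
x_example, y_example = toks[:-1], toks[1:]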

Model Creation

We can create a model using the LSTM_LM class.

d_vocab = len(vocab.itos)
bos_idx = vocab.stoi['bos']
d_embedding = 256
d_hidden = 1024
n_layers = 3
bidir = False
tie_weights = True    
input_dropout = 0.3
lstm_dropout = 0.3
 
model = LSTM_LM(d_vocab, 
                d_embedding,
                d_hidden, 
                n_layers,
                input_dropout,
                lstm_dropout,
                bos_idx, 
                bidir, 
                tie_weights)

We can see the model is quite simple: an embedding layer, three LSTM layers, and an output layer.

model
LSTM_LM(
  (embedding): Embedding(47, 256)
  (lstm): LSTM(
    (input_drop): SequenceDropout()
    (lstm_drop): SequenceDropout()
    (lstms): ModuleList(
      (0): LSTM(256, 1024, batch_first=True)
      (1): LSTM(1024, 1024, batch_first=True)
      (2): LSTM(1024, 256, batch_first=True)
    )
  )
  (head): Linear(in_features=256, out_features=47, bias=True)
)
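
The forward pass embeds the input tokens, runs them through the stacked LSTMs, and projects the final hidden states back to vocabulary logits. An illustrative skeleton of the same flow is shown below (not the actual LSTM_LM code, which also handles sequence dropout, hidden state management, and optional weight tying):

import torch.nn as nn

class TinyLSTMLM(nn.Module):
    # a stripped-down stand-in for LSTM_LM, mirroring the printed architecture above
    def __init__(self, d_vocab=47, d_embedding=256, d_hidden=1024):
        super().__init__()
        self.embedding = nn.Embedding(d_vocab, d_embedding)
        self.lstm1 = nn.LSTM(d_embedding, d_hidden, batch_first=True)
        self.lstm2 = nn.LSTM(d_hidden, d_hidden, batch_first=True)
        self.lstm3 = nn.LSTM(d_hidden, d_embedding, batch_first=True)
        self.head = nn.Linear(d_embedding, d_vocab)

    def forward(self, x):                 # x: (batch, seq_len) token ids
        h = self.embedding(x)             # (batch, seq_len, d_embedding)
        h, _ = self.lstm1(h)              # (batch, seq_len, d_hidden)
        h, _ = self.lstm2(h)              # (batch, seq_len, d_hidden)
        h, _ = self.lstm3(h)              # (batch, seq_len, d_embedding)
        return self.head(h)               # (batch, seq_len, d_vocab) logits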

Now we'll put the model into a GenerativeAgent to manage supervised training.

We need to specify a loss function - we will use standard cross entropy.

loss_function = CrossEntropy()
agent = GenerativeAgent(model, vocab, loss_function, dataset, base_model=False)

Now we can train in a supervised fashion on next token prediction.

agent.train_supervised(32, 1, 1e-3)
Epoch Train Loss Valid Loss Time
0 2.52853 3.02847 00:03

This was just a quick example to show the training API. We're not going to do a whole training process here. To train custom models, repeat this code with your own set of SMILES.

Pre-trained Models

The MRL model zoo offers a number of pre-trained models. We'll load one of these to continue.

We'll use the LSTM_LM_Small_Chembl model. This model was first trained on a chunk of the ZINC database, then fine-tuned on the ChEMBL database.

del model
del agent
gc.collect()
torch.cuda.empty_cache()
from mrl.model_zoo import LSTM_LM_Small_Chembl
agent = LSTM_LM_Small_Chembl(base_model=False)

Now, with a fully pre-trained model, we can look at drawing samples.

preds, lps = agent.model.sample_no_grad(256, 100, temperature=1.)
preds
tensor([[23, 28, 23,  ...,  2,  2,  2],
        [23, 23, 28,  ...,  2,  2,  2],
        [23, 28, 23,  ...,  2,  2,  2],
        ...,
        [23, 27,  5,  ...,  2,  2,  2],
        [23, 28, 34,  ...,  2,  2,  2],
        [23, 34, 11,  ...,  2,  2,  2]], device='cuda:0')
lps.shape
torch.Size([256, 100])

The sample_no_grad function gives us two outputs - preds and lps.

preds is a long tensor of size (bs, sl) containing the integer tokens of the samples.

lps is a float tensor of size (bs, sl) containing the log probabilities of each value in preds.
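
One common use of lps is to score each sample: summing the per-token log probabilities along the sequence dimension gives the model's log likelihood for each generated compound. A quick sketch (depending on how positions after the EOS token are filled, you may want to mask them first):

# per-compound log likelihood: sum per-token log probs over the sequence dimension
# (positions after EOS may need masking, depending on how they are filled)
sequence_lps = lps.sum(-1)    # shape (256,)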

We can now reconstruct the predictions back into SMILES strings

smiles = agent.reconstruct(preds)
smiles[:10]
['COC1(OC)CCC2(CC(Br)(Br)CCl)CCC21',
 'CCOc1nccc(-c2csc(N3CCN(C)CC3)n2)c1C',
 'COC(=O)C1=C(C)NC(C)=C(C(=O)OC(C)C)C1c1sc2ccccc2c1Cl',
 'C[C@]1(O)C[C@@H](c2nc(-c3ccc(C(=O)Nc4cc(C(F)F)ccn4)cc3)c3c(N)nccn23)C1',
 'CC(=O)Oc1ccc2c(c1)OCO2',
 'NC(/C=C1/C(=O)Oc2ccc(C(F)(F)F)cc21)C(F)(F)F',
 'CC(C#N)OC(=O)c1ccccc1O',
 'Cc1cc(Cl)ccc1OCC(=O)N(c1nc2ccccc2s1)C1CCCCC1',
 'FC(F)(F)c1nc(N2CCC(N3CCCCC3)CC2)c2cccnc2n1',
 'CN(C)[C@H]1CC[C@H](NC(=N)NC(=O)CC2CCCC2)CC1']
mols = to_mols(smiles)

Now let's look at some key generation statistics:

  • diversity - the fraction of unique samples
  • valid - the fraction of chemically valid samples
div = len(set(smiles))/len(smiles)
val = len([i for i in mols if i is not None])/len(mols)
print(f'Diversity:\t{div:.3f}\nValid:\t\t{val:.3f}')
Diversity:	1.000
Valid:		0.938
valid_mols = [i for i in mols if i is not None]
draw_mols(valid_mols[:16], mols_per_row=4)