LSTM Language Models
LSTM language models are autoregressive generative models built on the LSTM architecture. They are a good fit for RL-based optimization because they are lightweight, robust, and easy to optimize.
Language models are trained in a self-supervised fashion on next-token prediction. Given a series of tokens, the model predicts a probability distribution over the next token. Self-supervised training is fast and doesn't require any data labels; each text string in the dataset acts as its own label.
During generation, we sample from the model in an autoregressive fashion. Given an input token, the model predicts a distribution over the next token. We then sample from that distribution and feed the selected token back into the model. We repeat this process until either an end-of-sentence (EOS) token is predicted or the generated sequence reaches a maximum allowed length.
During sampling, we save the log probability of each sampled token. Summed over the sequence, these log probabilities give the model's estimated log-likelihood of the generated compound, and we can also backprop through this value.
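To make the sampling loop concrete, here is a minimal, self-contained sketch of autoregressive sampling with log probability tracking. It is illustrative only and assumes a model whose forward pass takes a token batch plus a hidden state and returns logits along with the new hidden state; the MRL models implement their own sampling methods, which we use later on.
import torch
import torch.nn.functional as F

def sample_sketch(model, bos_idx, eos_idx, max_len, temperature=1.0):
    # start from the beginning-of-sequence token (batch size 1 for clarity)
    token = torch.tensor([[bos_idx]])
    tokens, log_probs = [], []
    hidden = None
    for _ in range(max_len):
        logits, hidden = model(token, hidden)        # assumed interface: (logits, hidden)
        dist = F.log_softmax(logits[:, -1] / temperature, dim=-1)
        token = torch.multinomial(dist.exp(), 1)     # sample the next token
        tokens.append(token)
        log_probs.append(dist.gather(-1, token))     # save its log probability
        if token.item() == eos_idx:                  # stop at end-of-sentence
            break
    return torch.cat(tokens, dim=1), torch.cat(log_probs, dim=1)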
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.train.agent import *
from mrl.vocab import *
Setup
Before creating a model, we need to set up our data.
Our raw data is in the form of SMILES strings. We need to convert these to tensors.
First we need a Vocab to handle converting strings to tokens and mapping those tokens to integers. We will use the CharacterVocab class with the SMILES_CHAR_VOCAB vocabulary. This will tokenize SMILES on a character basis.
More sophisticated tokenization schemes exist, but character tokenization is nice for its simplicity and gives a small, compact vocabulary. Other strategies tokenize by more meaningful subwords, but they create a long tail of low-frequency tokens and lots of unk tokens.
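As a quick standalone illustration of the difference (a hypothetical example, not the MRL tokenizer): character tokenization splits a SMILES string into single characters, while an atom-level scheme keeps multi-character atoms such as Br together.
import re

smiles = 'Brc1ccccc1'  # bromobenzene, chosen for illustration

# character-level tokens: every character is its own token
char_tokens = list(smiles)

# hypothetical atom-level pattern that keeps Cl and Br intact
atom_pattern = re.compile(r'Cl|Br|[A-Za-z]|\d|\(|\)|=|#|\+|-|\[|\]|@|/|\\')
atom_tokens = atom_pattern.findall(smiles)

print(char_tokens)  # ['B', 'r', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1']
print(atom_tokens)  # ['Br', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1']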
df = pd.read_csv('../files/smiles.csv')
# if in Colab:
# download_files()
# df = pd.read_csv('files/smiles.csv')
df.head()
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
vocab
The vocab first tokenizes SMILES strings into characters, then numericalizes the tokens into integer keys.
' '.join(vocab.tokenize(df.smiles.values[0]))
' '.join([str(i) for i in vocab.numericalize(vocab.tokenize(df.smiles.values[0]))])
Now we need a dataset. We will use the Text_Dataset class.
dataset = Text_Dataset(df.smiles.values, vocab)
dataloader = dataset.dataloader(32)
Now we can look at the actual data
x,y = next(iter(dataloader))
x
y
You will notice the y tensor is the same as the x tensor with the values shifted by one. This is because the goal of autoregressive language modeling is to predict the next character given the previous series of characters.
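To see this input/target relationship on a single compound, here is a small sketch (illustrative only) that pairs each input token with the next-token target the model is asked to predict; the dataloader builds x and y in the same spirit.
# numericalize one SMILES string and line up inputs with next-token targets
seq = vocab.numericalize(vocab.tokenize(df.smiles.values[0]))
inp, target = seq[:-1], seq[1:]
list(zip(inp, target))[:8]  # each input token paired with the token that follows it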
Model Creation
We can create a model through the LSTM_LM class.
d_vocab = len(vocab.itos)
bos_idx = vocab.stoi['bos']
d_embedding = 256
d_hidden = 1024
n_layers = 3
bidir = False
tie_weights = True
input_dropout = 0.3
lstm_dropout = 0.3
model = LSTM_LM(d_vocab,
d_embedding,
d_hidden,
n_layers,
input_dropout,
lstm_dropout,
bos_idx,
bidir,
tie_weights)
We can see the model is quite simple. We have an embedding layer, three LSTM layers, and an output layer
model
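For reference, here is a minimal plain-PyTorch sketch of the same architecture idea: an embedding layer, a stack of LSTM layers, and a linear decoder whose weights can be tied to the embedding. This is an illustrative re-implementation, not the LSTM_LM class itself, and details such as dropout placement and projection layers may differ from MRL.
import torch
import torch.nn as nn

class TinyLSTMLM(nn.Module):
    def __init__(self, d_vocab, d_embedding, d_hidden, n_layers,
                 dropout=0.3, tie_weights=True):
        super().__init__()
        self.embed = nn.Embedding(d_vocab, d_embedding)
        self.lstm = nn.LSTM(d_embedding, d_hidden, n_layers,
                            batch_first=True, dropout=dropout)
        # project back to the embedding dimension so the decoder weights
        # can be shared with the embedding matrix
        self.proj = nn.Linear(d_hidden, d_embedding)
        self.decoder = nn.Linear(d_embedding, d_vocab)
        if tie_weights:
            self.decoder.weight = self.embed.weight

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(self.embed(x), hidden)
        return self.decoder(self.proj(out)), hidden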
Now we'll put the model into a GenerativeAgent to manage supervised training.
We need to specify a loss function - we will use standard cross entropy
loss_function = CrossEntropy()
agent = GenerativeAgent(model, vocab, loss_function, dataset, base_model=False)
Now we can train in a supervised fashion on next token prediction
agent.train_supervised(32, 1, 1e-3)
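Under the hood, supervised training amounts to a standard next-token training loop. The sketch below is a rough hand-rolled equivalent, not the MRL implementation: it assumes the model's forward pass returns logits of shape (bs, sl, d_vocab) and that a device variable is defined, which may not match the actual API.
import torch.nn.functional as F

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for x, y in dataloader:
    x, y = x.to(device), y.to(device)      # device is assumed to be defined
    logits = model(x)                      # assumed output shape: (bs, sl, d_vocab)
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                           y.reshape(-1))  # cross entropy over flattened positions
    opt.zero_grad()
    loss.backward()
    opt.step()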
This was just a quick example to show the training API. We're not going to do a whole training process here. To train custom models, repeat this code with your own set of SMILES.
Pre-trained Models
The MRL model zoo offers a number of pre-trained models. We'll load one of these to continue.
We'll use the LSTM_LM_Small_Chembl model. This model was first trained on a chunk of the ZINC database, then fine-tuned on the ChEMBL database.
del model
del agent
gc.collect()
torch.cuda.empty_cache()
from mrl.model_zoo import LSTM_LM_Small_Chembl
agent = LSTM_LM_Small_Chembl(base_model=False)
Now with a fully pre-trained model, we can look at drawing samples
preds, lps = agent.model.sample_no_grad(256, 100, temperature=1.)
preds
lps.shape
The sample_no_grad function gives us two outputs - preds and lps.
preds is a long tensor of size (bs, sl) containing the integer tokens of the samples.
lps is a float tensor of size (bs, sl) containing the log probabilities of each value in preds.
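Because each entry in lps is a per-token log probability, summing over the sequence dimension gives a log-likelihood estimate for each generated compound. This is a sketch; it assumes padded positions contribute zero or are handled upstream.
seq_log_likelihood = lps.sum(-1)  # shape (bs,), one value per sampled compound
seq_log_likelihood.shape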
We can now reconstruct the predictions back into SMILES strings
smiles = agent.reconstruct(preds)
smiles[:10]
mols = to_mols(smiles)
Now let's look at some key generation statistics.
- diversity - the fraction of unique samples
- valid - the fraction of chemically valid samples
div = len(set(smiles))/len(smiles)
val = len([i for i in mols if i is not None])/len(mols)
print(f'Diversity:\t{div:.3f}\nValid:\t\t{val:.3f}')
valid_mols = [i for i in mols if i is not None]
draw_mols(valid_mols[:16], mols_per_row=4)