LSTM Language Models
LSTM language models are autoregressive generative models built on the LSTM architecture. They are a good fit for RL-based optimization because they are lightweight, robust, and easy to optimize.
Language models are trained in a self-supervised fashion by next-token prediction. Given a series of tokens, the model predicts a probability distribution over the next token. Self-supervised training is fast and requires no data labels; each text string in the dataset labels itself.
During generation, we sample from the model in an autoregressive fashion. Given an input token, the model predicts a distribution over the next token. We sample from that distribution and feed the selected token back into the model, repeating the process until either an end-of-sentence (EOS) token is predicted or the generated sequence reaches a maximum allowed length.
During sampling, we save the log probability of each predicted token. Summed over the sequence, these log probabilities give the model's estimated likelihood of the generated compound, and we can backpropagate through this value.
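The sketch below shows what this sampling loop looks like in plain PyTorch. It is an illustrative sketch, not the library implementation: the model is assumed to return (logits, hidden state) for the current token, and sample_sequence, bos_idx, and eos_idx are hypothetical names.

import torch

def sample_sequence(model, bos_idx, eos_idx, max_len):
    # start from a beginning-of-sentence token
    token = torch.tensor([[bos_idx]])
    hidden = None
    tokens, log_probs = [], []
    for _ in range(max_len):
        logits, hidden = model(token, hidden)               # logits: (1, 1, vocab)
        dist = torch.distributions.Categorical(logits=logits[:, -1])
        token = dist.sample().unsqueeze(-1)                 # next token id
        tokens.append(token)
        log_probs.append(dist.log_prob(token.squeeze(-1)))  # differentiable log probability
        if token.item() == eos_idx:
            break
    # the summed log probability is the model's estimated log-likelihood of the sample
    return torch.cat(tokens, -1), torch.stack(log_probs).sum()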
LSTM_LM
The LSTM_LM model is a bare-bones LSTM language model. It consists of an embedding layer, a series of LSTM layers, and a final output layer that generates a prediction over the vocabulary.
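As a point of reference, the forward pass described above looks roughly like the following. This is a minimal sketch under stated assumptions (batch-first tensors, a single linear output head), not the actual LSTM_LM implementation; TinyLSTMLM is a hypothetical name.

import torch
from torch import nn

class TinyLSTMLM(nn.Module):
    # illustrative stand-in for LSTM_LM: embedding -> stacked LSTM -> linear head
    def __init__(self, d_vocab, d_embedding, d_hidden, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(d_vocab, d_embedding)
        self.lstm = nn.LSTM(d_embedding, d_hidden, n_layers, batch_first=True)
        self.head = nn.Linear(d_hidden, d_vocab)

    def forward(self, x, hidden=None):
        emb = self.embedding(x)               # (bs, sl, d_embedding)
        out, hidden = self.lstm(emb, hidden)  # (bs, sl, d_hidden)
        return self.head(out), hidden         # logits over the vocabulary

The example below exercises the actual LSTM_LM API: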
# LSTM_LM: vocab size 32, embedding dim 64, hidden dim 256, 2 LSTM layers
lm = LSTM_LM(32, 64, 256, 2)

# dummy batch of token ids; inputs and targets are offset by one token
ints = torch.randint(0, 31, (16, 10))
x = ints[:,:-1]
y = ints[:,1:]

out = lm(x)                              # next-token prediction logits
o,lp,lpg,e = lm.get_rl_tensors(x,y)      # tensors used for RL training
_ = lm.sample(8, 10)                     # sample 8 sequences of up to 10 tokens

# everything also works on the current device (GPU if available)
to_device(lm)
x = to_device(x)
y = to_device(y)
o,lp,lpg,e = lm.get_rl_tensors(x,y)
_ = lm.sample(8, 10)
Conditional LSTM
Conditional_LSTM_LM is a conditional variant of LSTM_LM. This model uses an encoder to generate a latent vector that is used to condition the output LSTM. It works with any of the Encoder subclasses in the layers section. Custom encoders also work, so long as they are compatible with the data format you are using and produce a single vector per batch item.
If condition_hidden=True, the latent vector is used to initialize the hidden state of the LSTM decoder. If condition_output=True, the latent vector is concatenated to the activations going into the first LSTM layer. Both options can be used simultaneously.
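A rough sketch of these two conditioning modes is shown below. It illustrates the idea under assumed tensor shapes (batch-first embeddings, a learned projection from the latent vector to the initial hidden state) and is not the library's implementation; TinyConditionalDecoder is a hypothetical name.

import torch
from torch import nn

class TinyConditionalDecoder(nn.Module):
    # illustrative conditional LSTM decoder supporting both conditioning modes
    def __init__(self, d_embedding, d_hidden, d_latent, n_layers,
                 condition_hidden=True, condition_output=False):
        super().__init__()
        self.condition_hidden = condition_hidden
        self.condition_output = condition_output
        d_input = d_embedding + (d_latent if condition_output else 0)
        self.lstm = nn.LSTM(d_input, d_hidden, n_layers, batch_first=True)
        self.to_hidden = nn.Linear(d_latent, 2 * n_layers * d_hidden)
        self.n_layers, self.d_hidden = n_layers, d_hidden

    def forward(self, emb, z):
        bs, sl, _ = emb.shape
        hidden = None
        if self.condition_hidden:
            # project the latent vector into the initial (h, c) state
            hc = self.to_hidden(z).view(bs, 2 * self.n_layers, self.d_hidden)
            h, c = hc.chunk(2, dim=1)
            hidden = (h.transpose(0, 1).contiguous(), c.transpose(0, 1).contiguous())
        if self.condition_output:
            # concatenate the latent vector to the inputs at every timestep
            emb = torch.cat([emb, z.unsqueeze(1).expand(-1, sl, -1)], dim=-1)
        return self.lstm(emb, hidden)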
One important detail of this model is that it imposes no strict prior distribution on the latent vector, unlike models such as VAEs, which impose a Gaussian prior on their latent variables. This poses a problem for sampling: we do not know what distribution to draw latent vectors from, and we could sample vectors that map to invalid outputs. One solution is to normalize the latent vector to a length of 1, imposing the constraint that all latent vectors lie on the surface of the unit sphere. This is similar to what is done in StyleGAN-type models. Set norm_latent=True to impose this constraint.
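In code, the normalization and the matching sampling strategy can be sketched as follows (illustrative only; the shapes are arbitrary):

import torch
import torch.nn.functional as F

z = torch.randn(8, 16)                              # raw latent vectors from an encoder
z = F.normalize(z, dim=-1)                          # rescale to unit length (norm_latent=True)
sampled = F.normalize(torch.randn(3, 16), dim=-1)   # sample latents as random points on the unit sphere

The example below uses the actual Conditional_LSTM_LM class together with a SphericalPrior: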
# encoder: maps a 128-dim input vector to a 16-dim latent vector
encoder = MLP_Encoder(128, [64, 32], 16, [0.1, 0.1])
# conditional LM: vocab size 32, embedding dim 64, hidden dim 128, latent dim 16, 2 LSTM layers
lm = Conditional_LSTM_LM(encoder, 32, 64, 128, 16, 2)

# dummy token batch plus conditioning vectors for the encoder
ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]
condition = torch.randn((8,128))

# spherical prior over the latent space with trainable parameters
prior = SphericalPrior(torch.zeros((encoder.d_latent,)),
                       torch.zeros((encoder.d_latent,)), True)

_ = lm(x, condition)
o,lp,lpg,e = lm.get_rl_tensors([x,condition],y)
_ = lm.sample(3, 80)

# attaching the prior lets gradients flow back into its parameters
lm.prior = prior
o,lp,lpg,e = lm.get_rl_tensors([x,condition],y)
loss = lpg.mean()
assert lm.prior.loc.grad is None
loss.backward()
assert lm.prior.loc.grad is not None
# standard lm: train LSTM_LM on SMILES strings by next-token prediction
import pandas as pd
import torch
from torch import optim
import matplotlib.pyplot as plt

from mrl.vocab import *
from mrl.dataloaders import *
# model classes and helpers (LSTM_LM, CrossEntropy, to_device, ...) are assumed
# to already be imported from the library

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)   # character-level SMILES vocabulary
ds = Text_Dataset(df.smiles.values, vocab)  # tokenized text dataset
dl = ds.dataloader(16, num_workers=0)       # batch size 16
loss = CrossEntropy()
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bos_idx = vocab.stoi['bos']  # index of the beginning-of-sentence token
bidir = False                # unidirectional LSTM
tie_weights = True           # tie the embedding and output layer weights
model = LSTM_LM(d_vocab,
                d_embedding,
                d_hidden,
                n_layers,
                input_dropout,
                lstm_dropout,
                bos_idx,
                bidir,
                tie_weights)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

# standard supervised training loop with a one-cycle LR schedule
losses = []
for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())

plt.plot(losses)
# fp conditional lm: condition generation on ECFP6 fingerprints of the molecules
from mrl.vocab import *
from mrl.dataloaders import *
from mrl.chem import ECFP6

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(df.smiles.values, vocab, ECFP6)   # pairs each SMILES with its fingerprint
dl = ds.dataloader(16, num_workers=0)
loss = CrossEntropy()
encoder = MLP_Encoder(2048, [1024, 512], 512, [0.1, 0.1])  # 2048-bit fingerprint -> 512-dim latent
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
condition_hidden = True
condition_output = False
bos_idx = vocab.stoi['bos']
norm_latent = True
model = Conditional_LSTM_LM(encoder,
                            d_vocab,
                            d_embedding,
                            d_hidden,
                            d_latent,
                            n_layers,
                            input_dropout,
                            lstm_dropout,
                            norm_latent,
                            condition_hidden,
                            condition_output,
                            bos_idx)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []
for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(*x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())

plt.plot(losses)
# text conditional lm: condition generation on an LSTM encoding of the input text
from mrl.vocab import *
from mrl.dataloaders import *

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Text_Dataset(df.smiles.values, vocab)
dl = ds.dataloader(16, num_workers=0)
loss = CrossEntropy()
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
condition_hidden = True
condition_output = False
bos_idx = vocab.stoi['bos']
norm_latent = True
encoder = LSTM_Encoder(
    d_vocab,
    d_embedding,
    d_hidden,
    n_layers,
    d_latent,
    input_dropout=input_dropout,
    lstm_dropout=lstm_dropout
)

model = Conditional_LSTM_LM(encoder,
                            d_vocab,
                            d_embedding,
                            d_hidden,
                            d_latent,
                            n_layers,
                            input_dropout,
                            lstm_dropout,
                            norm_latent,
                            condition_hidden,
                            condition_output,
                            bos_idx)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []
for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())

plt.plot(losses)
# vec to text: condition on a fingerprint of the Murcko scaffold, generate the full molecule
from mrl.vocab import *
from mrl.dataloaders import *
from mrl.chem import ECFP6
from rdkit.Chem.Scaffolds import MurckoScaffold
from mrl.chem import to_mol, to_smile

df = pd.read_csv('files/smiles.csv')

# build (scaffold, molecule) SMILES pairs: the scaffold is the source, the full molecule the target
source_smiles = [to_smile(MurckoScaffold.GetScaffoldForMol(to_mol(i))) for i in df.smiles.values]
target_smiles = df.smiles.values
smiles = [(source_smiles[i], target_smiles[i]) for i in range(len(source_smiles))]

vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(smiles, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
loss = CrossEntropy()
encoder = MLP_Encoder(2048, [1024, 512], 512, [0.1, 0.1])
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
condition_hidden = True
condition_output = False
bos_idx = vocab.stoi['bos']
norm_latent = True
model = Conditional_LSTM_LM(encoder,
                            d_vocab,
                            d_embedding,
                            d_hidden,
                            d_latent,
                            n_layers,
                            input_dropout,
                            lstm_dropout,
                            norm_latent,
                            condition_hidden,
                            condition_output,
                            bos_idx)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []
for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(*x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())

plt.plot(losses)
# vec to text with rollout: same scaffold-to-molecule setup, trained with forward rollout and decaying teacher forcing
from mrl.vocab import *
from mrl.dataloaders import *
from mrl.chem import ECFP6
from rdkit.Chem.Scaffolds import MurckoScaffold
from mrl.chem import to_mol, to_smile
df = pd.read_csv('files/smiles.csv')
source_smiles = [to_smile(MurckoScaffold.GetScaffoldForMol(to_mol(i))) for i in df.smiles.values]
target_smiles = df.smiles.values
smiles = [(source_smiles[i], target_smiles[i]) for i in range(len(source_smiles))]
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(smiles, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
loss = CrossEntropy()
encoder = MLP_Encoder(2048, [1024, 512], 512, [0.1, 0.1])
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
condition_hidden = True
condition_output = False
bos_idx = vocab.stoi['bos']
norm_latent = True
forward_rollout = True  # generate rollouts during the forward pass instead of pure teacher forcing
p_force = 1.            # initial probability of teacher forcing (feeding the ground-truth token)
force_decay = .9        # decay factor applied to p_force as training progresses
model = Conditional_LSTM_LM(encoder,
                            d_vocab,
                            d_embedding,
                            d_hidden,
                            d_latent,
                            n_layers,
                            input_dropout,
                            lstm_dropout,
                            norm_latent,
                            condition_hidden,
                            condition_output,
                            bos_idx,
                            forward_rollout=forward_rollout,
                            p_force=p_force,
                            force_decay=force_decay)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []
for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(*x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())

plt.plot(losses)