VAE Models
Variational Autoencoders (VAEs) are a class of generative models built on an encoder-decoder framework. The encoder maps an input down to a latent vector, the decoder reconstructs the input from that vector, and the model is trained on the quality of the reconstruction. VAEs impose a Gaussian prior on the latent space by regularizing how far latent vectors deviate from a unit Gaussian. As a result, we can easily sample valid latent vectors by drawing from a standard Gaussian distribution.
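As a rough, self-contained sketch (dummy tensors only, not the library's code), the regularization amounts to penalizing the encoder's latent distribution for drifting away from a unit Gaussian prior, which is exactly what makes sampling new latents from that prior reasonable:
import torch
from torch.distributions import Normal, kl_divergence

# hypothetical encoder outputs: a mean and log variance per batch item
mu, logvar = torch.randn(8, 128), torch.randn(8, 128)
posterior = Normal(mu, torch.exp(0.5 * logvar))
prior = Normal(torch.zeros(128), torch.ones(128))

# KL penalty added to the reconstruction loss during training
kl_term = kl_divergence(posterior, prior).sum(-1).mean()

# because the latent space is pulled toward the prior, new latent vectors
# can be drawn directly from a standard Gaussian
z = prior.sample([8])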
When generating text sequences (SMILES strings, amino acid sequences, etc.), we train the model in a self-supervised fashion using next-token prediction. In this framework, the decoder module is a conditional LSTM. These models are very similar to the Conditional_LSTM_LM model, with the added constraint on the latent space.
During generation, we sample from the model in an autoregressive fashion. Given an input token, the model predicts a distribution over the next token. We then sample from that distribution and feed the selected token back into the model. We repeat this process until either an end-of-sentence (EOS) token is predicted or the generated sequence reaches the maximum allowed length.
During sampling, we save the log probability of each predicted token. This gives us the model's estimated likelihood of the generated compound, and we can also backpropagate through this value.
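A minimal, generic sketch of this sampling loop is shown below; `dummy_step` is a hypothetical stand-in for the model's per-token prediction, which the real models handle internally in their `sample` methods:
import torch

vocab_size, bos_idx, eos_idx, max_len = 32, 0, 1, 16

def dummy_step(token):
    # stand-in for the model: return logits over the next token
    return torch.randn(vocab_size)

token = torch.tensor(bos_idx)
tokens, log_probs = [], []
for _ in range(max_len):
    dist = torch.distributions.Categorical(logits=dummy_step(token))
    token = dist.sample()                    # sample the next token
    tokens.append(token)
    log_probs.append(dist.log_prob(token))   # saved for likelihood / backprop
    if token.item() == eos_idx:              # stop at the EOS token
        break

sequence = torch.stack(tokens)
sequence_log_prob = torch.stack(log_probs).sum()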
VAE Transition
The VAE_Transition
class converts an input vector to a valid latent vector. Following standard VAE implementations, the input vector is used to predict a vector of means and a vector of log variances, which parameterize a Gaussian distribution. We then use the reparameterization trick to sample a latent vector from that distribution.
The module returns the latent vectors and a KL loss based on each latent distribution's deviation from a unit Gaussian.
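For reference, a minimal sketch of the reparameterization trick under the standard formulation (illustrative only, not the library's exact implementation):
import torch

# hypothetical encoder outputs for a batch of 8 items and a 128-dim latent space
mu = torch.randn(8, 128, requires_grad=True)
logvar = torch.randn(8, 128, requires_grad=True)

# reparameterization: z = mu + sigma * eps, so gradients flow through mu and logvar
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + std * eps

# closed-form KL divergence of N(mu, sigma^2) from the unit Gaussian prior
kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1).mean()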
VAE
VAE
is the base class for VAE models. VAE models consist of an encoder, a decoder, and a transition module.
The encoder module maps the model's inputs down to a vector. The model will work with any of the Encoder
subclasses in the layers
section. Custom encoders also work so long as they are compatible with the data format you are using and produce a single vector per batch item.
The transition module converts the encoder output to a valid latent vector. By default, VAEs use the VAE_Transition
which samples a latent vector assuming a unit Gaussian target distribution. Any module that returns a latent vector and a KL loss term (0 if not applicable) can be used in its place, as in the sketch below.
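As an illustration, a hypothetical deterministic transition, sketched only from the contract described above (project the encoder output and report a zero KL term), might look like this:
import torch
from torch import nn

class DeterministicTransition(nn.Module):
    "Hypothetical transition: maps the encoder output to a latent vector with no sampling"
    def __init__(self, d_in, d_latent):
        super().__init__()
        self.proj = nn.Linear(d_in, d_latent)

    def forward(self, x):
        z = self.proj(x)
        kl_loss = torch.tensor(0., device=x.device)  # no KL term for a deterministic mapping
        return z, kl_loss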
The decoder module uses the latent vector to reconstruct the input.
import torch

# build a VAE from an LSTM encoder and a conditional LSTM decoder
encoder = LSTM_Encoder(32, 64, 128, 2, 128)
decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                 condition_hidden=True, condition_output=True)
vae = VAE(encoder, decoder)

# toy batch of token sequences
ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]
_ = vae(x)

# decoder conditioned on the latent vector through the output only
decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                 condition_hidden=False, condition_output=True)
vae = VAE(encoder, decoder)
_ = vae(x)

# decoder conditioned on the latent vector through the hidden state only
decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                 condition_hidden=True, condition_output=False)
vae = VAE(encoder, decoder)
_ = vae(x)

# sample 8 sequences of up to 16 tokens, optionally from explicit latent vectors
_ = vae.sample(8, 16)
z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

# tensors used for reinforcement learning
o,lp,lpg,e = vae.get_rl_tensors(x,y)

# set a trainable prior from the encoding of a specific input
vae.set_prior_from_encoder(x[0].unsqueeze(0), trainable=True);
o,lp,lpg,e = vae.get_rl_tensors(x,y)
loss = lpg.mean()
assert vae.prior.loc.grad is None
loss.backward()
assert vae.prior.loc.grad is not None

# latent vectors can also be passed in directly
latent = torch.randn((x.shape[0], 128))
o,lp,lpg,e = vae.get_rl_tensors(x,y,latent=latent)
Preset VAEs
The following modules are preset encoder-decoder configurations.
LSTM_VAE
has an LSTM encoder and a conditional LSTM decoder. This works for most sequence-to-sequence tasks.
Conv_VAE
replaces the LSTM encoder with a 1D convolutional encoder (the model is still intended for sequences, not 2D images). The convolutional encoder tends to be faster and lighter than the LSTM encoder, and anecdotally it seems to perform better as well.
MLP_VAE
uses an MLP encoder. This model works best for reconstructing a sequence from a fingerprint, property vector, or similar fixed-length input.
vae = LSTM_VAE(32, 64, 128, 2, 128, condition_hidden=True, condition_output=True)
ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]
_ = vae(x)
_ = vae.sample(8, 16)
z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)
_ = vae.get_rl_tensors(x,y)
vae = Conv_VAE(32, 64, [128, 256], [7,7], [2,2], [0.1,0.1], 128, 2, 128,
               condition_hidden=False, condition_output=True)
ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]
_ = vae(y,x)
_ = vae.sample(8, 16)
z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)
_ = vae.get_rl_tensors(x,y)
vae = MLP_VAE(32, 64, 128, [64, 32], [0.1, 0.1], 128, 2, 128,
              condition_hidden=False, condition_output=True)
ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]
condition = torch.randn((8,128))
_ = vae(x, condition)
_ = vae.sample(8, 16)
z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)
_ = vae.get_rl_tensors([x, condition],y)
# text reconstruction: train LSTM_VAE to reconstruct SMILES strings
import pandas as pd
import matplotlib.pyplot as plt
from torch import optim

from mrl.vocab import *
from mrl.dataloaders import *

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Text_Dataset(df.smiles.values, vocab)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']
model = LSTM_VAE(
    d_vocab,
    d_embedding,
    d_hidden,
    n_layers,
    d_latent,
    input_dropout=0.0,
    lstm_dropout=0.0,
    condition_hidden=condition_hidden,
    condition_output=condition_output,
    prior=None,
    bos_idx=bos_idx,
    transition=None,
)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch

    preds = model(x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())
plt.plot(losses)
# text reconstruction with a convolutional encoder (Conv_VAE)
from mrl.vocab import *
from mrl.dataloaders import *
df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Text_Dataset(df.smiles.values, vocab)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()
d_vocab = len(vocab.itos)
d_embedding = 256
conv_filters = [128, 256, 512]
kernel_sizes = [7,7,7]
strides = [2,2,2]
conv_drops = [0.3, 0.3, 0.3]
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']
model = Conv_VAE(
    d_vocab,
    d_embedding,
    conv_filters,
    kernel_sizes,
    strides,
    conv_drops,
    d_hidden,
    n_layers,
    d_latent,
    input_dropout=input_dropout,
    lstm_dropout=lstm_dropout,
    condition_hidden=condition_hidden,
    condition_output=condition_output,
    prior=None,
    bos_idx=bos_idx,
    transition=None,
)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch

    preds = model(x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())
plt.plot(losses)
# vector to text: reconstruct SMILES from ECFP6 fingerprints
from mrl.vocab import *
from mrl.dataloaders import *
from mrl.chem import ECFP6
df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(df.smiles.values, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()
d_vocab = len(vocab.itos)
d_embedding = 256
encoder_d_in = 2048
encoder_dims = [1024, 512]
encoder_drops = [0.1, 0.1]
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']
model = MLP_VAE(
    d_vocab,
    d_embedding,
    encoder_d_in,
    encoder_dims,
    encoder_drops,
    d_hidden,
    n_layers,
    d_latent,
    input_dropout=input_dropout,
    lstm_dropout=lstm_dropout,
    condition_hidden=condition_hidden,
    condition_output=condition_output,
    prior=None,
    bos_idx=bos_idx,
    transition=None,
)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch

    preds = model(*x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())
plt.plot(losses)
# text to text: Murcko scaffolds to full SMILES
from mrl.vocab import *
from mrl.dataloaders import *
from rdkit.Chem.Scaffolds import MurckoScaffold
from mrl.chem import to_mol, to_smile
df = pd.read_csv('files/smiles.csv')
source_smiles = [to_smile(MurckoScaffold.GetScaffoldForMol(to_mol(i))) for i in df.smiles.values]
target_smiles = df.smiles.values
smiles = [(source_smiles[i], target_smiles[i]) for i in range(len(source_smiles))]
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = VAE_Seq2Seq_Dataset(smiles, vocab)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']
model = LSTM_VAE(
    d_vocab,
    d_embedding,
    d_hidden,
    n_layers,
    d_latent,
    input_dropout=0.0,
    lstm_dropout=0.0,
    condition_hidden=condition_hidden,
    condition_output=condition_output,
    prior=None,
    bos_idx=bos_idx,
    transition=None,
)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch

    preds = model(*x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())
plt.plot(losses)
# vec to text: scaffold fingerprints to full SMILES
from mrl.vocab import *
from mrl.dataloaders import *
from rdkit.Chem.Scaffolds import MurckoScaffold
from mrl.chem import to_mol, to_smile, ECFP6
df = pd.read_csv('files/smiles.csv')
source_smiles = [to_smile(MurckoScaffold.GetScaffoldForMol(to_mol(i))) for i in df.smiles.values]
target_smiles = df.smiles.values
smiles = [(source_smiles[i], target_smiles[i]) for i in range(len(source_smiles))]
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(smiles, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()
d_vocab = len(vocab.itos)
d_embedding = 256
encoder_d_in = 2048
encoder_dims = [1024, 512]
encoder_drops = [0.1, 0.1]
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']
model = MLP_VAE(
    d_vocab,
    d_embedding,
    encoder_d_in,
    encoder_dims,
    encoder_drops,
    d_hidden,
    n_layers,
    d_latent,
    input_dropout=input_dropout,
    lstm_dropout=lstm_dropout,
    condition_hidden=condition_hidden,
    condition_output=condition_output,
    prior=None,
    bos_idx=bos_idx,
    transition=None,
)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch

    preds = model(*x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())
plt.plot(losses)
# vec to text with rollout
from mrl.vocab import *
from mrl.dataloaders import *
from rdkit.Chem.Scaffolds import MurckoScaffold
from mrl.chem import to_mol, to_smile, ECFP6
df = pd.read_csv('files/smiles.csv')
source_smiles = [to_smile(MurckoScaffold.GetScaffoldForMol(to_mol(i))) for i in df.smiles.values]
target_smiles = df.smiles.values
smiles = [(source_smiles[i], target_smiles[i]) for i in range(len(source_smiles))]
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(smiles, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()
d_vocab = len(vocab.itos)
d_embedding = 256
encoder_d_in = 2048
encoder_dims = [1024, 512]
encoder_drops = [0.1, 0.1]
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']
# forward rollout settings: initial teacher forcing probability and its decay rate
forward_rollout=True
p_force=1.
force_decay=0.9
model = MLP_VAE(
    d_vocab,
    d_embedding,
    encoder_d_in,
    encoder_dims,
    encoder_drops,
    d_hidden,
    n_layers,
    d_latent,
    input_dropout=input_dropout,
    lstm_dropout=lstm_dropout,
    condition_hidden=condition_hidden,
    condition_output=condition_output,
    prior=None,
    bos_idx=bos_idx,
    transition=None,
    forward_rollout=forward_rollout,
    p_force=p_force,
    force_decay=force_decay
)

to_device(model)

opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                          steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch

    preds = model(*x)
    batch_loss = loss(preds, y)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()

    losses.append(batch_loss.item())
plt.plot(losses)