VAE-type models

VAE Models

Variational Autoencoders (VAEs) are a class of generative models that follow an encoder-decoder framework. The encoder maps an input down to a latent vector, the decoder reconstructs the input from that vector, and the model is trained on the quality of the reconstruction. VAEs impose a Gaussian prior on the latent space by regularizing the deviation of latent vectors from a unit Gaussian. As a result, we can easily sample valid latent vectors by drawing from a Gaussian distribution.

When generating text sequences (SMILES strings, amino acid sequences, etc.), we train the model in a self-supervised fashion using next-token prediction. In this framework, the decoder module is a conditional LSTM. These models are very similar to the Conditional_LSTM_LM model, with the added constraint on the latent space.

During generation, we sample from the model in an autoregressive fashion. Given an input token, the model predicts a distribution over the next token. We then sample from that distribution and feed the selected token back into the model. We repeat this process until either an end of sentence (EOS) token is predicted or the generated sequence reaches a maximum allowed length.

During sampling, we save the log probability of each predicted token. Summed over the sequence, these give the model's estimated log-likelihood of the generated compound, and we can also backpropagate through this value.
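A minimal sketch of this sampling loop is shown below. The toy decoder and its interface are assumptions for illustration only; the actual logic lives in each model's sample method.

import torch

# Toy stand-in decoder: maps (last token, latent vector) -> next-token logits.
# Purely illustrative; the real decoder is a conditional LSTM.
class ToyDecoder(torch.nn.Module):
    def __init__(self, d_vocab, d_latent):
        super().__init__()
        self.embedding = torch.nn.Embedding(d_vocab, d_latent)
        self.head = torch.nn.Linear(d_latent, d_vocab)
    def forward(self, tokens, z):
        return self.head(self.embedding(tokens[:, -1]) + z)

def sample_sketch(decoder, z, bos_idx=0, max_len=16):
    tokens = torch.full((z.shape[0], 1), bos_idx, dtype=torch.long)
    log_probs = []
    for _ in range(max_len):
        dist = torch.distributions.Categorical(logits=decoder(tokens, z))
        next_token = dist.sample()                    # sample from the predicted distribution
        log_probs.append(dist.log_prob(next_token))   # save log prob for likelihood / backprop
        tokens = torch.cat([tokens, next_token[:, None]], dim=1)
    # a full implementation would also stop early once every sequence has produced EOS
    return tokens, torch.stack(log_probs, dim=1)

tokens, log_probs = sample_sketch(ToyDecoder(32, 64), torch.randn(8, 64))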

VAE Transition

The VAE_Transition class converts an input vector to a valid latent vector. Following standard VAE implementations, the input vector is used to predict a vector of means and a vector of log variances, which parameterize a Gaussian distribution. We then use the Reparameterization Trick to sample a latent vector from that distribution.

The module returns the latent vectors and a KL loss based on the latent distribution's deviation from a unit Gaussian.
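A minimal sketch of this pattern, assuming two linear heads for the mean and log variance (the layer names here are illustrative, not the actual attributes of VAE_Transition):

import torch
import torch.nn as nn

class ReparamSketch(nn.Module):
    # Illustrative only: predict mu and log variance, sample with the
    # reparameterization trick, return (latent, KL loss vs a unit Gaussian)
    def __init__(self, d_latent):
        super().__init__()
        self.to_mu = nn.Linear(d_latent, d_latent)
        self.to_logvar = nn.Linear(d_latent, d_latent)
    def forward(self, x):
        mu, logvar = self.to_mu(x), self.to_logvar(x)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)   # reparameterization trick
        kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1).mean()
        return z, kl

z, kl_loss = ReparamSketch(128)(torch.randn(8, 128))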

class VAE_Transition[source]

VAE_Transition(d_latent) :: Module

VAE_Transition - converts an input vector to a latent vector using the reparameterization trick

Inputs:

  • d_latent int: latent variable dimensions

VAE

VAE is the base class for VAE models. VAE models consist of an encoder, a decoder, and a transition module.

The encoder module maps the model's inputs down to a vector. The model will work with any of the Encoder subclasses in the layers section. Custom encoders also work so long as they are compatible with the data format you are using and produce a single vector per batch item.

The transition module converts the encoder output to a valid latent vector. By default, VAEs use VAE_Transition, which samples a latent vector assuming a unit Gaussian target distribution. This can be substituted with any module that returns latent vectors and a KL loss term (0 if not applicable).

The decoder module uses the latent vector to reconstruct the input.
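As a rough sketch, the three modules compose like this (argument order and return values here are assumptions for illustration, not the exact VAE.forward):

def vae_forward_sketch(encoder, transition, decoder, x):
    enc_out = encoder(x)              # one vector per batch item
    z, kl_loss = transition(enc_out)  # latent sample plus KL penalty
    decoder_out = decoder(x, z)       # reconstruct the input conditioned on z
    return decoder_out, kl_loss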

class VAE[source]

VAE(encoder, decoder, prior=None, bos_idx=0, transition=None, forward_rollout=False, p_force=0.0, force_decay=0.99) :: GenerativeModel

VAE - base VAE class

Inputs:

  • encoder nn.Module: encoder module

  • decoder nn.Module: decoder module

  • prior Optional[nn.Module]: prior module

  • bos_idx int: BOS token

  • transition Optional[nn.Module]: transition module

  • forward_rollout bool: if True, run supervised training using rollout with teacher forcing. This is a technique used in some seq2seq models and should not be used for pure generative models (see the sketch after this list)

  • p_force float: teacher forcing probability

  • force_decay float: rate of decay of p_force
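The rollout option referenced above amounts to scheduled teacher forcing. The sketch below illustrates the idea only; the toy step function and tensor shapes are assumptions, not the library's rollout code.

import torch

# toy single-step decoder: (previous tokens, latent) -> next-token logits
def toy_step(prev_tok, z):
    return torch.randn(prev_tok.shape[0], 32)

def rollout_with_teacher_forcing(targets, z, p_force):
    inp = targets[:, :1]   # start from the first (BOS) column
    outputs = []
    for t in range(1, targets.shape[1]):
        logits = toy_step(inp, z)
        outputs.append(logits)
        pred = logits.argmax(-1, keepdim=True)
        # with probability p_force feed the ground truth token, otherwise the prediction
        inp = targets[:, t:t+1] if torch.rand(()) < p_force else pred
    return torch.stack(outputs, dim=1)

p_force, force_decay = 1.0, 0.99
_ = rollout_with_teacher_forcing(torch.randint(0, 31, (8, 10)), torch.randn(8, 16), p_force)
p_force *= force_decay   # p_force decays by force_decay as training progresses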

encoder = LSTM_Encoder(32, 64, 128, 2, 128)
decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                condition_hidden=True, condition_output=True)
vae = VAE(encoder, decoder)

ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]

_ = vae(x)
decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                condition_hidden=False, condition_output=True)
vae = VAE(encoder, decoder)

_ = vae(x)
decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                condition_hidden=True, condition_output=False)
vae = VAE(encoder, decoder)

_ = vae(x)

_ = vae.sample(8, 16)

z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

o,lp,lpg,e = vae.get_rl_tensors(x,y)
vae.set_prior_from_encoder(x[0].unsqueeze(0), trainable=True);

o,lp,lpg,e = vae.get_rl_tensors(x,y)
loss = lpg.mean()
assert vae.prior.loc.grad is None
loss.backward()
assert vae.prior.loc.grad is not None

latent = torch.randn((x.shape[0], 128))
o,lp,lpg,e = vae.get_rl_tensors(x,y,latent=latent)

Preset VAEs

The following modules are preset encoder-decoder configurations.

LSTM_VAE has an LSTM encoder and a conditional LSTM decoder. This works for most sequence-to-sequence type tasks.

Conv_VAE replaces the LSTM encoder with a 1D conv encoder (the model is still intended for sequences, not 2D images). The convolutional encoder tends to be faster and lighter than the LSTM encoder. Anecdotally, it seems to perform better than the LSTM encoder.

MLP_VAE uses an MLP encoder. This model works best for reconstructing a sequence from a fingerprint, property vector, or something similar.

class LSTM_VAE[source]

LSTM_VAE(d_vocab, d_embedding, d_hidden, n_layers, d_latent, input_dropout=0.0, lstm_dropout=0.0, condition_hidden=True, condition_output=True, prior=None, bos_idx=0, transition=None, forward_rollout=False, p_force=0.0, force_decay=0.99) :: VAE

LSTM_VAE - VAE with LSTM encoder and conditional LSTM decoder. Usable for text-to-text or seq-2-seq tasks or similar

Inputs:

  • d_vocab int: vocab size

  • d_embedding int: embedding dimension

  • d_hidden int: hidden dimension

  • n_layers int: number of LSTM layers (same for encoder and decoder)

  • d_latent int: latent vector dimension

  • input_dropout float: dropout percentage on inputs

  • lstm_dropout float: dropout on LSTM layers

  • condition_hidden bool: if True, latent vector is used to initialize the hidden state

  • condition_output bool: if True, latent vector is concatenated to inputs

  • prior Optional[nn.Module]: optional prior distribution to sample from. see Prior

  • bos_idx int: beginning of sentence token

  • transition Optional[nn.Module]: transition module

  • forward_rollout bool: if True, run supervised training using rollout with teacher forcing. This is a technique used in some seq2seq models and should not be used for pure generative models

  • p_force float: teacher forcing probability

  • force_decay float: rate of decay of p_force

vae = LSTM_VAE(32, 64, 128, 2, 128, condition_hidden=True, condition_output=True)

ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]

_ = vae(x)

_ = vae.sample(8, 16)

z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

_ = vae.get_rl_tensors(x,y)

class Conv_VAE[source]

Conv_VAE(d_vocab, d_embedding, conv_filters, kernel_sizes, strides, conv_drops, d_hidden, n_layers, d_latent, input_dropout=0.0, lstm_dropout=0.0, condition_hidden=True, condition_output=True, prior=None, bos_idx=0, transition=None, forward_rollout=False, p_force=0.0, force_decay=0.99) :: VAE

Conv_VAE - VAE with 1D Conv encoder and conditional LSTM decoder. Usable for text-to-text or seq-2-seq tasks or similar

Inputs:

  • d_vocab int: vocab size

  • d_embedding int: embedding dimension

  • conv_filters list[int]: filter sizes for conv layers ie [64, 128, 256]

  • kernel_sizes list[int]: kernel sizes for conv layers ie [5, 5, 5]

  • strides list[int]: strides for conv layers ie [2, 2, 2]

  • conv_drops list[float]: list of dropout probabilities ie [0.2, 0.2, 0.3]

  • d_hidden int: hidden dimension

  • n_layers int: number of LSTM layers (same for encoder and decoder)

  • d_latent int: latent variable dimension

  • input_dropout float: dropout percentage on inputs

  • lstm_dropout float: dropout on LSTM layers

  • condition_hidden bool: if True, latent vector is used to initialize the hidden state

  • condition_output bool: if True, latent vector is concatenated to the outputs of the embedding layer

  • prior Optional[nn.Module]: prior module

  • bos_idx int: BOS token

  • transition Optional[nn.Module]: transition module

  • forward_rollout bool: if True, run supervised training using rollout with teacher forcing. This is a technique used in some seq2seq models and should not be used for pure generative models

  • p_force float: teacher forcing probability

  • force_decay float: rate of decay of p_force

vae = Conv_VAE(32, 64, [128, 256], [7,7], [2,2], [0.1,0.1], 128, 2, 128, 
               condition_hidden=False, condition_output=True)

ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]

_ = vae(y,x)

_ = vae.sample(8, 16)

z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

_ = vae.get_rl_tensors(x,y)

class MLP_VAE[source]

MLP_VAE(d_vocab, d_embedding, encoder_d_in, encoder_dims, encoder_drops, d_hidden, n_layers, d_latent, input_dropout=0.0, lstm_dropout=0.0, condition_hidden=True, condition_output=True, prior=None, bos_idx=0, transition=None, forward_rollout=False, p_force=0.0, force_decay=0.99) :: VAE

MLP_VAE - VAE with MLP encoder and conditional LSTM decoder. Usable for reconstructing a sequence from a vector

Inputs:

  • d_vocab int: vocab size

  • d_embedding int: embedding dimension

  • encoder_d_in int: encoder input dimension

  • encoder_dims list[int]: list of encoder layer sizes ie [1024, 512, 256]

  • encoder_drops list[float]: list of dropout probabilities ie [0.2, 0.2, 0.3]

  • d_hidden int: hidden dimension

  • n_layers int: number of LSTM layers (same for encoder and decoder)

  • d_latent int: latent variable dimension

  • input_dropout float: dropout percentage on inputs

  • lstm_dropout float: dropout on LSTM layers

  • condition_hidden bool: if True, latent vector is used to initialize the hidden state

  • condition_output bool: if True, latent vector is concatenated to the outputs of the embedding layer

  • prior Optional[nn.Module]: prior module

  • bos_idx int: BOS token

  • transition Optional[nn.Module]: transition module

  • forward_rollout bool: if True, run supervised training using rollout with teacher forcing. This is a technique used in some seq2seq models and should not be used for pure generative models

  • p_force float: teacher forcing probability

  • force_decay float: rate of decay of p_force

vae = MLP_VAE(32, 64, 128, [64, 32], [0.1, 0.1], 128, 2, 128, 
               condition_hidden=False, condition_output=True)

ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]

condition = torch.randn((8,128))


_ = vae(x, condition)

_ = vae.sample(8, 16)

z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

_ = vae.get_rl_tensors([x, condition],y)

class VAELoss[source]

VAELoss(weight=1.0)

VAELoss - loss for VAE models

Inputs:

  • weight float: KL loss weight
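In the examples below, the loss is called as loss(preds, y). A minimal sketch of what a loss of this form computes, assuming preds bundles the decoder logits with the KL term from the transition module:

import torch.nn.functional as F

def vae_loss_sketch(decoder_logits, kl_loss, targets, weight=1.0):
    # reconstruction term: token-level cross entropy against the shifted targets
    recon = F.cross_entropy(decoder_logits.transpose(1, 2), targets)
    # total loss: reconstruction plus weighted KL regularization
    return recon + weight * kl_loss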

vae_seq2seq_collate[source]

vae_seq2seq_collate(batch, pad_idx, batch_first=True)
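vae_seq2seq_collate presumably pads each side of a batch of (source, target) pairs to a common length with pad_idx. A minimal sketch of that behavior under that assumption (not the library's exact implementation, which may arrange decoder inputs and targets differently):

from torch.nn.utils.rnn import pad_sequence

def pad_collate_sketch(batch, pad_idx, batch_first=True):
    # batch: list of (source_tensor, target_tensor) pairs of token indices
    sources, targets = zip(*batch)
    src = pad_sequence(sources, batch_first=batch_first, padding_value=pad_idx)
    tgt = pad_sequence(targets, batch_first=batch_first, padding_value=pad_idx)
    return src, tgt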

class VAE_Seq2Seq_Dataset[source]

VAE_Seq2Seq_Dataset(*args, **kwds) :: Text_Dataset

VAE_Seq2Seq_Dataset - seq to seq dataset for VAEs

Inputs:

  • sequences - list[tuple]: list of text tuples (source, target)

  • vocab - Vocab: vocabulary for tokenization/numericalization

  • collate_function Callable: batch collate function. If None, defaults to vae_seq2seq_collate

# text reconstruction

from mrl.vocab import *
from mrl.dataloaders import *

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

ds = Text_Dataset(df.smiles.values, vocab)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()

d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']

model = LSTM_VAE(
                d_vocab,
                d_embedding,
                d_hidden,
                n_layers,
                d_latent,
                input_dropout=0.0,
                lstm_dropout=0.0,
                condition_hidden=condition_hidden,
                condition_output=condition_output,
                prior=None,
                bos_idx=bos_idx,
                transition=None,
            )


to_device(model)

opt = optim.Adam(model.parameters())

scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                    steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(x)
    batch_loss = loss(preds, y)
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()
    losses.append(batch_loss.item())
    
plt.plot(losses)
# text reconstruction conv

from mrl.vocab import *
from mrl.dataloaders import *

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

ds = Text_Dataset(df.smiles.values, vocab)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()

d_vocab = len(vocab.itos)
d_embedding = 256
conv_filters = [128, 256, 512]
kernel_sizes = [7,7,7]
strides = [2,2,2]
conv_drops = [0.3, 0.3, 0.3]
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']

model = Conv_VAE(
                d_vocab,
                d_embedding,
                conv_filters,
                kernel_sizes,
                strides,
                conv_drops,
                d_hidden,
                n_layers,
                d_latent,
                input_dropout=input_dropout,
                lstm_dropout=lstm_dropout,
                condition_hidden=condition_hidden,
                condition_output=condition_output,
                prior=None,
                bos_idx=bos_idx,
                transition=None,
            )


to_device(model)

opt = optim.Adam(model.parameters())

scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                    steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(x)
    batch_loss = loss(preds, y)
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()
    losses.append(batch_loss.item())
    
plt.plot(losses)
# vector to text

from mrl.vocab import *
from mrl.dataloaders import *
from mrl.chem import ECFP6

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

ds = Vec_To_Text_Dataset(df.smiles.values, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()

d_vocab = len(vocab.itos)
d_embedding = 256
encoder_d_in = 2048
encoder_dims = [1024, 512]
encoder_drops = [0.1, 0.1]
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']


model = MLP_VAE(
            d_vocab,
            d_embedding,
            encoder_d_in,
            encoder_dims,
            encoder_drops,
            d_hidden,
            n_layers,
            d_latent,
            input_dropout=input_dropout,
            lstm_dropout=lstm_dropout,
            condition_hidden=condition_hidden,
            condition_output=condition_output,
            prior=None,
            bos_idx=bos_idx,
            transition=None,
        )

to_device(model)

opt = optim.Adam(model.parameters())

scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                    steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(*x)
    batch_loss = loss(preds, y)
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()
    losses.append(batch_loss.item())
    
plt.plot(losses)
# text to text

from mrl.vocab import *
from mrl.dataloaders import *
from rdkit.Chem.Scaffolds import MurckoScaffold
from mrl.chem import to_mol, to_smile


df = pd.read_csv('files/smiles.csv')
source_smiles = [to_smile(MurckoScaffold.GetScaffoldForMol(to_mol(i))) for i in df.smiles.values]
target_smiles = df.smiles.values
smiles = [(source_smiles[i], target_smiles[i]) for i in range(len(source_smiles))]

vocab = CharacterVocab(SMILES_CHAR_VOCAB)

ds = VAE_Seq2Seq_Dataset(smiles, vocab)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()

d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']

model = LSTM_VAE(
                d_vocab,
                d_embedding,
                d_hidden,
                n_layers,
                d_latent,
                input_dropout=0.0,
                lstm_dropout=0.0,
                condition_hidden=condition_hidden,
                condition_output=condition_output,
                prior=None,
                bos_idx=bos_idx,
                transition=None,
            )


to_device(model)

opt = optim.Adam(model.parameters())

scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                    steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(*x)
    batch_loss = loss(preds, y)
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()
    losses.append(batch_loss.item())
    
plt.plot(losses)
# vec to text

from mrl.vocab import *
from mrl.dataloaders import *
from rdkit.Chem.Scaffolds import MurckoScaffold
from mrl.chem import to_mol, to_smile, ECFP6


df = pd.read_csv('files/smiles.csv')
source_smiles = [to_smile(MurckoScaffold.GetScaffoldForMol(to_mol(i))) for i in df.smiles.values]
target_smiles = df.smiles.values
smiles = [(source_smiles[i], target_smiles[i]) for i in range(len(source_smiles))]

vocab = CharacterVocab(SMILES_CHAR_VOCAB)


ds = Vec_To_Text_Dataset(smiles, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()

d_vocab = len(vocab.itos)
d_embedding = 256
encoder_d_in = 2048
encoder_dims = [1024, 512]
encoder_drops = [0.1, 0.1]
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']


model = MLP_VAE(
            d_vocab,
            d_embedding,
            encoder_d_in,
            encoder_dims,
            encoder_drops,
            d_hidden,
            n_layers,
            d_latent,
            input_dropout=input_dropout,
            lstm_dropout=lstm_dropout,
            condition_hidden=condition_hidden,
            condition_output=condition_output,
            prior=None,
            bos_idx=bos_idx,
            transition=None,
        )

to_device(model)

opt = optim.Adam(model.parameters())

scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                    steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(*x)
    batch_loss = loss(preds, y)
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()
    losses.append(batch_loss.item())
    
plt.plot(losses)
# vec to text with rollout

from mrl.vocab import *
from mrl.dataloaders import *
from rdkit.Chem.Scaffolds import MurckoScaffold
from mrl.chem import to_mol, to_smile, ECFP6


df = pd.read_csv('files/smiles.csv')
source_smiles = [to_smile(MurckoScaffold.GetScaffoldForMol(to_mol(i))) for i in df.smiles.values]
target_smiles = df.smiles.values
smiles = [(source_smiles[i], target_smiles[i]) for i in range(len(source_smiles))]

vocab = CharacterVocab(SMILES_CHAR_VOCAB)


ds = Vec_To_Text_Dataset(smiles, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
loss = VAELoss()

d_vocab = len(vocab.itos)
d_embedding = 256
encoder_d_in = 2048
encoder_dims = [1024, 512]
encoder_drops = [0.1, 0.1]
d_hidden = 1024
d_latent = 512
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bidir = False
condition_hidden = True
condition_output = True
bos_idx = vocab.stoi['bos']
forward_rollout=True
p_force=1.
force_decay=0.9


model = MLP_VAE(
            d_vocab,
            d_embedding,
            encoder_d_in,
            encoder_dims,
            encoder_drops,
            d_hidden,
            n_layers,
            d_latent,
            input_dropout=input_dropout,
            lstm_dropout=lstm_dropout,
            condition_hidden=condition_hidden,
            condition_output=condition_output,
            prior=None,
            bos_idx=bos_idx,
            transition=None,
            forward_rollout=forward_rollout,
            p_force=p_force,
            force_decay=force_decay
        )

to_device(model)

opt = optim.Adam(model.parameters())

scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                    steps_per_epoch=len(dl), epochs=10)

losses = []

for i, batch in enumerate(dl):
    batch = to_device(batch)
    x,y = batch
    preds = model(*x)
    batch_loss = loss(preds, y)
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()
    losses.append(batch_loss.item())
    
plt.plot(losses)