PyTorch model layers
import torch  # LinearBlock, the priors, the LSTM variants, the encoders, and MLP below are assumed to be imported from this library

# fully connected block with batch norm and dropout
layer = LinearBlock(128, 64, bn=True, dropout=0.5)
_ = layer(torch.randn(16, 128))
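The block bundles a fully connected layer with its normalization and regularization. As a rough plain-PyTorch analogue (a sketch only; the library's exact layer ordering and activation may differ):

# plain-PyTorch sketch of a Linear -> BatchNorm -> activation -> Dropout block
import torch.nn as nn

block = nn.Sequential(nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.5))
assert block(torch.randn(16, 128)).shape == (16, 64)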
# trainable prior: rsample keeps the gradient graph to its parameters; sample is detached
p = NormalPrior(torch.zeros((64,)), torch.zeros((64,)), trainable=True)
assert p.rsample(5).requires_grad
assert not p.sample(5).requires_grad

# non-trainable prior: no draw carries gradients
p = SphericalPrior(torch.zeros((2,)), torch.zeros((2,)), trainable=False)
assert not p.rsample(5).requires_grad
assert not p.sample(5).requires_grad
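This mirrors the convention in torch.distributions: rsample draws through the reparameterization trick so gradients reach trainable parameters, while sample is detached. A minimal illustration with a plain torch.distributions.Normal (not the library's prior classes):

# rsample keeps the graph to loc/scale; sample does not
loc = torch.zeros(64, requires_grad=True)
scale = torch.ones(64, requires_grad=True)
dist = torch.distributions.Normal(loc, scale)
assert dist.rsample((5,)).requires_grad
assert not dist.sample((5,)).requires_grad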
d_embedding = 64
d_hidden = 128
d_latent = 32
n_layers = 2
# condition both the initial hidden state and the output on the latent vector
l1 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                      condition_hidden=True, condition_output=True,
                      bidir=False, batch_first=True)

# same conditioning, bidirectional
l2 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                      condition_hidden=True, condition_output=True,
                      bidir=True, batch_first=True)

# condition the output only
l3 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                      condition_hidden=False, condition_output=True,
                      bidir=False, batch_first=True)

# condition the hidden state only
l4 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                      condition_hidden=True, condition_output=False,
                      bidir=False, batch_first=True)

# no conditioning, bidirectional, with input and LSTM dropout
l5 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                      condition_hidden=False, condition_output=False,
                      bidir=True, batch_first=True, input_dropout=0.5, lstm_dropout=0.5)
bs = 12
x = torch.randn((bs, 21, d_embedding))   # batch of sequences, length 21
z = torch.randn((bs, d_latent))          # latent conditioning vector

# forward with the latent alone, or with an explicit initial hidden state
_ = l1(x, z)
_ = l1(x, z, l1.latent_to_hidden(z))   # hidden state derived from the latent
_ = l2(x, z)
_ = l2(x, z, l2.latent_to_hidden(z))
_ = l3(x, z)
_ = l3(x, z, l3.get_new_hidden(bs))    # fresh hidden state
_ = l4(x, z)
_ = l4(x, z, l4.get_new_hidden(bs))
_ = l5(x, z)
_ = l5(x, None)                        # l5 uses no conditioning, so the latent can be None
_ = l5(x, z, l5.get_new_hidden(bs))
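Conceptually, condition_hidden derives the LSTM's initial hidden state from the latent (cf. latent_to_hidden above), while condition_output conditions the sequence outputs on it as well. A rough sketch of the hidden-state conditioning with plain torch.nn, reusing the tensors defined above (illustrative only; the projection layer here is hypothetical, not the library's implementation):

# sketch: derive the initial hidden state of a plain nn.LSTM from the latent z
import torch.nn as nn

lstm = nn.LSTM(d_embedding, d_hidden, n_layers, batch_first=True)
to_hidden = nn.Linear(d_latent, n_layers * d_hidden)         # hypothetical projection
h0 = to_hidden(z).reshape(bs, n_layers, d_hidden).permute(1, 0, 2).contiguous()
c0 = torch.zeros_like(h0)
out, (hn, cn) = lstm(x, (h0, c0))                            # out: (bs, 21, d_hidden)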
# unconditional LSTM layers
l1 = LSTM(d_embedding, d_hidden, d_embedding, n_layers, bidir=False, batch_first=True)
l2 = LSTM(d_embedding, d_hidden, d_embedding, n_layers, bidir=True, batch_first=True,
          input_dropout=0.5, lstm_dropout=0.5)
_ = l1(x)
_ = l1(x, l1.get_new_hidden(bs))
_ = l2(x)
_ = l2(x, l2.get_new_hidden(bs))
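With bidir=True the underlying recurrent features double in width, one set per direction. A plain nn.LSTM comparison (not the library's classes, which may add an output projection):

# bidirectional nn.LSTM output has 2 * d_hidden features per step
import torch.nn as nn

uni = nn.LSTM(d_embedding, d_hidden, n_layers, batch_first=True)
bi = nn.LSTM(d_embedding, d_hidden, n_layers, batch_first=True, bidirectional=True)
assert uni(x)[0].shape[-1] == d_hidden
assert bi(x)[0].shape[-1] == 2 * d_hidden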
d_latent = 128

# each encoder maps its input to a d_latent-dimensional vector
l = LSTM_Encoder(32, 64, 128, 2, 128, input_dropout=0.5, lstm_dropout=0.5)
assert l(torch.randint(0, 31, (10, 15))).shape[-1] == d_latent

m = MLP_Encoder(128, [64, 32, 16], d_latent, [0.1, 0.1, 0.1])
assert m(torch.randn(8, 128)).shape[-1] == d_latent

c = Conv_Encoder(32, 64, d_latent, [32, 16], [7, 7], [2, 2], [0.1, 0.1])
assert c(torch.randint(0, 31, (10, 15))).shape[-1] == d_latent
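As a rough plain-PyTorch analogue of the MLP_Encoder call above, with hidden sizes and dropout mirroring its arguments (illustrative only, not the library's implementation):

# sketch: an MLP encoder as a stack of Linear/ReLU/Dropout layers ending in d_latent
import torch.nn as nn

mlp_sketch = nn.Sequential(
    nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(32, 16), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(16, d_latent),
)
assert mlp_sketch(torch.randn(8, 128)).shape[-1] == d_latent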
# outrange=[0, 15] bounds the model output to that interval
active_model = MLP(128, [64, 32], 1, [0.2, 0.2], outrange=[0, 15])
p = active_model(torch.randn((32, 128)))
assert p.min() >= 0 and p.max() <= 15
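A common way to implement such an output range (shown as an illustration, not necessarily how MLP does it) is to pass the raw output through a sigmoid scaled to the interval:

# sketch: bounding an unbounded output to [low, high] with a scaled sigmoid
low, high = 0.0, 15.0
raw = torch.randn(32, 1)                            # stand-in for the raw network output
bounded = torch.sigmoid(raw) * (high - low) + low
assert bounded.min() >= low and bounded.max() <= high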