Adeno-Associated Virus Capsid Design
This tutorial runs an end-to-end workflow for designing AAV capsids.
Adeno-Associated Viruses (AAV) are small viruses that generally do not cause disease. Their non-pathogenic nature makes them an attractive choice for gene therapy delivery vectors. One drawback to using AAV capsids for gene delivery is that they can be neutralized by natural immunity. This sets up a protein engineering problem: we want to develop new variants of the virus that can evade the human immune system.
We could approach this with a scanning mutagenesis approach, but this would produce a large number of invalid sequences. Instead, we will build a score function based on laboratory data and use a generative model to exploit this score function. The hope is that this approach yields more realistic variants.
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because samples must be evaluated on CPU. If you have a multi-core machine, we recommend uncommenting and running the set_global_pool cells in the notebook. This enables multiprocessing, which will result in 2-4x speedups.
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime to GPU.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from sklearn.metrics import r2_score
os.makedirs('untracked_files', exist_ok=True)
Data
The dataset comes from the paper "Generative AAV capsid diversification by latent interpolation". The authors looked at a 28-amino-acid region of the AAV2 VP3 protein shown to have immunological significance.
The authors created a mutation library based on 564 naturally occurring sequences. They then trained a VAE model on this dataset, sampled new 28-AA sequences from this model, and tested them in the lab.
We will use their laboratory data to build a score function. Then we will run a generative screen against that score function.
# ! wget https://raw.githubusercontent.com/churchlab/Generative_AAV_design/main/Data/vae2021_processed_data.csv
df = pd.read_csv('vae2021_processed_data.csv')
df.dropna(subset=['VAE_virus_S'], inplace=True)
df.head()
Our metric of interest is VAE_virus_S. This is the log2 ratio of a variant's frequency in the virus pool to the frequency of the corresponding plasmid in the plasmid pool. Higher values indicate higher viability. The goal of the design is to produce variants predicted to beat the wildtype sequence. The dataset indicates which sequences currently perform at this level (the beats_wt column).
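The enrichment metric can be sketched from raw read counts. This is a minimal illustration with made-up counts and a hypothetical helper, log2_enrichment. The paper's exact normalization (pseudocounts, replicate handling) is not described here and may differ. It shows only the core idea: normalize counts to frequencies within each pool, then take the log2 ratio.

```python
import math

def log2_enrichment(virus_counts, plasmid_counts, pseudocount=0.0):
    """Hypothetical helper: log2(freq in virus pool / freq in plasmid pool).

    The optional pseudocount is an assumption to avoid log(0) for variants
    that vanish from the virus pool; it is not taken from the paper.
    """
    virus_total = sum(virus_counts.values())
    plasmid_total = sum(plasmid_counts.values())
    scores = {}
    for variant, p_count in plasmid_counts.items():
        v_count = virus_counts.get(variant, 0)
        virus_freq = (v_count + pseudocount) / virus_total
        plasmid_freq = (p_count + pseudocount) / plasmid_total
        scores[variant] = math.log2(virus_freq / plasmid_freq)
    return scores

# Made-up counts: variant A is enriched after packaging, B is depleted.
print(log2_enrichment({'A': 30, 'B': 10}, {'A': 20, 'B': 20}))
```

A positive score means the variant is over-represented in the virus pool relative to its plasmid, which is why higher values indicate higher viability.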
df.VAE_virus_S.hist(alpha=0.5, density=True, label='full_dataset')
df[(df.viable==1) & (df.beats_wt==1)].VAE_virus_S.hist(alpha=0.5, density=True, label='Beats WT')
plt.legend();
Score Function
Now we want to develop a score function for predicting virus viability (VAE_virus_S). We will use a CNN encoder with an MLP head to predict the score described above.
Our input data will be token integers for amino acids. Note that fingerprint representations are a poor fit for peptide work because peptides contain many repeating substructures.
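To make "token integers" concrete, here is a minimal character-level encoder. The vocabulary is hypothetical. MRL's CharacterVocab(AMINO_ACID_VOCAB) may use different special tokens and ordering, although the later use of aa_vocab.itos[4:] suggests its first four entries are also special tokens.

```python
# Hypothetical vocabulary: 4 special tokens followed by the 20 standard amino acids.
SPECIAL_TOKENS = ['pad', 'bos', 'eos', 'unk']
AMINO_ACIDS = list('ACDEFGHIKLMNPQRSTVWY')
itos = SPECIAL_TOKENS + AMINO_ACIDS
stoi = {tok: i for i, tok in enumerate(itos)}

def encode(seq, max_len=None):
    """Map a sequence to integer ids, wrapped in bos/eos and optionally padded."""
    ids = [stoi['bos']] + [stoi.get(aa, stoi['unk']) for aa in seq] + [stoi['eos']]
    if max_len is not None:
        ids = ids[:max_len] + [stoi['pad']] * max(0, max_len - len(ids))
    return ids

print(encode('MA', max_len=6))  # [1, 14, 4, 2, 0, 0]
```

Each residue becomes a single integer. The CNN's embedding layer (d_embedding) then learns a vector for each id.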
We will train on 90% of the data and validate on the 10% held out.
train_df = df.sample(frac=0.9, random_state=42).copy()
valid_df = df[~df.index.isin(train_df.index)].copy()
aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)
train_ds = Text_Prediction_Dataset(train_df.aa.values, train_df.VAE_virus_S.values, aa_vocab)
test_ds = Text_Prediction_Dataset(valid_df.aa.values, valid_df.VAE_virus_S.values, aa_vocab)
This is the model we will use:
class Predictive_CNN(nn.Module):
    def __init__(self,
                 d_vocab,
                 d_embedding,
                 d_latent,
                 filters,
                 kernel_sizes,
                 strides,
                 dropouts,
                 mlp_dims,
                 mlp_drops,
                 d_out,
                 outrange
                ):
        super().__init__()
        # convolutional encoder: token ids -> latent vector of size d_latent
        self.conv_encoder = Conv_Encoder(
            d_vocab,
            d_embedding,
            d_latent,
            filters,
            kernel_sizes,
            strides,
            dropouts,
        )
        # MLP head: latent vector -> prediction bounded by outrange
        self.mlp_head = MLP(
            d_latent,
            mlp_dims,
            d_out,
            mlp_drops,
            outrange=outrange
        )

    def forward(self, x):
        encoded = self.conv_encoder(x)
        out = self.mlp_head(encoded)
        return out
d_vocab = len(aa_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
outrange = [-10, 10]
virus_model = Predictive_CNN(
d_vocab,
d_embedding,
d_latent,
filters,
kernel_sizes,
strides,
dropouts,
mlp_dims,
mlp_drops,
d_out,
outrange
)
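The outrange=[-10, 10] setting bounds the model's output. One common way to do this is a scaled sigmoid. The sketch below assumes that mechanism for illustration; it is not taken from MRL's MLP source.

```python
import math

def _sigmoid(x):
    # numerically stable logistic function (avoids overflow for large |x|)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def squash_to_range(x, low=-10.0, high=10.0):
    """Assumed outrange behavior: map an unbounded activation into (low, high)."""
    return low + (high - low) * _sigmoid(x)

print(squash_to_range(0.0))  # 0.0, the midpoint of [-10, 10]
```

Bounding the output keeps extreme predictions from dominating the reward later, at the cost of compressing scores near the ends of the range.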
r_agent = PredictiveAgent(virus_model, MSELoss(), train_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 20, 1e-3)
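Optional: save the score function weights to untracked_files/virus_predictor.pt, which is where the reward section loads them from.
Optional: to load the exact weights used in this tutorial instead, run r_agent.load_state_dict(model_from_url('virus_predictor.pt')).
Next, we evaluate the model on the held-out validation set.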
valid_dl = test_ds.dataloader(256, num_workers=0, shuffle=False)
r_agent.model.eval();
preds = []
targs = []
with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()
preds = preds.squeeze(-1)
Our score function has an R-squared value of about 0.883 on the validation set.
fig, ax = plt.subplots()
ax.scatter(targs, preds, alpha=0.5, s=1)
plt.xlabel('Target')
plt.ylabel('Prediction')
slope, intercept = np.polyfit(targs, preds, 1)
ax.plot(np.array(ax.get_xlim()), intercept + slope*np.array(ax.get_xlim()), c='r')
plt.text(5., 9., 'R-squared = %0.3f' % r2_score(targs, preds));
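For reference, the R-squared reported by sklearn's r2_score is the coefficient of determination, 1 - SS_res / SS_tot. Here is a dependency-free sketch of the same formula.

```python
def r_squared(targets, preds):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_t = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, preds))
    ss_tot = sum((t - mean_t) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot

print(r_squared([1, 2, 3], [2, 2, 2]))  # 0.0: no better than predicting the mean
```

This score measures agreement with the identity line (prediction equals target), not with the fitted red line. It can therefore be lower than the squared correlation when predictions are systematically shifted or scaled. Argument order also matters: targets come first.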
We should also look at the distribution of predictions across the known sequences. The predicted values will differ somewhat from the measured values, so this gives us a sense of what predicted score we want to see from the model.
df['preds'] = r_agent.predict_data(df.aa.values).detach().cpu().numpy()
df.preds.max()
np.percentile(df.preds, 99)
The maximum predicted value is 8.63. A sequence scoring 5.33 or higher would be in the top 1% of all known sequences.
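The 5.33 cutoff comes from np.percentile(df.preds, 99). NumPy's default method linearly interpolates between the two nearest sorted values. A small pure-Python sketch of that rule:

```python
def percentile(values, q):
    """Linear-interpolation percentile, matching NumPy's default method."""
    xs = sorted(values)
    if not xs:
        raise ValueError("values must be non-empty")
    rank = (q / 100.0) * (len(xs) - 1)   # fractional index into the sorted list
    lo = int(rank)                       # floor, since rank >= 0
    hi = min(lo + 1, len(xs) - 1)
    frac = rank - lo
    return xs[lo] + frac * (xs[hi] - xs[lo])

print(percentile([1, 2, 3, 4, 5], 50))  # 3.0
```

Because the cutoff is computed on predictions rather than measured values, it is the right yardstick for judging generated sequences scored by the same model.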
Chemical Space
Next we need to develop our chemical space. This is where we decide what sequences will be allowed and what sequences will be removed.
Getting the right filtering parameters makes a large difference in sequence quality. In practice, finding the right constraints is an iterative process. First, run a generative screen. Then examine the highest scoring sequences and look for undesirable properties or structural features. Then update the template and iterate.
For peptides, the presence of arginine has been shown to be toxic [ref]. We will apply a template filter limiting the number of arginine residues per unit length to at most 0.1.
We will also limit the maximum frequency of any single residue in a sample to 0.3. This prevents a common failure mode in which the same residue is repeated many times (i.e. MSSSSSSSRP). This failure mode exploits a flaw in the simplistic score functions we are using.
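The composition rules above can be sketched as plain functions. This is a hypothetical illustration, not MRL's CharacterCountFilter. It assumes each residue's count is divided by sequence length and checked separately. Failing sequences receive the template's fail_score of -10.

```python
from collections import Counter

def passes_composition_filters(seq, arg_max=0.1, residue_max=0.3):
    """Hypothetical filter: arginine (R) fraction <= arg_max and
    no single residue above residue_max of the sequence length."""
    if not seq:
        return False
    counts = Counter(seq)
    n = len(seq)
    if counts.get('R', 0) / n > arg_max:
        return False
    return max(counts.values()) / n <= residue_max

def score_or_fail(seq, score_fn, fail_score=-10.0):
    """Score only sequences that pass the filters; others get fail_score."""
    return score_fn(seq) if passes_composition_filters(seq) else fail_score

print(passes_composition_filters('MSSSSSSSRP'))  # False: S makes up 70% of the sequence
```

Prefiltering with a fixed low score gives the policy a clear signal to avoid degenerate sequences, without spending compute on scoring them.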
aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)
template = Template([ValidityFilter(),
CharacterCountFilter(['R'], min_val=0, max_val=0.1, per_length=True, mode='protein'),  # 'R' is the one-letter code for arginine
CharacterCountFilter(aa_vocab.itos[4:], min_val=0, max_val=0.3,
per_length=True, mode='protein')],
[], fail_score=-10., log=False, use_lookup=False, mode='protein')
template_cb = TemplateCallback(template, prefilter=True)
Load Model
We load the LSTM_LM_Small_Swissprot model. This is a basic LSTM-based language model trained on part of the Swissprot protein database.
agent = LSTM_LM_Small_Swissprot(drop_scale=0.3, opt_kwargs={'lr':5e-5})
Fine-Tune Model
The pretrained model we loaded is a very general model that can produce a high diversity of structures. However, what we actually want are structures with high viability. To induce this, we fine-tune on high scoring sequences.
The dataset denotes two categories of high scoring sequences. The viable category contains ~3000 samples that were able to successfully assemble and package genetic material. The beats_wt category contains ~300 samples with higher expression than the wildtype variant.
The beats_wt dataset is rather small for fine-tuning, so we will first fine-tune on the viable dataset, then on the beats_wt dataset.
agent.update_dataset_from_inputs(df[df.viable==1].aa.values)
agent.train_supervised(32, 8, 5e-5)
agent.update_dataset_from_inputs(df[df.beats_wt==1].aa.values)
agent.train_supervised(32, 6, 5e-5)
agent.base_to_model()
Optional: save fine-tuned weights
Loss
We use PPO as our policy gradient loss.
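For intuition, here are the two core pieces of PPO as described by Schulman et al.: generalized advantage estimation (GAE) and the clipped surrogate loss. This is a generic sketch, not MRL's implementation. It assumes the first positional argument 0.99 is the discount factor gamma. It omits the value loss, entropy bonus, and KL-target terms that the configuration also sets.

```python
def gae_advantages(rewards, values, last_value, gamma=0.99, lam=0.95):
    """Generalized advantage estimation over one trajectory."""
    advantages = [0.0] * len(rewards)
    next_value = last_value
    running = 0.0
    for t in reversed(range(len(rewards))):
        # TD error: reward plus discounted next value minus current value
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages

def ppo_clip_loss(ratio, advantage, cliprange=0.3):
    """Clipped surrogate loss for one action (negated for minimization)."""
    clipped = max(min(ratio, 1.0 + cliprange), 1.0 - cliprange)
    return -min(ratio * advantage, clipped * advantage)

print(gae_advantages([0.0, 0.0, 1.0], [0.0, 0.0, 0.0], 0.0, gamma=1.0, lam=1.0))
print(ppo_clip_loss(1.5, 1.0))  # -1.3: the gain is capped at 1 + cliprange
```

Clipping limits how far one update can move the policy away from the policy that generated the samples. This matters here because the reward model can be exploited quickly.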
pg = PPO(0.99,
0.5,
lam=0.95,
v_coef=0.5,
cliprange=0.3,
v_cliprange=0.3,
ent_coef=0.01,
kl_target=0.03,
kl_horizon=3000,
scale_rewards=True)
loss = PolicyLoss(pg, 'PPO',
value_head=ValueHead(256),
v_update_iter=2,
vopt_kwargs={'lr':1e-3})
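Reward
Next we rebuild the score function model, load the saved weights, and freeze the model so it can serve as the reward.
aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)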
d_vocab = len(aa_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
outrange = [-10, 10]
reward_model = Predictive_CNN(
d_vocab,
d_embedding,
d_latent,
filters,
kernel_sizes,
strides,
dropouts,
mlp_dims,
mlp_drops,
d_out,
outrange
)
r_ds = Text_Prediction_Dataset(['M'], [0.], aa_vocab)
r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.load_weights('untracked_files/virus_predictor.pt')
# r_agent.load_state_dict(model_from_url('virus_predictor.pt')) # optional - load exact weights
reward_model.eval();
freeze(reward_model)
reward_function = Reward(r_agent.predict_data, weight=1)
virus_reward = RewardCallback(reward_function, 'virus')
Optional Reward: Stability Metric
There has been a lot of recent work on large-scale transformer language models for unsupervised learning of protein structure. One interesting finding is a rough relationship between the log probability a generative model assigns to a protein sequence and the stability of that protein.
We can use the log probability values from a pretrained protein transformer model as a proxy for stability. Including this as a reward function can help keep the generated peptides realistic.
To include this as a reward term, run the code below to install the ESM library to access a pretrained protein transformer model.
In the interest of time, we will use the smallest ESM model, with 43M parameters, rather than the large-scale 630M-parameter model. Even with the smaller model, this reward term adds significantly to the training runtime.
# ! pip install fair-esm
import esm
protein_model, alphabet = esm.pretrained.esm1_t6_43M_UR50S()
batch_converter = alphabet.get_batch_converter()
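class PeptideStability():
    def __init__(self, model, alphabet, batch_converter):
        self.model = model
        to_device(self.model)
        self.model.eval()  # disable dropout so scores are deterministic
        self.alphabet = alphabet
        self.batch_converter = batch_converter

    def __call__(self, samples):
        # ESM's batch converter expects (label, sequence) pairs
        data = [
            (f'protein{i}', samples[i]) for i in range(len(samples))
        ]
        batch_labels, batch_strs, batch_tokens = self.batch_converter(data)
        batch_tokens = to_device(batch_tokens)
        with torch.no_grad():
            results = self.model(batch_tokens)
        lps = F.log_softmax(results['logits'], -1)
        # log probability the model assigns to each observed token
        token_lps = lps.gather(2, batch_tokens.unsqueeze(-1)).squeeze(-1)
        # average over non-padding positions so padding does not skew shorter sequences
        mask = (batch_tokens != self.alphabet.padding_idx).float()
        mean_lps = (token_lps * mask).sum(-1) / mask.sum(-1)
        return mean_lps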
ps = PeptideStability(protein_model, alphabet, batch_converter)
stability_reward = Reward(ps, weight=0.1, bs=300)
stability_cb = RewardCallback(stability_reward, name='stability')
stability_reward(df.aa.values[:10])
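The stability score averages, over positions, the log-softmax probability of the token actually observed at each position. The pure-Python sketch below reproduces that arithmetic on toy logits, so no ESM model is needed. The pad_id argument is an optional assumption that excludes padding positions from the average.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one position's logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def mean_token_log_prob(logits_per_position, token_ids, pad_id=None):
    """Average log p(observed token) over positions, optionally skipping padding."""
    total, count = 0.0, 0
    for logits, tok in zip(logits_per_position, token_ids):
        if pad_id is not None and tok == pad_id:
            continue
        total += log_softmax(logits)[tok]
        count += 1
    if count == 0:
        raise ValueError("no non-padding positions to average")
    return total / count

# Uniform logits over 4 tokens give log(1/4) at every position.
print(mean_token_log_prob([[0.0] * 4, [0.0] * 4], [1, 2]))
```

Higher (less negative) values mean the protein model finds the sequence more natural. This is the property the stability reward pushes toward.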
Samplers
We create the following samplers:
sampler1 ModelSampler: this samples from the main model. This sampler adds 1000 sequences to the buffer on each buffer build, and samples 40% of each batch on the fly from the main model.
sampler2 ModelSampler: this samples from the baseline model and is not sampled live on each batch.
sampler3 LogSampler: this samples high scoring samples from the log.
sampler4 TokenSwapSampler: this uses token swap combichem to generate new samples from high scoring samples.
sampler5 DatasetSampler: this sprinkles a small number of high scoring samples into each buffer build.
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 200)
sampler4 = TokenSwapSampler('samples', 'rewards', 10, 98, 200, aa_vocab, 0.2)
sampler5 = DatasetSampler(df[(df.beats_wt==1)].aa.values,
'data', buffer_size=6)
samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]
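Token-swap combichem can be illustrated as random point mutation of high scoring sequences. This hypothetical sketch assumes the 0.2 passed to TokenSwapSampler is a per-position swap probability. MRL's actual mutation scheme may differ. A swapped position can also draw the same residue again.

```python
import random

AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'

def token_swap(seq, swap_prob=0.2, alphabet=AMINO_ACIDS, rng=random):
    """Hypothetical sketch: replace each residue with a random amino acid
    with probability swap_prob; other residues are kept unchanged."""
    return ''.join(rng.choice(alphabet) if rng.random() < swap_prob else aa
                   for aa in seq)

rng = random.Random(0)
parents = ['MSTKRPLAGD', 'MKTAYIAKQR']  # stand-ins for high scoring log samples
children = [token_swap(p, 0.2, rng=rng) for p in parents]
print(children)
```

Mutating known good sequences explores the neighborhood of high reward regions more cheaply than sampling from scratch.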
Callbacks
Additional callbacks:
SupervisedCB: runs supervised training on the top 3% of samples every 400 batches.
MaxCallback: prints the max reward for each batch.
PercentileCallback: prints the 90th percentile score for each batch.
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
cbs = [supervised_cb, live_p90, live_max]
env = Environment(agent, template_cb, samplers=samplers, rewards=[virus_reward, stability_cb], losses=[loss],
cbs=cbs)
env.fit(128, 28, 400, 25)
env.log.plot_metrics()
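The environment was given two reward terms: the virus score (weight 1) and the stability score (weight 0.1). A natural reading is that each sample's total reward is a weighted sum of the terms. The sketch below assumes this; it is not confirmed from MRL's source. The reward values are made up.

```python
def total_reward(sample, reward_fns):
    """Weighted sum of reward terms: sum_i weight_i * fn_i(sample)."""
    return sum(weight * fn(sample) for fn, weight in reward_fns)

# Made-up scores standing in for the virus predictor and the ESM stability score.
virus = lambda s: 6.0
stability = lambda s: -2.5
print(total_reward('MSTKRPLAGD', [(virus, 1.0), (stability, 0.1)]))  # about 5.75
```

With a weight of 0.1, the stability term acts as a small correction next to the virus score, which the outrange setting bounds to [-10, 10].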
Evaluation
Based on our score function, we determined that a sequence with a predicted score of 5.33 or higher would be in the top 1% of sequences in the dataset. A sequence with a predicted score of 8.63 or higher would beat out all sequences in the dataset.
We found 1670 sequences with a predicted score of 5.33 or higher and 13 sequences with a predicted score of 8.63 or higher
env.log.df[(env.log.df.virus>5.33) & ~(env.log.df.sources=='data_buffer')].shape
env.log.df[(env.log.df.virus>8.63) & ~(env.log.df.sources=='data_buffer')].shape
We can generate logo stack plots to visualize which residues are favored at each position.
from collections import Counter, defaultdict
def plot_logo(seqs):
    # normalized residue frequency at each position
    freqs = []
    max_len = max(len(s) for s in seqs)
    for i in range(max_len):
        # only sequences long enough to have position i contribute
        aas = [s[i] for s in seqs if i < len(s)]
        counts = Counter(aas)
        total = sum(counts.values())
        norm_counts = defaultdict(float)
        for k in counts.keys():
            norm_counts[k] = counts[k]/total
        freqs.append(norm_counts)
    # one row per position, one column per amino acid (skipping special tokens)
    aas = aa_vocab.itos[4:]
    dfs = []
    for i, f in enumerate(freqs):
        df_iter = pd.DataFrame([[f[aa] for aa in aas]], columns=aas)
        df_iter['Position'] = i
        dfs.append(df_iter)
    dfs = pd.concat(dfs)
    dfs = dfs.set_index('Position')
    return dfs.groupby('Position').mean().plot.bar(stacked=True, figsize=(12,8))
Here's the logo plot for high scoring sequences in the dataset
plot_logo(df[(df.beats_wt==1) & (df.VAE_virus_S>7)].aa.values)
Here's the logo plot for high scoring generated sequences.
plot_logo(env.log.df[(env.log.df.virus>8.9) & ~(env.log.df.sources=='data_buffer')].samples.values)
We can see similarities between the dataset and generated sequences at many positions, with some significant differences (e.g. position 15).
So are these residue changes real and meaningful? That depends on the quality of our score function. The only way to know is to test the sequences in the lab.