Using MRL to design AAV capsids

Adeno-Associated Virus Capsid Design

This tutorial runs an end-to-end workflow for designing AAV capsids.

Adeno-Associated Viruses (AAV) are small viruses that generally do not cause disease. Their non-pathogenic nature makes them attractive as gene therapy delivery vectors. One drawback to using AAV capsids for gene delivery is that they can be neutralized by natural immunity. This sets up a protein engineering problem: we want to develop new variants of the virus that can evade the human immune system.

We could approach this with scanning mutagenesis, but that would produce a high number of invalid sequences. Instead, we will build a score function from laboratory data and use a generative model to exploit it. We expect this approach to yield more realistic variants.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook (see the sketch below). This enables multiprocessing, which typically gives 2-4x speedups.
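For example, a cell like the following enables the pool. This is a sketch only; the exact set_global_pool signature may differ between MRL versions, and the worker count is a suggestion:

# uncomment on a multi-core machine to parallelize CPU-bound sample evaluation
# set_global_pool(min(os.cpu_count(), 12))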

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from sklearn.metrics import r2_score
os.makedirs('untracked_files', exist_ok=True)

Data

The dataset comes from the paper Generative AAV capsid diversification by latent interpolation. The authors looked at a 28-amino-acid section of the AAV2 VP3 protein shown to have immunological significance.

The authors created a mutation library based on 564 naturally occurring sequences. They then trained a VAE model on this dataset, sampled new 28-AA sequences from this model, and tested them in the lab.

We will use their laboratory data to build a score function. Then we will run a generative screen against that score function.

# ! wget https://raw.githubusercontent.com/churchlab/Generative_AAV_design/main/Data/vae2021_processed_data.csv
df = pd.read_csv('vae2021_processed_data.csv')
df.dropna(subset=['VAE_virus_S'], inplace=True)
df.head()
Unnamed: 0 category_orig lin_fit pred_fit sampling aa wt_dist mask VAE_virus_S mut category wt_conserved_pos colors beats_wt viable
0 0 VAE-MSA -47.652505 29.069985 most_likely DEEEIRTTNPVATEQYGVTATNLQNSNT 6 _________________VTA____NS_T 6.660194 6.0 VAE-MSA 22.0 plum 1 1
1 1 VAE-MSA -48.071833 28.955883 most_likely SEEEIRTTNPVATEQYGTTATNLQSSNT 7 S________________TTA____SS_T 2.716090 7.0 VAE-MSA 21.0 plum 0 1
2 2 VAE-MSA -47.482053 28.946108 most_likely DEEEIRTTNPVATEQYGVTATNLQNSTT 7 _________________VTA____NSTT 6.619935 7.0 VAE-MSA 21.0 plum 1 1
3 3 VAE-MSA -47.745287 28.280085 most_likely DEEEIRTTNPVATEQYGVTATNLQSSNT 6 _________________VTA____SS_T 6.344426 6.0 VAE-MSA 22.0 plum 1 1
4 4 VAE-MSA -46.253139 28.170899 most_likely DEEEIRTTNPVATEQYGTTATNLQNSNT 6 _________________TTA____NS_T 6.148984 6.0 VAE-MSA 22.0 plum 1 1

Our metric of interest is VAE_virus_S. This is the log2 ratio of a variant's frequency in the virus pool relative to the frequency of the corresponding plasmid in the plasmid pool. Higher values indicate higher viability. The goal of the design is to produce variants predicted to beat the wildtype sequence; the dataset indicates which sequences currently perform at this level.
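To make the metric concrete, here is the arithmetic with hypothetical pool frequencies (illustrative only; the dataset already contains the computed scores):

# log2 enrichment of a variant: virus-pool frequency vs plasmid-pool frequency
virus_freq = 0.004    # hypothetical frequency of the variant in the virus pool
plasmid_freq = 0.001  # hypothetical frequency of the matching plasmid
np.log2(virus_freq / plasmid_freq)  # 2.0; higher values indicate higher viability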

df.VAE_virus_S.hist(alpha=0.5, density=True, label='full_dataset')
df[(df.viable==1) & (df.beats_wt==1)].VAE_virus_S.hist(alpha=0.5, density=True, label='Beats WT')
plt.legend();

Score Function

Now we want to develop a score function for predicting capsid viability. We will use a CNN encoder with an MLP head to predict the VAE_virus_S score described above.

Our input data will be token integers for amino acids. Note that fingerprint representations are a poor fit for peptide work because peptides contain many repeating substructures.
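To make the input representation concrete, here is a toy encoding (illustrative only; the real mapping comes from aa_vocab, which also reserves special tokens at the start of its vocabulary):

# toy token-integer encoding, not the actual aa_vocab mapping
toy_stoi = {aa: i for i, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
[toy_stoi[aa] for aa in 'DEEEIRTT']  # [2, 3, 3, 3, 7, 14, 16, 16]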

We will train on 90% of the data and validate on the 10% held out.

train_df = df.sample(frac=0.9, random_state=42).copy()
valid_df = df[~df.index.isin(train_df.index)].copy()
aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)

train_ds = Text_Prediction_Dataset(train_df.aa.values, train_df.VAE_virus_S.values, aa_vocab)
test_ds = Text_Prediction_Dataset(valid_df.aa.values, valid_df.VAE_virus_S.values, aa_vocab)

This is the model we will use:

class Predictive_CNN(nn.Module):
    def __init__(self,
                 d_vocab,
                 d_embedding,
                 d_latent,
                 filters,
                 kernel_sizes,
                 strides,
                 dropouts,
                 mlp_dims,
                 mlp_drops,
                 d_out,
                 outrange
                ):
        super().__init__()
        
        self.conv_encoder = Conv_Encoder(
                                        d_vocab,
                                        d_embedding,
                                        d_latent,
                                        filters,
                                        kernel_sizes,
                                        strides,
                                        dropouts,
                                    )
        
        self.mlp_head = MLP(
                            d_latent,
                            mlp_dims,
                            d_out,
                            mlp_drops,
                            outrange=outrange
                            )
        
    def forward(self, x):
        encoded = self.conv_encoder(x)
        out = self.mlp_head(encoded)
        return out
d_vocab = len(aa_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
outrange = [-10, 10]


virus_model = Predictive_CNN(
                    d_vocab,
                    d_embedding,
                    d_latent,
                    filters,
                    kernel_sizes,
                    strides,
                    dropouts,
                    mlp_dims,
                    mlp_drops,
                    d_out,
                    outrange
                )
r_agent = PredictiveAgent(virus_model, MSELoss(), train_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 20, 1e-3)
Epoch Train Loss Valid Loss Time
0 3.82397 3.43804 00:06
1 3.80967 3.17915 00:06
2 2.48006 1.58518 00:06
3 1.86415 1.09765 00:06
4 1.34830 0.79360 00:06
5 1.46673 0.69241 00:06
6 1.37931 0.57727 00:06
7 2.57791 0.69516 00:06
8 1.28085 0.90489 00:06
9 1.70550 0.53054 00:06
10 1.45487 0.49308 00:06
11 0.66824 0.56526 00:06
12 0.53722 0.34486 00:06
13 0.72762 0.36265 00:06
14 0.31081 0.27825 00:06
15 0.93874 0.26328 00:06
16 0.62005 0.24898 00:06
17 0.56705 0.26793 00:06
18 0.33817 0.21645 00:06
19 0.63277 0.20664 00:06

Optional: save score function weights
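A minimal sketch, assuming PredictiveAgent exposes save_weights as the counterpart to the load_weights call used in the Reward section below (the path matches the one loaded there):

r_agent.save_weights('untracked_files/virus_predictor.pt')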

 

Optional: to load the exact weights used, run the following:
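This mirrors the optional load call shown again in the Reward section below:

# r_agent.load_state_dict(model_from_url('virus_predictor.pt'))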

 
valid_dl = test_ds.dataloader(256, num_workers=0, shuffle=False)
r_agent.model.eval();

preds = []
targs = []

with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
        
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()

preds = preds.squeeze(-1)

Our score function has an R-squared value of about 0.883 on the validation dataset.

fig, ax = plt.subplots()

ax.scatter(targs, preds, alpha=0.5, s=1)
plt.xlabel('Target')
plt.ylabel('Prediction')

slope, intercept = np.polyfit(targs, preds, 1)
ax.plot(np.array(ax.get_xlim()), intercept + slope*np.array(ax.get_xlim()), c='r')

plt.text(5., 9., 'R-squared = %0.3f' % r2_score(targs, preds));

We should also take a look at the prediction distribution for our known actives, since the predicted values will differ somewhat from the actual values. This gives us a sense of what scores we want to see from the model.

df['preds'] = r_agent.predict_data(df.aa.values).detach().cpu().numpy()
df.preds.max()
8.637711
np.percentile(df.preds, 99)
5.337126960754395

The maximum predicted value is 8.63. A sequence scoring 5.33 or higher would be in the top 1% of all known sequences.
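The same predict_data interface used above can sanity-check any single candidate against these thresholds:

# score one known sequence and compare against the top-1% cutoff
single_pred = r_agent.predict_data([df.aa.values[0]]).item()
single_pred > np.percentile(df.preds, 99)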

Chemical Space

Next we need to define our chemical space. This is where we decide which sequences are allowed and which are removed.

Getting the right filtering parameters makes a huge difference in sequence quality. In practice, finding the right constraints is an iterative process: run a generative screen, examine the highest scoring sequences for undesirable properties or structural features, then update the template and repeat.

For peptides, the presence of arginine has been shown to be toxic [ref]. We will apply a template filter on the number of arginine residues per unit length.

We will also limit the maximum residue frequency in a sample to 0.3. This prevents a common failure mode where the same residue repeats many times (e.g. MSSSSSSSRP), a flaw of the simplistic score functions we are using.

aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)

template = Template([ValidityFilter(),
                     # cap arginine ('R', the residue discussed above) at 10% of sequence length
                     CharacterCountFilter(['R'], min_val=0, max_val=0.1, per_length=True, mode='protein'),
                     # cap every amino acid (vocab entries after the special tokens) at 30% of sequence length
                     CharacterCountFilter(aa_vocab.itos[4:], min_val=0, max_val=0.3, 
                                          per_length=True, mode='protein')], 
                    [], fail_score=-10., log=False, use_lookup=False, mode='protein')

template_cb = TemplateCallback(template, prefilter=True)

Load Model

We load the LSTM_LM_Small_Swissprot model. This is a basic LSTM-based language model trained on part of the Swissprot protein database.

agent = LSTM_LM_Small_Swissprot(drop_scale=0.3, opt_kwargs={'lr':5e-5})

Fine-Tune Model

The pretrained model we loaded is a very general model that can produce a high diversity of sequences. However, what we actually want are sequences likely to assemble into viable capsids. To induce this, we fine-tune on high scoring sequences.

The dataset denotes two categories of high scoring sequences. The viable category contains ~3000 samples that successfully assembled and packaged genetic material. The beats_wt category contains ~300 samples with higher expression than the wildtype variant.

The beats_wt dataset is a bit small for fine-tuning on its own, so we will first fine-tune on the viable dataset, then on the beats_wt dataset.

# fine-tune first on the larger viable set
agent.update_dataset_from_inputs(df[df.viable==1].aa.values)
agent.train_supervised(32, 8, 5e-5)

# then on the smaller beats_wt set
agent.update_dataset_from_inputs(df[df.beats_wt==1].aa.values)
agent.train_supervised(32, 6, 5e-5)

agent.base_to_model()
Epoch Train Loss Valid Loss Time
0 2.98167 2.94361 00:08
1 2.29855 2.01615 00:08
2 1.18073 1.43596 00:07
3 1.18336 1.24130 00:07
4 1.00311 1.16523 00:07
5 1.40234 1.13306 00:07
6 1.26355 1.12186 00:07
7 0.81410 1.12125 00:08
Epoch Train Loss Valid Loss Time
0 0.82398 0.64645 00:06
1 0.88730 0.63657 00:06
2 0.70039 0.62748 00:06
3 0.69084 0.62392 00:06
4 0.73466 0.62024 00:06
5 0.61169 0.61998 00:06

Optional: save fine-tuned weights
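A hedged sketch along the same lines, assuming the generative agent shares the save_weights interface used above; the filename is hypothetical:

agent.save_weights('untracked_files/aav_finetuned_lm.pt')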

 
 

Reinforcement Learning

Now we enter the reinforcement learning stage.

Loss

We use PPO as our policy gradient loss.

pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})

Reward

Here we pass the reward agent we trained earlier to a callback.

aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)

d_vocab = len(aa_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
outrange = [-10, 10]


reward_model = Predictive_CNN(
                    d_vocab,
                    d_embedding,
                    d_latent,
                    filters,
                    kernel_sizes,
                    strides,
                    dropouts,
                    mlp_dims,
                    mlp_drops,
                    d_out,
                    outrange
                )

# placeholder dataset; the trained predictor weights are loaded below
r_ds = Text_Prediction_Dataset(['M'], [0.], aa_vocab)

r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})

r_agent.load_weights('untracked_files/virus_predictor.pt')
# r_agent.load_state_dict(model_from_url('virus_predictor.pt')) # optional - load exact weights

reward_model.eval();

freeze(reward_model)

reward_function = Reward(r_agent.predict_data, weight=1)

virus_reward = RewardCallback(reward_function, 'virus')

Optional Reward: Stability Metric

There has been a lot of great work recently on large-scale transformer language models for unsupervised learning on protein sequences. One interesting finding is a rough relationship between the sequence log probability given by a generative model and the stability of the protein.

We can use the log probability values from a pretrained protein transformer model as a proxy for stability. Including this as a reward function can help keep the generated peptides realistic.

To include this reward term, run the code below to install the ESM library, which provides pretrained protein transformer models.

In the interest of time, we will use the smallest ESM model with 43M parameters rather than the large-scale 650M parameter model. Note that even with the smaller model, this reward term adds significantly to the training runtime.

# ! pip install fair-esm
import esm
protein_model, alphabet = esm.pretrained.esm1_t6_43M_UR50S()
batch_converter = alphabet.get_batch_converter()
class PeptideStability():
    def __init__(self, model, alphabet, batch_converter):
        self.model = model
        to_device(self.model)
        self.alphabet = alphabet
        self.batch_converter = batch_converter
        
    def __call__(self, samples):
        # (name, sequence) pairs in the format the ESM batch converter expects
        data = [
            (f'protein{i}', samples[i]) for i in range(len(samples))
        ]
        
        batch_labels, batch_strs, batch_tokens = self.batch_converter(data)

        with torch.no_grad():
            results = self.model(to_device(batch_tokens))

        # per-token log-probabilities over the vocabulary
        lps = F.log_softmax(results['logits'], -1)

        # log-probability of each actual token, averaged over the sequence
        mean_lps = lps.gather(2, to_device(batch_tokens).unsqueeze(-1)).squeeze(-1).mean(-1)
        
        return mean_lps
ps = PeptideStability(protein_model, alphabet, batch_converter)
stability_reward = Reward(ps, weight=0.1, bs=300)
stability_cb = RewardCallback(stability_reward, name='stability')
stability_reward(df.aa.values[:10])
tensor([-0.4359, -0.4157, -0.4197, -0.4416, -0.4249, -0.4359, -0.4065, -0.4426,
        -0.4420, -0.4499], device='cuda:0')

Samplers

We create the following samplers:

  • sampler1 ModelSampler: samples from the main model. This sampler adds 1000 sequences to the buffer on each buffer build; the following argument sets the fraction of each batch sampled on the fly from the main model (0. here).
  • sampler2 ModelSampler: samples from the baseline model and is not sampled live on each batch.
  • sampler3 LogSampler: samples high scoring sequences from the log.
  • sampler4 TokenSwapSampler: uses token-swap combichem to generate new samples from high scoring ones.
  • sampler5 DatasetSampler: sprinkles a small number of known high scoring dataset sequences into each buffer build.
gen_bs = 1500

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 200)
sampler4 = TokenSwapSampler('samples', 'rewards', 10, 98, 200, aa_vocab, 0.2)
sampler5 = DatasetSampler(df[(df.beats_wt==1)].aa.values, 
                          'data', buffer_size=6)

samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]

Callbacks

Additional callbacks

supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)

cbs = [supervised_cb, live_p90, live_max]

Environment and Train

Now we can put together our Environment and run the training process

env = Environment(agent, template_cb, samplers=samplers, rewards=[virus_reward, stability_cb], losses=[loss],
                 cbs=cbs)
 
# batch size 128, sequence length 28, 400 training batches, log every 25 iterations
env.fit(128, 28, 400, 25)
iterations rewards rewards_final new diversity bs template valid virus stability PPO rewards_live_p90 rewards_live_max
0 -1.395 -1.395 1.000 1.000 128 0.000 1.000 -0.850 -0.545 0.766 -0.505 0.577
25 -0.893 -0.893 0.961 1.000 128 0.000 1.000 -0.257 -0.636 0.954 -0.388 1.612
50 -1.355 -1.355 0.922 1.000 128 0.000 1.000 -0.563 -0.792 1.333 -0.813 1.709
75 -0.483 -0.483 0.883 1.000 128 0.000 1.000 0.143 -0.626 1.409 -0.014 4.435
100 -0.082 -0.082 0.836 1.000 128 0.000 1.000 0.585 -0.667 1.915 0.440 3.581
125 -0.315 -0.315 0.898 1.000 128 0.000 1.000 0.313 -0.628 1.402 -0.061 3.180
150 -0.251 -0.251 0.891 1.000 128 0.000 1.000 0.385 -0.635 1.593 0.927 3.964
175 -0.536 -0.536 0.914 1.000 128 0.000 1.000 0.090 -0.626 1.053 0.388 3.711
200 -0.155 -0.155 0.883 1.000 128 0.000 1.000 0.508 -0.663 1.955 2.968 5.444
225 0.253 0.253 0.875 1.000 128 0.000 1.000 0.931 -0.678 1.996 3.170 5.373
250 0.194 0.194 0.875 1.000 128 0.000 1.000 0.838 -0.643 1.893 2.214 4.495
275 -0.033 -0.033 0.938 1.000 128 0.000 1.000 0.472 -0.506 1.557 1.815 5.068
300 0.429 0.429 0.922 1.000 128 0.000 1.000 1.048 -0.619 3.572 5.125 6.869
325 0.677 0.677 0.875 1.000 128 0.000 1.000 1.324 -0.648 2.413 4.447 6.803
350 0.821 0.821 0.930 1.000 128 0.000 1.000 1.310 -0.489 1.856 3.381 5.525
375 1.333 1.333 0.906 1.000 128 0.000 1.000 2.027 -0.694 2.674 5.372 6.963
env.log.plot_metrics()

Evaluation

Based on our score function, we determined that a sequence with a predicted score of 5.33 or higher would be in the top 1% of sequences relative to the training data. A sequence with a predicted score of 8.63 or higher would beat all sequences in the dataset.

We found 1670 sequences with a predicted score of 5.33 or higher, and 13 sequences with a predicted score of 8.63 or higher (excluding samples drawn directly from the dataset buffer).

env.log.df[(env.log.df.virus>5.33) & ~(env.log.df.sources=='data_buffer')].shape
(1670, 8)
env.log.df[(env.log.df.virus>8.63) & ~(env.log.df.sources=='data_buffer')].shape
(13, 8)

We can generate logo stack plots to visualize which residues are favored.

from collections import Counter, defaultdict  # explicit imports in case mrl.imports does not provide them

def plot_logo(seqs):
    # per-position residue frequency distributions
    freqs = []

    for i in range(len(seqs[0])):
        aas = [j[i] for j in seqs]
        counts = Counter(aas)
        total = sum(counts.values())
        norm_counts = defaultdict(lambda: 0)
        for k in counts.keys():
            norm_counts[k] = counts[k]/total
        freqs.append(norm_counts)
        
    aas = aa_vocab.itos[4:]  # amino acid tokens (skipping special tokens)
    dfs = []

    for i, f in enumerate(freqs):
        df_iter = pd.DataFrame([[f[aa] for aa in aas]], columns=aas)
        df_iter['Position'] = i
        dfs.append(df_iter)
        
    dfs = pd.concat(dfs)
    dfs = dfs.set_index('Position')

    # stacked bar chart of residue frequencies at each position
    return dfs.groupby('Position').mean().plot.bar(stacked=True, figsize=(12,8))

Here's the logo plot for high scoring sequences in the dataset

plot_logo(df[(df.beats_wt==1) & (df.VAE_virus_S>7)].aa.values)
<AxesSubplot:xlabel='Position'>

Here's the residue plot for high scoring generated sequences

plot_logo(env.log.df[(env.log.df.virus>8.9) & ~(env.log.df.sources=='data_buffer')].samples.values)
<AxesSubplot:xlabel='Position'>

We can see similarities between the dataset and generated sequences at many positions, along with some significant differences (e.g. position 15).

So are these residue changes real and meaningful? That depends on the quality of our score function. The only way to know is to test the sequences in the lab.