Using MRL to design antimicrobial peptides

Antimicrobial Peptide Design

This tutorial runs an end-to-end workflow for designing antimicrobial peptides using MRL.

Antimicrobial peptides are a family of peptides that kill bacteria. These peptides are of interest as candidate antibiotics and anticancer therapeutics.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score
os.makedirs('untracked_files', exist_ok=True)

Data

The dataset comes from the AMPlify repo. It contains ~8300 short peptides, each labeled as antimicrobial or not antimicrobial.

df = pd.read_csv('../files/anti_microbial_peptides.csv')

# if in Colab
# download_files()
# df = pd.read_csv('files/anti_microbial_peptides.csv')
df.head()
   name        sequence                                            dataset  label
0  >trAMP0001  GLLDTFKNLALNAAKSAGVSVLNSLSCKLSKTC                   train    1
1  >trAMP0002  AKKPVAKKAAGGVKKPK                                   train    1
2  >trAMP0003  GIIDIAKKLVGGIRNVLGI                                 train    1
3  >trAMP0004  MAGFLKVVQLLAKYGSKAVQWAWANKGKILDWLNAGQAIDWVVSKI...   train    1
4  >trAMP0005  YGPGDGHGGGHGGGHGGGHGNGQGGGHGHGPGGGFGGGHGGGHGGG...   train    1

Most peptides are short, with fewer than 50 amino acids.

df.sequence.map(lambda x: len(x)).hist()
<AxesSubplot:>

Score Function

Now we want to develop a score function for predicting antimicrobial activity. We will use a CNN encoder with an MLP head to predict a binary classification value for antimicrobial activity.

Our input data will be token integers for amino acids. Note that fingerprint representations are a poor fit for peptide work because peptides contain many repeating substructures.
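As a rough illustration of what "token integers for amino acids" means, here is a minimal sketch of character-level tokenization. The vocabulary, the special token names and the helper functions are hypothetical and only stand in for what CharacterVocab and AMINO_ACID_VOCAB do; the real MRL vocabulary may use different tokens or ordering.

```python
# Hypothetical vocabulary: special tokens first, then the 20 standard amino acids.
# This is an assumption for illustration, not MRL's actual AMINO_ACID_VOCAB.
SPECIAL_TOKENS = ['bos', 'eos', 'pad', 'unk']
AMINO_ACIDS = list('ACDEFGHIKLMNPQRSTVWY')

itos = SPECIAL_TOKENS + AMINO_ACIDS            # index -> token
stoi = {tok: i for i, tok in enumerate(itos)}  # token -> index


def numericalize(sequence):
    """Map a peptide string to integer ids, wrapped in bos/eos markers."""
    ids = [stoi['bos']]
    ids += [stoi.get(aa, stoi['unk']) for aa in sequence]
    ids.append(stoi['eos'])
    return ids


def pad_batch(batch):
    """Right-pad a list of id lists to a common length with the pad id."""
    max_len = max(len(ids) for ids in batch)
    return [ids + [stoi['pad']] * (max_len - len(ids)) for ids in batch]


batch = pad_batch([numericalize('AKKPVAKK'), numericalize('GIIDIAKKLVG')])
```

Each row of `batch` has the same length, which is what lets a convolutional encoder process peptides of different lengths together.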

We will train on 95% of the data and validate on the 5% held out.

train_df = df[df.dataset=='train']
valid_df = df[df.dataset=='valid']
aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)

amp_ds = Text_Prediction_Dataset(train_df.sequence.values, train_df.label.values, aa_vocab)
test_ds = Text_Prediction_Dataset(valid_df.sequence.values, valid_df.label.values, aa_vocab)

This is the model we will use:

class Predictive_CNN(nn.Module):
    def __init__(self,
                 d_vocab,
                 d_embedding,
                 d_latent,
                 filters,
                 kernel_sizes,
                 strides,
                 dropouts,
                 mlp_dims,
                 mlp_drops,
                 d_out
                ):
        super().__init__()
        
        self.conv_encoder = Conv_Encoder(
                                        d_vocab,
                                        d_embedding,
                                        d_latent,
                                        filters,
                                        kernel_sizes,
                                        strides,
                                        dropouts,
                                    )
        
        self.mlp_head = MLP(
                            d_latent,
                            mlp_dims,
                            d_out,
                            mlp_drops
                            )
        
    def forward(self, x):
        encoded = self.conv_encoder(x)
        out = self.mlp_head(encoded)
        return out
d_vocab = len(aa_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1


amp_model = Predictive_CNN(
                    d_vocab,
                    d_embedding,
                    d_latent,
                    filters,
                    kernel_sizes,
                    strides,
                    dropouts,
                    mlp_dims,
                    mlp_drops,
                    d_out
                )
r_agent = PredictiveAgent(amp_model, BinaryCrossEntropy(), amp_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 12, 1e-3)
Epoch Train Loss Valid Loss Time
0 0.48329 0.35328 00:07
1 0.19658 0.29164 00:05
2 0.22350 0.38356 00:05
3 0.56445 0.29710 00:06
4 0.06063 0.26211 00:05
5 1.61786 0.17783 00:05
6 0.10195 0.15677 00:05
7 0.27435 0.09549 00:05
8 0.25234 0.08911 00:05
9 0.05291 0.09178 00:05
10 0.01665 0.11609 00:05
11 0.03879 0.13075 00:05

Optional: save score function weights

 

Optional: to load the exact weights used, run the following:

 
valid_dl = test_ds.dataloader(256, shuffle=False)
r_agent.model.eval();

preds = []
targs = []

with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
        
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()
fpr, tpr, _ = roc_curve(targs, torch.tensor(preds).sigmoid().squeeze().numpy())
roc_auc = auc(fpr, tpr)

plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
         lw=lw, label='ROC curve (area = %0.4f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
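For intuition about the area reported in the plot legend: the ROC AUC equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, counting ties as one half. The sketch below computes it directly from that definition in plain Python. The function name is hypothetical, and this pairwise form is only practical for small inputs; the cell above uses scikit-learn instead.

```python
def pairwise_roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError('need at least one positive and one negative example')
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count as half
    return wins / (len(pos) * len(neg))


example_auc = pairwise_roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```

Because the AUC depends only on the ranking of the scores, it is the same whether it is computed on raw logits or on sigmoid-scaled predictions.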

Chemical Space

Next we need to develop our chemical space. This is where we decide what sequences will be allowed and what sequences will be removed.

Getting the right filtering parameters makes a huge difference in sequence quality. In practice, finding the right constraints is an iterative process. First run a generative screen. Then examine the highest scoring sequences and look for undesirable properties or structural features. Then update the template and iterate.

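For peptides, the presence of arginine has been shown to be toxic [ref]. We will apply a template filter on the number of residues per unit length. Note that the filter below is applied to the one-letter code 'A', which denotes alanine; arginine's code is 'R', so change the character if you want to limit arginine specifically.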

We will also limit the maximum frequency of any single residue in a sample to 0.3. This prevents a common failure mode in which the same residue is repeated many times (e.g. MSSSSSSSRP). This failure mode is a flaw of the simplistic score functions we are using.

aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)

template = Template([ValidityFilter(),
                     CharacterCountFilter(['A'], min_val=0, max_val=0.1, per_length=True, mode='protein'),
                     CharacterCountFilter(aa_vocab.itos[4:], min_val=0, max_val=0.3, 
                                          per_length=True, mode='protein')], 
                    [], fail_score=-10., log=False, use_lookup=False, mode='protein')

template_cb = TemplateCallback(template, prefilter=True)
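To make the per-length frequency limits concrete, here is a minimal plain-Python sketch of the kind of check the template performs. It assumes each residue is checked separately against its own maximum fraction; the function names are hypothetical and this is not MRL's CharacterCountFilter implementation. The 0.1 and 0.3 values mirror the max_val arguments in the template cell.

```python
def residue_fraction(sequence, residue):
    """Fraction of positions in `sequence` occupied by `residue`."""
    if not sequence:
        return 0.0
    return sequence.count(residue) / len(sequence)


def passes_frequency_limits(sequence, limits, default_max):
    """Return True if no residue exceeds its maximum allowed fraction.

    limits: dict mapping a residue to its own maximum fraction.
    default_max: maximum fraction applied to every other residue.
    """
    for residue in set(sequence):
        max_frac = limits.get(residue, default_max)
        if residue_fraction(sequence, residue) > max_frac:
            return False
    return True


limits = {'A': 0.1}  # stricter cap on 'A', as in the template cell
repeat_ok = passes_frequency_limits('MSSSSSSSRP', limits, 0.3)   # S is 7/10
varied_ok = passes_frequency_limits('GIKDLVNTSE', limits, 0.3)   # each is 1/10
```

A sequence that fails a check like this would receive the template's fail_score rather than being scored by the reward model.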

Load Model

We load the LSTM_LM_Small_Swissprot model. This is a basic LSTM-based language model trained on part of the Swissprot protein database.

agent = LSTM_LM_Small_Swissprot(drop_scale=0.3, opt_kwargs={'lr':1e-4})

Fine-Tune Model

The pretrained model we loaded is a very general model that can produce a high diversity of structures. However, what we actually want are structures with antimicrobial activity. To induce this, we fine-tune on the active peptides in the dataset.

agent.update_dataset_from_inputs(df[df.label==1].sequence.values)
agent.train_supervised(32, 8, 5e-5)
agent.base_to_model()
Epoch Train Loss Valid Loss Time
0 1.34313 0.93171 00:09
1 1.60108 0.85136 00:09
2 1.26246 0.80041 00:09
3 1.09434 0.78952 00:09
4 0.70235 0.78172 00:09
5 1.37438 0.77918 00:09
6 0.61792 0.77801 00:09
7 0.71503 0.77747 00:09

Optional: save fine-tuned weights

 
 

Reinforcement Learning

Now we enter the reinforcement learning stage.

Loss

We use PPO as our policy gradient loss.

pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})
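For readers unfamiliar with PPO, the cliprange argument controls the standard clipped surrogate objective. The sketch below shows that objective for a single sample in plain Python. It is a generic illustration of the published PPO formulation, not MRL's implementation, which also includes the value, entropy and KL terms configured above. The function name is hypothetical.

```python
import math


def ppo_clip_objective(new_logp, old_logp, advantage, cliprange=0.3):
    """Clipped surrogate objective for one sample (to be maximized).

    ratio = pi_new(a|s) / pi_old(a|s), computed from log probabilities.
    The ratio is clipped to [1 - cliprange, 1 + cliprange] and the more
    pessimistic of the clipped and unclipped terms is kept.
    """
    ratio = math.exp(new_logp - old_logp)
    clipped_ratio = max(1.0 - cliprange, min(1.0 + cliprange, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)


# Policy doubled the probability of a good sample: the gain is capped at 1.3.
capped = ppo_clip_objective(math.log(2.0), 0.0, advantage=1.0)
# Policy halved the probability of a bad sample: the clipped term -0.7 is kept.
floored = ppo_clip_objective(math.log(0.5), 0.0, advantage=-1.0)
```

The clipping removes the incentive to move the policy far from the one that generated the samples in a single update, which helps keep training stable.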

Reward

Here we pass the reward agent we trained earlier to a callback.

Since our model is a classification model, we have a decision with respect to what value we output from our score function. We could output the raw logit value from the model, or the sigmoid-scaled classification prediction.

If we use the sigmoid-scaled output, we tend to get lots of samples with scores in the 0.999 range that only differ by a very small amount. This can make it difficult to differentiate the true winners.

If we use the logit output, we can get much better differentiation of samples at the top end. However, we have a much higher chance of finding weird samples that get absurd logit values.

The code below uses the raw logit value, clipped to a range of [-10, 10].

d_vocab = len(aa_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1


reward_model = Predictive_CNN(
                    d_vocab,
                    d_embedding,
                    d_latent,
                    filters,
                    kernel_sizes,
                    strides,
                    dropouts,
                    mlp_dims,
                    mlp_drops,
                    d_out
                )


r_ds = Text_Prediction_Dataset(['M'], [0.], aa_vocab)

r_agent = PredictiveAgent(reward_model, BinaryCrossEntropy(), r_ds, opt_kwargs={'lr':1e-3})

r_agent.load_weights('untracked_files/amp_predictor.pt')
# r_agent.load_state_dict(model_from_url('amp_predictor.pt')) # optional - load exact weights

reward_model.eval();

freeze(reward_model)

class ClippedModelReward():
    def __init__(self, agent, minclip, maxclip):
        self.agent = agent
        self.minclip = minclip
        self.maxclip = maxclip
        
    def __call__(self, sequences):
        preds = self.agent.predict_data(sequences)
        preds = torch.clamp(preds, self.minclip, self.maxclip)
        return preds
        
reward_function = Reward(ClippedModelReward(r_agent, -10, 10), weight=1)

amp_reward = RewardCallback(reward_function, 'amp')
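The trade-off between sigmoid-scaled and logit rewards discussed above can be seen with a few numbers. The sketch below uses only the standard library; the example logit values are made up for illustration.

```python
import math


def sigmoid(x):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def clipped_logit(x, lo=-10.0, hi=10.0):
    """Clamp a raw logit to [lo, hi], mirroring the [-10, 10] range used above."""
    return max(lo, min(hi, x))


logits = [7.0, 9.0, 25.0]            # hypothetical model outputs
probs = [sigmoid(x) for x in logits]
clipped = [clipped_logit(x) for x in logits]

# Sigmoid outputs for 7.0 and 9.0 differ by less than 0.001,
# while the clipped logits still differ by 2.0.
prob_gap = probs[1] - probs[0]
logit_gap = clipped[1] - clipped[0]
```

The outlying logit of 25.0 is capped at 10.0, which limits how much a single odd sample with an extreme prediction can dominate the reward.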

Optional Reward: Stability Metric

There has been a lot of great work recently looking at large scale transformer language models for unsupervised learning of protein structures. One interesting thing that has emerged is a rough relationship between the protein sequence log probability given by a generative model and the stability of the protein sequence.

We can use the log probability values from a pretrained protein transformer model as a proxy for stability. Including this as a reward function can help keep the generated peptides realistic.

To include this as a reward term, run the code below to install the ESM library to access a pretrained protein transformer model.

In the interest of time, we will use the smallest ESM model with 43M parameters, rather than the large scale 630M parameter model. Note that even with the smaller model, this reward term adds significantly to the training runtime.

# ! pip install fair-esm
import esm
protein_model, alphabet = esm.pretrained.esm1_t6_43M_UR50S()
batch_converter = alphabet.get_batch_converter()
class PeptideStability():
    def __init__(self, model, alphabet, batch_converter):
        self.model = model
        to_device(self.model)
        self.alphabet = alphabet
        self.batch_converter = batch_converter
        
    def __call__(self, samples):
        
        data = [
            (f'protein{i}', samples[i]) for i in range(len(samples))
        ]
        
        batch_labels, batch_strs, batch_tokens = self.batch_converter(data)

        with torch.no_grad():
            results = self.model(to_device(batch_tokens))

        lps = F.log_softmax(results['logits'], -1)

        mean_lps = lps.gather(2, to_device(batch_tokens).unsqueeze(-1)).squeeze(-1).mean(-1)
        
        return mean_lps
ps = PeptideStability(protein_model, alphabet, batch_converter)
stability_reward = Reward(ps, weight=0.1, bs=300)
stability_cb = RewardCallback(stability_reward, name='stability')
stability_reward(df.sequence.values[:10])
tensor([-2.8469, -3.1892, -3.2956, -1.7978, -0.1794, -4.1585, -3.9322, -3.8093,
        -2.4914, -3.7035], device='cuda:0')
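The stability score above is the mean log probability that the model assigns to the tokens actually present in the sequence. Here is a minimal plain-Python sketch of that calculation for a single sequence. The function names and the optional pad_id argument are hypothetical; pad_id is an addition that is not in the class above, which averages over every position of the padded batch.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def mean_token_log_prob(logits_per_position, token_ids, pad_id=None):
    """Average log probability of the observed token at each position.

    logits_per_position: one list of vocabulary logits per sequence position.
    token_ids: the observed token id at each position.
    pad_id: if given, positions holding this id are skipped (an assumption,
            not part of the notebook's PeptideStability class).
    """
    total, count = 0.0, 0
    for logits, tok in zip(logits_per_position, token_ids):
        if pad_id is not None and tok == pad_id:
            continue
        total += log_softmax(logits)[tok]
        count += 1
    return total / count if count else 0.0


# Uniform logits over a 4-token vocabulary give log(1/4) at every position.
uniform_score = mean_token_log_prob([[0.0] * 4, [0.0] * 4], [1, 2])
```

Scores closer to zero mean the model finds the sequence more probable, which is the property used here as a rough proxy for stability.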

Samplers

We create the following samplers:

  • sampler1 ModelSampler: this samples from the main model. This sampler adds 1000 compounds to the buffer at each buffer build. Its on-the-fly batch fraction is set to 0. in the code below, so no samples are drawn live from the main model during a batch.
  • sampler2 ModelSampler: this samples from the baseline model and is not sampled live on each batch
  • sampler3 LogSampler: this samples high scoring samples from the log
  • sampler4 TokenSwapSampler: this uses token swap combichem to generate new samples from high scoring samples
  • sampler5 DatasetSampler: this sprinkles in a small amount of known actives into each buffer build. Note that we subset by peptide length to align with the generated length (75 amino acids)
gen_bs = 1500

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 200)
sampler4 = TokenSwapSampler('samples', 'rewards', 10, 98, 200, aa_vocab, 0.2)
sampler5 = DatasetSampler(df[(df.label==1) & (df.sequence.map(lambda x: len(x)<=75))].sequence.values, 
                          'data', buffer_size=4)

samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]
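As a rough sketch of the idea behind the log and token swap samplers, the code below selects samples at or above a reward percentile and then mutates them by randomly replacing residues. The function names, the nearest-rank percentile rule and the per-position swap rule are assumptions made for illustration; they are not taken from MRL's LogSampler or TokenSwapSampler.

```python
import math
import random

AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'


def top_percentile(samples, rewards, percentile):
    """Keep samples whose reward is at or above the given percentile
    (nearest-rank definition)."""
    if not rewards:
        return []
    ranked = sorted(rewards)
    k = int(math.ceil(percentile / 100 * len(ranked))) - 1
    k = max(0, min(len(ranked) - 1, k))
    cutoff = ranked[k]
    return [s for s, r in zip(samples, rewards) if r >= cutoff]


def token_swap(sequence, swap_prob, rng):
    """Replace each residue with a random amino acid with probability swap_prob."""
    out = []
    for aa in sequence:
        if rng.random() < swap_prob:
            out.append(rng.choice(AMINO_ACIDS))
        else:
            out.append(aa)
    return ''.join(out)


rng = random.Random(0)
logged = ['GIIDIAKK', 'AKKPVAKK', 'GLLDTFKN', 'MAGFLKVV']
scores = [1.0, 2.0, 3.0, 4.0]
best = top_percentile(logged, scores, 75)
mutants = [token_swap(s, 0.2, rng) for s in best]
```

Mutating high scoring sequences in this way explores their close neighbors, which complements sampling fresh sequences from the generative model.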

Callbacks

Additional callbacks

supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)

cbs = [supervised_cb, live_p90, live_max]

Environment and Train

Now we can put together our Environment and run the training process.

env = Environment(agent, template_cb, samplers=samplers, rewards=[amp_reward, stability_cb], losses=[loss],
                 cbs=cbs)
set_global_pool(min(12, os.cpu_count()))
env.fit(128, 75, 300, 20)
iterations rewards rewards_final new diversity bs template valid amp stability PPO rewards_live_p90 rewards_live_max
0 -0.824 -0.824 1.000 1.000 128 0.000 1.000 2.367 -3.191 2.453 3.657 6.105
20 1.215 1.215 0.969 1.000 128 0.000 1.000 3.946 -2.731 4.314 7.218 9.701
40 3.259 3.259 0.922 1.000 128 0.000 1.000 5.021 -1.762 5.137 9.453 9.752
60 6.111 6.111 0.844 1.000 128 0.000 1.000 7.243 -1.131 6.201 9.640 9.676
80 6.526 6.526 0.852 1.000 128 0.000 1.000 7.476 -0.950 4.232 9.597 9.675
100 7.899 7.899 0.844 1.000 128 0.000 1.000 8.536 -0.637 2.952 9.732 9.747
120 8.080 8.080 0.875 1.000 128 0.000 1.000 8.717 -0.637 1.876 9.698 9.720
140 8.489 8.489 0.820 1.000 128 0.000 1.000 8.878 -0.390 0.945 9.742 9.762
160 8.553 8.553 0.828 1.000 128 0.000 1.000 9.049 -0.496 0.694 9.702 9.727
180 8.925 8.925 0.852 1.000 128 0.000 1.000 9.333 -0.408 0.427 9.755 9.766
200 9.152 9.152 0.797 1.000 128 0.000 1.000 9.461 -0.309 0.211 9.712 9.727
220 9.195 9.195 0.781 1.000 128 0.000 1.000 9.520 -0.324 0.310 9.760 9.776
240 9.096 9.096 0.797 1.000 128 0.000 1.000 9.503 -0.407 0.486 9.713 9.725
260 9.295 9.295 0.805 1.000 128 0.000 1.000 9.600 -0.305 0.262 9.765 9.779
280 9.299 9.299 0.859 1.000 128 0.000 1.000 9.623 -0.324 0.406 9.774 9.806
env.log.plot_metrics()

Evaluation

Looking at high scoring sequences, we see that many contain a high concentration of K residues. These features likely come from a small number of dataset samples with high K content. For example, AATKPKKAGAEAAPKKPAKKQTKKKPAKKAGGKKKPKRAGAKKAKK is an AMP sequence in the dataset.

If these traits are undesirable, they can be controlled by updating the Template to limit other residues the same way we limited A residues.

env.log.df[env.log.df.rewards>9.8]
samples sources rewards rewards_final template amp stability PPO
24119 GFLIGLAAEGIKKIGGKIGKIIGKIIKKVGKNIGDTVEKIGKNAGK... base_buffer 9.800656 9.800656 0.0 10.0 -0.199344 -0.473588
24514 GIPGAIGKAIKGGLGKVLKGCGVKGAKIIGGGRKVGINGKVVKKVV... base_buffer 9.800788 9.800788 0.0 10.0 -0.199212 -0.319604
29395 GLLGLFGKAIKDIGVKVIGKAIKNIGIKGIDKIMKGIIGKKKIPNV... live_buffer 9.802987 9.802987 0.0 10.0 -0.197013 -0.461350
29602 GIFGGIAKGVKNAIKKLGKKIGGKIGKGIGKIIIGKATKGAHGKKV... live_buffer 9.802441 9.802441 0.0 10.0 -0.197559 -0.438858
30374 GLIGALKKAAKKIGKKVGKKIGGKIVKGGMIKKGGKGIIGDKTIGG... live_buffer 9.820191 9.820191 0.0 10.0 -0.179809 -0.489034
30917 GIAGAIGKAGIKAIKKGRKIVGKGIGKILIKKVGKGIGNKIKVINP... live_buffer 9.805631 9.805631 0.0 10.0 -0.194369 -0.387992
31535 GILGKVGKKIGKGVTKIVGKIGKIKIGGKITNHVIGVIKKVGKSYI... base_buffer 9.801026 9.801026 0.0 10.0 -0.198973 -0.457702
32550 GILGGIVKKLAGKIAKGIGKIVKGLGKKIGKGAGVTGKILGEKLGK... base_buffer 9.804182 9.804182 0.0 10.0 -0.195818 -0.310737
32688 GVLAALGKAIGKAVKGAGVKAIKGLKVKTIGKIGTIIGKIAGKVIL... live_buffer 9.810798 9.810798 0.0 10.0 -0.189203 -0.362843
aas = ['G', 'V', 'I', 'K']

fig, axes = plt.subplots(2,2, figsize=(12,8))

for i, ax in enumerate(axes.flat):
    if i<len(aas):
        env.log.df[env.log.df.rewards>9.].samples.map(
            lambda x: x.count(aas[i])/len(x)).hist(density=True, alpha=0.5, label=f'Generated {aas[i]} Counts',
                                                  ax=ax)
        df[df.label==1].sequence.map(
            lambda x: x.count(aas[i])/len(x)).hist(density=True, alpha=0.5, label=f'Dataset {aas[i]} Counts',
                                                  ax=ax)
        ax.legend()
    else:
        ax.axis('off')