Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool
cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.
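For reference, the pool cell mirrors the set_global_pool call used later in this notebook (the worker count of 12 is just an example):

set_global_pool(min(12, os.cpu_count())) # create a multiprocessing pool for CPU-bound sample evaluation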
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime type to GPU.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from sklearn.metrics import roc_curve, auc, roc_auc_score
os.makedirs('untracked_files', exist_ok=True)
Data
The dataset comes from the AMPlify repo. It contains ~8300 short peptides classified as antimicrobial or not antimicrobial.
df = pd.read_csv('../files/anti_microbial_peptides.csv')
# if in Colab
# download_files()
# df = pd.read_csv('files/anti_microbial_peptides.csv')
df.head()
Most peptides are short, fewer than 50 amino acids.
df.sequence.map(lambda x: len(x)).hist()
Score Function
Now we want to develop a score function for predicting antimicrobial activity. We will use a CNN encoder with an MLP head to predict a binary classification value for antimicrobial activity.
Our input data will be token integers for amino acids. Note that fingerprint representations are a poor fit for peptide work because peptides contain many repeating substructures.
We will train on 95% of the data and validate on the 5% held out.
train_df = df[df.dataset=='train']
valid_df = df[df.dataset=='valid']
aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)
amp_ds = Text_Prediction_Dataset(train_df.sequence.values, train_df.label.values, aa_vocab)
test_ds = Text_Prediction_Dataset(valid_df.sequence.values, valid_df.label.values, aa_vocab)
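As a quick illustration of the input representation (a sketch assuming CharacterVocab exposes a stoi string-to-index mapping alongside the itos list used below), a peptide string maps to a list of token integers:

seq = train_df.sequence.values[0]
token_ids = [aa_vocab.stoi[c] for c in seq] # one integer index per amino acid character
print(seq[:10], token_ids[:10])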
This is the model we will use:
class Predictive_CNN(nn.Module):
    def __init__(self,
                 d_vocab,
                 d_embedding,
                 d_latent,
                 filters,
                 kernel_sizes,
                 strides,
                 dropouts,
                 mlp_dims,
                 mlp_drops,
                 d_out
                 ):
        super().__init__()
        self.conv_encoder = Conv_Encoder(
            d_vocab,
            d_embedding,
            d_latent,
            filters,
            kernel_sizes,
            strides,
            dropouts,
        )

        self.mlp_head = MLP(
            d_latent,
            mlp_dims,
            d_out,
            mlp_drops
        )

    def forward(self, x):
        encoded = self.conv_encoder(x)
        out = self.mlp_head(encoded)
        return out
d_vocab = len(aa_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
amp_model = Predictive_CNN(
d_vocab,
d_embedding,
d_latent,
filters,
kernel_sizes,
strides,
dropouts,
mlp_dims,
mlp_drops,
d_out
)
r_agent = PredictiveAgent(amp_model, BinaryCrossEntropy(), amp_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 12, 1e-3)
Optional: save score function weights
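A minimal sketch, assuming PredictiveAgent exposes a save_weights method symmetric with the load_weights call used later in this notebook (the path matches the one loaded below):

r_agent.save_weights('untracked_files/amp_predictor.pt') # assumed save API, symmetric with load_weights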
Optional: to load the exact weights used, run the following:
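(This is the same model_from_url call that appears commented out later in the Reward section.)

r_agent.load_state_dict(model_from_url('amp_predictor.pt'))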
valid_dl = test_ds.dataloader(256, shuffle=False)
r_agent.model.eval();
preds = []
targs = []
with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x, y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()
fpr, tpr, _ = roc_curve(targs, torch.tensor(preds).sigmoid().squeeze().numpy())
roc_auc = auc(fpr, tpr)
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
lw=lw, label='ROC curve (area = %0.4f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
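Since roc_auc_score is already imported, we can also sanity-check the AUC directly from the raw logits (AUC is unchanged by the monotonic sigmoid):

print(roc_auc_score(targs.squeeze(), preds.squeeze())) # should match the curve-based AUC above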
Chemical Space
Next we need to develop our chemical space. This is where we decide what sequences will be allowed and what sequences will be removed.
Getting the right filtering parameters makes a huge difference in sequence quality. In practice, finding the right constraints is an iterative process: first run a generative screen, then examine the highest scoring sequences for undesirable properties or structural features, then update the template and repeat.
For peptides, a high arginine content has been shown to be toxic [ref]. We will apply a template filter limiting the number of arginine (R) residues per unit length.
We will also limit the maximum frequency of any single residue in a sample to 0.3. This prevents a common failure mode where the same residue is repeated many times (e.g. MSSSSSSSRP), which is a flaw in the simplistic score functions we are using.
aa_vocab = CharacterVocab(AMINO_ACID_VOCAB)
template = Template([ValidityFilter(),
CharacterCountFilter(['R'], min_val=0, max_val=0.1, per_length=True, mode='protein'),
CharacterCountFilter(aa_vocab.itos[4:], min_val=0, max_val=0.3,
per_length=True, mode='protein')],
[], fail_score=-10., log=False, use_lookup=False, mode='protein')
template_cb = TemplateCallback(template, prefilter=True)
Load Model
We load the LSTM_LM_Small_Swissprot model. This is a basic LSTM-based language model trained on part of the Swissprot protein database.
agent = LSTM_LM_Small_Swissprot(drop_scale=0.3, opt_kwargs={'lr':1e-4})
agent.update_dataset_from_inputs(df[df.label==1].sequence.values)
agent.train_supervised(32, 8, 5e-5)
agent.base_to_model()
Optional: save fine-tuned weights
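A minimal sketch, assuming the generative agent exposes a save_weights method symmetric with load_weights; the filename is hypothetical:

agent.save_weights('untracked_files/amp_lstm_finetuned.pt') # hypothetical path and assumed API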
Loss
We use PPO as our policy gradient loss.
pg = PPO(0.99,
0.5,
lam=0.95,
v_coef=0.5,
cliprange=0.3,
v_cliprange=0.3,
ent_coef=0.01,
kl_target=0.03,
kl_horizon=3000,
scale_rewards=True)
loss = PolicyLoss(pg, 'PPO',
value_head=ValueHead(256),
v_update_iter=2,
vopt_kwargs={'lr':1e-3})
Reward
Here we pass the reward agent we trained earlier to a callback.
Since our model is a classification model, we need to decide what value our score function outputs: the raw logit from the model, or the sigmoid-scaled classification probability.
If we use the sigmoid-scaled output, we tend to get lots of samples with scores in the 0.999 range that differ only by a very small amount. This can make it difficult to differentiate the true winners.
If we use the logit output, we can get much better differentiation of samples at the top end. However, we have a much higher chance of finding weird samples that get absurd logit values.
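To see the saturation concretely, note how very different logits collapse into nearly identical sigmoid scores:

torch.sigmoid(torch.tensor([5., 8., 12.])) # tensor([0.9933, 0.9997, 1.0000])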
The code below uses the raw logit value, clipped to the range [-10, 10].
d_vocab = len(aa_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
reward_model = Predictive_CNN(
d_vocab,
d_embedding,
d_latent,
filters,
kernel_sizes,
strides,
dropouts,
mlp_dims,
mlp_drops,
d_out
)
r_ds = Text_Prediction_Dataset(['M'], [0.], aa_vocab) # placeholder dataset; trained weights are loaded below
r_agent = PredictiveAgent(reward_model, BinaryCrossEntropy(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.load_weights('untracked_files/amp_predictor.pt')
# r_agent.load_state_dict(model_from_url('amp_predictor.pt')) # optional - load exact weights
reward_model.eval();
freeze(reward_model)
class ClippedModelReward():
    def __init__(self, agent, minclip, maxclip):
        self.agent = agent
        self.minclip = minclip
        self.maxclip = maxclip

    def __call__(self, sequences):
        preds = self.agent.predict_data(sequences)
        preds = torch.clamp(preds, self.minclip, self.maxclip)
        return preds
reward_function = Reward(ClippedModelReward(r_agent, -10, 10), weight=1)
amp_reward = RewardCallback(reward_function, 'amp')
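As a quick sanity check (mirroring the stability check in the next section), Reward objects are callable on a list of samples:

reward_function(df.sequence.values[:10]) # clipped logits; label==1 peptides should score high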
Optional Reward: Stability Metric
There has been a lot of great work recently looking at large scale transformer language models for unsupervised learning of protein structures. One interesting thing that has emerged is a rough relationship between the protein sequence log probability given by a generative model and the stability of the protein sequence.
We can use the log probability values from a pretrained protein transformer model as a proxy for stability. Including this as a reward function can help keep the generated peptides realistic.
To include this as a reward term, run the code below to install the ESM library to access a pretrained protein transformer model.
In the interest of time, we will use the smallest ESM model with 43M parameters, rather than the large scale 650M parameter model. Note that even with the smaller model, this reward term adds significantly to the training runtime.
# ! pip install fair-esm
import esm
protein_model, alphabet = esm.pretrained.esm1_t6_43M_UR50S()
batch_converter = alphabet.get_batch_converter()
class PeptideStability():
    def __init__(self, model, alphabet, batch_converter):
        self.model = model
        to_device(self.model)
        self.alphabet = alphabet
        self.batch_converter = batch_converter

    def __call__(self, samples):
        # pack sequences into (name, sequence) pairs for the ESM batch converter
        data = [
            (f'protein{i}', samples[i]) for i in range(len(samples))
        ]
        batch_labels, batch_strs, batch_tokens = self.batch_converter(data)

        with torch.no_grad():
            results = self.model(to_device(batch_tokens))

        # mean per-token log probability of each sequence under the model
        lps = F.log_softmax(results['logits'], -1)
        mean_lps = lps.gather(2, to_device(batch_tokens).unsqueeze(-1)).squeeze(-1).mean(-1)
        return mean_lps
ps = PeptideStability(protein_model, alphabet, batch_converter)
stability_reward = Reward(ps, weight=0.1, bs=300)
stability_cb = RewardCallback(stability_reward, name='stability')
stability_reward(df.sequence.values[:10])
Samplers
We create the following samplers:
- sampler1 (ModelSampler): samples from the main model. This sampler adds 1000 sequences to the buffer on each buffer build, and samples 40% of each batch on the fly from the main model
- sampler2 (ModelSampler): samples from the baseline model and is not sampled live on each batch
- sampler3 (LogSampler): samples high scoring sequences from the log
- sampler4 (TokenSwapSampler): uses token swap combichem to generate new samples from high scoring samples
- sampler5 (DatasetSampler): sprinkles a small number of known actives into each buffer build. Note that we subset by peptide length to align with the generated length (75 amino acids)
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0.4, gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 200)
sampler4 = TokenSwapSampler('samples', 'rewards', 10, 98, 200, aa_vocab, 0.2)
sampler5 = DatasetSampler(df[(df.label==1) & (df.sequence.map(lambda x: len(x)<=75))].sequence.values,
'data', buffer_size=4)
samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]
Callbacks
Additional callbacks:
- SupervisedCB: runs supervised training on the top 2% of samples every 20 batches
- MaxCallback: prints the max reward for each batch
- PercentileCallback: prints the 90th percentile score each batch
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
cbs = [supervised_cb, live_p90, live_max]
env = Environment(agent, template_cb, samplers=samplers, rewards=[amp_reward, stability_cb], losses=[loss],
cbs=cbs)
set_global_pool(min(12, os.cpu_count()))
env.fit(128, 75, 300, 20)
env.log.plot_metrics()
Evaluation
Looking at high scoring sequences, we see that many contain a high concentration of K residues. These features likely come from a small number of dataset samples with high K content. For example, AATKPKKAGAEAAPKKPAKKQTKKKPAKKAGGKKKPKRAGAKKAKK is an AMP sequence in the dataset.
If these traits are undesirable, they can be controlled by updating the Template to limit other residues the same way we limited R residues.
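For example, a variant of the earlier Template adding a cap on K frequency might look like this (the 0.25 threshold is illustrative, not a tuned value):

template = Template([ValidityFilter(),
                     CharacterCountFilter(['R'], min_val=0, max_val=0.1, per_length=True, mode='protein'),
                     CharacterCountFilter(['K'], min_val=0, max_val=0.25, per_length=True, mode='protein'),
                     CharacterCountFilter(aa_vocab.itos[4:], min_val=0, max_val=0.3,
                                          per_length=True, mode='protein')],
                    [], fail_score=-10., log=False, use_lookup=False, mode='protein')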
env.log.df[env.log.df.rewards>9.8]
aas = ['G', 'V', 'I', 'K']
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

for i, ax in enumerate(axes.flat):
    if i < len(aas):
        env.log.df[env.log.df.rewards>9.].samples.map(
            lambda x: x.count(aas[i])/len(x)).hist(density=True, alpha=0.5,
                                                   label=f'Generated {aas[i]} Counts', ax=ax)
        df[df.label==1].sequence.map(
            lambda x: x.count(aas[i])/len(x)).hist(density=True, alpha=0.5,
                                                   label=f'Dataset {aas[i]} Counts', ax=ax)
        ax.legend()
    else:
        ax.axis('off')