Using MRL to design promoters

Promoter Design

This tutorial runs an end-to-end workflow for designing promoter sequences.

Promoters are regions of the genome where transcription of RNA begins. This makes promoter sequences a region of interest for biotechnology, bioinformatics and medical research.

This notebook uses MRL to design promoter sequences with reinforcement learning.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained because generated samples must be scored on CPU. If you have a multi-core machine, it is recommended to uncomment and run the set_global_pool cells in the notebook (see the example after the imports below). This enables multiprocessing, which typically yields a 2-4x speedup.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime type to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score
os.makedirs('untracked_files', exist_ok=True)
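
Optional: enable multiprocessing for the CPU-bound scoring steps. This assumes the set_global_pool helper exposed by the MRL imports above:

# uncomment on a multi-core machine to parallelize CPU-bound evaluation
# set_global_pool(min(10, os.cpu_count()))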

Data

The dataset comes from Classifying Promoters by Interpreting the Hidden Information of DNA Sequences via Deep Learning and Combination of Continuous FastText N-Grams.

The dataset contains 6764 DNA sequences: 1474 strong promoters, 1791 weak promoters, and 3499 non-promoter sequences.

Side note: the authors released their data as a 381-page PDF document. Who does that? It was so hard to work with!

df = pd.read_csv('../files/promoters.csv')

# if in Colab
# download_files()
# df = pd.read_csv('files/promoters.csv')
df.head()
sequence id label
0 TAGATGTCCTTGATTAACACCAAAATTAAACCTTTTAAAAACCAGG... ECK120016719 strong
1 AAAGAAAATAATTAATTTTACAGCTGTTAAACCAAACGGTTATAAC... ECK120009966 strong
2 CTGCTGTTCCTTGCGATCGAAAAGATCAAGGGCGGACCGGTATCCG... ECK120010006 strong
3 GCGGAAGCACAAATTGCACCAGGTACGGAACTAAAAGCCGTAGATG... ECK120016583 strong
4 AAATACTTATGGTGCGCTGGCTTCTTTGGAACTTGCGCAGCAATTT... ECK120016567 strong
df.label.value_counts()
non-promoter    3499
weak            1791
strong          1474
Name: label, dtype: int64
df.sequence.map(lambda x: len(x)).value_counts() # most sequences are 81 nucleotides long
81     6761
80        1
140       1
135       1
Name: sequence, dtype: int64

For the purposes of this tutorial, we won't look at promoter strength. We simplify the task to binary classification between promoter and non-promoter sequences.

df['int_label'] = df.label.map(lambda x: int(x in ('strong', 'weak'))) # 1 for promoters, 0 otherwise
df.int_label.value_counts()
0    3499
1    3265
Name: int_label, dtype: int64

Score Function

Now we want to develop a score function for predicting whether a sequence is a promoter. We will use a CNN encoder with an MLP head to predict promoter class.

Our input data will be token integers for nucleic acids.

We will train on 95% of the data and validate on the 5% held out.

train_df = df.sample(frac=0.95, random_state=42).copy()
valid_df = df[~df.index.isin(train_df.index)].copy()
n_vocab = KmerVocab(DNA_TRIMERS, 3)

promoter_ds = Text_Prediction_Dataset(train_df.sequence.values, train_df.int_label.values, n_vocab)
test_ds = Text_Prediction_Dataset(valid_df.sequence.values, valid_df.int_label.values, n_vocab)
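
To make the input encoding concrete, here is a rough sketch of the assumed 3-mer tokenization (KmerVocab's actual implementation may differ): splitting an 81-nucleotide sequence into non-overlapping trimers yields 27 tokens, which matches the sequence length passed to env.fit later in this notebook.

# illustrative sketch only - not the actual KmerVocab implementation
seq = train_df.sequence.values[0]
trimers = [seq[i:i+3] for i in range(0, len(seq) - 2, 3)]  # non-overlapping 3-mers
print(len(seq), len(trimers))  # 81 nucleotides -> 27 trimer tokens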

This is the model we will use:

class Predictive_CNN(nn.Module):
    def __init__(self,
                 d_vocab,
                 d_embedding,
                 d_latent,
                 filters,
                 kernel_sizes,
                 strides,
                 dropouts,
                 mlp_dims,
                 mlp_drops,
                 d_out
                ):
        super().__init__()
        
        self.conv_encoder = Conv_Encoder(
                                        d_vocab,
                                        d_embedding,
                                        d_latent,
                                        filters,
                                        kernel_sizes,
                                        strides,
                                        dropouts,
                                    )
        
        self.mlp_head = MLP(
                            d_latent,
                            mlp_dims,
                            d_out,
                            mlp_drops
                            )
        
    def forward(self, x):
        encoded = self.conv_encoder(x)
        out = self.mlp_head(encoded)
        return out
d_vocab = len(n_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1


promoter_model = Predictive_CNN(
                    d_vocab,
                    d_embedding,
                    d_latent,
                    filters,
                    kernel_sizes,
                    strides,
                    dropouts,
                    mlp_dims,
                    mlp_drops,
                    d_out
                )
r_agent = PredictiveAgent(promoter_model, BinaryCrossEntropy(), promoter_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 10, 1e-3, opt_kwargs={'weight_decay':5e-3}) # batch size 32, 10 epochs, lr 1e-3
Epoch Train Loss Valid Loss Time
0 0.46897 0.37497 00:04
1 0.54312 0.31922 00:04
2 0.56689 0.34757 00:04
3 0.24959 0.16821 00:04
4 0.49990 0.31942 00:04
5 0.18535 0.17889 00:04
6 0.36100 0.21455 00:04
7 0.08074 0.42434 00:04
8 0.29440 0.17360 00:04
9 0.08274 0.09494 00:04

Optional: save score function weights

# assumed counterpart to the load_weights call later in this notebook
# r_agent.save_weights('untracked_files/promoter_predictor.pt')

Optional: to load the exact weights used, run the following:

r_agent.load_state_dict(model_from_url('promoter_predictor.pt'))
valid_dl = test_ds.dataloader(256, shuffle=False)
r_agent.model.eval();

preds = []
targs = []

with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
        
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()
fpr, tpr, _ = roc_curve(targs, torch.tensor(preds).sigmoid().squeeze().numpy())
roc_auc = auc(fpr, tpr)

plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
         lw=lw, label='ROC curve (area = %0.4f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()

Chemical Space

Next we need to define our chemical space. This is where we decide which sequences are allowed and which are filtered out.

For DNA sequences, we cap the frequency of any single nucleotide at 50% of the sequence length. We also limit repeats: at most four occurrences of any homopolymer 4-mer ('AAAA', 'TTTT', 'GGGG', 'CCCC') and at most two occurrences of any homopolymer 6-mer.

template = Template([ValidityFilter(),
                     CharacterCountFilter(['A', 'T', 'G', 'C'], min_val=0, max_val=0.5, 
                                          per_length=True, mode='dna'),
                    CharacterCountFilter(['AAAA', 'TTTT', 'GGGG', 'CCCC'], min_val=0, max_val=4, 
                                          per_length=False, mode='dna'),
                    CharacterCountFilter(['AAAAAA', 'TTTTTT', 'GGGGGG', 'CCCCCC'], min_val=0, max_val=2, 
                                          per_length=False, mode='dna')],
                    [], fail_score=-1., log=False, use_lookup=False, mode='dna')

template_cb = TemplateCallback(template, prefilter=True)

Load Model

We load the LSTM_LM_Small_HGenome_3Mer model. This is a basic LSTM-based language model trained on 400 bp chunks of the human genome, tokenized as 3-mers.

agent = LSTM_LM_Small_HGenome_3Mer(drop_scale=0.3, opt_kwargs={'lr':5e-5})

Fine-Tune Model

The pretrained model we loaded is a very general model that can produce a high diversity of sequences. However, what we actually want are promoter-like sequences. To help the model converge, we fine-tune on the promoter dataset.

agent.update_dataset_from_inputs(df[df.int_label==1].sequence.values)
agent.train_supervised(32, 8, 5e-5) # batch size 32, 8 epochs, lr 5e-5
agent.base_to_model()
Epoch Train Loss Valid Loss Time
0 4.55166 4.72411 00:07
1 4.40919 4.43935 00:07
2 3.90611 3.88763 00:07
3 3.91249 3.82219 00:07
4 3.84867 3.82571 00:07
5 3.81493 3.81817 00:06
6 3.80724 3.81137 00:06
7 3.84021 3.81283 00:07

Optional: save fine-tuned weights

# assumed API, mirroring the score-function save above; filename is illustrative
# agent.save_weights('untracked_files/promoter_finetuned.pt')

Reinforcement Learning

Now we enter the reinforcement learning stage.

Loss

We use PPO as our policy gradient loss.

pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})

Reward

Here we pass the reward agent we trained earlier to a callback.

Since our model is a classifier, we need to decide what value the score function outputs: the raw logit from the model, or the sigmoid-scaled classification probability.

If we use the sigmoid-scaled output, we tend to get lots of samples with scores in the 0.999 range that only differ by a very small amount. This can make it difficult to differentiate the true winners.

If we use the logit output, we can get much better differentiation of samples at the top end. However, we have a much higher chance of finding weird samples that get absurd logit values.
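
A quick numeric illustration of the saturation issue:

# logits 8 and 10 are far apart, but after the sigmoid they are nearly indistinguishable
print(torch.sigmoid(torch.tensor([2., 5., 8., 10.])))
# tensor([0.8808, 0.9933, 0.9997, 1.0000])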

The code below uses the raw logit value, clipped to the range [-10, 10].

n_vocab = KmerVocab(DNA_TRIMERS, 3)

d_vocab = len(n_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1


reward_model = Predictive_CNN(
                    d_vocab,
                    d_embedding,
                    d_latent,
                    filters,
                    kernel_sizes,
                    strides,
                    dropouts,
                    mlp_dims,
                    mlp_drops,
                    d_out
                )


r_ds = Text_Prediction_Dataset(['ATC'], [0.], n_vocab) # placeholder dataset; we only need the agent wrapper for loading weights and predicting

r_agent = PredictiveAgent(reward_model, BinaryCrossEntropy(), r_ds, opt_kwargs={'lr':1e-3})

r_agent.load_weights('untracked_files/promoter_predictor.pt')
# r_agent.load_state_dict(model_from_url('promoter_predictor.pt')) # optional - load exact weights

reward_model.eval();

freeze(reward_model)

class ClippedModelReward():
    def __init__(self, agent, minclip, maxclip):
        self.agent = agent
        self.minclip = minclip
        self.maxclip = maxclip
        
    def __call__(self, sequences):
        preds = self.agent.predict_data(sequences)
        preds = torch.clamp(preds, self.minclip, self.maxclip)
        return preds
        
reward_function = Reward(ClippedModelReward(r_agent, -10, 10), weight=1)

promoter_reward = RewardCallback(reward_function, 'promoter')
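
Optional sanity check: since ClippedModelReward simply forwards raw sequences to predict_data, it can be called directly on a few held-out sequences.

# optional - score a few held-out sequences with the clipped reward
# ClippedModelReward(r_agent, -10, 10)(list(valid_df.sequence.values[:4]))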

Samplers

We create the following samplers:

  • sampler1 ModelSampler: samples from the main model, adding 1000 sequences to the buffer on each buffer build. The on-the-fly batch fraction is set to 0 here, so no additional live samples are drawn during each batch.
  • sampler2 ModelSampler: samples from the baseline model; it is likewise not sampled live on each batch.
  • sampler3 LogSampler: samples high-scoring sequences from the log.
  • sampler4 TokenSwapSampler: uses token-swap combichem to generate new sequences from high-scoring ones.
  • sampler5 DatasetSampler: sprinkles a small number of known promoters into each buffer build.
gen_bs = 1500

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 200)
sampler4 = TokenSwapSampler('samples', 'rewards', 10, 98, 200, n_vocab, 0.2)
sampler5 = DatasetSampler(df[(df.int_label==1) & (df.sequence.map(lambda x: len(x)<=100))].sequence.values, 
                          'data', buffer_size=6)

samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]

Callbacks

Additional callbacks: SupervisedCB periodically runs supervised training on the highest-scoring samples from the log, while MaxCallback and PercentileCallback add the max and 90th-percentile rewards of live samples to the printed metrics.

supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)

cbs = [supervised_cb, live_p90, live_max]

Environment and Train

Now we can put together our Environment and run the training process.

env = Environment(agent, template_cb, samplers=samplers, rewards=[promoter_reward], losses=[loss],
                 cbs=cbs)
 
env.fit(128, 27, 300, 20) # batch size 128, sample length 27 trimer tokens (81 bp), 300 batches, log every 20
iterations rewards rewards_final new diversity bs template valid promoter PPO rewards_live_p90 rewards_live_max
0 0.870 0.870 1.000 1.000 128 0.000 1.000 0.870 2.774 3.647 4.020
20 1.077 1.077 0.992 1.000 128 0.000 1.000 1.077 2.354 3.524 4.336
40 1.838 1.838 0.961 1.000 128 0.000 1.000 1.838 1.637 3.784 4.161
60 2.141 2.141 0.953 1.000 128 0.000 1.000 2.141 1.697 4.094 5.032
80 2.685 2.685 0.953 1.000 128 0.000 1.000 2.685 1.080 3.960 4.803
100 3.212 3.212 0.898 1.000 128 0.000 1.000 3.212 0.750 4.243 5.120
120 3.422 3.422 0.961 1.000 128 0.000 1.000 3.422 0.442 4.421 4.880
140 3.467 3.467 0.891 1.000 128 0.000 1.000 3.467 0.598 4.339 5.242
160 3.753 3.753 0.906 1.000 128 0.000 1.000 3.753 0.444 4.384 5.118
180 3.923 3.923 0.883 1.000 128 0.000 1.000 3.923 0.305 4.605 5.139
200 3.910 3.910 0.891 1.000 128 0.000 1.000 3.910 0.209 4.595 5.216
220 3.877 3.877 0.906 1.000 128 0.000 1.000 3.877 0.413 4.695 5.219
240 3.797 3.797 0.914 1.000 128 0.000 1.000 3.797 0.418 4.625 5.017
260 4.058 4.058 0.922 1.000 128 0.000 1.000 4.058 0.222 4.606 5.137
280 4.141 4.141 0.859 1.000 128 0.000 1.000 4.141 9181181952.000 4.701 5.173
env.log.plot_metrics()
def plot_logo(seqs):
    # per-position nucleotide frequencies, plotted as a stacked bar chart (a simple sequence logo)
    freqs = []

    for i in range(len(seqs[0])):
        aas = [j[i] for j in seqs]
        counts = Counter(aas)
        total = sum(counts.values())
        norm_counts = defaultdict(lambda: 0)
        for k in counts.keys():
            norm_counts[k] = counts[k]/total
        freqs.append(norm_counts)
        
    aas = ['A', 'T', 'G', 'C']
    dfs = []

    for i, f in enumerate(freqs):
        df_iter = pd.DataFrame([[f[aa] for aa in aas]], columns=aas)
        df_iter['Position'] = i
        dfs.append(df_iter)
        
    dfs = pd.concat(dfs)
    dfs = dfs.set_index('Position')

    return dfs.groupby('Position').mean().plot.bar(stacked=True, figsize=(16,8))
plot_logo(df[df.label=='strong'].sequence.values[:800])
<AxesSubplot:xlabel='Position'>
plot_logo(env.log.df[env.log.df.promoter>5.5].samples.values)
<AxesSubplot:xlabel='Position'>