Promoter Design
This tutorial runs an end-to-end workflow for designing promoter sequences.
Promoters are regions of the genome where transcription of RNA begins. This makes promoter sequences a region of interest for biotechnology, bioinformatics, and medical research.
This notebook uses MRL to design promoter sequences with reinforcement learning.
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained because generated samples are evaluated on the CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This enables multiprocessing, which typically gives a 2-4x speedup.
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime type to GPU.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score
os.makedirs('untracked_files', exist_ok=True)
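As noted in the Performance Notes above, multiprocessing can speed up sample evaluation on a multi-core machine. A sketch of such a cell; the set_global_pool signature assumed here (a worker count, imported via the MRL star imports) should be checked against your MRL version:
# uncomment on a multi-core machine to evaluate samples with a multiprocessing pool (assumed signature)
# set_global_pool(min(10, os.cpu_count()))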
Data
The dataset comes from Classifying Promoters by Interpreting the Hidden Information of DNA Sequences via Deep Learning and Combination of Continuous FastText N-Grams.
The dataset contains ~6700 DNA sequences. ~1400 are strong promoters, ~1791 are weak promoters, and ~3500 are non-promoter sequences.
Side note: the authors released their data as a 381-page PDF document. Who does that? It was so hard to work with!
df = pd.read_csv('../files/promoters.csv')
# if running on Colab
# download_files()
# df = pd.read_csv('files/promoters.csv')
df.head()
df.label.value_counts()
df.sequence.map(lambda x: len(x)).value_counts() # most sequences are 81 nucleotides long
For the purpose of this tutorial, we won't look at promoter strength. We will simplify the task to binary classification between promoter and non-promoter sequences.
df['int_label'] = df.label.map(lambda x: 1*(x=='strong' or x=='weak')+0*(x=='non-promoter'))
df.int_label.value_counts()
train_df = df.sample(frac=0.95, random_state=42).copy()
valid_df = df[~df.index.isin(train_df.index)].copy()
n_vocab = KmerVocab(DNA_TRIMERS, 3)
promoter_ds = Text_Prediction_Dataset(train_df.sequence.values, train_df.int_label.values, n_vocab)
test_ds = Text_Prediction_Dataset(valid_df.sequence.values, valid_df.int_label.values, n_vocab)
This is the model we will use:
class Predictive_CNN(nn.Module):
    def __init__(self,
                 d_vocab,
                 d_embedding,
                 d_latent,
                 filters,
                 kernel_sizes,
                 strides,
                 dropouts,
                 mlp_dims,
                 mlp_drops,
                 d_out
                ):
        super().__init__()

        self.conv_encoder = Conv_Encoder(
            d_vocab,
            d_embedding,
            d_latent,
            filters,
            kernel_sizes,
            strides,
            dropouts,
        )

        self.mlp_head = MLP(
            d_latent,
            mlp_dims,
            d_out,
            mlp_drops
        )

    def forward(self, x):
        encoded = self.conv_encoder(x)
        out = self.mlp_head(encoded)
        return out
d_vocab = len(n_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
promoter_model = Predictive_CNN(
d_vocab,
d_embedding,
d_latent,
filters,
kernel_sizes,
strides,
dropouts,
mlp_dims,
mlp_drops,
d_out
)
r_agent = PredictiveAgent(promoter_model, BinaryCrossEntropy(), promoter_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 10, 1e-3, opt_kwargs={'weight_decay':5e-3})
Optional: save score function weights
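That cell is omitted here; a minimal sketch, assuming PredictiveAgent exposes a save_weights counterpart to the load_weights call used later in the Reward section (the path matches the one loaded there):
# assumed API: counterpart of r_agent.load_weights(...) used in the Reward section below
# r_agent.save_weights('untracked_files/promoter_predictor.pt')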
Optional: to load the exact weights used, run the following:
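A sketch of that cell, mirroring the commented-out loading line that appears later in the Reward section:
# load the published score-function weights instead of the ones trained above (optional)
# r_agent.load_state_dict(model_from_url('promoter_predictor.pt'))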
valid_dl = test_ds.dataloader(256, shuffle=False)
r_agent.model.eval();
preds = []
targs = []
with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()
fpr, tpr, _ = roc_curve(targs, torch.tensor(preds).sigmoid().squeeze().numpy())
roc_auc = auc(fpr, tpr)
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
lw=lw, label='ROC curve (area = %0.4f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
template = Template([ValidityFilter(),
CharacterCountFilter(['A', 'T', 'G', 'C'], min_val=0, max_val=0.5,
per_length=True, mode='dna'),
CharacterCountFilter(['AAAA', 'TTTT', 'GGGG', 'CCCC'], min_val=0, max_val=4,
per_length=False, mode='dna'),
CharacterCountFilter(['AAAAAA', 'TTTTTT', 'GGGGGG', 'CCCCCC'], min_val=0, max_val=2,
per_length=False, mode='dna')],
[], fail_score=-1., log=False, use_lookup=False, mode='dna')
template_cb = TemplateCallback(template, prefilter=True)
Load Model
We load the LSTM_LM_Small_HGenome_3Mer model. This is a basic LSTM-based language model trained on 400 bp chunks of the human genome, tokenized as 3-mers. We fine-tune it on the promoter sequences in our dataset before starting reinforcement learning.
agent = LSTM_LM_Small_HGenome_3Mer(drop_scale=0.3, opt_kwargs={'lr':5e-5})
agent.update_dataset_from_inputs(df[df.int_label==1].sequence.values)
agent.train_supervised(32, 8, 5e-5)
agent.base_to_model()
Optional: save fine-tuned weights
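That cell is omitted as well; a minimal sketch, assuming the generative agent exposes the same save_weights-style interface as the predictive agent (the file name is illustrative):
# assumed API and illustrative path
# agent.save_weights('untracked_files/promoter_lm_finetuned.pt')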
Loss
We use PPO as our policy gradient loss.
pg = PPO(0.99,
0.5,
lam=0.95,
v_coef=0.5,
cliprange=0.3,
v_cliprange=0.3,
ent_coef=0.01,
kl_target=0.03,
kl_horizon=3000,
scale_rewards=True)
loss = PolicyLoss(pg, 'PPO',
value_head=ValueHead(256),
v_update_iter=2,
vopt_kwargs={'lr':1e-3})
Reward
Here we wrap the score function we trained earlier in a reward callback.
Since our score function is a classification model, we need to decide what value it should return. We could output the raw logit from the model, or the sigmoid-scaled classification probability.
If we use the sigmoid-scaled output, we tend to get many samples with scores in the 0.999 range that differ only by tiny amounts, which makes it hard to differentiate the true winners.
If we use the logit output, we get much better differentiation between samples at the top end. However, we have a much higher chance of generating degenerate samples that receive absurdly large logit values.
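To illustrate the saturation issue, compare two hypothetical samples whose logits differ by 3; after the sigmoid they are nearly indistinguishable:
logits = torch.tensor([7.0, 10.0])  # hypothetical logits for two generated sequences
torch.sigmoid(logits)               # ~0.9991 vs ~0.99995: almost no difference to rank on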
The code below uses the raw logit value, clipped to the range [-10, 10].
n_vocab = KmerVocab(DNA_TRIMERS, 3)
d_vocab = len(n_vocab.itos)
d_embedding = 256
d_latent = 512
filters = [256, 512, 1024]
kernel_sizes = [7, 7, 7]
strides = [2,2,2]
dropouts = [0.2, 0.2, 0.2]
mlp_dims = [512, 256, 128]
mlp_drops = [0.2, 0.2, 0.2]
d_out = 1
reward_model = Predictive_CNN(
d_vocab,
d_embedding,
d_latent,
filters,
kernel_sizes,
strides,
dropouts,
mlp_dims,
mlp_drops,
d_out
)
r_ds = Text_Prediction_Dataset(['ATC'], [0.], n_vocab)
r_agent = PredictiveAgent(reward_model, BinaryCrossEntropy(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.load_weights('untracked_files/promoter_predictor.pt')
# r_agent.load_state_dict(model_from_url('promoter_predictor.pt')) # optional - load exact weights
reward_model.eval();
freeze(reward_model)
class ClippedModelReward():
    def __init__(self, agent, minclip, maxclip):
        self.agent = agent
        self.minclip = minclip
        self.maxclip = maxclip

    def __call__(self, sequences):
        preds = self.agent.predict_data(sequences)
        preds = torch.clamp(preds, self.minclip, self.maxclip)
        return preds
reward_function = Reward(ClippedModelReward(r_agent, -10, 10), weight=1)
promoter_reward = RewardCallback(reward_function, 'promoter')
Samplers
We create the following samplers:
sampler1 (ModelSampler): samples from the main model, adding 1000 sequences to the buffer on each buffer build.
sampler2 (ModelSampler): samples from the baseline model and is not sampled live on each batch.
sampler3 (LogSampler): samples high-scoring sequences from the log.
sampler4 (TokenSwapSampler): uses token-swap combichem to generate new sequences from high-scoring samples.
sampler5 (DatasetSampler): sprinkles a small number of known promoters into each buffer build.
gen_bs = 1500
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 200)
sampler4 = TokenSwapSampler('samples', 'rewards', 10, 98, 200, n_vocab, 0.2)
sampler5 = DatasetSampler(df[(df.int_label==1) & (df.sequence.map(lambda x: len(x)<=100))].sequence.values,
'data', buffer_size=6)
samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]
Callbacks
Additional callbacks:
SupervisedCB: periodically runs supervised training on the top-scoring samples from the log.
MaxCallback: prints the max reward for each batch.
PercentileCallback: prints the 90th percentile score for each batch.
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
cbs = [supervised_cb, live_p90, live_max]
env = Environment(agent, template_cb, samplers=samplers, rewards=[promoter_reward], losses=[loss],
cbs=cbs)
env.fit(128, 27, 300, 20)
env.log.plot_metrics()
def plot_logo(seqs):
    freqs = []
    for i in range(len(seqs[0])):
        aas = [j[i] for j in seqs]
        counts = Counter(aas)
        total = sum(counts.values())
        norm_counts = defaultdict(lambda: 0)
        for k in counts.keys():
            norm_counts[k] = counts[k]/total
        freqs.append(norm_counts)

    aas = ['A', 'T', 'G', 'C']
    dfs = []
    for i, f in enumerate(freqs):
        df_iter = pd.DataFrame([[f[aa] for aa in aas]], columns=aas)
        df_iter['Position'] = i
        dfs.append(df_iter)

    dfs = pd.concat(dfs)
    dfs = dfs.set_index('Position')
    return dfs.groupby('Position').mean().plot.bar(stacked=True, figsize=(16,8))
plot_logo(df[df.label=='strong'].sequence.values[:800])
plot_logo(env.log.df[env.log.df.promoter>5.5].samples.values)