RL Train Cycle Overview
The goal of this tutorial is to walk through the RL fit cycle to familiarize ourselves with the Events cycle and get a better understanding of how the Callback and Environment classes work.
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.
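For reference, the relevant cell looks something like the commented sketch below. The worker count is illustrative, and the exact signature of set_global_pool is assumed here rather than documented in this notebook.
# Uncomment to use multiprocessing for CPU-bound sample evaluation
# (assumes set_global_pool accepts a number of worker processes)
# set_global_pool(min(10, os.cpu_count()))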
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime to GPU.
High Level Overview
The Environment
At the highest level, we have the Environment class. The Environment holds together several sub-modules and orchestrates them during the fit loop. The following are contained in the Environment:
- agent - this is the actual model we're training
- template_cb - this holds a Template class that we use to define our chemical space
- samplers - samplers generate new samples to train on
- buffer - the buffer collects and distributes samples from all the samplers
- rewards - rewards score samples
- losses - losses generate values we can backpropagate through
- log - the log holds a record of all samples in the training process
Callbacks and the Event Cycle
Each one of the above items is a Callback. A Callback is a general class that can hook into the Environment fit cycle at a number of pre-defined Events. When the Environment calls a specific Event, the event name is passed to every callback in the Environment. If a given Callback has a defined function named after the event, that function is called. This creates a very flexible system for customizing training loops.
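Conceptually, the dispatch works something like the sketch below. This is not the actual Environment implementation, just the idea: every registered callback is checked for a method matching the event name.
# Conceptual sketch of event dispatch (not MRL's actual implementation)
def call_event(callbacks, event_name):
    for cb in callbacks:
        event_fn = getattr(cb, event_name, None)
        if event_fn is not None:
            event_fn()
Writing a custom callback then amounts to subclassing Callback and defining methods named after the events you want to hook into.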
We'll be looking more at Events later. For now, we'll just list them in brief. These are the events called during the RL training cycle, in the order they are executed:
- setup - called when the Environment is created, used to set up values
- before_train - called before training is started
- build_buffer - draws samples from samplers into the buffer
- filter_buffer - filters samples in the buffer
- after_build_buffer - called after buffer filtering. Used for cleanup, logging, etc.
- before_batch - called before a batch starts, used to set up the batch state
- sample_batch - samples are drawn from samplers and the buffer into the batch state
- before_filter_batch - allows preprocessing of samples before filtering
- filter_batch - filters samples in the batch state
- after_sample - used for calculating sampling metrics
- before_compute_reward - used to set up any values needed for reward computation
- compute_reward - used by rewards to compute rewards for all samples in the batch state
- after_compute_reward - used for logging reward metrics
- reward_modification - modify rewards in ways not tracked by the log
- after_reward_modification - log reward modification metrics
- get_model_outputs - generate necessary tensors from the model
- after_get_model_outputs - used for any processing required prior to loss calculation
- compute_loss - compute loss values
- zero_grad - zero gradients
- before_step - used for computation before the optimizer step (ie gradient clipping)
- step - step the optimizer
- after_batch - compute batch stats
- after_train - final event after all training batches
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from collections import Counter
Agent
The Agent is the actual model we want to train. For this example, we will use the LSTM_LM_Small_ZINC model, which is an LSTM_LM model trained on a chunk of the ZINC database.
The agent actually contains two versions of the model: the main model, which we train with every update iteration, and a baseline model, which is updated as an exponentially weighted moving average of the main model. Both models are used in the RL training algorithm we will set up later.
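Conceptually, the baseline update is an exponential moving average over the main model's weights, roughly like the sketch below. The decay value beta here is illustrative, not the value the Agent actually uses.
# Conceptual sketch of the baseline EMA update (illustrative, not the Agent's exact code)
def update_baseline(model, base_model, beta=0.99):
    with torch.no_grad():
        for p, base_p in zip(model.parameters(), base_model.parameters()):
            base_p.mul_(beta).add_(p, alpha=1 - beta)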
agent = LSTM_LM_Small_ZINC(drop_scale=0.5,opt_kwargs={'lr':5e-5})
Template
The Template class is used to control the chemical space. We can set parameters on what molecular properties we want to allow. For this example, we set the following:
- Hard Filters - must have qualities
  - ValidityFilter - must be a valid chemical structure
  - SingleCompoundFilter - samples must be single compounds
  - RotBondFilter - compounds can have at most 8 rotatable bonds
  - ChargeFilter - compounds must have no net charge
- Soft Filters - nice to have qualities
  - QEDFilter - compounds with QED of at least 0.5 get a score bonus of 1
  - SAFilter - compounds with an SA score of at most 5 get a score bonus of 1
We then pass the Template to the TemplateCallback, which integrates the template into the fit loop. Note that we pass prefilter=True to the TemplateCallback, which ensures compounds that don't meet our hard filters are removed from training.
template = Template([ValidityFilter(),
SingleCompoundFilter(),
RotBondFilter(None, 8),
ChargeFilter(0, 0)],
[QEDFilter(0.5, None, score=1.),
SAFilter(None, 5, score=1.)])
template_cb = TemplateCallback(template, prefilter=True)
Reward
Next we set up the reward. Here the score comes from a regression model trained to predict erbB1 affinity from ECFP6 fingerprints, wrapped in a class that maps SMILES strings to predicted scores:
class FP_Regression_Score():
    def __init__(self, fname):
        self.model = torch.load(fname)
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)

    def __call__(self, samples):
        mols = to_mols(samples)
        fps = maybe_parallel(self.fp_function, mols)
        fps = [fp_to_array(i) for i in fps]
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)
        return preds
# if in the repo
reward_function = FP_Regression_Score('../files/erbB1_regression.sklearn')
# if in Colab:
# download_files()
# reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')
reward = Reward(reward_function, weight=1.)
aff_reward = RewardCallback(reward, 'aff')
We can think of the score function as a black box that takes in samples (SMILES strings) and returns a single numeric score for each sample. Any score function that follows this paradigm can be integrated into MRL.
samples = ['Brc1cc2c(NCc3cccs3)ncnc2s1',
'Brc1cc2c(NCc3ccncc3)ncnc2s1']
reward_function(samples)
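As a sketch of how a different score could be plugged in, here is a hypothetical reward that simply returns the QED of each compound, computed with rdkit. It follows the same call pattern as FP_Regression_Score above: SMILES strings in, one number per sample out. The class and variable names here are made up for illustration.
from rdkit import Chem
from rdkit.Chem.QED import qed

class QED_Score():
    # hypothetical score function: list of SMILES in, one float per sample out
    def __call__(self, samples):
        mols = [Chem.MolFromSmiles(s) for s in samples]
        return np.array([qed(m) if m is not None else 0. for m in mols])

# qed_reward = Reward(QED_Score(), weight=1.)
# qed_reward_cb = RewardCallback(qed_reward, 'qed')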
Loss Function
For our loss, we will use the PPO reinforcement learning algorithm. See the PPO paper for full details.
The gist of it is that the loss function takes a batch of samples and directs the model to increase the probability of above-average samples (relative to the batch mean) and decrease the probability of below-average samples.
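In pseudocode, the core of the clipped surrogate objective looks roughly like the sketch below, where advantages measures how much better or worse each sample scored relative to a baseline (roughly, the batch mean). This is a simplified sketch of PPO, not MRL's exact implementation.
# Simplified sketch of the PPO clipped surrogate loss (not MRL's exact code)
def ppo_surrogate(new_logprobs, old_logprobs, advantages, cliprange=0.3):
    ratios = (new_logprobs - old_logprobs).exp()
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1 - cliprange, 1 + cliprange) * advantages
    return -torch.min(unclipped, clipped).mean()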
pg = PPO(0.99,
0.5,
lam=0.95,
v_coef=0.5,
cliprange=0.3,
v_cliprange=0.3,
ent_coef=0.01,
kl_target=0.03,
kl_horizon=3000,
scale_rewards=True)
loss = PolicyLoss(pg, 'PPO',
value_head=ValueHead(256),
v_update_iter=2,
vopt_kwargs={'lr':1e-3})
Samplers
Samplers fill the role of generating samples to train on. We will use four samplers for this run:
- sampler1: ModelSampler - this sampler will draw samples from the main model in the Agent. We set buffer_size=1000, which means we will generate 1000 samples every time we build the buffer. We set p_batch=0.5, which means during training, 50% of each batch will be sampled on the fly from the main model and the rest of the batch will come from the buffer
- sampler2: ModelSampler - this sampler is the same as sampler1, but we draw from the baseline model instead of the main model. We set p_batch=0., so this sampler will only contribute to the buffer
- sampler3: LogSampler - this sampler looks through the log of previous samples. Based on our input arguments, it grabs the top 95th percentile of samples in the log and randomly selects 100 samples from that subset
- sampler4: DatasetSampler - this sampler is seeded with erbB1 training data used to train the score function. This sampler will randomly select 4 samples from the dataset to add to the buffer
gen_bs = 1500
# if in the repo
df = pd.read_csv('../files/erbB1_affinity_data.csv')
# if in Colab
# download_files()
# df = pd.read_csv('files/erbB1_affinity_data.csv')
df = df[df.neg_log_ic50>9.2]
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0.5, gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 95, 100)
sampler4 = DatasetSampler(df.smiles.values, 'erbB1_data', buffer_size=4)
samplers = [sampler1, sampler2, sampler3, sampler4]
Other Callbacks
We'll add three more callbacks:
- MaxCallback: this will grab the max reward within a batch that came from the source live. live is the name we gave to sampler1 above. This means the max callback will grab all outputs from sampler1 corresponding to samples from the live model and add the largest to the batch metrics
- PercentileCallback: this does the same as MaxCallback, but instead of printing the maximum score, it prints the 90th percentile score
- NoveltyReward: this is a reward modification that gives a bonus score of 0.05 to new samples (ie samples that haven't appeared before in training)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
new_cb = NoveltyReward(weight=0.05)
cbs = [new_cb, live_p90, live_max]
Setup
The first event occurs when we create our Environment using the callbacks we set up before. Instantiating the Environment registers all callbacks and runs the setup event. Many callbacks use the setup event to add terms to the batch log or the metrics log.
env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=[loss],
cbs=cbs)
env.buffer
env.buffer.buffer
The Log holds a number of containers for tracking training outputs:
- metrics: dictionary of batch metrics. Each key maps to a list, where each value in the list is the metric term for a given batch
- batch_log: dictionary of batch items. Each key maps to a list. Each element in that list is a list containing the batch values for that key in a given batch
- unique_samples: dictionary of unique samples and the rewards for those samples. Useful for looking up if a sample has been seen before
- df: dataframe of unique samples and all associated values stored in the batch_log
We can see that these log terms have already been populated during the setup event.
env.log.metrics
env.log.batch_log
env.log.df
The keys in the above dictionaries were added by the associated callbacks. For example, look at the setup method in ModelSampler, the type of sampler we used for sampler1:
def setup(self):
    if self.p_batch>0. and self.track:
        log = self.environment.log
        log.add_metric(f'{self.name}_diversity')
        log.add_metric(f'{self.name}_valid')
        log.add_metric(f'{self.name}_rewards')
        log.add_metric(f'{self.name}_new')
We gave sampler1 the name live. As a result, the terms live_diversity, live_valid, live_rewards and live_new were added to the metrics.
We can also look at the setup method of our loss function loss:
def setup(self):
    if self.track:
        log = self.environment.log
        log.add_metric(self.name)
        log.add_log(self.name)
This is responsible for the PPO terms in the batch_log and the metrics. The PPO metrics term will store the average PPO loss value across a batch, while the PPO batch log term will store the PPO value for each item in the batch.
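For example, once a few batches have run, you could compare the two (a hedged sketch, assuming the dictionary-style access described above):
# env.log.metrics['PPO']    # list with one mean PPO loss value per batch
# env.log.batch_log['PPO']  # list with one list of per-sample PPO values per batch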
The Fit Cycle
At this point, we could start training using Environment.fit. We could call env.fit(200, 90, 10, 2) to train for 10 batches with a batch size of 200. For this tutorial, we will step through each part of the fit cycle and observe what is happening.
Before Train
The first stage of the fit cycle is the before_train stage. This sets the batch size and sequence length based on the inputs to Environment.fit (which we will set manually) and prints the top of the log.
env.bs = 200 # batch size of 200
env.sl = 90 # max sample length of 90 steps
mb = master_bar(range(1))
env.log.pbar = mb
env.report = 1
env.log.report = 1 # report stats every batch
env('before_train')
env.buffer.buffer
env('build_buffer')
Now we have 2004 items in the buffer.
len(env.buffer.buffer)
We can use the buffer_sources attribute to see where each item came from. We have 1000 items from live_buffer, which corresponds to sampler1 sampling from the main model. We have 1000 items from base_buffer, which corresponds to sampler2 sampling from the baseline model. We have 4 items from erbB1_data_buffer, our dataset sampler (sampler4).
Our log sampler, sampler3, was set to start sampling after 10 training iterations, so we don't currently have any samples from that sampler.
Counter(env.buffer.buffer_sources)
filter_buffer
It's likely that some of these samples don't match the compound requirements defined in the Template we used, so we want to filter the buffer for passing compounds. This is what the filter_buffer event does. For this current example, the only callback doing any buffer filtering is the template callback. However, filter_buffer can be used to implement any form of buffer filtering.
Any callback that passes a list of boolean values to Buffer._filter_buffer can filter the buffer, as in the sketch below.
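As an illustration, a hypothetical callback that drops very long SMILES strings from the buffer might look like this. The class name and length cutoff are made up, and constructor details are glossed over; the key point is passing a list of booleans to Buffer._filter_buffer.
class MaxLengthBufferFilter(Callback):
    # hypothetical callback: drop buffer samples longer than 100 characters
    max_len = 100

    def filter_buffer(self):
        buffer = self.environment.buffer
        valids = [len(sample) <= self.max_len for sample in buffer.buffer]
        buffer._filter_buffer(valids)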
After filtering, we have 1829 remaining samples
env('filter_buffer')
len(env.buffer.buffer)
Counter(env.buffer.buffer_sources)
Sample Batch
The next event stage is the sample_batch stage. This consists of the following events:
- before_batch: set up/refresh any required state prior to batch sampling
- sample_batch: draw one batch of samples
- before_filter_batch: evaluate the unfiltered batch
- filter_batch: filter the batch
- after_sample: compute sample-based metrics
before_batch
This event is used to create a new BatchState for the environment. The batch state is a container designed to hold any values required by the batch.
env.batch_state = BatchState()
env('before_batch')
Currently the batch state only has placeholder values for commonly generated terms
env.batch_state
sample_batch
Now we actually draw samples to form a batch. All of our Sampler objects have a p_batch value, which designates what percentage of the batch should come from that sampler. Batch sampling is designed such that individual sampler p_batch values are respected, and any remaining batch percentage comes from the buffer.
Only sampler1 has p_batch>0., with a value of p_batch=0.5. This means 50% of the batch will be sampled on the fly from sampler1, and the remaining 50% of the batch will come from the buffer.
Using a hybrid of live sampling and buffer sampling seems to work best. That said, it is possible to have every batch be 100% buffer samples (like offline RL), or 100% live samples (like online RL).
env('sample_batch')
Now we can see we've populated several terms in the batch state. BatchState.samples now has a list of samples. BatchState.sources has the source of each sample.
We also added BatchState.live_raw and BatchState.base_raw. These terms hold the outputs of sampler1 and sampler2. When we filter BatchState.samples, we can refer to the _raw terms to see what samples were removed.
Note that BatchState.base_raw is an empty list, since sampler2.p_batch=0.
env.batch_state.keys()
BatchState.sources holds the source of each sample. We have 100 samples from live, which corresponds to our on-the-fly samples from sampler1. The remaining 100 samples come from live_buffer and base_buffer. This means they came from either sampler1 (live) or sampler2 (base) by way of being sampled from the buffer.
Counter(env.batch_state['sources'])
env.batch_state['samples'][:5]
env.batch_state['sources'][:5]
env.batch_state['live_raw'][:5]
env.batch_state['base_raw']
filter_batch
Now the batch will be filtered by our Template, as well as by any other callbacks with a filter_batch method.
env('filter_batch')
We can see that 13 of our 200 samples were removed by filtering
len(env.batch_state['samples'])
We can compare the values in BatchState.samples and BatchState.live_raw to see what was filtered.
raw_samples = env.batch_state['live_raw']
filtered_samples = [env.batch_state['samples'][i] for i in range(len(env.batch_state['samples']))
if env.batch_state.sources[i]=='live']
len(filtered_samples), len(raw_samples)
[i for i in raw_samples if not i in filtered_samples]
env('after_sample')
We can see that several values have been added to Environment.log.metrics:
- new: percent of samples that have not been seen before
- diversity: number of unique samples relative to the number of total samples
- bs: true batch size after filtering
- valid: percent of samples that passed filtering
- live_diversity: number of unique samples relative to the number of total samples from sampler1
- live_valid: percent of samples that passed filtering from sampler1
- live_new: percent of samples that have not been seen before from sampler1
env.log.metrics
Compute Reward
After we sample a batch, we enter the compute_reward stage. This consists of the following events:
- before_compute_reward - used to set up any values needed for reward computation
- compute_reward - used by rewards to compute rewards for all samples in the batch state
- after_compute_reward - used for logging reward metrics
- reward_modification - modify rewards in ways not tracked by the log
- after_reward_modification - log reward modification metrics
before_compute_reward
This event can be used to set up any values needed for reward computation. Most rewards only need the raw samples as inputs, but rewards can use other inputs if needed. The only requirement for a reward is that it returns a tensor with one value per batch item.
By default, the Agent class will tensorize the samples present at this step. Our PPO loss will also add placeholder values for the terms needed by that function.
env('before_compute_reward')
A number of new items have populated the batch state
env.batch_state.keys()
env.batch_state.x # x tensor
env.batch_state.y # y tensor
env.batch_state.mask # padding mask
compute_reward
This step actually computes rewards. The BatchState has a tensor of zeros as a placeholder for reward values. Rewards will compute a numeric score for each item in the batch and add it to BatchState.rewards.
env.batch_state.rewards
env('compute_reward')
env.batch_state.rewards
So where did these rewards come from?
One reward term comes from our Template. We specified soft rewards for compounds with QED>=0.5 and SA<=5. Compounds could score a maximum of 2 from the template.
We also have the reward from the erbB1 regression model we set up earlier.
The specific rewards from each of these sources are logged in the BatchState. For the Template, we have BatchState.template and BatchState.template_passes.
env.batch_state.keys()
Template scores:
env.batch_state.template
BatchState.template_passes shows which samples passed the hard filters. Since we decided to prefilter with our template earlier, all remaining samples are passing.
env.batch_state.template_passes
And here we have the erbB1 regression scores:
env.batch_state.aff
env('after_compute_reward')
env.log.metrics
reward_modification
The reward modification event can be thought of as a second reward that isn't logged. The reason for including this is to allow for transient, "batch context" rewards that don't affect logged values.
When we set up our callbacks earlier, we had the term
new_cb = NoveltyReward(weight=0.05)
which adds a bonus score of 0.05 to new, never-before-seen samples. The point of this callback is to give the model a soft incentive to generate novel samples.
We want this score to impact our current batch. However, if we treated it the same as our actual rewards, the samples would be saved into env.log with their scores inflated by 0.05. Later, when our LogSampler samples from the log, the sampling would be influenced by a score that was only supposed to be given once.
Separating out rewards and reward modifications lets us avoid this.
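As a rough sketch of the same pattern, a hypothetical callback could hand out a transient bonus during the reward_modification event. Exactly how NoveltyReward updates the reward tensor internally may differ; this just shows the shape of the idea, with made-up names and cutoffs.
class ShortSampleBonus(Callback):
    # hypothetical reward modification: small transient bonus for compact samples
    weight = 0.05

    def reward_modification(self):
        batch_state = self.environment.batch_state
        bonus = torch.tensor([1. if len(s) < 60 else 0. for s in batch_state.samples])
        # assumption: the batch state's rewards tensor can be adjusted in place here
        batch_state.rewards = batch_state.rewards + bonus.to(batch_state.rewards.device) * self.weight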
env('reward_modification')
env.batch_state.novel
env('after_reward_modification')
env.log.metrics
Get Model Outputs
After computing rewards, we move on to setting up our loss calculation. The get_model_outputs stage is about generating the values that we will backpropagate through. This stage consists of the following events:
- get_model_outputs - generate necessary tensors from the model
- after_get_model_outputs - used for any processing required prior to loss calculation
get_model_outputs
This is where we generate the tensor values used for loss computation.
The specifics of what happens here depend on the type of model used. For autoregressive models, this step involves taking the x and y tensors we generated during the before_compute_reward event and doing a forward pass.
x is a tensor of size (bs, sl). Running x through the model gives a set of log probabilities of size (bs, sl, d_vocab). We then use y to gather the relevant log probs, giving a gathered log prob tensor of size (bs, sl).
We generate these values from both the main model and the baseline model.
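In plain PyTorch, the gathering step looks roughly like this (a sketch of the idea, with F as torch.nn.functional; the Agent's actual forward pass handles more details):
# Sketch: x is (bs, sl), y holds the target token indices
logits = model(x)                                                    # (bs, sl, d_vocab)
logprobs = F.log_softmax(logits, -1)                                 # (bs, sl, d_vocab)
gathered_logprobs = logprobs.gather(2, y.unsqueeze(-1)).squeeze(-1)  # (bs, sl)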
env('get_model_outputs')
env.batch_state.keys()
env.batch_state.model_logprobs.shape, env.batch_state.model_gathered_logprobs.shape
Compute Loss
Now we actually compute a loss value and do an optimizer update. See the PPO class for a description of the policy gradient algorithm used.
Loss computation consists of the following steps:
- compute_loss - compute loss values
- zero_grad - zero gradients
- before_step - used for computation before the optimizer step (ie gradient clipping)
- step - step the optimizer
compute_loss
When we first created our BatchState, there was a placeholder value for loss. This is the value that will ultimately be backpropagated through. This means we can run any sort of loss configuration, so long as the final values end up in BatchState.loss.
For example, the PPO policy gradient algorithm we are using involves a ValueHead that predicts values at every time step. This model is held in the PolicyLoss callback that holds the PPO class. During the compute_loss event, PPO computes an additional loss for the value head, which is added to BatchState.loss. PolicyLoss also holds an optimizer for the ValueHead parameters.
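Since anything accumulated into BatchState.loss gets backpropagated, extra loss terms can be added by any callback with a compute_loss method. Here is a hypothetical sketch (not part of MRL) that adds an entropy bonus, ignoring padding for brevity:
class EntropyBonus(Callback):
    # hypothetical: subtract an entropy bonus from the loss to encourage exploration
    coef = 0.01

    def compute_loss(self):
        batch_state = self.environment.batch_state
        logprobs = batch_state.model_logprobs                      # (bs, sl, d_vocab)
        entropy = -(logprobs.exp() * logprobs).sum(-1).mean()
        batch_state.loss = batch_state.loss - self.coef * entropy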
env.batch_state.loss
env('compute_loss')
env.batch_state.loss
zero_grad
This is an event to zero the gradients of all optimizers in play. We currently have one optimizer in the Agent for our generative model and one in PolicyLoss for the ValueHead of our policy gradient algorithm.
env('zero_grad')
env.batch_state.loss.backward()
env('before_step')
step
This is the actual optimizer step. This will step both the Agent and PolicyLoss optimizers.
env('step')
env('after_batch')
env.log.df
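The log dataframe now holds one row per unique sample seen so far. As a quick check (assuming column names matching the log terms registered earlier, e.g. samples and rewards), you could inspect the current best-scoring compounds:
# env.log.df.sort_values('rewards', ascending=False).head()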
env('after_train')
Conclusions
Hopefully walking through the training process step by step has made the process more understandable. We conclude by simply running Environment.fit so we don't have to go through things step by step anymore.
env.fit(200, 90, 50, 4)