Overview of the RL training cycle

RL Train Cycle Overview

The goal of this tutorial is to walk through the RL fit cycle to familiarize ourselves with the event cycle and get a better understanding of how the Callback and Environment classes work.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU

High Level Overview

The Environment

At the highest level, we have the Environment class. The Environment holds together several sub-modules and orchestrates them during the fit loop. The following are contained in the Environment:

  • agent - This is the actual model we're training
  • template_cb - this holds a Template class that we use to define our chemical space
  • samplers - samplers generate new samples to train on
  • buffer - the buffer collects and distributes samples from all the samplers
  • rewards - rewards score samples
  • losses - losses generate values we can backpropagate through
  • log - the log holds a record of all samples in the training process

Callbacks and the Event Cycle

Each one of the above items is a Callback. A Callback is a general class that can hook into the Environment fit cycle at a number of pre-defined Events. When the Environment calls a specific Event, the event name is passed to every callback in the Environment. If a given Callback has a method named after that event, the method is called. This creates a very flexible system for customizing training loops.
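
To make the dispatch concrete, here is a minimal sketch of the pattern in plain Python. The class and method names are made up for illustration; the actual MRL Callback and Environment classes also handle registration, ordering and shared state.

class SketchCallback:
    def before_batch(self):
        print('resetting batch state')

class SketchEnvironment:
    def __init__(self, cbs):
        self.cbs = cbs

    def __call__(self, event_name):
        # broadcast the event: run the matching method on any callback that defines it
        for cb in self.cbs:
            event = getattr(cb, event_name, None)
            if event is not None:
                event()

sketch_env = SketchEnvironment([SketchCallback()])
sketch_env('before_batch')  # prints 'resetting batch state'
sketch_env('after_batch')   # no-op: SketchCallback does not define after_batch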

We'll be looking more at Events later. For now, we'll just list them in brief. These are the events called during the RL training cycle in the order they are executed:

  • setup - called when the Environment is created, used to set up values
  • before_train - called before training is started
  • build_buffer - draws samples from samplers into the buffer
  • filter_buffer - filters samples in the buffer
  • after_build_buffer - called after buffer filtering. Used for cleanup, logging, etc
  • before_batch - called before a batch starts, used to set up the batch state
  • sample_batch - samples are drawn from samplers and the buffer into the batch state
  • before_filter_batch - allows preprocessing of samples before filtering
  • filter_batch - filters samples in batch state
  • after_sample - used for calculating sampling metrics
  • before_compute_reward - used to set up any values needed for reward computation
  • compute_reward - used by rewards to compute rewards for all samples in the batch state
  • after_compute_reward - used for logging reward metrics
  • reward_modification - modify rewards in ways not tracked by the log
  • after_reward_modification - log reward modification metrics
  • get_model_outputs - generate necessary tensors from the model
  • after_get_model_outputs - used for any processing required prior to loss calculation
  • compute_loss - compute loss values
  • zero_grad - zero grad
  • before_step - used for computation before the optimizer step (i.e. gradient clipping)
  • step - step optimizer
  • after_batch - compute batch stats
  • after_train - final event after all training batches
import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
/home/dmai/miniconda3/envs/mrl/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: to-Python converter for boost::shared_ptr<RDKit::FilterCatalogEntry const> already registered; second conversion method ignored.
  return f(*args, **kwds)
from collections import Counter
 

Getting Started

We start by creating all the components we need to train a model

Agent

The Agent is the actual model we want to train. For this example, we will use the LSTM_LM_Small_ZINC model, which is an LSTM_LM model trained on a chunk of the ZINC database.

The agent actually contains two versions of the model: the main model, which we train with every update iteration, and a baseline model, which is updated as an exponentially weighted moving average of the main model. Both models are used in the RL training algorithm we will set up later.
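
As a rough sketch of what an exponentially weighted moving average update looks like (the Agent handles this internally; the smoothing factor and update schedule below are assumptions for illustration):

import torch

@torch.no_grad()
def ema_update(main_model, baseline_model, beta=0.99):
    # blend each baseline parameter towards the corresponding main model parameter
    for p_main, p_base in zip(main_model.parameters(), baseline_model.parameters()):
        p_base.mul_(beta).add_(p_main, alpha=1 - beta)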

agent = LSTM_LM_Small_ZINC(drop_scale=0.5,opt_kwargs={'lr':5e-5})

Template

The Template class is used to control the chemical space. We can set parameters on which molecular properties we want to allow. For this example, we set the following:

  • Hard Filters - must-have qualities
    • ValidityFilter - samples must be valid molecules
    • SingleCompoundFilter - samples must consist of a single compound
    • RotBondFilter - compounds can have at most 8 rotatable bonds
    • ChargeFilter - compounds must have a formal charge of 0
  • Soft Filters - nice-to-have qualities
    • QEDFilter - compounds get a score bonus of +1 if their QED value is greater than 0.5
    • SAFilter - compounds get a score bonus of +1 if their SA score is less than 5

We then pass the Template to the TemplateCallback which integrates the template into the fit loop. Note that we pass prefilter=True to the TemplateCallback, which ensures compounds that don't meet our hard filters are removed from training

template = Template([ValidityFilter(), 
                     SingleCompoundFilter(), 
                     RotBondFilter(None, 8),
                     ChargeFilter(0, 0)],
                    [QEDFilter(0.5, None, score=1.),
                     SAFilter(None, 5, score=1.)])

template_cb = TemplateCallback(template, prefilter=True)

Reward

For the reward, we will load a scikit-learn linear regression model. This model was trained to predict affinity against erbB1 using molecular fingerprints as inputs

This score function is extremely simple and likely won't translate well to real affinity. It is used as a lightweight example

class FP_Regression_Score():
    def __init__(self, fname):
        self.model = torch.load(fname)                              # pickled sklearn regression model
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)  # ECFP6 fingerprint function

    def __call__(self, samples):
        mols = to_mols(samples)                        # SMILES strings -> RDKit mols
        fps = maybe_parallel(self.fp_function, mols)   # fingerprint each mol
        fps = [fp_to_array(i) for i in fps]            # fingerprints -> numpy arrays
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)             # predicted affinity, one value per sample
        return preds

# if in the repo
reward_function = FP_Regression_Score('../files/erbB1_regression.sklearn')

# if in Colab:
# download_files()
# reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')

reward = Reward(reward_function, weight=1.)

aff_reward = RewardCallback(reward, 'aff')

We can think of the score function as a black box that takes in samples (SMILES strings) and returns a single numeric score for each sample. Any score function that follows this paradigm can be integrated into MRL

samples = ['Brc1cc2c(NCc3cccs3)ncnc2s1',
           'Brc1cc2c(NCc3ccncc3)ncnc2s1']

reward_function(samples)
array([5.33797993, 6.17020286])
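
To show how little the paradigm requires, here is a toy score function with the same contract: a list of SMILES strings in, one numeric score per sample out. It simply rewards longer SMILES strings, which is not a chemically meaningful objective; it only illustrates the interface, and could be wrapped with Reward in the same way as above.

class LengthScore():
    "Toy score function: one numeric value per input SMILES string"
    def __call__(self, samples):
        return np.array([len(s) for s in samples], dtype=np.float32)

toy_reward_function = LengthScore()
toy_reward_function(['Brc1cc2c(NCc3cccs3)ncnc2s1'])  # array([26.], dtype=float32)
# toy_reward = Reward(toy_reward_function, weight=1.)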

Loss Function

For our loss, we will use the PPO reinforcement learning algorithm. See the PPO paper for full details.

The gist of it is that the loss function takes a batch of samples and directs the model to increase the probability of above-average samples (relative to the batch mean) and decrease the probability of below-average samples.
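
To build intuition, here is a stripped-down sketch of that idea as a plain policy gradient loss. It omits everything that makes PPO PPO (the clipped surrogate, the value head, the KL penalty), but shows how centering rewards on the batch mean pushes the log-probabilities of above-average samples up and below-average samples down.

import torch

def simple_policy_gradient_loss(logprobs, rewards):
    # logprobs: (bs,) summed log-probability of each sampled sequence
    # rewards:  (bs,) scalar reward for each sample
    advantages = rewards - rewards.mean()             # positive above the batch mean, negative below
    return -(advantages.detach() * logprobs).mean()   # minimizing this raises P(above-average samples)

logprobs = torch.randn(4, requires_grad=True)
rewards = torch.tensor([1.0, 3.0, 2.0, 0.5])
simple_policy_gradient_loss(logprobs, rewards).backward()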

pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})

Samplers

Samplers fill the role of generating samples to train on. We will use four samplers for this run:

  • sampler1: ModelSampler - this sampler will draw samples from the main model in the Agent. We set buffer_size=1000, which means we will generate 1000 samples every time we build the buffer. We set p_batch=0.5, which means during training, 50% of each batch will be sampled on the fly from the main model and the rest of the batch will come from the buffer
  • sampler2: ModelSampler - this sampler is the same as sampler1, but we draw from the baseline model instead of the main model. We set p_batch=0., so this sampler will only contribute to the buffer
  • sampler3: LogSampler - this sampler looks through the log of previous samples. Based on our input arguments, it grabs the top 95th percentile of samples in the log and randomly selects 100 samples from that subset
  • sampler4: DatasetSampler - this sampler is seeded with erbB1 training data used to train the score function. This sampler will randomly select 4 samples from the dataset to add to the buffer
gen_bs = 1500

# if in the repo
df = pd.read_csv('../files/erbB1_affinity_data.csv')

# if in Colab
# download_files()
# df = pd.read_csv('files/erbB1_affinity_data.csv')

df = df[df.neg_log_ic50>9.2]

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0.5, gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 95, 100)
sampler4 = DatasetSampler(df.smiles.values, 'erbB1_data', buffer_size=4)

samplers = [sampler1, sampler2, sampler3, sampler4]

Other Callbacks

We'll add three more callbacks:

  • MaxCallback: this will grab the max reward within a batch that came from the source live. live is the name we gave to sampler1 above. This means the max callback will grab all outputs from sampler1 corresponding to samples from the live model and add the largest to the batch metrics
  • PercentileCallback: this does the same as MaxCallback but instead of printing the maximum score, it prints the 90th percentile score
  • NoveltyReward: this is a reward modification that gives a bonus score of 0.05 to new samples (i.e. samples that haven't appeared before in training)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
new_cb = NoveltyReward(weight=0.05)

cbs = [new_cb, live_p90, live_max]

Training Walkthrough

Now we will step through the training cycle looking at how each callback event is used

Setup

The first event occurs when we create our Environment using the callbacks we set up before. Instantiating the Environment registers all callbacks and runs the setup event. Many callbacks use the setup event to add terms to the batch log or the metrics log.

env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=[loss],
                 cbs=cbs)

Creating the Environment also created a Buffer and a Log inside it.

The Buffer holds a list of samples, which is currently empty

env.buffer
buffer
env.buffer.buffer
[]

The Log holds a number of containers for tracking training outputs

  • metrics: dictionary of batch metrics. Each key maps to a list where each value in the list is the metric term for a given batch
  • batch_log: dictionary of batch items. Each key maps to a list. Each element in that list is a list containing the batch values for that key in a given batch
  • unique_samples: dictionary of unique samples and the rewards for those samples. Useful for looking up if a sample has been seen before
  • df: dataframe of unique samples and all associated values stored in the batch_log

We can see that these log terms have already been populated during the setup event

env.log.metrics
{'rewards': [],
 'rewards_final': [],
 'new': [],
 'diversity': [],
 'bs': [],
 'template': [],
 'valid': [],
 'live_diversity': [],
 'live_valid': [],
 'live_rewards': [],
 'live_new': [],
 'aff': [],
 'novel': [],
 'PPO': [],
 'rewards_live_p90': [],
 'rewards_live_max': []}
env.log.batch_log
{'samples': [],
 'sources': [],
 'rewards': [],
 'rewards_final': [],
 'template': [],
 'aff': [],
 'novel': [],
 'PPO': []}
env.log.df
samples sources rewards rewards_final template aff novel PPO

The keys in the above dictionaries were added by the associated callbacks. For example, look at the setup method in ModelSampler, the type of sampler we used for sampler1:

    def setup(self):
        if self.p_batch>0. and self.track:
            log = self.environment.log
            log.add_metric(f'{self.name}_diversity')
            log.add_metric(f'{self.name}_valid')
            log.add_metric(f'{self.name}_rewards')
            log.add_metric(f'{self.name}_new')

We gave sampler1 the name live. As a result, the terms live_diversity, live_valid, live_rewards and live_new were added to the metrics.

We can also look at the setup method of our loss function loss:

    def setup(self):
        if self.track:
            log = self.environment.log
            log.add_metric(self.name)
            log.add_log(self.name)

This is responsible for the PPO terms in the batch_log and the metrics. The PPO metrics term will store the average PPO loss value across a batch, while the PPO batch log term will store the PPO value for each item in a batch

The Fit Cycle

At this point, we could start training using Environment.fit. Calling env.fit(200, 90, 10, 2) would train for 10 batches with a batch size of 200, a max sample length of 90, and metrics reported every 2 batches. For this tutorial, we will instead step through each part of the fit cycle and observe what is happening

Before Train

The first stage of the fit cycle is the before_train stage. This sets the batch size and sequence length based on the inputs to Environment.fit (which we will set manually) and prints the top of the log

env.bs = 200 # batch size of 200
env.sl = 90 # max sample length of 90 steps
mb = master_bar(range(1))
env.log.pbar = mb
env.report = 1
env.log.report = 1 # report stats every batch
env('before_train')

Build Buffer

The next stage of the cycle is the build_buffer stage. This consists of the following events:

  • build_buffer: samplers add items to the buffer
  • filter_buffer: the buffer is filtered
  • after_build_buffer: use as needed

Going into this stage, our buffer is empty:

env.buffer.buffer
[]

build_buffer

By calling the build_buffer event, our samplers will add items to the buffer

env('build_buffer')

Now we have 2004 items in the buffer.

len(env.buffer.buffer)
2004

We can use the buffer_sources attribute to see where each item came from. We have 1000 items from live_buffer which corresponds to sampler1, sampling from the main model.

We have 1000 items from base_buffer which corresponds to sampler2, sampling from the baseline model.

We have 4 items from erbB1_data_buffer, our dataset sampler (sampler4).

Our log sampler, sampler3, was set to start sampling after 10 training iterations, so we don't currently have any samples from that sampler

Counter(env.buffer.buffer_sources)
Counter({'live_buffer': 1000, 'base_buffer': 1000, 'erbB1_data_buffer': 4})

filter_buffer

It's likely some of these samples don't match the compound requirements defined in the Template we used, so we want to filter the buffer down to passing compounds. This is what the filter_buffer event does. For the current example, the only callback doing any buffer filtering is the template callback. However, the filter_buffer event can be used to implement any form of buffer filtering.

Any callback that passes a list of boolean values to Buffer._filter_buffer can filter the buffer.
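
As a sketch of what custom buffer filtering could look like (the Callback constructor details below are assumptions; the only part taken from the text above is passing a boolean mask to Buffer._filter_buffer, where True means keep):

class MaxLengthFilter(Callback):
    "Hypothetical filter: drop buffer items whose SMILES string is too long"
    def __init__(self, max_len=100):
        super().__init__()
        self.max_len = max_len

    def filter_buffer(self):
        buffer = self.environment.buffer
        valids = [len(sample) <= self.max_len for sample in buffer.buffer]  # one bool per item
        buffer._filter_buffer(valids)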

After filtering, we have 1829 remaining samples

env('filter_buffer')
len(env.buffer.buffer)
1829
Counter(env.buffer.buffer_sources)
Counter({'live_buffer': 922, 'base_buffer': 905, 'erbB1_data_buffer': 2})

after_build_buffer

Next is the after_build_buffer event. None of our current callbacks make use of this event, but it exists to allow for evaluation/postprocessing/whatever after buffer creation.

Sample Batch

The next event stage is the sample_batch stage. This consists of the following events:

  • before_batch: set up/refresh any required state prior to batch sampling
  • sample_batch: draw one batch of samples
  • before_filter_batch: evaluate unfiltered batch
  • filter_batch: filter batch
  • after_sample: compute sample based metrics

before_batch

This event is used to create a new BatchState for the environment. The batch state is a container designed to hold any values required by the batch

env.batch_state = BatchState()
env('before_batch')

Currently the batch state only has placeholder values for commonly generated terms

env.batch_state
{'samples': [],
 'sources': [],
 'rewards': tensor(0., device='cuda:0'),
 'loss': tensor(0., device='cuda:0', grad_fn=<CopyBackwards>),
 'latent_data': {}}

sample_batch

Now we actually draw samples to form a batch. All of our Sampler objects have a p_batch value, which designates what percentage of the batch should come from that sampler. Batch sampling is designed such that individual sampler p_batch values are respected, and any remaining batch percentage comes from the buffer.

Only sampler1 has p_batch>0., with a value of p_batch=0.5. This means 50% of the batch will be sampled on the fly from sampler1, and the remaining 50% of the batch will come from the buffer.

Using a hybrid of live sampling and buffer sampling seems to work best. That said, it is possible to have every batch be 100% buffer samples (like offline RL), or have 100% be live samples (like online RL)
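
For our settings, the batch composition works out as sketched below (the exact rounding is handled by the framework, but the counts match the batch we draw next):

bs = 200
live_n   = int(bs * sampler1.p_batch)   # 0.5 -> 100 samples generated on the fly by the main model
base_n   = int(bs * sampler2.p_batch)   # 0.0 -> 0 samples generated on the fly by the baseline model
buffer_n = bs - live_n - base_n         # 100 samples drawn from the buffer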

env('sample_batch')

Now we can see we've populated several terms in the batch state. BatchState.samples now has a list of samples. BatchState.sources has the source of each sample.

We also added BatchState.live_raw and BatchState.base_raw. These terms hold the outputs of sampler1 and sampler2. When we filter BatchState.samples, we can refer to the _raw terms to see what samples were removed.

Note that BatchState.base_raw is an empty list since sampler2.p_batch=0.

env.batch_state.keys()
dict_keys(['samples', 'sources', 'rewards', 'loss', 'latent_data', 'live_raw', 'base_raw'])

BatchState.sources holds the source of each sample. We have 100 samples from live, which corresponds to our on the fly samples from sampler1. The remaining 100 samples come from live_buffer and base_buffer. This means they came from either sampler1 (live) or sampler2 (base) by way of being sampled from the buffer

Counter(env.batch_state['sources'])
Counter({'live_buffer': 49, 'base_buffer': 51, 'live': 100})
env.batch_state['samples'][:5]
['COc1ccc2c(c1)OC[C@@H]2CC(=O)N(C)CCN1C(=O)c2ccccc2C1=O',
 'CC(C)CC[C@@](C)(O)CNC(=O)[C@H]1CC[C@H](C(C)C)CC1',
 'CCOCc1nnc(N2CCC(C#N)CC2)n1Cc1coc2c(C)cccc12',
 'N#CC1(NC(=O)[C@H]2CC23CCN(CCOc2ccccc2F)CC3)CCC1',
 'O=C(N[C@H](C1CCC1)C1CC1)C(=O)N1CCn2c(cnc2C(F)(F)F)C1']
env.batch_state['sources'][:5]
['live_buffer', 'live_buffer', 'live_buffer', 'base_buffer', 'live_buffer']
env.batch_state['live_raw'][:5]
['Cc1cc(-c2noc(C[C@H](NC(=O)OC(C)(C)C)c3ccccc3)n2)cc(C)c1Br',
 'O=C(OC[C@@H]1CCC[C@@H](O)C1)c1cnc2ccccn2c1=O',
 'COc1ccc(C[C@@H](C)CNC(=O)N[C@H]2CCc3[nH]ncc3C2)cc1OC',
 'CC[C@@H](C)N1C[C@@]2(CC1=O)COCCN(C(=O)Cc1c(C)noc1Cl)C2',
 'CSc1nccnc1C(=O)N(CCO)C1CCSCC1']
env.batch_state['base_raw']
[]

before_filter_batch

This event is not used by any of our current callbacks. It provides a hook to influence the batch state prior to filtering

filter_batch

Now the batch will be filtered by our Template, as well as any other callbacks with a filter_batch method

env('filter_batch')

We can see that 13 of our 200 samples were removed by filtering

len(env.batch_state['samples'])
187

We can compare the values in BatchState.samples and BatchState.live_raw to see what was filtered

raw_samples = env.batch_state['live_raw']
filtered_samples = [env.batch_state['samples'][i] for i in range(len(env.batch_state['samples'])) 
                    if env.batch_state.sources[i]=='live']

len(filtered_samples), len(raw_samples)
(87, 100)
[i for i in raw_samples if not i in filtered_samples]
['CC[C@@H](C)N1C[C@@]2(CC1=O)COCCN(C(=O)Cc1c(C)noc1Cl)C2',
 'CC[C@@H](Cn1c(CCc2ccccc2)nnc1N(C)Cc1ccc(Cl)s1)N1CCCC1',
 'CCO[C@H]1C[C@H](NC(=O)NCCCc2nc3ccccc3[nH]2)C1(CC)CC',
 'CCOCCCC(=O)N1C[C@@H](C)[C@H](NC(=O)C[C@H](C)c2cnn(C)c2)C1',
 'Cc1nnc(NC(=O)C2(c3cccc(C)c3)CC2)s1',
 'CCN(CC)C[C@H](F)C(=O)N(C)C[C@H]1CCN1C(=O)CC1CC1',
 'CCN(CC)[C@H](CNC(=O)C(=O)NCCC(=O)O)CC(C)C',
 'Cn1cnnc1CN1C[C@@H](O)[C@H](NC(=O)c2ncnc3sccc32)C1',
 'CCC[C@@H](C(=O)NC[C@](C)(NC(=O)CC)C1CC1)c1ccccn1',
 'CC(C)(C(=O)N1CC[C@]2(C1)CN(CC#N)CCO2)c1ccccc1',
 'CCC[C@@H](C)C(=O)N1CCC2(CN(C(=O)[C@H](C)C3CC3)C2)C1',
 'CN(C)S(=O)(=O)CCNC(=O)Nc1ccc(OC(F)F)c(Cc2ccccc2)c1',
 'CCC(CC)(C(=O)N[C@@H](CO)C[C@@H](O)c1ccccc1)c1ccc(OC)cc1',
 'C=C[C@H](CC)CC(=O)NC[C@H]1C[C@@H](NCc2cc(C)no2)C1',
 'CC(C)[C@H]1C[C@H](CC(=O)NCc2ccccc2CN2CCC(C)CC2)CCO1',
 'O=C(COC(=O)c1cccc(COc2ccccc2)c1)NCCOc1ccc(F)cc1',
 'C[C@H](CCCNC[C@@H]1CCCCO1)NC(=O)C(C)(C)C1CCCC1',
 'CCCc1nc(CNC[C@@H](C)NC(=O)[C@@H]2C[C@@H]3O[C@H]2[C@H]2C[C@H]23)cs1',
 'CC[C@@H](NC(=O)C(=O)N1CCC2(C1)CCOCC2)c1cccc(S(N)(=O)=O)c1',
 'CC(C)c1cc(C(=O)N2C[C@@H](C)[C@H](Nc3nc4ccccc4s3)C2)no1',
 'CNS(=O)(=O)CCn1c(-c2cnn(C)c2)nnc1N(C)Cc1ccc(N2CCOCC2)cc1',
 'COC[C@H](C)N1CC2(C1)CCN(C(=O)[C@@H]1CCCOC1)CC2',
 'O=C(CN1CCC(=O)NC1=O)N[C@H]1CCN(CCO)CC1(C)C',
 'CCCN(CCC)C(=O)c1cccc(C(=O)O[C@H](c2nc(C)no2)c2ccccc2)c1']

after_sample

The after_sample event is used to calculate metrics related to sampling

env('after_sample')

We can see that several values have been added to Environment.log.metrics

  • new: percent of samples that have not been seen before
  • diversity: number of unique samples relative to the number of total samples
  • bs: true batch size after filtering
  • valid: percent of samples that passed filtering
  • live_diversity: number of unique samples relative to the number of total samples from sampler1
  • live_valid: percent of samples that passed filtering from sampler1
  • live_new: percent of samples that have not been seen before from sampler1
env.log.metrics
{'rewards': [],
 'rewards_final': [],
 'new': [1.0],
 'diversity': [1.0],
 'bs': [187],
 'template': [],
 'valid': [0.935],
 'live_diversity': [1.0],
 'live_valid': [0.87],
 'live_rewards': [],
 'live_new': [1.0],
 'aff': [],
 'novel': [],
 'PPO': [],
 'rewards_live_p90': [],
 'rewards_live_max': []}

Compute Reward

After we sample a batch, we enter the compute_reward stage. This consists of the following events:

  • before_compute_reward - used to set up any values needed for reward computation
  • compute_reward - used by rewards to compute rewards for all samples in the batch state
  • after_compute_reward - used for logging reward metrics
  • reward_modification - modify rewards in ways not tracked by the log
  • after_reward_modification - log reward modification metrics

before_compute_reward

This event can be used to set up any values needed for reward computation. Most rewards only need the raw samples as inputs, but rewards can use other inputs if needed. The only requirement for a reward is that it returns a tensor with one value per batch item.

By default, the Agent class will tensorize the samples present at this step. Our PPO loss will also add placeholder values for the terms needed by that function

env('before_compute_reward')

A number of new items have populated the batch state

env.batch_state.keys()
dict_keys(['samples', 'sources', 'rewards', 'loss', 'latent_data', 'live_raw', 'base_raw', 'model_gathered_logprobs', 'base_gathered_logprobs', 'mask', 'trajectory_rewards', 'model_logprobs', 'base_logprobs', 'value_input', 'x', 'y', 'bs', 'lengths', 'sl'])
env.batch_state.x # x tensor
tensor([[ 0, 23, 28,  ...,  2,  2,  2],
        [ 0, 23, 23,  ...,  2,  2,  2],
        [ 0, 23, 23,  ...,  2,  2,  2],
        ...,
        [ 0, 23, 31,  ...,  2,  2,  2],
        [ 0, 23, 31,  ...,  2,  2,  2],
        [ 0, 23, 27,  ...,  2,  2,  2]], device='cuda:0')
env.batch_state.y # y tensor
tensor([[23, 28, 34,  ...,  2,  2,  2],
        [23, 23,  5,  ...,  2,  2,  2],
        [23, 23, 28,  ...,  2,  2,  2],
        ...,
        [23, 31, 23,  ...,  2,  2,  2],
        [23, 31, 23,  ...,  2,  2,  2],
        [23, 27,  5,  ...,  2,  2,  2]], device='cuda:0')
env.batch_state.mask # padding mask
tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')

compute_reward

This step actually computes rewards. The BatchState has a tensor of 0s as a placeholder for reward values. Rewards will compute a numeric score for each item in the batch and add it to BatchState.rewards

env.batch_state.rewards
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
env('compute_reward')
env.batch_state.rewards
tensor([ 7.4464,  6.9985,  8.6470,  8.5587,  6.4126,  6.3626,  5.4491,  8.5913,
         7.6887,  7.8485,  7.3070,  8.1959,  7.4571,  4.8321,  7.7518,  6.9646,
         9.0067,  8.7635,  6.6427,  7.1914,  7.6669,  8.5002,  5.4258,  5.9034,
         5.8713,  7.9348,  6.8055,  7.3766,  5.3981,  8.8464, 10.2083,  7.7705,
         4.0919,  5.4432,  7.3088,  7.7280,  9.6633,  8.1869,  6.4295,  4.5877,
         5.8513,  8.9504,  6.9817,  7.2481,  8.7868,  6.8186,  8.5599,  5.9329,
         7.7897,  9.1975,  4.8209,  6.8187,  6.1769,  7.6650,  5.1367,  7.6710,
         7.6972,  8.7095,  6.5353, 11.3099,  5.6652, 10.3340,  9.7048,  9.0019,
         5.3304,  7.3670,  9.6277,  8.0019,  8.7727,  8.0397,  7.7085,  7.4089,
         7.5608,  7.7316,  8.3243,  7.4006,  8.1880,  4.9502,  5.2859,  7.4890,
         6.8271,  8.1306,  6.1215,  7.2989,  5.9260,  8.4519,  8.8245,  5.2587,
         9.2377,  8.6317,  7.1252,  9.7453,  6.5998,  6.4446,  5.7345,  5.8627,
         8.8814,  6.4167,  6.9583,  6.4811,  5.5591,  4.6644,  8.7930,  7.7654,
         7.6589,  8.2830,  7.4545,  7.1201,  8.6511,  8.1746,  7.3143,  6.9259,
         6.7407,  8.6857,  4.5039,  5.8502,  9.2962,  6.1733,  8.3172,  3.5688,
         5.6853,  5.2566,  8.1857,  4.6914, 10.1267,  8.0808,  6.0151,  5.8964,
         6.9508,  4.9608,  4.1395,  6.6439,  6.1228,  5.7914,  5.4085,  8.6273,
         8.1571,  7.5381,  8.7961,  7.3255,  6.3455,  8.1302,  6.9034,  9.5043,
         7.9387,  7.6071,  5.7581, 10.3549,  6.5539,  6.8969,  7.5937,  7.9687,
         6.1901,  2.9819,  8.2540,  6.9759,  9.0442,  6.0747,  5.9112,  4.8929,
         8.3603,  5.3070,  8.4804,  8.1893,  9.2698,  9.5042,  8.8843,  7.7454,
         7.7210,  6.5514,  8.9875,  7.7331,  9.7721,  6.0056,  7.1964,  6.9201,
         6.0658,  6.8281,  6.1915,  4.9288,  7.2785,  8.8447,  4.9078,  8.3447,
         6.4719,  7.8901,  7.1313], device='cuda:0')

So where did these rewards come from?

One reward term comes from our Template. We specified soft rewards for compounds with QED>=0.5 and SA<=5. Compounds could score a maximum of 2 from the template.

We also have the reward from the erbB1 regression model we set up earlier.

The specific rewards from each of these sources are logged in the BatchState

For the Template, we have BatchState.template and BatchState.template_passes

env.batch_state.keys()
dict_keys(['samples', 'sources', 'rewards', 'loss', 'latent_data', 'live_raw', 'base_raw', 'model_gathered_logprobs', 'base_gathered_logprobs', 'mask', 'trajectory_rewards', 'model_logprobs', 'base_logprobs', 'value_input', 'x', 'y', 'bs', 'lengths', 'sl', 'template', 'template_passes', 'aff'])

Template scores:

env.batch_state.template
array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2., 1., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])

BatchState.template_passes shows which samples passed the hard filters. Since we decided to prefilter with our template earlier, all remaining samples are passing

env.batch_state.template_passes
array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True])

And here we have the erbB1 regression scores

env.batch_state.aff
tensor([5.4464, 4.9985, 6.6470, 6.5587, 4.4126, 4.3626, 3.4491, 6.5913, 5.6887,
        5.8485, 5.3070, 6.1959, 5.4571, 2.8321, 5.7518, 4.9646, 7.0067, 6.7635,
        4.6427, 5.1914, 5.6669, 6.5002, 4.4258, 3.9034, 3.8713, 5.9348, 4.8055,
        5.3766, 3.3981, 6.8464, 8.2083, 6.7705, 2.0919, 3.4432, 5.3088, 5.7280,
        7.6633, 6.1869, 4.4295, 2.5877, 3.8513, 6.9504, 4.9817, 5.2481, 6.7868,
        4.8186, 6.5599, 3.9329, 5.7897, 7.1975, 3.8209, 4.8187, 4.1769, 5.6650,
        3.1367, 5.6710, 5.6972, 6.7095, 4.5353, 9.3099, 3.6652, 8.3340, 8.7048,
        7.0019, 3.3304, 5.3670, 7.6277, 6.0019, 6.7727, 6.0397, 5.7085, 5.4089,
        5.5608, 5.7316, 6.3243, 5.4006, 7.1880, 2.9502, 3.2859, 5.4890, 4.8271,
        6.1306, 4.1215, 5.2989, 3.9260, 6.4519, 6.8245, 3.2587, 7.2377, 6.6317,
        5.1252, 7.7453, 4.5998, 4.4446, 3.7345, 3.8627, 6.8814, 4.4167, 4.9583,
        4.4811, 4.5591, 2.6644, 6.7930, 5.7654, 5.6589, 6.2830, 5.4545, 5.1201,
        6.6511, 6.1746, 5.3143, 4.9259, 4.7407, 6.6857, 2.5039, 3.8502, 7.2962,
        4.1733, 6.3172, 1.5688, 3.6853, 3.2566, 6.1857, 2.6914, 8.1267, 6.0808,
        4.0151, 3.8964, 4.9508, 2.9608, 2.1395, 4.6439, 4.1228, 3.7914, 4.4085,
        6.6273, 6.1571, 5.5381, 6.7961, 5.3255, 4.3455, 6.1302, 4.9034, 7.5043,
        5.9387, 5.6071, 3.7581, 8.3549, 4.5539, 4.8969, 5.5937, 5.9687, 4.1901,
        1.9819, 6.2540, 4.9759, 7.0442, 4.0747, 3.9112, 2.8929, 6.3603, 3.3070,
        6.4804, 6.1893, 7.2698, 7.5042, 6.8843, 5.7454, 5.7210, 4.5514, 6.9875,
        5.7331, 7.7721, 4.0056, 5.1964, 4.9201, 4.0658, 4.8281, 4.1915, 2.9288,
        5.2785, 6.8447, 2.9078, 6.3447, 4.4719, 5.8901, 5.1313],
       device='cuda:0')
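
As a quick sanity check (a sketch, assuming the total reward is simply the sum of the two contributions, each with weight 1), we can compare the total rewards against template + aff. For the first sample, 2.0 + 5.4464 = 7.4464.

total    = env.batch_state.rewards.detach().cpu().numpy()
template = env.batch_state.template                      # numpy array of template scores
aff      = env.batch_state.aff.detach().cpu().numpy()    # affinity scores
np.allclose(total, template + aff, atol=1e-3)            # expected: True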

after_compute_reward

This event is used to calculate metrics on the rewards

env('after_compute_reward')
env.log.metrics
{'rewards': [7.2516804],
 'rewards_final': [],
 'new': [1.0],
 'diversity': [1.0],
 'bs': [187],
 'template': [1.9572192513368984],
 'valid': [0.935],
 'live_diversity': [1.0],
 'live_valid': [0.87],
 'live_rewards': [7.1182923],
 'live_new': [1.0],
 'aff': [array(5.294461, dtype=float32)],
 'novel': [],
 'PPO': [],
 'rewards_live_p90': [8.845569610595703],
 'rewards_live_max': [10.354888]}

reward_modification

The reward modification event can be thought of as a second reward that isn't logged. The reason for including this is to allow for transient, "batch context" rewards that don't affect logged values.

When we set up our callbacks earlier, we had a term

new_cb = NoveltyReward(weight=0.05)

which adds a bonus score of 0.05 to new, never-before-seen samples. The point of this callback is to give the model a soft incentive to generate novel samples.

We want this score to impact our current batch. However, if we treated it the same as our actual rewards, the samples would be saved into env.log with their scores inflated by 0.05. Later, when our LogSampler samples from the log, the sampling would be influenced by a score that was only supposed to be given once.

Separating out rewards and reward modifications lets us avoid this
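
For illustration, a reward-modification callback might look something like the sketch below. The real NoveltyReward ships with MRL; the constructor and internals here are assumptions, and how the framework folds the bonus into rewards_final is handled elsewhere. The sketch only shows computing a per-sample bonus during the reward_modification event using the log's record of unique samples.

class SketchNoveltyBonus(Callback):
    def __init__(self, weight=0.05):
        super().__init__()
        self.weight = weight

    def reward_modification(self):
        env = self.environment
        seen = env.log.unique_samples       # previously seen samples (see the Log description above)
        rewards = env.batch_state.rewards
        bonus = torch.tensor([self.weight if s not in seen else 0.
                              for s in env.batch_state.samples],
                             device=rewards.device, dtype=rewards.dtype)
        # the real callback adds a bonus like this to the reward used for training
        # (rewards_final) while leaving the logged `rewards` untouched
        env.batch_state.sketch_bonus = bonus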

env('reward_modification')
env.batch_state.novel
tensor([0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500],
       device='cuda:0')

after_reward_modification

Similar to after_compute_reward, this event can be used to compute stats on reward modifications

env('after_reward_modification')
env.log.metrics
{'rewards': [7.2516804],
 'rewards_final': [7.30168],
 'new': [1.0],
 'diversity': [1.0],
 'bs': [187],
 'template': [1.9572192513368984],
 'valid': [0.935],
 'live_diversity': [1.0],
 'live_valid': [0.87],
 'live_rewards': [7.1182923],
 'live_new': [1.0],
 'aff': [array(5.294461, dtype=float32)],
 'novel': [array(0.05, dtype=float32)],
 'PPO': [],
 'rewards_live_p90': [8.845569610595703],
 'rewards_live_max': [10.354888]}

Get Model Outputs

After computing rewards, we move on to setting up our loss calculation. The get_model_outputs stage generates the values that we will backpropagate through. This stage consists of the following events:

  • get_model_outputs - generate necessary tensors from the model
  • after_get_model_outputs - used for any processing required prior to loss calculation

get_model_outputs

This is where we generate tensor values used for loss computation.

The specifics of what happens here depend on the type of model used. For autoregressive models, this step involves taking the x and y tensors we generated during the before_compute_reward event and doing a forward pass.

x is a tensor of size (bs, sl). Running x through the model will give a set of log probabilities of size (bs, sl, d_vocab). We then use y to gather the relevant log probs to get a gathered log prob tensor of size (bs, sl).

We generate these values from both the main model and the baseline model
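
Here is the gather step with dummy tensors so the shapes are explicit (the Agent performs the real forward pass internally; the sizes below are only for illustration):

import torch
import torch.nn.functional as F

bs, sl, d_vocab = 4, 10, 47
logits = torch.randn(bs, sl, d_vocab)                        # stand-in for a model forward pass on x
logprobs = F.log_softmax(logits, dim=-1)                     # (bs, sl, d_vocab)
y = torch.randint(0, d_vocab, (bs, sl))                      # next-token ids
gathered = logprobs.gather(2, y.unsqueeze(-1)).squeeze(-1)   # (bs, sl): log prob of each actual token
gathered.shape                                               # torch.Size([4, 10])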

env('get_model_outputs')
env.batch_state.keys()
dict_keys(['samples', 'sources', 'rewards', 'loss', 'latent_data', 'live_raw', 'base_raw', 'model_gathered_logprobs', 'base_gathered_logprobs', 'mask', 'trajectory_rewards', 'model_logprobs', 'base_logprobs', 'value_input', 'x', 'y', 'bs', 'lengths', 'sl', 'template', 'template_passes', 'aff', 'rewards_final', 'novel', 'model_output', 'model_encoded', 'model_latent', 'y_gumbel', 'base_output', 'base_encoded', 'base_latent', 'state_values', 'ref_state_values'])
env.batch_state.model_logprobs.shape, env.batch_state.model_gathered_logprobs.shape
(torch.Size([187, 74, 47]), torch.Size([187, 74]))

after_get_model_outputs

This event is not used by any of our current callbacks, but can be used for any sort of post-processing needed before loss computation

Compute Loss

Now we actually compute a loss value and do an optimizer update. See the PPO class for a description of the policy gradient algorithm used.

Loss computation consists of the following steps:

  • compute_loss - compute loss values
  • zero_grad - zero grad
  • before_step - used for computation before the optimizer step (i.e. gradient clipping)
  • step - step optimizer

compute_loss

When we first created our BatchState, there was a placeholder value for loss. This is the value that will ultimately be backpropagated through. This means we can run any sort of loss configuration, so long as the final values end up in BatchState.loss.

For example, the PPO policy gradient algorithm we are using involves a ValueHead that predicts values at every time step. This model is held in the PolicyLoss callback that holds the PPO class. During the compute_loss event, PPO computes an additional loss for the value head, which is added to BatchState.loss. PolicyLoss also holds an optimizer for the ValueHead parameters.
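
As an illustration of how open-ended this is, a hypothetical callback could add its own term during compute_loss; the only requirement is that the extra term ends up in BatchState.loss. The constructor details below are assumptions, and PPO already includes an entropy bonus via ent_coef, so this is purely a sketch of the mechanism.

class SketchEntropyBonus(Callback):
    def __init__(self, coef=0.01):
        super().__init__()
        self.coef = coef

    def compute_loss(self):
        batch_state = self.environment.batch_state
        logprobs = batch_state.model_logprobs                      # (bs, sl, d_vocab) log probabilities
        entropy = -(logprobs.exp() * logprobs).sum(-1).mean()      # mean per-token entropy
        # (a real implementation would mask out padding tokens here)
        batch_state.loss = batch_state.loss - self.coef * entropy  # bonus for exploration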

env.batch_state.loss
tensor(0., device='cuda:0', grad_fn=<CopyBackwards>)
env('compute_loss')
env.batch_state.loss
tensor(0.4805, device='cuda:0', grad_fn=<AddBackward0>)

zero_grad

This is an event to zero gradients of all optimizers in play. We currently have one optimizer in Agent for our generative model and one in PolicyLoss for the ValueHead of our policy gradient algorithm.

env('zero_grad')
env.batch_state.loss.backward()

before_step

This event runs before the actual optimizer step and is used for things like gradient clipping
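
A hypothetical gradient-clipping callback could hook in here. The attribute paths and constructor below are assumptions; only the choice of event comes from the text.

class SketchGradClip(Callback):
    def __init__(self, max_norm=1.0):
        super().__init__()
        self.max_norm = max_norm

    def before_step(self):
        # clip gradients of the generative model before the optimizer step
        model = self.environment.agent.model
        torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_norm)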

env('before_step')

step

This is the actual optimizer step. This will step both the Agent and PolicyLoss optimizers

env('step')

After Batch

The after_batch stage consists of a single after_batch event. This is used for any updates at the end of the batch.

In particular, the Log will update Log.df and the Agent will update the baseline model

env('after_batch')
env.log.df
samples sources rewards rewards_final template aff novel PPO
0 COc1ccc2c(c1)OC[C@@H]2CC(=O)N(C)CCN1C(=O)c2ccc... live_buffer 7.446445 7.496445 2.0 5.446445 0.05 -0.112296
1 CC(C)CC[C@@](C)(O)CNC(=O)[C@H]1CC[C@H](C(C)C)CC1 live_buffer 6.998550 7.048550 2.0 4.998550 0.05 0.199071
2 CCOCc1nnc(N2CCC(C#N)CC2)n1Cc1coc2c(C)cccc12 live_buffer 8.646967 8.696967 2.0 6.646966 0.05 -0.360880
3 N#CC1(NC(=O)[C@H]2CC23CCN(CCOc2ccccc2F)CC3)CCC1 base_buffer 8.558744 8.608745 2.0 6.558744 0.05 -0.420598
4 O=C(N[C@H](C1CCC1)C1CC1)C(=O)N1CCn2c(cnc2C(F)(... live_buffer 6.412636 6.462636 2.0 4.412636 0.05 0.700054
... ... ... ... ... ... ... ... ...
182 O=C(Nc1cccc(C2CCC2)c1)c1ccc2c(c1)COC2 live 4.907760 4.957760 2.0 2.907760 0.05 2.387913
183 CC(C)c1ccc([C@H](NC(=O)c2ccccc2Cl)C(C)C)cc1 live 8.344660 8.394660 2.0 6.344660 0.05 -0.408492
184 C[C@@H](CNC(=O)Cc1c[nH]c2cnccc12)NCCC(F)(F)F live 6.471912 6.521913 2.0 4.471912 0.05 0.576185
185 C[C@@](CO)(NC(=O)c1ccc(=O)n(Cc2ccccc2)n1)c1ccc... live 7.890116 7.940116 2.0 5.890116 0.05 -0.261187
186 CN(CCCCNC(=O)[C@@H]1CC1(C)C)Cc1ccccc1 live 7.131267 7.181267 2.0 5.131267 0.05 0.066176

187 rows × 8 columns

After Train

The after_train event can be used to calculate any final statistics or other values as desired

env('after_train')

Conclusions

Hopefully walking through the training process step by step has made the process more understandable. We conclude by simply running Environment.fit so we don't have to go through things step by step anymore

env.fit(200, 90, 50, 4)
iterations rewards rewards_final new diversity bs template valid live_diversity live_valid live_rewards live_new aff novel PPO rewards_live_p90 rewards_live_max
4 7.312 7.362 1.000 1.000 194 1.964 0.970 1.000 0.940 7.095 1.000 5.348 0.050 0.638 9.296 11.366
8 7.305 7.355 1.000 1.000 189 1.942 0.945 1.000 0.890 7.512 1.000 5.363 0.050 0.504 9.157 11.167
12 7.270 7.320 1.000 1.000 189 1.968 0.945 1.000 0.890 7.160 1.000 5.302 0.050 0.485 9.258 11.448
16 7.299 7.349 1.000 1.000 191 1.974 0.955 1.000 0.910 7.226 1.000 5.325 0.050 0.545 9.175 11.390
20 7.307 7.356 0.984 1.000 193 1.969 0.965 1.000 0.930 7.433 1.000 5.338 0.049 0.470 8.971 10.498
24 7.382 7.431 0.974 1.000 195 1.979 0.975 1.000 0.950 7.321 1.000 5.403 0.049 0.569 9.201 11.858
28 7.403 7.452 0.964 1.000 194 1.964 0.970 1.000 0.940 7.190 1.000 5.440 0.048 0.507 8.878 11.578
32 7.385 7.433 0.969 1.000 191 1.963 0.955 1.000 0.910 7.334 1.000 5.421 0.048 0.488 8.888 11.747
36 7.294 7.343 0.974 1.000 191 1.979 0.955 1.000 0.910 7.212 1.000 5.315 0.049 0.571 8.949 11.737
40 7.057 7.106 0.984 1.000 189 1.968 0.945 1.000 0.890 6.976 1.000 5.089 0.049 0.445 8.795 9.695
44 7.448 7.497 0.985 1.000 196 1.959 0.980 1.000 0.960 7.661 1.000 5.489 0.049 0.619 9.440 12.977
48 7.430 7.479 0.979 1.000 188 1.968 0.940 1.000 0.880 7.287 1.000 5.462 0.049 0.371 9.150 10.677