Implementing MolDQN with MRL

MolDQN

MolDQN is a value-based (deep Q-learning) RL approach to compound generation. The MolDQN paradigm differs from MRL's generative model approach. For this reason, MolDQN makes for an interesting example of integrating new or different modeling approaches into MRL.

MRL Approach

MRL focuses mainly on generative models. These are models that produce chemical structures through some sampling process. Generative models give us a probability associated with each sample. This allows us to run an optimization procedure where we tell the generative model to increase the probability of high scoring compounds. The score for each compound is given by some model or function outside the generative model.
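
To make that concrete, here is a generic policy-gradient style sketch with made-up numbers. This is not MRL's actual loss (that lives in mrl.policy_gradient); it just shows the idea of pushing up the log-probability of samples in proportion to their external score.

import torch

# made-up log-probabilities from a generative model and scores from an external function
log_probs = torch.tensor([-12.3, -15.1, -9.8], requires_grad=True)
rewards = torch.tensor([0.81, 0.12, 0.64])

# weight each sample's log-probability by its reward and ascend that objective
loss = -(log_probs * rewards).mean()
loss.backward()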

The MolDQN Approach

MolDQN uses a value-based Q-learning approach. The Q model is given some molecular state and predicts a scalar value for that state. This means there are no generative models involved. So how do we get new molecule states? MolDQN uses a heuristics-based approach to generating new compounds. This works as follows:

  1. Start with some initial state (e.g. CCC)
  2. Define a list of allowed atom types and bond types (e.g. ['C', 'O', 'N'], ['single', 'double', 'triple'])
  3. Create all variants of the initial state with 1 atom added or 1 bond added (CCCO, C=CC, CC(C)=N, ...)
  4. Use the model to score each possible next state
  5. Select the state with the highest predicted score

This process lets us build up a set of (state, next_state) sample pairs. During optimization, we then train the model to predict the value of state plus a portion of the value of next_state. Including a portion of future rewards in the prediction is an important part of MolDQN. We could forego the RL aspect entirely and use just the sampling procedure in a beam-search sort of way, where we calculate the highest value next state and simply choose it. The flaw with this approach is that it can get stuck in local optima. With RL, we hope the model can learn that choosing a state with a lower value on the next immediate step can actually be advantageous, as it can lead to an overall higher reward in, say, 3 steps.
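
As a toy illustration of that point (made-up numbers, not MolDQN code): a greedy choice picks the state that looks best right now, while a Q-style value that folds in a discounted estimate of future reward can prefer a state whose immediate score is lower.

import numpy as np

# made-up numbers: immediate scores and estimated future values for two candidate next states
immediate = np.array([0.50, 0.40])   # state 0 looks better right now
future = np.array([0.55, 0.80])      # but state 1 leads to better compounds later
gamma = 0.995                        # weight on future value

greedy_choice = immediate.argmax()       # beam-search style: picks state 0
q_values = immediate + gamma * future    # Q-style estimate folding in future value
rl_choice = q_values.argmax()            # picks state 1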

Integrating MolDQN

To integrate the MolDQN approach into MRL, we need to design the following:

  1. A Base_Dataset variant for MolDQN
  2. An Agent variant for MolDQN
  3. A Sampler variant for MolDQN

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *

MolDQN Dataset

We need to build a variant of Base_Dataset to work with MolDQN. This means we need to implement the following methods:

  • __len__
  • __getitem__
  • new
  • split_on_idxs

We also need a collate function.

First we need to decide what our samples are going to look like. To be compatible with MRL, samples need to be hashable.

MolDQN requires three pieces of information to make a prediction:

  • Initial state
  • Current step
  • Possible next states

The initial state will be a SMILES string. The dataloader will convert the SMILES string into a fingerprint to be sent to the model.

The current step will be an integer value. The use of this value is discussed in more detail later. In short, MolDQN alters the score of a state based on how many steps the model has been running, to encourage the model to pursue longer rollout paths. This value will be concatenated to the initial state fingerprint.

The possible next states will be a tuple of SMILES strings. These will be converted into a batch of fingerprints.

With this definition, a sample might look like:

('CCC', 4, ('CCCC', 'CCC=N', 'CCCO'))

Which would be converted to

[FP('CCC'),4], [[FP('CCCC'), 3], [FP('CCC=N'), 3], [FP('CCCO'), 3]]
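
To show what the FP(...) shorthand stands for, here is a minimal standalone sketch of the conversion using RDKit directly. It assumes a 2048-bit Morgan fingerprint of radius 3 (i.e. ECFP6); the dataset below uses MRL's own fp_to_array helper instead.

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def smile_to_input(smile, steps_left, n_bits=2048):
    # 2048-bit Morgan fingerprint of radius 3 (ECFP6)
    bitvect = AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smile), 3, nBits=n_bits)
    arr = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(bitvect, arr)
    # append the step counter as the final input feature
    return np.append(arr.astype(np.float32), np.float32(steps_left))

state_input = smile_to_input('CCC', 4)                                              # shape (2049,)
next_inputs = np.stack([smile_to_input(s, 3) for s in ('CCCC', 'CCC=N', 'CCCO')])   # shape (3, 2049)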

class DQNDataset(Base_Dataset):
    def __init__(self, samples, fp_function, collate_function):
        super().__init__(collate_function)
        
        self.samples = samples
        self.fp_function = fp_function
        
    def __len__(self):
        return len(self.samples)
    
    def get_state_fps(self, state, steps_left):
        fp = fp_to_array(self.fp_function(state))
        fp = np.append(fp, steps_left)
        return torch.tensor(fp).float()
    
    def get_next_state_fps(self, next_states, steps_left):
        next_fps = [fp_to_array(self.fp_function(i)) for i in next_states]
        next_fps = np.stack([np.append(i, steps_left-1) for i in next_fps])
        return torch.tensor(next_fps).float()
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        smile, steps_left, next_states = sample
        fp = self.get_state_fps(smile, steps_left)
        next_fps = self.get_next_state_fps(next_states, steps_left)
        
        return fp, next_fps
    
    def new(self, samples):
        return self.__class__(samples, self.fp_function, self.collate_function)
    
    def split_on_idxs(self, train_idxs, valid_idxs):
        
        train_ds = self.new([self.samples[i] for i in train_idxs])
        valid_ds = self.new([self.samples[i] for i in valid_idxs])
        return (train_ds, valid_ds)
        
    
def moldqn_collate(batch):
    fps = torch.stack([i[0] for i in batch])
    
    y_size = max([i[1].shape[0] for i in batch])
    y_block = torch.zeros((len(batch), y_size, fps.shape[-1]))-1
    
    nfps = [i[1] for i in batch]
    for i, nfp in enumerate(nfps):
        num_fps = nfp.shape[0]
        y_block[i, :num_fps] = nfp

    return fps, y_block

ds = DQNDataset([('C', 4, ('CCC',))], ECFP6, moldqn_collate)  # next_states must be a tuple, hence the trailing comma
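
As a quick sanity check of the shapes (assuming ECFP6 yields a 2048-bit fingerprint, so each input vector has 2049 features once the step counter is appended):

fp, next_fps = ds[0]
fp.shape        # torch.Size([2049])    - fingerprint of 'C' plus steps_left
next_fps.shape  # torch.Size([1, 2049]) - one candidate next state ('CCC')

x, y = moldqn_collate([ds[0]])
x.shape         # torch.Size([1, 2049])
y.shape         # torch.Size([1, 1, 2049]) - padded block of next-state fingerprints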

MolDQN Agent

Since our data is tensorized into fingerprint vectors, our model will be an MLP-type model. We now need to build an Agent to use this model.

MolDQN uses a double Q learning approach. This fits with the BaselineAgent framework of having a main model and a baseline model.
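
The exact bookkeeping lives inside MRL's BaselineAgent (the base_update and base_update_iter arguments below), but conceptually the baseline network is kept as a slow-moving copy of the main network. A generic Polyak-style sketch of that idea, assuming base_update plays the role of tau, might look like:

import torch
import torch.nn as nn

def soft_update(main: nn.Module, baseline: nn.Module, tau: float = 0.99):
    # blend the baseline (target) parameters toward the main parameters
    with torch.no_grad():
        for p_main, p_base in zip(main.parameters(), baseline.parameters()):
            p_base.mul_(tau).add_(p_main, alpha=1 - tau)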

We build a new BaselineAgent with the following methods implemented:

  • setup: define terms to log for MolDQN
  • before_compute_reward: correctly tensorize samples
  • reward_modification: apply reward scaling (see below)
  • compute_loss: compute MolDQN loss (see below)

Reward Modification

MolDQN modifies rewards based on what step the reward was generated. If we rollout samples over 40 steps, we want to bias the model towards long sample paths rather than ones that terminate early. This is accomplished by scaling the reward by reward_modified = reward * scaling_factor ** steps_left.

The MolDQN authors call this scaling factor a "discount factor". This is extremely annoying because "discount factor" is already a standard concept in RL. The MolDQN authors create a completely different concept and then give it the same name, presumably just to confuse you.

Standard RL discounting takes the form reward_discounted[i] = reward[i] + discount_factor*reward[i+1], where the discount factor adds a small amount of future reward to the current reward. MolDQN "discounting" is purely based on the number of steps left, and has the effect of reducing the current reward rather than adding to it.
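
A toy comparison with made-up numbers makes the distinction clear:

rewards = [0.2, 0.3, 0.5]     # made-up per-step rewards

# standard RL discounting: add a portion of the next step's reward to the current one
discount_factor = 0.9
reward_discounted_0 = rewards[0] + discount_factor * rewards[1]   # 0.2 + 0.9*0.3 = 0.47

# MolDQN "discounting": scale the current reward down by how many steps remain
scaling_factor = 0.98
steps_left = 10
reward_modified_0 = rewards[0] * scaling_factor ** steps_left     # 0.2 * 0.98**10 ≈ 0.163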

We implement this as a reward modification. This means we can see the true undiscounted rewards as well as the discounted rewards.

Loss Function

Computing the MolDQN loss is done in the following steps:

  1. Grab a batch of (state, steps_left, next_states) samples
  2. Use the baseline model to predict rewards of next_states
  3. Select a single next state for each sample based on the baseline model's predictions
  4. Compute the value of state with the main model (model(state))
  5. Compute the target reward: target_reward = reward(state) + gamma * baseline_model(next_state)
  6. Compute Huber loss between model(state) and target_reward

model = MLP_Encoder(2049, [1024, 512, 256, 128], 1, [0.1, 0.1, 0.1, 0.1])
class DQNAgent(BaselineAgent):
    def __init__(self, model, dataset, discount,
                 gamma,
                base_update=0.99, base_update_iter=5,
                opt_kwargs={}, clip=1., name='DQN'):
        super().__init__(model=model, 
                         loss_function=None, 
                         dataset=dataset, 
                         base_update=base_update,
                         base_update_iter=base_update_iter, 
                         base_model=True, 
                         opt_kwargs=opt_kwargs, 
                         clip=clip, 
                         name=name)
        
        self.discount = discount
        self.gamma = gamma
  
    def setup(self):
        log = self.environment.log
        log.add_metric('dqn_loss')
        log.add_log('dqn_loss')
        
        log.add_metric('dqn_reward')
        log.add_log('dqn_reward')
        
        log.add_log('full_samples')
        log.add_log('steps_left')
    
    def before_compute_reward(self):
        env = self.environment
        batch_state = env.batch_state
        samples = batch_state.samples
 
        batch_ds = self.dataset.new(samples)
        batch = batch_ds.collate_function([batch_ds[i] for i in range(len(batch_ds))])
        batch = to_device(batch)
        bs = len(batch_ds)
        x,y = batch
            
        batch_state.x = x
        batch_state.y = y
        batch_state.bs = bs
        batch_state.rewards = to_device(torch.zeros(bs))
        batch_state.full_samples = samples
        batch_state.steps_left = x[:,-1]
        batch_state.samples = [i[0] for i in samples]
        
    def reward_modification(self):
        env = self.environment
        batch_state = env.batch_state
        
        x = batch_state.x
        steps_left = x[:,-1]
        batch_state.rewards_final = batch_state.rewards_final * self.discount ** (steps_left)
        
        batch_state['dqn_reward'] = batch_state.rewards_final
        
        env.log.update_metric('dqn_reward', 
                              batch_state.rewards_final.mean().detach().cpu().numpy())
        
    def compute_loss(self):
        env = self.environment
        batch_state = env.batch_state
        
        x = batch_state.x
        y = batch_state.y
        
        steps_left = x[:,-1]
        reward = batch_state.rewards_final
        dones = steps_left<=0
        
        with torch.no_grad():
            
            bs,num_fps,dim = y.shape
            preds = self.base_model(y.view(-1, dim)).view(bs,num_fps)
            mask = ((y==-1).sum(-1))==0
            actions = preds.masked_fill(~mask, float('-inf')).argmax(-1)
            
        act_fps = y[torch.arange(bs), actions]
        
        v_pred = self.model(x).squeeze()
        
        with torch.no_grad():
            v_t1_pred = self.base_model(act_fps).squeeze()
            
        v_target = reward + self.gamma*v_t1_pred*(1-dones.float())
        v_error = v_pred - v_target

        loss = torch.where(
            torch.abs(v_error) < 1.0,
            0.5 * v_error * v_error,
            1.0 * (torch.abs(v_error) - 0.5),
        )
        
        
        self.environment.batch_state.loss += loss.mean()
        self.environment.batch_state['dqn_loss'] = loss.detach().cpu().numpy()
        env.log.update_metric('dqn_loss', 
                              loss.mean().detach().cpu().numpy())
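
As an aside, the elementwise Huber term computed above matches PyTorch's built-in smooth L1 loss, so an equivalent formulation (not what this notebook uses) would be F.smooth_l1_loss with reduction='none':

import torch
import torch.nn.functional as F

v_pred = torch.tensor([0.2, 1.8])
v_target = torch.tensor([0.0, 0.0])

v_error = v_pred - v_target
manual = torch.where(torch.abs(v_error) < 1.0, 0.5 * v_error * v_error, torch.abs(v_error) - 0.5)
builtin = F.smooth_l1_loss(v_pred, v_target, reduction='none')
assert torch.allclose(manual, builtin)
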
agent = DQNAgent(model, 
                 ds, 
                 0.98,   # discount: reward scaling factor applied per step left
                 .995,   # gamma: weight on the baseline model's next-state value
                 base_update=0.995, 
                 base_update_iter=1,
                 opt_kwargs={'lr':1e-3}, 
                 clip=1.)

Sampler

Now we need to build a Sampler to generate samples. This sampler needs to follow the "add one atom or bond" rollout procedure for MolDQN. It also needs to generate samples in our pre-determined format of (str[state], int[steps_left], tuple[next_states])

We will use the add_bond_combi and add_atom_combi functions from the chem module to generate next states.

During rollout, we start with some initial state (C, N, or O; the sampler below always starts episodes from C). We generate all possible next states and evaluate them with the baseline model. With probability eps, we randomly select the next state; otherwise we select the highest scoring state as predicted by the model. The eps parameter is decayed by a factor of 0.99907 after every episode, down to a floor of min_eps.
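
For a rough sense of how quickly exploration decays with the values used below (eps starting at 0.999, a 0.99907 decay per episode, and a 0.05 floor):

def eps_after(n_episodes, eps=0.999, decay=0.99907, min_eps=0.05):
    for _ in range(n_episodes):
        eps = max(eps * decay, min_eps)
    return eps

eps_after(1000)   # ~0.39: still fairly exploratory after 1000 episodes
eps_after(3000)   # ~0.06: close to the floor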

The build_episode method handles running the rollout.

We also build a DQNLogSampler which will pull high scoring episodes from the log

class DQNSampler(Sampler):
    def __init__(self, agent, buffer_size, eps, min_eps, max_steps=40):
        super().__init__('dqn_sampler', buffer_size=buffer_size)
        
        self.agent = agent
        self.transition_dict = {}
        self.starting_states = ['C', 'N', 'O']
        self.atom_types = ['C', 'N', 'O', -1, -2]
        self.eps = eps
        self.min_eps = min_eps
        self.max_steps = max_steps
                 
        
    def build_buffer(self):
        
        buffer_items = []
        while len(buffer_items) < self.buffer_size:
            buffer_items += self.build_episode('C', self.max_steps)

        if buffer_items:
            self.environment.buffer.add(buffer_items, self.name)
        
        
    def build_episode(self, state, steps_left):

        buffer_items = []
        starting_step = self.max_steps - steps_left
        
        for i in range(starting_step, self.max_steps):
            steps_left = self.max_steps - i
            new_states = self.get_new_states(state)
            buffer_items.append((state, steps_left, new_states))
            
            if len(new_states)>1:
                new_fps = self.agent.dataset.get_next_state_fps(new_states, steps_left)

                with torch.no_grad():
                    new_fps = to_device(new_fps).float()
                    values = self.agent.base_model(new_fps)

                self.eps = max(self.eps, self.min_eps)

                if np.random.uniform() < self.eps:
                    action = np.random.randint(0, values.shape[0])
                else:
                    action = values.argmax()

                state = new_states[action]
            else:
                break
            
        self.eps *= 0.99907
        return buffer_items
        
    def get_new_states(self, state):
        
        if state in self.transition_dict.keys():
            clean_states = self.transition_dict[state]
        else:
            new_states = add_bond_combi(state) + add_atom_combi(state, self.atom_types) + [state]
            new_states = list(set(new_states))
            clean_states = []
            for ns in new_states:
                if (ns is not None) and (not '.' in ns) and (len(ns)>0):
                    clean_states.append(ns)
                    
            clean_states = tuple(clean_states)
            self.transition_dict[state] = clean_states
            
        return clean_states
    
class DQNLogSampler(DQNSampler):
    def __init__(self, agent, percentile, lookup_name, 
                 buffer_size, eps, min_eps, max_steps=40):
        super().__init__(agent, buffer_size, eps, min_eps, max_steps)
        self.percentile = percentile
        self.lookup_name = lookup_name
        
    def build_buffer(self):
        
        df = self.environment.log.df
        
        if df.shape[0]>100:
            df = df[df[self.lookup_name]>np.percentile(df[self.lookup_name].values, self.percentile)]

            buffer_items = []
            while len(buffer_items) < self.buffer_size:

                sample = df.sample(n=1)
                state = sample.samples.values[0]
                steps = int(sample.steps_left.values[0])

                buffer_items += self.build_episode(state, steps)

            if buffer_items:
                self.environment.buffer.add(buffer_items, self.name)
        
        
sampler1 = DQNSampler(agent, 500, 0.999, 0.05, max_steps=40)  # buffer_size=500, eps=0.999, min_eps=0.05
sampler2 = LogSampler('full_samples', 'rewards', 10, 95, 100)
sampler3 = DQNLogSampler(agent, 95, 'rewards', 100, 0.999, 0.05, max_steps=40)  # percentile=95, lookup_name='rewards', buffer_size=100

samplers = [sampler1, sampler2, sampler3]

Template

Here we define a simple template. For this example, we will only optimize towards QED

template = Template([ValidityFilter(), 
                     SingleCompoundFilter()],
                    [QEDFilter(None, None, score=PassThroughScore())], 
                    fail_score=-1., log=False)

template_cb = TemplateCallback(template, prefilter=False, do_filter=False)
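
Since this template optimizes for QED (with PassThroughScore presumably using the raw QED value as the score), individual compounds can be sanity-checked directly with RDKit's own QED implementation. This is an independent check, not part of the MRL pipeline:

from rdkit import Chem
from rdkit.Chem import QED

# drug-likeness scores for a trivial molecule vs a known drug (paracetamol)
for smile in ('CCC', 'CC(=O)Nc1ccc(O)cc1'):
    print(smile, round(QED.qed(Chem.MolFromSmiles(smile)), 3))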

In our samplers, we set the rollout length to be 40 steps. One metric of interest is the average number of steps left for each item in a batch. This shows how well the model is learning to choose longer sample paths, which is directly incentivized by the reward structure.

Here's a quick callback to grab that value

class PathLength(Callback):
    def __init__(self):
        super().__init__(name='avg_steps_left')
        
    def setup(self):
        log = self.environment.log
        log.add_metric(self.name)
        
    def after_compute_reward(self):
        log = self.environment.log
        x = self.environment.batch_state.x
        log.update_metric(self.name, x[:,-1].mean().detach().cpu().numpy())

Here are some callbacks for tracking metrics

live_max = MaxCallback('rewards', None)
live_p90 = PercentileCallback('rewards', None, 90)
pl_cb = PathLength()

cbs = [live_p90, live_max, pl_cb]

Environment and Train

Now we can put together our Environment and run the training process

env = Environment(agent, template_cb, samplers=samplers, rewards=[], losses=[],
                 cbs=cbs)
 
env.fit(48, 40, 2000, 50)
iterations rewards rewards_final new diversity bs template valid dqn_loss dqn_reward avg_steps_left rewards_p90 rewards_max
0 0.264 0.171 1.000 1.000 48 0.264 1.000 0.124 0.171 20.646 0.400 0.537
50 0.361 0.264 1.000 1.000 48 0.361 1.000 0.057 0.264 17.000 0.594 0.693
100 0.342 0.242 1.000 1.000 48 0.342 1.000 0.033 0.242 18.292 0.575 0.723
150 0.339 0.249 1.000 1.000 48 0.339 1.000 0.028 0.249 16.542 0.622 0.798
200 0.366 0.266 1.000 1.000 48 0.366 1.000 0.024 0.266 18.208 0.588 0.676
250 0.404 0.281 1.000 1.000 48 0.404 1.000 0.029 0.281 19.250 0.576 0.721
300 0.428 0.331 1.000 1.000 48 0.428 1.000 0.022 0.331 15.229 0.630 0.767
350 0.456 0.342 1.000 1.000 48 0.456 1.000 0.021 0.342 16.229 0.689 0.783
400 0.473 0.354 1.000 1.000 48 0.473 1.000 0.025 0.354 17.000 0.698 0.814
450 0.487 0.364 1.000 1.000 48 0.487 1.000 0.029 0.364 17.458 0.688 0.785
500 0.442 0.326 1.000 1.000 48 0.442 1.000 0.016 0.326 17.417 0.646 0.737
550 0.481 0.348 1.000 1.000 48 0.481 1.000 0.016 0.348 18.875 0.684 0.756
600 0.503 0.371 1.000 1.000 48 0.503 1.000 0.012 0.371 17.083 0.723 0.814
650 0.488 0.361 1.000 1.000 48 0.488 1.000 0.010 0.361 16.729 0.705 0.778
700 0.546 0.417 1.000 1.000 48 0.546 1.000 0.013 0.417 16.333 0.751 0.878
750 0.507 0.382 1.000 1.000 48 0.507 1.000 0.015 0.382 16.708 0.696 0.814
800 0.494 0.380 1.000 1.000 48 0.494 1.000 0.018 0.380 15.521 0.756 0.824
850 0.563 0.446 1.000 1.000 48 0.563 1.000 0.027 0.446 14.375 0.724 0.891
900 0.589 0.462 1.000 1.000 48 0.589 1.000 0.026 0.462 13.521 0.741 0.799
950 0.581 0.466 1.000 1.000 48 0.581 1.000 0.018 0.466 14.000 0.780 0.826
1000 0.543 0.402 1.000 1.000 48 0.543 1.000 0.024 0.402 17.438 0.741 0.840
1050 0.583 0.455 1.000 1.000 48 0.583 1.000 0.018 0.455 14.896 0.764 0.844
1100 0.572 0.442 1.000 1.000 48 0.572 1.000 0.026 0.442 15.708 0.763 0.803
1150 0.583 0.453 1.000 1.000 48 0.583 1.000 0.021 0.453 16.146 0.779 0.828
1200 0.535 0.417 1.000 1.000 48 0.535 1.000 0.028 0.417 13.958 0.769 0.846
1250 0.622 0.474 1.000 1.000 48 0.622 1.000 0.019 0.474 16.312 0.778 0.857
1300 0.595 0.468 1.000 1.000 48 0.595 1.000 0.027 0.468 14.625 0.756 0.826
1350 0.549 0.405 1.000 1.000 48 0.549 1.000 0.031 0.405 17.708 0.757 0.844
1400 0.624 0.500 1.000 1.000 48 0.624 1.000 0.046 0.500 13.208 0.772 0.856
1450 0.618 0.471 1.000 1.000 48 0.618 1.000 0.029 0.471 16.938 0.833 0.861
1500 0.590 0.456 1.000 1.000 48 0.590 1.000 0.038 0.456 15.229 0.806 0.927
1550 0.580 0.451 1.000 1.000 48 0.580 1.000 0.031 0.451 15.229 0.788 0.822
1600 0.661 0.544 1.000 1.000 48 0.661 1.000 0.054 0.544 12.396 0.864 0.926
1650 0.626 0.484 1.000 1.000 48 0.626 1.000 0.040 0.484 15.500 0.831 0.927
1700 0.649 0.511 1.000 1.000 48 0.649 1.000 0.045 0.511 14.208 0.835 0.892
1750 0.664 0.509 1.000 1.000 48 0.664 1.000 0.031 0.509 14.812 0.807 0.872
1800 0.671 0.541 1.000 1.000 48 0.671 1.000 0.028 0.541 13.312 0.843 0.904
1850 0.656 0.510 1.000 1.000 48 0.656 1.000 0.054 0.510 15.896 0.831 0.901
1900 0.719 0.579 1.000 1.000 48 0.719 1.000 0.044 0.579 12.896 0.859 0.883
1950 0.652 0.503 1.000 1.000 48 0.652 1.000 0.040 0.503 15.208 0.863 0.907
env.log.plot_metrics()