Polymer Band Gap Optimization

This tutorial runs an end-to-end workflow for designing low band gap polymers.

In solid-state physics, the band gap is the energy difference between the highest occupied electronic states (the valence band) and the lowest unoccupied states (the conduction band). This property is of great interest in the development of organic photovoltaic cells (OPVCs). The band gap of a polymer determines which part of the light spectrum the material can absorb and convert into energy: a material cannot absorb photons with energy less than its band gap. The band gap also influences the voltage and current produced by an OPVC.

Modern OPVCs use multiple materials with different band gaps in a single cell to capture a wider spectrum of light. One issue in materials science for OPVCs is designing low band gap polymers. Low band gap polymers are needed to absorb infrared and near-infrared light.
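
As a worked example of the photon-energy argument above, a band gap can be converted to the longest wavelength the material can absorb using E = hc/λ. The sketch below assumes band gaps are expressed in electronvolts (the text does not state the unit explicitly); the function name `cutoff_wavelength_nm` is illustrative and not part of MRL.

```python
# Physical constants in SI units (exact values under the 2019 SI definitions).
PLANCK_J_S = 6.62607015e-34      # Planck constant, J*s
LIGHT_SPEED_M_S = 299792458.0    # speed of light, m/s
EV_IN_JOULES = 1.602176634e-19   # one electronvolt in joules


def cutoff_wavelength_nm(band_gap_ev):
    """Longest wavelength (nm) whose photons still carry the band gap energy."""
    if band_gap_ev <= 0:
        raise ValueError("band gap must be positive")
    energy_j = band_gap_ev * EV_IN_JOULES
    # lambda = h * c / E, converted from meters to nanometers
    return PLANCK_J_S * LIGHT_SPEED_M_S / energy_j * 1e9


# Assumed example band gaps in eV, chosen only for illustration.
for gap in (6.9, 3.0, 1.3):
    print(f"{gap:.1f} eV -> {cutoff_wavelength_nm(gap):.0f} nm")
```

Under these assumptions, a band gap near 1.3 eV corresponds to a cutoff of roughly 950 nm, past the red edge of the visible range (about 750 nm), which is why low band gaps are needed for near-infrared absorption.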

In this notebook we will use MRL to design low band gap polymers.

Performance Notes

The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook. This will trigger the use of multiprocessing, which will result in 2-4x speedups.

This notebook may run slowly on Colab due to CPU limitations.

If running on Colab, remember to change the runtime to GPU.

import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
from sklearn.metrics import r2_score
os.makedirs('untracked_files', exist_ok=True)

Data

The dataset comes from the paper "Polymer Informatics with Multi-Task Learning". It contains polymer chain band gap values for ~3300 polymers computed with DFT simulations.

# ! wget https://khazana.gatech.edu/dataset/MTL_Khazana.zip --no-check-certificate
# ! unzip MTL_Khazana.zip
# ! rm MTL_Khazana.zip
df = pd.read_csv('export.csv')
df = df[df.property=='Egc']
from rdkit import Chem
def clean_smile(smile):
    mol = to_mol(smile)
    for atom in mol.GetAtoms():
        atom.SetIsotope(0)
    mol = Chem.RemoveHs(mol)
    return to_smile(mol)
df['smiles'] = [clean_smile(i) for i in df.smiles.values]
df.head()
     Unnamed: 0  smiles        property  value
822         822  *C*           Egc       6.8972
823         823  *CC(*)C       Egc       6.5196
824         824  *CC(*)CC      Egc       6.5170
825         825  *CC(*)CCC     Egc       6.7336
826         826  *CC(*)CC(C)C  Egc       6.7394
df.value.hist()
[Figure: histogram of band gap values (df.value)]

Score Function

Now we want to develop a score function for predicting band gap. We are going to use an MLP-type model to predict band gap from an ECFP6 fingerprint.

We will train on 90% of the data and validate on the 10% held out.

df_train = df.sample(frac=0.9, random_state=42).copy()
df_valid = df[~df.index.isin(df_train.index)].copy()
reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[-1,11])
r_ds = Vec_Prediction_Dataset(df_train.smiles.values, df_train.value.values, ECFP6)
r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 10, 1e-3)
Epoch  Train Loss  Valid Loss  Time
0      2.07289     0.72712     00:05
1      1.20504     0.74069     00:05
2      1.02586     0.37059     00:05
3      4.20034     0.37407     00:05
4      1.21133     0.32353     00:05
5      0.53076     0.29154     00:05
6      0.25303     0.33754     00:05
7      0.26530     0.36427     00:05
8      0.21324     0.33180     00:05
9      0.44266     0.42136     00:05

Optional: save score function weights

 

Optional: to load the exact weights used, run the following:

 
valid_ds = Vec_Prediction_Dataset(df_valid.smiles.values, df_valid.value.values, ECFP6)

valid_dl = valid_ds.dataloader(256, num_workers=0, shuffle=False)
r_agent.model.eval()

preds = []
targs = []

with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
        
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()

preds = preds.squeeze(-1)

Our score function has an R-squared value of about 0.84 on the validation dataset.

fig, ax = plt.subplots()

ax.scatter(targs, preds, alpha=0.5, s=1)
plt.xlabel('Target')
plt.ylabel('Prediction')

slope, intercept = np.polyfit(targs, preds, 1)
ax.plot(np.array(ax.get_xlim()), intercept + slope*np.array(ax.get_xlim()), c='r')

plt.text(5., 9., 'R-squared = %0.3f' % r2_score(targs, preds));
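
For reference, the R-squared value reported here follows the standard definition 1 - SS_res / SS_tot. The sketch below is a minimal pure-Python version of that formula; the helper name `r_squared` and the toy numbers are illustrative assumptions and are not taken from the dataset or from MRL.

```python
def r_squared(targets, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(targets) != len(predictions) or len(targets) == 0:
        raise ValueError("inputs must be non-empty and the same length")
    mean_target = sum(targets) / len(targets)
    # Residual sum of squares: error of the model's predictions
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, predictions))
    # Total sum of squares: error of always predicting the mean target
    ss_tot = sum((t - mean_target) ** 2 for t in targets)
    if ss_tot == 0:
        raise ValueError("targets are constant; R-squared is undefined")
    return 1.0 - ss_res / ss_tot


# Made-up values for illustration only.
toy_targets = [1.0, 2.0, 3.0, 4.0]
toy_predictions = [1.1, 1.9, 3.2, 3.8]
print(round(r_squared(toy_targets, toy_predictions), 3))
```

A value of 1 means the predictions match the targets exactly, while 0 means the model does no better than always predicting the mean target.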

We should also take a look at the distribution of predicted band gaps. The predicted values will differ somewhat from the actual values, so this gives us a sense of what scores we should expect to see from the model.

df['preds'] = r_agent.predict_data(df.smiles.values).detach().cpu().numpy()
df.preds.min()
0.122281075
-np.percentile(-df.preds, 99) # 1st percentile of predictions; negated because lower band gap is better
1.327669801712034

The minimum predicted value is 0.12. A compound with a predicted band gap of 1.3 or lower would be in the lowest 1% of band gaps relative to the dataset.
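
The double negation used to compute this threshold can be checked on a small example. The sketch below is a pure-Python percentile with linear interpolation (assumed here to mirror NumPy's default behavior); the helper name `percentile` and the toy values are illustrative only. It shows that negating the 99th percentile of the negated values gives the same number as the 1st percentile of the original values.

```python
def percentile(values, q):
    """q-th percentile (0-100) using linear interpolation between sorted values."""
    if not values:
        raise ValueError("values must be non-empty")
    ordered = sorted(values)
    rank = (len(ordered) - 1) * q / 100.0   # fractional index into the sorted list
    lower = int(rank)
    upper = min(lower + 1, len(ordered) - 1)
    frac = rank - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * frac


# Made-up "predicted band gaps" for illustration only.
toy_preds = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]

low_tail = percentile(toy_preds, 1)
negated_high_tail = -percentile([-v for v in toy_preds], 99)
print(low_tail, negated_high_tail)
```

Both expressions pick out the low end of the distribution, which is the region of interest when smaller band gaps are better.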

Chemical Space

Next we need to develop our chemical space. This is where we decide what compounds will be allowed and what compounds will be removed.

Getting the right filtering parameters makes a huge difference in compound quality. In practice, finding the right constraints is an iterative process. First run a generative screen. Then examine the highest scoring compounds and look for undesirable properties or structural features. Then update the template and iterate.

smarts = ['[#6](=[#16])(-[#7])-[#7]',
        '[#6]=[#6]=[#6]',
        '[#7][F,Cl,Br,I]',
        '[*]#[Cl,Br,I]',
        '[*]-[Cl,Br,I]-[*]',
        '[#6;!R]=[#6;!R]-[#6;!R]=[#6;!R]',
        '[#6]#[#6]',
        '[#15]',
        '[#16]',
        '[*]=[#17,#9,#35]',
        '[*]=[*]=[*]',
        '[*]-[#6]=[#6H2]',
        '[#7]~[#8]',
        '[#7]~[#7]',
        '[*;R]=[*;!R]']

template = Template([ValidityFilter(), 
                     SingleCompoundFilter(), 
                     AttachmentFilter(min_val=2, max_val=2),
                     MolWtFilter(None, 600),
                     RingFilter(1,None),
                     ExclusionFilter(smarts, criteria='any')
                    ],
                    [], 
                    fail_score=-1., log=False, use_lookup=False)

template_cb = TemplateCallback(template, prefilter=True)
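
The pass/fail logic of the template can be pictured with a small conceptual sketch. This is not MRL's implementation: it assumes each compound has already been reduced to a dictionary of precomputed properties, and every name below (`in_range`, `HARD_FILTERS`, `template_score`, the dictionary keys) is hypothetical. It only illustrates the idea that a compound failing any hard filter receives the fail score, while a passing compound with no soft filters scores 0.

```python
FAIL_SCORE = -1.0  # mirrors fail_score=-1. in the template above


def in_range(value, min_val=None, max_val=None):
    """True when value lies inside [min_val, max_val]; None means unbounded."""
    if min_val is not None and value < min_val:
        return False
    if max_val is not None and value > max_val:
        return False
    return True


# Hypothetical hard filters over precomputed properties of a compound.
HARD_FILTERS = [
    lambda c: in_range(c["attachments"], 2, 2),   # exactly two attachment points
    lambda c: in_range(c["mol_wt"], None, 600),   # no lower bound, at most 600
    lambda c: in_range(c["rings"], 1, None),      # at least one ring
    lambda c: not c["excluded_substructure"],     # no excluded pattern matched
]


def passes_hard_filters(compound):
    """A compound must satisfy every hard filter to be kept."""
    return all(check(compound) for check in HARD_FILTERS)


def template_score(compound):
    """FAIL_SCORE for rejected compounds, 0.0 otherwise (no soft filters here)."""
    return 0.0 if passes_hard_filters(compound) else FAIL_SCORE


# Made-up compound description for illustration only.
example = {"attachments": 2, "mol_wt": 350.0, "rings": 2,
           "excluded_substructure": False}
print(template_score(example), template_score(dict(example, rings=0)))
```

Tightening the chemical space then amounts to adding or adjusting entries in the filter list and rerunning the screen.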

Load Model

We load the LSTM_LM_Small_PI1M model. This is a basic LSTM-based language model trained on the PI1M polymer library.

agent = LSTM_LM_Small_PI1M(drop_scale=0.3, opt_kwargs={'lr':5e-5})

Fine-Tune Model

The pretrained model we loaded is a very general model that can produce a high diversity of structures. However, what we actually want are structures with low band gap. To induce this, we fine-tune on the dataset.

This step deserves careful thought. Fine-tuning on known low band gap polymers will bias the model towards known chemotypes. This will significantly reduce the number of iterations needed for the model to find high scoring compounds. However, the model will likely converge on compounds similar to known chemotypes.

On the other hand, we could skip this stage of fine-tuning. If we did this, the model would take longer to converge (at least 3x more batches), but we have a much greater chance of finding new chemotypes.

In the interest of time, we do the additional fine-tuning to speed up convergence.

agent.update_dataset_from_inputs(df[df.value<6].smiles.values)
agent.train_supervised(32, 6, 1e-4)

agent.update_dataset_from_inputs(df[df.value<4].smiles.values)
agent.train_supervised(32, 5, 1e-4)

agent.update_dataset_from_inputs(df[df.value<2].smiles.values)
agent.train_supervised(32, 3, 1e-4)

agent.base_to_model()
Epoch  Train Loss  Valid Loss  Time
0      0.18745     0.20679     00:07
1      0.19874     0.17411     00:07
2      0.19635     0.16575     00:07
3      0.19232     0.16242     00:07
4      0.27759     0.16143     00:07
5      0.16866     0.16152     00:07

Epoch  Train Loss  Valid Loss  Time
0      0.22635     0.38918     00:06
1      0.19260     0.39359     00:06
2      0.20890     0.39567     00:06
3      0.34051     0.39773     00:06
4      0.18027     0.39808     00:06

Epoch  Train Loss  Valid Loss  Time
0      0.17865     0.27674     00:05
1      0.14835     0.27729     00:05
2      0.16405     0.27732     00:05

Optional: save fine-tuned weights

 
agent.load_weights('untracked_files/finetuned_model.pt')

Reinforcement Learning

Now we enter the reinforcement learning stage.

Loss

We use PPO as our policy gradient loss.

pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})
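
To give a sense of what the cliprange setting controls, the sketch below implements the textbook PPO clipped surrogate loss for a single sample. It is a generic illustration under the assumption that MRL's PPO follows the standard formulation; it is not MRL's code, and the names `clip` and `ppo_clipped_loss` are hypothetical. Here `ratio` is the probability of an action under the updated policy divided by its probability under the policy that generated the sample, and `advantage` is the estimated advantage of that action.

```python
def clip(value, low, high):
    """Restrict value to the closed interval [low, high]."""
    return max(low, min(value, high))


def ppo_clipped_loss(ratio, advantage, cliprange=0.3):
    """Per-sample clipped surrogate loss (negated objective, to be minimized)."""
    unclipped = ratio * advantage
    clipped = clip(ratio, 1.0 - cliprange, 1.0 + cliprange) * advantage
    # Taking the minimum keeps the more pessimistic of the two objectives,
    # so moving the ratio far outside the clip range earns no extra credit.
    return -min(unclipped, clipped)


# Made-up ratios and advantages for illustration only.
for ratio, advantage in [(1.5, 1.0), (0.5, -1.0), (0.5, 1.0), (1.0, 2.0)]:
    print(ratio, advantage, ppo_clipped_loss(ratio, advantage))
```

With a clip range of 0.3, a sample with positive advantage stops contributing additional gain once the ratio exceeds 1.3, which limits how far a single update can move the policy.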

Reward

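Here we load the reward model we trained earlier and wrap it in a reward callback. The reward weight of -1 means that lower predicted band gaps earn higher rewards.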

reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[0,15])

r_ds = Vec_Prediction_Dataset(['C'], [0], partial(failsafe_fp, fp_function=ECFP6))

r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})

r_agent.load_weights('untracked_files/polymer_bandgap.pt')
# r_agent.load_state_dict(model_from_url('polymer_bandgap.pt')) # optional - load exact weights

r_agent.model.eval();

freeze(r_agent.model)

reward = Reward(r_agent.predict_data, weight=-1.)

bg_reward = RewardCallback(reward, '-bandgap')
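
The effect of the negative weight can be seen in a small standalone sketch. It assumes the reward is simply the weight multiplied by the predicted band gap, which is what weight=-1. suggests; the function name `bandgap_reward` and the toy predictions are illustrative and not part of MRL.

```python
def bandgap_reward(predicted_band_gap, weight=-1.0):
    """Scale a predicted band gap into a reward; a negative weight favors small gaps."""
    return weight * predicted_band_gap


# Made-up predicted band gaps for illustration only.
toy_predictions = [6.9, 2.0, 1.3, 0.7]
toy_rewards = [bandgap_reward(p) for p in toy_predictions]

# The highest reward belongs to the lowest predicted band gap.
best_index = max(range(len(toy_rewards)), key=lambda i: toy_rewards[i])
print(toy_predictions[best_index], toy_rewards[best_index])

# A band gap cutoff of 1.3 becomes a reward cutoff of -1.3 with the sign flipped.
below_cutoff = [p for p, r in zip(toy_predictions, toy_rewards) if r > -1.3]
print(below_cutoff)
```

This is why thresholds on logged rewards appear as negative numbers: selecting rewards greater than -1.3 is the same as selecting predicted band gaps below 1.3.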

Samplers

We create the following samplers:

  • sampler1 ModelSampler: this samples from the main model
  • sampler2 ModelSampler: this samples from the baseline model
  • sampler3 LogSampler: this samples high scoring samples from the log
  • sampler4 CombichemSampler: this sampler runs combichem generation on the top scoring compounds. The combination of generative models with combichem greatly accelerates finding high scoring compounds
  • sampler5 DatasetSampler: this sprinkles in a small amount of known actives into each buffer build

Note on the use of combichem: adding combichem to the sampler mix greatly accelerates convergence to high scoring compounds. However, combichem tends to end up with "weirder" structures that are more exploitative of the score function. These structures can be removed by updating the template.

gen_bs = 1500

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100)

mutators = [
    ChangeAtom(['6', '7', '8']),
    AppendAtomSingle(['C', 'N', 'O']),
    AppendAtomsDouble(['C', 'N', 'O']),
    DeleteAtom(),
    ChangeBond(),
    InsertAtomSingle(['C', 'N', 'O']),
    InsertAtomDouble(['C', 'N']),
    AddRing(),
    ShuffleNitrogen(10)
]

mc = MutatorCollection(mutators)

crossovers = [FragmentCrossover()]

cbc = CombiChem(mc, crossovers, template=template, rewards=[reward],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)

sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem')

sampler5 = DatasetSampler(df[df.value<1.5].smiles.values, 'data', buffer_size=4)

samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]

Callbacks

Additional callbacks:

  • SupervisedCB: runs supervised training on the top 2% of samples every 20 batches
  • MaxCallback: prints the max reward for each batch
  • PercentileCallback: prints the 90th percentile score each batch
  • Timeout: prohibits training on the same sample more than once every 10 batches
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=2)
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)

cbs = [supervised_cb, live_p90, live_max]

Environment and Train

Now we can put together our Environment and run the training process.

env = Environment(agent, template_cb, samplers=samplers, rewards=[bg_reward], losses=[loss],
                 cbs=cbs)
set_global_pool(min(10, os.cpu_count()))
env.fit(128, 100, 300, 20)
iterations rewards rewards_final new diversity bs template valid -bandgap PPO rewards_live_p90 rewards_live_max
0 -6.287 -6.287 1.000 1.000 128 0.000 1.000 -6.287 0.229 -4.975 -2.003
20 -5.476 -5.476 0.773 1.000 128 0.000 1.000 -5.476 0.329 -5.301 -1.961
40 -5.352 -5.352 0.695 1.000 128 0.000 1.000 -5.352 0.369 -3.978 -2.827
60 -4.599 -4.599 0.484 1.000 128 0.000 1.000 -4.599 0.557 -2.527 -2.385
80 -4.468 -4.468 0.492 1.000 128 0.000 1.000 -4.468 0.257 -2.674 -1.961
100 -3.612 -3.612 0.562 1.000 128 0.000 1.000 -3.612 0.627 -2.561 -1.961
120 -3.269 -3.269 0.641 1.000 128 0.000 1.000 -3.269 0.765 -2.618 -2.402
140 -2.607 -2.607 0.625 1.000 128 0.000 1.000 -2.607 0.611 -2.693 -2.361
160 -2.803 -2.803 0.562 1.000 128 0.000 1.000 -2.803 0.576 -2.580 -2.276
180 -2.287 -2.287 0.523 1.000 128 0.000 1.000 -2.287 0.515 -2.796 -2.571
200 -2.410 -2.410 0.445 1.000 128 0.000 1.000 -2.410 0.423 -2.071 -1.969
220 -2.199 -2.199 0.352 1.000 128 0.000 1.000 -2.199 0.441 -2.223 -1.973
240 -2.021 -2.021 0.289 1.000 128 0.000 1.000 -2.021 0.597 -2.099 -1.394
260 -2.278 -2.278 0.461 1.000 128 0.000 1.000 -2.278 0.528 -1.494 -0.810
280 -2.349 -2.349 0.422 1.000 128 0.000 1.000 -2.349 0.563 -0.995 -0.785
env.log.plot_metrics()

Evaluation

We determined earlier that a compound with a predicted band gap of 1.3 or lower would fall in the lowest 1% of the dataset. Because the reward is the negated prediction, this corresponds to a reward above -1.3. Looking at the log, we see ~2000 samples that meet this threshold. All of these samples are novel samples that came from the model run rather than from the original dataset samples added in by sampler5.

env.log.df[env.log.df.rewards>-1.3].shape
(2021, 7)
subset = env.log.df[(env.log.df.rewards>-.72) & ~(env.log.df.sources=='data_buffer')]
mols = to_mols(subset.samples.values)
legends = [f'{subset.rewards.values[i]:.3f}' for i in range(subset.shape[0])]
draw_mols(mols, legends=legends, sub_img_size=(400,400))