Polymer Band Gap Optimization
This tutorial runs an end-to-end workflow for designing low band gap polymers.
In physics, the band gap is the energy difference between the highest occupied and lowest unoccupied electronic states of a material. This property is of great interest in the development of organic photovoltaic cells (OPVCs). The band gap of a polymer material determines what part of the light spectrum the material can absorb and convert into energy. A material cannot absorb photons with energy less than the band gap. The band gap also influences the voltage and current produced by an OPVC.
Modern OPVCs use multiple materials with different band gaps in a single cell to capture a wider spectrum of light. One issue in materials science for OPVCs is designing low band gap polymers. Low band gap polymers are needed to absorb infrared and near-infrared light.
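As a rough guide (a standard photon energy conversion, not part of the dataset used here), the longest wavelength a material can absorb is lambda (nm) ≈ 1240 / band gap (eV). A polymer with a 1.3 eV band gap can therefore absorb out to roughly 950 nm, well into the near-infrared, while a 2 eV material cuts off around 620 nm.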
In this notebook we will use MRL to design low band gap polymers.
Performance Notes
The workflow in this notebook is more CPU-constrained than GPU-constrained due to the need to evaluate samples on CPU. If you have a multi-core machine, it is recommended that you uncomment and run the set_global_pool cells in the notebook (see the sketch below). This will trigger the use of multiprocessing, which can result in 2-4x speedups.
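For reference, the set_global_pool cell used later in this notebook looks like the following; the pool size is capped at 10 here, so adjust it to your machine's core count:
set_global_pool(min(10, os.cpu_count()))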
This notebook may run slowly on Colab due to CPU limitations.
If running on Colab, remember to change the runtime to GPU.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *
from sklearn.metrics import r2_score
os.makedirs('untracked_files', exist_ok=True)
Data
The dataset comes from Polymer Informatics with Multi-Task Learning. It contains polymer chain band gap values for ~3300 polymers computed with DFT simulations.
# ! wget https://khazana.gatech.edu/dataset/MTL_Khazana.zip --no-check-certificate
# ! unzip MTL_Khazana.zip
# ! rm MTL_Khazana.zip
df = pd.read_csv('export.csv')
df = df[df.property=='Egc'] # keep only the polymer chain band gap ('Egc') values
from rdkit import Chem
def clean_smile(smile):
    # standardize dataset SMILES: strip isotope labels and explicit hydrogens
    mol = to_mol(smile)
    for atom in mol.GetAtoms():
        atom.SetIsotope(0)
    mol = Chem.RemoveHs(mol)
    return to_smile(mol)
df['smiles'] = [clean_smile(i) for i in df.smiles.values]
df.head()
df.value.hist()
Score Function
Now we want to develop a score function for predicting band gap. We are going to use an MLP-type model to predict band gap from an ECFP6 fingerprint.
We will train on 90% of the data and validate on the 10% held out.
df_train = df.sample(frac=0.9, random_state=42).copy()
df_valid = df[~df.index.isin(df_train.index)].copy()
reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[-1,11]) # 2048-bit input, dropout 0.2 per hidden layer, output bounded to [-1, 11]
r_ds = Vec_Prediction_Dataset(df_train.smiles.values, df_train.value.values, ECFP6) # featurize SMILES as ECFP6 fingerprints
r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.train_supervised(32, 10, 1e-3) # batch size 32, 10 epochs, lr 1e-3
Optional: save score function weights
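A minimal sketch of the save step, assuming the load_weights call further down in this notebook reads a standard PyTorch state dict from this path:
torch.save(r_agent.model.state_dict(), 'untracked_files/polymer_bandgap.pt')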
Optional: to load the exact weights used, run the following:
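This mirrors the (commented-out) load call used for the reward model later in this notebook, assuming model_from_url resolves the hosted checkpoint:
r_agent.load_state_dict(model_from_url('polymer_bandgap.pt'))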
valid_ds = Vec_Prediction_Dataset(df_valid.smiles.values, df_valid.value.values, ECFP6)
valid_dl = valid_ds.dataloader(256, num_workers=0, shuffle=False)
r_agent.model.eval()
preds = []
targs = []
# collect predictions on the held-out validation set
with torch.no_grad():
    for i, batch in enumerate(valid_dl):
        batch = to_device(batch)
        x,y = batch
        pred = r_agent.model(x)
        preds.append(pred.detach().cpu())
        targs.append(y.detach().cpu())
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()
preds = preds.squeeze(-1)
Our score function has an R² value of about 0.84 on the validation dataset.
fig, ax = plt.subplots()
ax.scatter(targs, preds, alpha=0.5, s=1)
plt.xlabel('Target')
plt.ylabel('Prediction')
slope, intercept = np.polyfit(targs, preds, 1)
ax.plot(np.array(ax.get_xlim()), intercept + slope*np.array(ax.get_xlim()), c='r')
plt.text(5., 9., 'R-squared = %0.3f' % r2_score(targs, preds));
We should also look at the distribution of predicted band gaps. The predicted values will differ somewhat from the actual values. This gives us a sense of what score we should aim for from the generative model.
df['preds'] = r_agent.predict_data(df.smiles.values).detach().cpu().numpy()
df.preds.min()
-np.percentile(-df.preds, 99) # negate because we want smaller values; equivalent to np.percentile(df.preds, 1)
The minimum predicted value is 0.12. A compound scoring 1.3 or lower would be in the top 1% of lowest band gaps relative to the dataset.
Chemical Space
Next we need to develop our chemical space. This is where we decide what compounds will be allowed and what compounds will be removed.
Getting the right filtering parameters makes a huge difference in compound quality. In practice, finding the right constraints is an iterative process. First run a generative screen. Then examine the highest scoring compounds and look for undesirable properties or structural features. Then update the template and iterate.
smarts = ['[#6](=[#16])(-[#7])-[#7]',        # thiourea
          '[#6]=[#6]=[#6]',                  # allene
          '[#7][F,Cl,Br,I]',                 # nitrogen-halogen bond
          '[*]#[Cl,Br,I]',                   # triple bond to a halogen
          '[*]-[Cl,Br,I]-[*]',               # divalent halogen
          '[#6;!R]=[#6;!R]-[#6;!R]=[#6;!R]', # acyclic conjugated diene
          '[#6]#[#6]',                       # alkyne
          '[#15]',                           # phosphorus
          '[#16]',                           # sulfur
          '[*]=[#17,#9,#35]',                # double bond to a halogen
          '[*]=[*]=[*]',                     # cumulated double bonds
          '[*]-[#6]=[#6H2]',                 # terminal vinyl group
          '[#7]~[#8]',                       # nitrogen-oxygen bond
          '[#7]~[#7]',                       # nitrogen-nitrogen bond
          '[*;R]=[*;!R]']                    # exocyclic double bond
template = Template([ValidityFilter(),                       # structures must be chemically valid
                     SingleCompoundFilter(),                 # single compounds only
                     AttachmentFilter(min_val=2, max_val=2), # exactly 2 polymer attachment points
                     MolWtFilter(None, 600),                 # molecular weight at most 600
                     RingFilter(1,None),                     # at least one ring
                     ExclusionFilter(smarts, criteria='any') # reject molecules matching any SMARTS above
                    ],
                    [],
                    fail_score=-1., log=False, use_lookup=False)
template_cb = TemplateCallback(template, prefilter=True)
Load Model
We load the LSTM_LM_Small_PI1M model. This is a basic LSTM-based language model trained on the PI1M polymer library.
agent = LSTM_LM_Small_PI1M(drop_scale=0.3, opt_kwargs={'lr':5e-5})
Fine-Tune Model
The pretrained model we loaded is a very general model that can produce a high diversity of structures. However, what we actually want are structures with low band gap. To induce this, we fine-tune on the dataset.
This step deserves some real thought. Fine-tuning on known low band gap polymers will bias the model towards known chemotypes. This will significantly reduce the number of iterations needed for the model to find high scoring compounds. However, the model will likely converge on compounds similar to known chemotypes.
On the other hand, we could skip this fine-tuning stage. If we did, the model would take longer to converge (at least 3x more batches), but we would have a much greater chance of finding new chemotypes.
In the interest of time, we do the additional fine-tuning to speed up convergence.
# fine-tune on progressively lower band gap subsets of the dataset
agent.update_dataset_from_inputs(df[df.value<6].smiles.values)
agent.train_supervised(32, 6, 1e-4)
agent.update_dataset_from_inputs(df[df.value<4].smiles.values)
agent.train_supervised(32, 5, 1e-4)
agent.update_dataset_from_inputs(df[df.value<2].smiles.values)
agent.train_supervised(32, 3, 1e-4)
agent.base_to_model() # sync the main and baseline models after fine-tuning
Optional: save fine-tuned weights
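A minimal sketch, assuming agent.load_weights in the next cell reads a standard PyTorch state dict:
torch.save(agent.model.state_dict(), 'untracked_files/finetuned_model.pt')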
agent.load_weights('untracked_files/finetuned_model.pt')
Loss
We use PPO as our policy gradient loss.
pg = PPO(0.99,               # discount factor
         0.5,                # KL penalty coefficient
         lam=0.95,           # GAE lambda
         v_coef=0.5,         # value loss weight
         cliprange=0.3,      # policy clip range
         v_cliprange=0.3,    # value clip range
         ent_coef=0.01,      # entropy bonus weight
         kl_target=0.03,     # target KL for the adaptive penalty
         kl_horizon=3000,
         scale_rewards=True)
loss = PolicyLoss(pg, 'PPO',
                  value_head=ValueHead(256),
                  v_update_iter=2,
                  vopt_kwargs={'lr':1e-3})
# rebuild the score function and load the trained weights
reward_model = MLP(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2], outrange=[0,15])
r_ds = Vec_Prediction_Dataset(['C'], [0], partial(failsafe_fp, fp_function=ECFP6)) # failsafe_fp guards against fingerprint failures on invalid SMILES
r_agent = PredictiveAgent(reward_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})
r_agent.load_weights('untracked_files/polymer_bandgap.pt')
# r_agent.load_state_dict(model_from_url('polymer_bandgap.pt')) # optional - load exact weights
r_agent.model.eval();
freeze(r_agent.model) # keep score function weights fixed during RL
reward = Reward(r_agent.predict_data, weight=-1.) # negative weight: lower predicted band gap = higher reward
bg_reward = RewardCallback(reward, '-bandgap')
Samplers
We create the following samplers:
- sampler1 ModelSampler: samples from the main model
- sampler2 ModelSampler: samples from the baseline model
- sampler3 LogSampler: samples high scoring compounds from the log
- sampler4 CombichemSampler: runs combichem generation on the top scoring compounds. The combination of generative models with combichem greatly accelerates finding high scoring compounds
- sampler5 DatasetSampler: sprinkles a small number of known actives into each buffer build
A note on the use of combichem: adding combichem to the sampler mix greatly accelerates convergence to high scoring compounds. However, combichem tends to produce "weirder" structures that are more exploitative of the score function. These structures can be removed by updating the template.
gen_bs = 1500 # generation batch size
sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 400, 0., gen_bs)      # samples from the main model
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 400, 0., gen_bs) # samples from the baseline model
sampler3 = LogSampler('samples', 'rewards', 10, 98, 100) # resamples high scoring compounds from the log
mutators = [
ChangeAtom(['6', '7', '8']),
AppendAtomSingle(['C', 'N', 'O']),
AppendAtomsDouble(['C', 'N', 'O']),
DeleteAtom(),
ChangeBond(),
InsertAtomSingle(['C', 'N', 'O']),
InsertAtomDouble(['C', 'N']),
AddRing(),
ShuffleNitrogen(10)
]
mc = MutatorCollection(mutators)
crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[reward],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)
sampler4 = CombichemSampler(cbc, 20, 98, 0.2, 1, 'rewards', 'combichem') # combichem on top scoring compounds
sampler5 = DatasetSampler(df[df.value<1.5].smiles.values, 'data', buffer_size=4) # known low band gap polymers
samplers = [sampler1, sampler2, sampler3, sampler4, sampler5]
Callbacks
Additional callbacks:
- SupervisedCB: runs supervised training on the top 2% of samples every 20 batches
- MaxCallback: prints the max reward for each batch
- PercentileCallback: prints the 90th percentile score each batch
- Timeout: prohibits training on the same sample more than once every 10 batches
supervised_cb = SupervisedCB(agent, 20, 0.5, 98, 1e-4, 64, epochs=2) # supervised training on top 2% (98th percentile) of samples every 20 batches
live_max = MaxCallback('rewards', 'live') # print max live reward each batch
live_p90 = PercentileCallback('rewards', 'live', 90) # print 90th percentile live reward each batch
cbs = [supervised_cb, live_p90, live_max]
env = Environment(agent, template_cb, samplers=samplers, rewards=[bg_reward], losses=[loss],
cbs=cbs)
set_global_pool(min(10, os.cpu_count())) # enable multiprocessing for CPU-bound sample evaluation
env.fit(128, 100, 300, 20)
env.log.plot_metrics()
Evaluation
We determined before that a compound scoring 1.3 or lower would be in the top 1% of samples. Looking at the log, we see ~2000 samples with a score of less than 1.3. All of these are novel samples that came from the model run rather than from the original dataset (which was added in by sampler5).
env.log.df[env.log.df.rewards>-1.3].shape # rewards are negated band gaps, so >-1.3 means predicted band gap < 1.3
subset = env.log.df[(env.log.df.rewards>-.72) & ~(env.log.df.sources=='data_buffer')] # top scorers, excluding compounds drawn from the dataset
mols = to_mols(subset.samples.values)
legends = [f'{subset.rewards.values[i]:.3f}' for i in range(subset.shape[0])]
draw_mols(mols, legends=legends, sub_img_size=(400,400))