Pretrained Models
The use of pretrained models has had a pronounced impact on computer vision and natural language processing. Pretrained models are first trained on large datasets, which allows them to develop high-quality representations that cover a broad swath of the data landscape. These models can then be applied to downstream tasks where available data is more limited. Fine-tuning a pretrained model on a small dataset almost always yields better performance than training from scratch on that same small dataset.
The MRL library provides a large number of pretrained generative models for multiple data modalities. These models have been trained on large datasets to generate small molecules (whole compounds, r-groups, linkers), polymers (generated as monomers), nucleic acid sequences and protein sequences. The full list of pretrained models can be found at the Model Zoo page.
When using these models, it is common to first fine-tune them on a dataset that better represents the desired search space. This tutorial shows how to load a pretrained model and fine-tune it on a new dataset.
import sys
sys.path.append('..')
from mrl.imports import *
from mrl.chem import *
from mrl.model_zoo import LSTM_LM_Small_ZINC
Here we load the LSTM_LM_Small_ZINC pretrained model. This is a 3-layer LSTM model trained on ~79 million compounds from the ZINC library.
agent = LSTM_LM_Small_ZINC()
agent.model
We can generate compounds by sampling from the model.
smiles = agent.sample_and_reconstruct(100, 100)
mols = to_mols(smiles)
draw_mols(mols[:4], mols_per_row=4)
Now we read in our dataset. This is just a toy dataset to illustrate the API. More realistically, this could be a designed library, for example one containing molecules built around a specific scaffold of interest.
df = pd.read_csv('../files/smiles.csv')
# if in Colab:
# download_files()
# df = pd.read_csv('files/smiles.csv')
df
Here we load our dataset into the agent.
agent.update_dataset_from_inputs(df.smiles.values)
Now we train. An important question here is how much fine-tuning should be done. This is something you have to feel out based on the size of your fine-tuning dataset and the needs of the downstream tasks. Fine-tuning extensively on a small dataset can lead to mode collapse in the resulting model. On the other hand, not doing enough fine-tuning can lead to a model that fails to generate compounds of interest.
As a general rule of thumb, 1e-5 works well as an initial learning rate for fine-tuning. Depending on your dataset, higher learning rates (1e-4 or 1e-3) may be appropriate.
agent.train_supervised(32, 1, 1e-5) # batch size 32, 1 epoch, learning rate 1e-5
When first fine-tuning on a new dataset, a good approach is to fine-tune at a lower learning rate for 1 epoch at a time, monitoring the generated compounds after each training epoch. Generate 2000+ compounds and check the generated dataset for percent unique and percent valid compounds.
If the fine-tuning dataset represents a specific, opinionated chemical space (i.e., specific scaffolds), spot check compounds to see if the model has adapted to the fine-tuning dataset.
Continue fine-tuning until the generated compounds look acceptable, so long as the percent unique and percent valid metrics remain high.
smiles = agent.batch_sample_and_reconstruct(4000, 512, 90)
mols = to_mols(smiles)
percent_valid = len([i for i in mols if i is not None])/len(mols)
percent_unique = len(set(smiles))/len(smiles)
print(f'Percent Valid: {percent_valid:.3f}, Percent Unique: {percent_unique:.3f}')
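To tie the pieces together, here is a minimal sketch of the epoch-at-a-time workflow described above, reusing the same train_supervised and batch_sample_and_reconstruct calls from this tutorial. The scaffold SMILES, epoch count, sample size, and stopping thresholds are illustrative placeholders rather than values prescribed by the MRL library; adjust them to your own dataset and downstream task.

from rdkit import Chem

# hypothetical scaffold of interest; replace with your own query structure
scaffold = Chem.MolFromSmiles('c1ccc2[nH]ccc2c1')

for epoch in range(5):
    # fine-tune for a single epoch at a low learning rate
    agent.train_supervised(32, 1, 1e-5)

    # sample after each epoch and evaluate the generated compounds
    smiles = agent.batch_sample_and_reconstruct(2000, 512, 90)
    mols = to_mols(smiles)
    valid = [m for m in mols if m is not None]

    percent_valid = len(valid)/len(mols)
    percent_unique = len(set(smiles))/len(smiles)
    percent_scaffold = len([m for m in valid if m.HasSubstructMatch(scaffold)])/max(len(valid), 1)

    print(f'Epoch {epoch}: valid {percent_valid:.3f}, unique {percent_unique:.3f}, scaffold {percent_scaffold:.3f}')

    # dropping validity or uniqueness is an early warning sign of mode collapse
    if percent_valid < 0.9 or percent_unique < 0.9:
        break

If the scaffold-match rate stays low while validity and uniqueness remain high, a few more fine-tuning epochs (or a somewhat higher learning rate) may be warranted.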