Datasets
Datasets subclass the PyTorch `Dataset` class. MRL datasets add a collate function and the `Base_Dataset.dataloader` method to easily generate PyTorch dataloaders from the same class.

Like all PyTorch datasets, subclassed datasets must contain valid `__len__` and `__getitem__` methods. MRL datasets should also include a `new` method. The purpose of `new` is to create a new dataset from new data using the same input arguments and collate function as the current dataset. This is used during generative training to ensure generated samples are processed and batched the same way as the training data.
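The `new` pattern can be illustrated with a toy standalone sketch (this is not MRL's actual implementation; the class, the `transform` argument, and the sample data are invented for this example):

```python
class ToyDataset:
    """Minimal dataset that can rebuild itself from fresh data
    while reusing its own configuration."""
    def __init__(self, items, transform=str.upper):
        self.items = list(items)
        self.transform = transform   # stands in for vocab / collate setup

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.transform(self.items[idx])

    def new(self, items):
        # same constructor arguments and processing, new data
        return type(self)(items, transform=self.transform)

train_ds = ToyDataset(['cco', 'ccn'])
generated_ds = train_ds.new(['cco'])   # generated samples, identical processing
assert generated_ds[0] == train_ds[0] == 'CCO'
```

Because `new` copies the current dataset's configuration, downstream code never has to re-specify vocabularies or collate functions when batching generated samples.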
Text Datasets
Text datasets deal with tokenizing and numericalizing text data, like SMILES strings. `Text_Dataset` returns numericalized SMILES for language modeling. `Text_Prediction_Dataset` returns numericalized SMILES along with some `y_val` output value, for tasks like property prediction.
```python
import pandas as pd

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

ds = Text_Dataset(df.smiles.values, vocab)
dl = ds.dataloader(16, num_workers=0)
x, y = next(iter(dl))

# language-model batches: targets are inputs shifted by one token
assert (x[:,1:] == y[:,:-1]).all()
# splitting conserves the dataset
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
```
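The `x[:,1:] == y[:,:-1]` assertion reflects standard next-token language modeling: the target sequence is the input shifted by one position, so position `i` of the input predicts position `i+1`. A framework-free illustration (the token ids here are made up):

```python
# Next-token language modeling: target is the input shifted left by one.
seq = [0, 5, 9, 3, 1]        # toy token ids: bos, tokens, eos
x, y = seq[:-1], seq[1:]     # input / target
assert x[1:] == y[:-1]       # same relation the batch assertion checks above
```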
```python
# tuple inputs: (source, target) pairs are also supported
ds = Text_Dataset([(i, i) for i in df.smiles.values], vocab)
dl = ds.dataloader(16, num_workers=0)
x, y = next(iter(dl))
assert (x[:,1:] == y[:,:-1]).all()
```

```python
# with cache=True, processed items are stored after first access
ds = Text_Dataset(df.smiles.values, vocab, cache=True)
_ = ds[1]
assert 1 in ds.cache.keys()
```
```python
import torch

# Text_Prediction_Dataset pairs each SMILES with an output value
ds = Text_Prediction_Dataset(df.smiles.values, [0]*len(df.smiles.values), vocab)
dl = ds.dataloader(16, num_workers=0)
x, y = next(iter(dl))
assert (y == torch.zeros(y.shape).float()).all()
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
```

```python
ds = Text_Prediction_Dataset(df.smiles.values, [0]*len(df.smiles.values), vocab, cache=True)
_ = ds[0]
assert 0 in ds.cache.keys()
```
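The caching behavior checked above can be sketched in plain Python (a toy stand-in, not MRL's code; the `.strip()` transform substitutes for tokenization):

```python
class CachedDataset:
    """Items are processed once, then served from a dict keyed by index."""
    def __init__(self, items, cache=False):
        self.items = items
        self.cache = {} if cache else None

    def __getitem__(self, idx):
        if self.cache is not None and idx in self.cache:
            return self.cache[idx]          # cache hit: skip reprocessing
        out = self.items[idx].strip()       # stands in for tokenize/numericalize
        if self.cache is not None:
            self.cache[idx] = out
        return out

ds = CachedDataset([' CCO ', ' c1ccccc1 '], cache=True)
_ = ds[0]
assert 0 in ds.cache
```

Caching trades memory for speed, which pays off when the same items are revisited over many epochs.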
Vector Datasets
Another common dataset framework deals with vectors derived from a molecule. This could be a vector of properties, a fingerprint, or any task where a molecule-derived vector is needed.

`Vector_Dataset` is a base dataset that simply returns the molecule-derived vector. `Vec_To_Text_Dataset` returns the molecule-derived vector along with tokenized SMILES strings. This is used for tasks like generating compounds based on an input vector or fingerprint.
```python
from mrl.chem import ECFP6

df = pd.read_csv('files/smiles.csv')
ds = Vector_Dataset(df.smiles.values, ECFP6, cache=True)
dl = ds.dataloader(16, num_workers=0)
batch = next(iter(dl))

new1, new2 = ds.split(0.1)

x = next(iter(dl))
assert x.shape == (16, 2048)   # ECFP6 fingerprints are 2048-bit vectors
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
```
```python
from mrl.chem import ECFP6

vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(df.smiles.values, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
x, y = next(iter(dl))
assert len(x) == 2   # x holds both the tokenized SMILES and the vector
assert (x[0][:,1:] == y[:,:-1]).all()
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
```
```python
from mrl.chem import ECFP6

# tuple inputs work here as well
ds = Vec_To_Text_Dataset(
    [(df.smiles.values[i], df.smiles.values[i+1]) for i in range(len(df.smiles.values)-1)],
    vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
x, y = next(iter(dl))
assert len(x) == 2
```
```python
from mrl.chem import ECFP6

ds = Vec_Prediction_Dataset(df.smiles.values, [0 for i in df.smiles.values], ECFP6)
dl = ds.dataloader(16, num_workers=0)
x, y = next(iter(dl))
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
```
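The recurring `split(0.2)` assertions check that a percentage split conserves the dataset. One way such a split can be implemented is by shuffling indices and cutting at the given fraction (a sketch of the general technique, not MRL's actual code; `split_indices` is an invented helper):

```python
import random

def split_indices(n, p, seed=0):
    """Hold out a fraction p of n indices for validation; keep the rest."""
    idxs = list(range(n))
    random.Random(seed).shuffle(idxs)   # seeded for reproducibility
    cut = int(n * p)
    return idxs[cut:], idxs[:cut]       # (train, valid)

train, valid = split_indices(10, 0.2)
assert len(train) + len(valid) == 10   # the split conserves the data
assert len(valid) == 2
```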