Datasets
Datasets subclass the PyTorch Dataset class. MRL datasets add a collate function and the Base_Dataset.dataloader method to easily generate PyTorch dataloaders from the same class.
Like all PyTorch datasets, subclass datasets must implement valid __len__ and __getitem__ methods. MRL datasets should also include a new method.
The purpose of new is to create a new dataset from new data using the same input arguments and collate function as the current dataset. This is used during generative training to ensure generated samples are processed and batched the same way as the training data.
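The pattern above can be sketched in plain PyTorch. ToyDataset below is a hypothetical example, not part of MRL; Base_Dataset's actual signatures may differ, but the shape is the same: __len__/__getitem__, a collate function, a dataloader convenience method, and new.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    # Illustrative sketch of the MRL dataset pattern (not mrl's Base_Dataset)
    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return torch.tensor(self.items[idx], dtype=torch.float)

    def collate(self, batch):
        # batching logic lives on the dataset so it travels with it
        return torch.stack(batch)

    def dataloader(self, bs, **kwargs):
        return DataLoader(self, batch_size=bs, collate_fn=self.collate, **kwargs)

    def new(self, items):
        # build a dataset of new data (e.g. generated samples) that is
        # processed and batched identically to the training data
        return type(self)(items)

ds = ToyDataset([[0.,1.],[2.,3.],[4.,5.]])
batch = next(iter(ds.dataloader(2)))
gen_ds = ds.new([[6.,7.]])
```

Because new returns the same class with the same collate logic, downstream code can batch generated samples without any special casing.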
Text Datasets
Text datasets deal with tokenizing and numericalizing text data, like SMILES strings.
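Character-level numericalization maps each character of a SMILES string to an integer index, typically wrapped with begin/end tokens. A minimal sketch of the idea in plain Python (illustrative only; CharacterVocab's actual API and token set may differ):

```python
# toy character vocab: special tokens first, then a few SMILES characters
SPECIAL = ['bos', 'eos', 'pad', 'unk']
chars = SPECIAL + list('CcNnOo()=#123')

stoi = {c: i for i, c in enumerate(chars)}  # string-to-index lookup

def numericalize(smiles):
    # bos token + one index per character + eos token
    return ([stoi['bos']]
            + [stoi.get(c, stoi['unk']) for c in smiles]
            + [stoi['eos']])

ids = numericalize('CCO')
```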
Text_Dataset returns numericalized SMILES for language modeling.
Text_Prediction_Dataset returns numericalized SMILES along with a y_val output value, for tasks like property prediction.
import pandas as pd
import torch

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Text_Dataset(df.smiles.values, vocab)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
# language-model targets are the inputs shifted one token to the left
assert (x[:,1:] == y[:,:-1]).all()
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
# Text_Dataset also accepts (input, target) tuples
ds = Text_Dataset([(i,i) for i in df.smiles.values], vocab)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
assert (x[:,1:] == y[:,:-1]).all()
# cache=True stores processed items after first access
ds = Text_Dataset(df.smiles.values, vocab, cache=True)
_ = ds[1]
assert 1 in ds.cache.keys()
ds = Text_Prediction_Dataset(df.smiles.values, [0]*len(df.smiles.values), vocab)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
assert (y == torch.zeros(y.shape).float()).all()
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
ds = Text_Prediction_Dataset(df.smiles.values, [0]*len(df.smiles.values), vocab, cache=True)
_ = ds[0]
assert 0 in ds.cache.keys()
Vector Datasets
Another common dataset framework is where we are dealing with vectors derived from a molecule. This could be a vector of properties, fingerprints, or any task where a molecule-derived vector is needed.
Vector_Dataset is a base dataset that simply returns the molecule-derived vector.
Vec_To_Text_Dataset returns the molecule-derived vector along with tokenized SMILES strings. This is used for tasks like generating compounds conditioned on an input vector or fingerprint.
from mrl.chem import ECFP6
df = pd.read_csv('files/smiles.csv')
ds = Vector_Dataset(df.smiles.values, ECFP6, cache=True)
dl = ds.dataloader(16, num_workers=0)
x = next(iter(dl))
# ECFP6 fingerprints are 2048-bit vectors
assert x.shape==(16,2048)
# split returns two datasets whose sizes sum to the original
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
from mrl.chem import ECFP6
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(df.smiles.values, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
# x is a tuple of model inputs: tokenized SMILES plus the fingerprint vector
assert len(x)==2
assert (x[0][:,1:] == y[:,:-1]).all()
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)
from mrl.chem import ECFP6
# Vec_To_Text_Dataset also accepts (input, target) SMILES pairs
ds = Vec_To_Text_Dataset(
[(df.smiles.values[i],df.smiles.values[i+1]) for i in range(len(df.smiles.values)-1)],
vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
assert len(x)==2
from mrl.chem import ECFP6
ds = Vec_Prediction_Dataset(df.smiles.values, [0 for i in df.smiles.values], ECFP6)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)