PyTorch datasets, dataloaders, collate functions and vocabularies

Collate Functions

Collate functions combine individual Dataset outputs into batches.

batch_sequences

batch_sequences(sequences, pad_idx)

Packs sequences into a dense tensor, padding shorter sequences with pad_idx.
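A minimal sketch of the padding step, written with plain Python lists so it runs without PyTorch. The helper name pad_batch is hypothetical, and right-padding is an assumption; the source does not say which side is padded. In PyTorch, torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_idx) right-pads a list of 1-D tensors in the same way.

```python
from typing import List, Sequence


def pad_batch(sequences: Sequence[Sequence[int]], pad_idx: int) -> List[List[int]]:
    """Right-pad integer sequences with pad_idx up to the longest length (hypothetical helper)."""
    max_len = max(len(seq) for seq in sequences)
    # every row ends up with max_len entries, forming a rectangular batch
    return [list(seq) + [pad_idx] * (max_len - len(seq)) for seq in sequences]


batch = pad_batch([[5, 6, 7], [8, 9]], pad_idx=0)
print(batch)  # [[5, 6, 7], [8, 9, 0]]
```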

lm_collate

lm_collate(batch, pad_idx, batch_first=True)

Collate function for language models. Returns a padded batch (x, y) for next-token prediction, in which y[:, t] is the token that follows x[:, t].
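One way to produce the shifted pair, sketched with plain lists: pad the whole batch first, then use each row without its last token as x and without its first token as y. Padding before slicing keeps x[:,1:] equal to y[:,:-1] even when sequences differ in length, which is the property checked in the Text_Dataset example. The name lm_collate_sketch is hypothetical, this is not necessarily how lm_collate is implemented, and batch_first handling is omitted (rows are always batch-first here).

```python
from typing import List, Optional, Sequence, Tuple


def lm_collate_sketch(
    batch: Sequence[Tuple[Sequence[int], Optional[object]]], pad_idx: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical next-token collate; items are (sequence_ints, None) pairs."""
    seqs = [list(seq) for seq, _ in batch]
    max_len = max(len(seq) for seq in seqs)
    # pad to a rectangle first so both slices below stay aligned
    padded = [seq + [pad_idx] * (max_len - len(seq)) for seq in seqs]
    x = [row[:-1] for row in padded]  # inputs: drop the last position
    y = [row[1:] for row in padded]   # targets: drop the first position
    return x, y


x, y = lm_collate_sketch([([1, 5, 6, 2], None), ([1, 7, 2], None)], pad_idx=0)
print(x)  # [[1, 5, 6], [1, 7, 2]]
print(y)  # [[5, 6, 2], [7, 2, 0]]
```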

sequence_prediction_collate

sequence_prediction_collate(batch, pad_idx, batch_first=True)

Collate function for predicting a y value (such as a molecular property) from a sequence.
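A sketch of the same idea for (sequence_ints, y_val) items, assuming the whole padded sequence is the input (no shift) and the y values are gathered into one flat list. The function name is hypothetical, and casting y to float is an assumption made so the targets suit a regression loss.

```python
from typing import List, Sequence, Tuple, Union

Number = Union[int, float]


def sequence_prediction_collate_sketch(
    batch: Sequence[Tuple[Sequence[int], Number]], pad_idx: int
) -> Tuple[List[List[int]], List[float]]:
    """Hypothetical collate for (sequence_ints, y_val) items."""
    seqs = [list(seq) for seq, _ in batch]
    max_len = max(len(seq) for seq in seqs)
    # the full padded sequence is the model input; nothing is shifted
    x = [seq + [pad_idx] * (max_len - len(seq)) for seq in seqs]
    # one target per row, converted to float
    y = [float(y_val) for _, y_val in batch]
    return x, y


x, y = sequence_prediction_collate_sketch([([1, 2, 3], 0), ([4, 5], 1)], pad_idx=0)
```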

vector_collate

vector_collate(batch)

Collate function for vectors

vec_to_text_collate

vec_to_text_collate(batch, pad_idx, batch_first=True)

Collate function for predicting a sequence from an input vector, where the batched sequence tensor is also needed as model input (e.g., predicting SMILES from properties).
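The Vec_To_Text_Dataset example checks that x has two elements and that x[0] is shifted one step against y. A list-based sketch consistent with that check is shown here; it assumes x[1] holds the conditioning vectors (the source does not state what the second element is), and the function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def vec_to_text_collate_sketch(
    batch: Sequence[Tuple[Sequence[float], Sequence[int]]], pad_idx: int
) -> Tuple[Tuple[List[List[int]], List[List[float]]], List[List[int]]]:
    """Hypothetical collate for (vector, sequence_ints) items."""
    vectors = [list(vec) for vec, _ in batch]
    seqs = [list(seq) for _, seq in batch]
    max_len = max(len(seq) for seq in seqs)
    padded = [seq + [pad_idx] * (max_len - len(seq)) for seq in seqs]
    x_seq = [row[:-1] for row in padded]  # input tokens
    y = [row[1:] for row in padded]       # next-token targets
    # the model input bundles the token batch with the conditioning vectors
    return (x_seq, vectors), y


x, y = vec_to_text_collate_sketch(
    [([0.1, 0.9], [1, 4, 2]), ([0.5, 0.5], [1, 2])], pad_idx=0
)
```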

vector_prediction_collate

vector_prediction_collate(batch)

Collate function for predicting a y value (such as a molecular property) from a vector.
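Because molecule-derived vectors such as fingerprints usually share one fixed length, a collate for (vector, y_val) items can stack them directly instead of padding. A list-based sketch under that assumption; the function name and the length check are hypothetical.

```python
from typing import List, Sequence, Tuple, Union

Number = Union[int, float]


def vector_prediction_collate_sketch(
    batch: Sequence[Tuple[Sequence[Number], Number]]
) -> Tuple[List[List[float]], List[float]]:
    """Hypothetical collate for (vector, y_val) items."""
    vectors = [[float(v) for v in vec] for vec, _ in batch]
    # vectors are stacked, not padded, so they must all share one length
    if len({len(vec) for vec in vectors}) != 1:
        raise ValueError("all vectors in a batch must have the same length")
    y = [float(y_val) for _, y_val in batch]
    return vectors, y


x, y = vector_prediction_collate_sketch([([0, 1], 2), ([1, 1], 3)])
```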

collate_ds

collate_ds(ds)

Datasets

MRL datasets subclass the PyTorch Dataset class. They add a collate function and a Base_Dataset.dataloader method that generates PyTorch dataloaders directly from the dataset.

Like all PyTorch datasets, subclasses must implement valid __len__ and __getitem__ methods. MRL datasets should also implement a new method.

The new method creates a dataset from new data using the same input arguments and collate function as the current dataset. This is used during generative training so that generated samples are processed and batched the same way as the training data.
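A minimal sketch of the new pattern, using a hypothetical ToySequenceDataset rather than an MRL class. The design choice worth noting is type(self): calling new on a subclass instance returns an instance of that subclass, carrying over its constructor arguments and collate function.

```python
from typing import Callable, List, Optional, Sequence


class ToySequenceDataset:
    """Hypothetical dataset illustrating the `new` method; not part of MRL."""

    def __init__(self, sequences: Sequence[str], pad_idx: int = 0,
                 collate_function: Optional[Callable] = None):
        self.sequences: List[str] = list(sequences)
        self.pad_idx = pad_idx
        self.collate_function = collate_function

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]

    def new(self, sequences: Sequence[str]) -> "ToySequenceDataset":
        # same constructor arguments and collate function, different data;
        # type(self) keeps the subclass when called on a subclass instance
        return type(self)(sequences, pad_idx=self.pad_idx,
                          collate_function=self.collate_function)


train_ds = ToySequenceDataset(["CCO", "CCN"], pad_idx=3, collate_function=list)
generated_ds = train_ds.new(["c1ccccc1"])
```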

class Base_Dataset

Base_Dataset(*args, **kwds) :: Dataset

Base_Dataset - base dataset

Inputs:

  • collate_function Callable: batch collate function for the particular dataset class

  • cache Bool: if True, cache dataset
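The cache flag can be pictured as a dictionary keyed by item index, which matches the examples that check 1 in ds.cache.keys() after reading ds[1]. A sketch under that assumption, using a hypothetical ToyCachedDataset (not an MRL class) whose processing function records each call so the cache hit is visible.

```python
from typing import Any, Callable, Dict, List, Sequence


class ToyCachedDataset:
    """Hypothetical dataset showing index-keyed caching; not part of MRL."""

    def __init__(self, items: Sequence[Any], process: Callable[[Any], Any],
                 cache: bool = False):
        self.items: List[Any] = list(items)
        self.process = process
        self.use_cache = cache
        self.cache: Dict[int, Any] = {}

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Any:
        if self.use_cache and idx in self.cache:
            return self.cache[idx]  # skip reprocessing
        out = self.process(self.items[idx])
        if self.use_cache:
            self.cache[idx] = out
        return out


calls = []


def count_chars(smiles: str) -> int:
    calls.append(smiles)  # record each time processing actually runs
    return len(smiles)


ds = ToyCachedDataset(["CCO", "CC"], count_chars, cache=True)
first, second = ds[1], ds[1]
```

A dataloader method on such a class could plausibly wrap torch.utils.data.DataLoader(self, batch_size=bs, collate_fn=self.collate_function, **kwargs), which would fit calls such as ds.dataloader(16, num_workers=0); that is an assumption about Base_Dataset, not something the source confirms.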

Text Datasets

Text datasets deal with tokenizing and numericalizing text data, like SMILES strings.

Text_Dataset returns numericalized SMILES for language modeling.

Text_Prediction_Dataset returns numericalized SMILES along with a y_val output value, for tasks like property prediction.

class Text_Dataset

Text_Dataset(*args, **kwds) :: Base_Dataset

Text_Dataset - base dataset for language models

Inputs:

  • sequences [list[str], list[tuple]]: list of text sequences or text tuples (source, target)

  • vocab Vocab: vocabulary for tokenization/numericalization

  • collate_function Callable: batch collate function. If None, defaults to lm_collate

  • cache Bool: if True, cache dataset

If sequences is a list of strings, __getitem__ returns a tuple of (sequence_ints, None). This is suitable for language modeling where the goal is to predict the input sequence.

If sequences is a list of tuples, __getitem__ returns a tuple of (input_sequence_ints, output_sequence_ints). This is suitable for seq-to-seq tasks where the predicted sequence differs from the input sequence.
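The two __getitem__ behaviours described above can be sketched with a hypothetical ToyTextDataset. It uses a toy character-to-index dict in place of a Vocab and omits any special tokens a real vocabulary may add, so the integer values are illustrative only.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union

Item = Union[str, Tuple[str, str]]


class ToyTextDataset:
    """Hypothetical stand-in for Text_Dataset.__getitem__ behaviour; not part of MRL."""

    def __init__(self, sequences: Sequence[Item], stoi: Dict[str, int]):
        self.sequences = list(sequences)
        self.stoi = stoi  # toy character-to-index map in place of a Vocab

    def numericalize(self, text: str) -> List[int]:
        return [self.stoi[ch] for ch in text]

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[List[int]]]:
        item = self.sequences[idx]
        if isinstance(item, str):
            # language modeling: the target is derived later from the input itself
            return self.numericalize(item), None
        source, target = item  # seq-to-seq: separate input and output
        return self.numericalize(source), self.numericalize(target)


ds = ToyTextDataset(["CCO", ("CO", "NC")], {"C": 0, "O": 1, "N": 2})
```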

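import pandas as pd
# CharacterVocab, SMILES_CHAR_VOCAB and the MRL dataset classes are assumed to be imported already

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

# language modeling: y[:, t] is the token that follows x[:, t]
ds = Text_Dataset(df.smiles.values, vocab)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))

assert (x[:,1:] == y[:,:-1]).all()

# split returns datasets whose lengths sum to len(ds)
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)

# seq-to-seq input: (source, target) tuples
ds = Text_Dataset([(i,i) for i in df.smiles.values], vocab)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))

assert (x[:,1:] == y[:,:-1]).all()

# with cache=True, processed items are stored in ds.cache keyed by index
ds = Text_Dataset(df.smiles.values, vocab, cache=True)
_ = ds[1]
assert 1 in ds.cache.keys()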

class Text_Prediction_Dataset

Text_Prediction_Dataset(*args, **kwds) :: Text_Dataset

Text_Prediction_Dataset - base dataset for predicting from text strings

Inputs:

  • sequences list[str]: list of text sequences

  • y_vals list[int, float]: list of paired output values

  • vocab Vocab: vocabulary for tokenization/numericalization

  • collate_function Callable: batch collate function. If None, defaults to sequence_prediction_collate

  • cache Bool: if True, cache dataset

__getitem__ returns a tuple of (sequence_ints, y_vals), suitable for regression or classification from the sequence.
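A sketch of this pairing with a hypothetical ToyTextPredictionDataset (not an MRL class) and a toy character-to-index dict in place of a Vocab. The length check in the constructor is an assumption added for illustration; the source only says y_vals are "paired" with the sequences.

```python
from typing import Dict, List, Sequence, Tuple, Union

Number = Union[int, float]


class ToyTextPredictionDataset:
    """Hypothetical stand-in for Text_Prediction_Dataset.__getitem__; not part of MRL."""

    def __init__(self, sequences: Sequence[str], y_vals: Sequence[Number],
                 stoi: Dict[str, int]):
        # assumption: each sequence must have exactly one paired y_val
        if len(sequences) != len(y_vals):
            raise ValueError("sequences and y_vals must have the same length")
        self.sequences = list(sequences)
        self.y_vals = list(y_vals)
        self.stoi = stoi

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Tuple[List[int], Number]:
        ints = [self.stoi[ch] for ch in self.sequences[idx]]
        return ints, self.y_vals[idx]


ds = ToyTextPredictionDataset(["CCO", "CN"], [0.5, 1.2], {"C": 0, "O": 1, "N": 2})
```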

import torch

# dummy y_vals: one 0 per SMILES string
ds = Text_Prediction_Dataset(df.smiles.values, [0]*len(df.smiles.values), vocab)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
# every y value in the batch is 0
assert (y == torch.zeros(y.shape).float()).all()
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)

ds = Text_Prediction_Dataset(df.smiles.values, [0]*len(df.smiles.values), vocab, cache=True)
_ = ds[0]
assert 0 in ds.cache.keys()

Vector Datasets

Another common setup involves vectors derived from a molecule, such as a vector of properties, a fingerprint, or any other molecule-derived vector a task needs.

Vector_Dataset is a base dataset that simply returns the molecule-derived vector.

Vec_To_Text_Dataset returns the molecule-derived vector and the numericalized SMILES string. This is used for tasks like generating compounds from an input vector or fingerprint.

Vec_Prediction_Dataset returns the molecule-derived vector along with a y_val output value, for tasks like property prediction.

class Vector_Dataset

Vector_Dataset(*args, **kwds) :: Base_Dataset

Vector_Dataset - base dataset for molecule-derived vectors

Inputs:

  • sequences list[str]: list of text sequences

  • vec_function Callable: function to convert sequence to a vector

  • collate_function Callable: batch collate function. If None, defaults to vector_collate

  • cache Bool: if True, cache dataset
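The role of vec_function can be sketched without RDKit by swapping in a deliberately crude, hypothetical char_count_vector in place of a real fingerprint such as ECFP6 (it even counts the C in Cl as carbon). ToyVectorDataset is likewise hypothetical; the point is only that the vector is computed from the stored text sequence when an item is requested.

```python
from typing import Callable, List, Sequence


def char_count_vector(smiles: str, alphabet: str = "CNO") -> List[int]:
    """Hypothetical vec_function: count occurrences of each character in alphabet."""
    return [smiles.count(ch) for ch in alphabet]


class ToyVectorDataset:
    """Hypothetical stand-in for Vector_Dataset; not part of MRL."""

    def __init__(self, sequences: Sequence[str],
                 vec_function: Callable[[str], List[int]]):
        self.sequences = list(sequences)
        self.vec_function = vec_function

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> List[int]:
        # the vector is computed on demand from the stored text sequence
        return self.vec_function(self.sequences[idx])


ds = ToyVectorDataset(["CCO", "CCN"], char_count_vector)
```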

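import pandas as pd
from mrl.chem import ECFP6

df = pd.read_csv('files/smiles.csv')
# ECFP6 converts each SMILES string into a fingerprint vector
ds = Vector_Dataset(df.smiles.values, ECFP6, cache=True)
dl = ds.dataloader(16, num_workers=0)

x = next(iter(dl))
# a batch of 16 fingerprints, each of length 2048
assert x.shape==(16,2048)

# split returns datasets whose lengths sum to len(ds)
new1, new2 = ds.split(0.1)
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)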

class Vec_To_Text_Dataset

Vec_To_Text_Dataset(*args, **kwds) :: Vector_Dataset

Vec_To_Text_Dataset - base dataset for predicting text sequences from vectors

Inputs:

  • sequences [list[str], list[tuple]]: list of text sequences or text tuples (source, target)

  • vocab Vocab: vocabulary for tokenization/numericalization

  • vec_function Callable: function to convert a sequence to a vector

  • collate_function Callable: batch collate function. If None, defaults to vec_to_text_collate

  • cache Bool: if True, cache dataset

__getitem__ returns a tuple of (sequence_vector, sequence_ints).

If sequences is a list of strings, both sequence_vector and sequence_ints will be derived from the same sequence.

If sequences is a list of tuples, sequence_vector is derived from the first (source) sequence and sequence_ints from the second (target) sequence.
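Both cases can be sketched with a hypothetical ToyVecToTextDataset: a plain string is used as both source and target, while a tuple is unpacked into the two roles. The toy character-to-index dict and the crude char_count_vector stand in for a Vocab and a real fingerprint such as ECFP6; neither is part of MRL.

```python
from typing import Callable, Dict, List, Sequence, Tuple, Union

Item = Union[str, Tuple[str, str]]


def char_count_vector(smiles: str, alphabet: str = "CNO") -> List[int]:
    """Hypothetical vec_function used in place of a real fingerprint."""
    return [smiles.count(ch) for ch in alphabet]


class ToyVecToTextDataset:
    """Hypothetical stand-in for Vec_To_Text_Dataset.__getitem__; not part of MRL."""

    def __init__(self, sequences: Sequence[Item], stoi: Dict[str, int],
                 vec_function: Callable[[str], List[int]]):
        self.sequences = list(sequences)
        self.stoi = stoi
        self.vec_function = vec_function

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Tuple[List[int], List[int]]:
        item = self.sequences[idx]
        # a plain string serves as both source and target
        source, target = (item, item) if isinstance(item, str) else item
        return self.vec_function(source), [self.stoi[ch] for ch in target]


ds = ToyVecToTextDataset(["CCO", ("CCO", "NC")], {"C": 0, "O": 1, "N": 2},
                         char_count_vector)
```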

from mrl.chem import ECFP6

vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Vec_To_Text_Dataset(df.smiles.values, vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
# x holds two elements; x[0] is the input token batch, shifted one step against y
assert len(x)==2
assert (x[0][:,1:] == y[:,:-1]).all()
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)

# tuple input: the fingerprint of each SMILES is paired with the next SMILES as target
ds = Vec_To_Text_Dataset(
    [(df.smiles.values[i],df.smiles.values[i+1]) for i in range(len(df.smiles.values)-1)], 
    vocab, ECFP6)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
assert len(x)==2

class Vec_Prediction_Dataset

Vec_Prediction_Dataset(*args, **kwds) :: Vector_Dataset

Vec_Prediction_Dataset - base dataset for predicting y_vals from vectors

Inputs:

  • sequences list[str]: list of text sequences

  • y_vals list[int, float]: list of paired output values

  • vec_function Callable: function to convert a sequence to a vector

  • collate_function Callable: batch collate function. If None, defaults to vector_prediction_collate

  • cache Bool: if True, cache dataset

from mrl.chem import ECFP6

ds = Vec_Prediction_Dataset(df.smiles.values, [0 for i in df.smiles.values], ECFP6)
dl = ds.dataloader(16, num_workers=0)
x,y = next(iter(dl))
assert sum([len(i) for i in ds.split(0.2)]) == len(ds)