Sampler Callbacks
Samplers serve two main functions in the fit loop. During the `build_buffer` event, samplers add samples to the Buffer. During the `sample_batch` event, samplers add samples to the current BatchState.

Samplers generally have the ability to toggle which events they add samples to. For example, if you wanted to do entirely offline RL, you could disable live sampling during the `sample_batch` event and train only on samples stored in the buffer.
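To make the two events concrete, here is a minimal sketch of a sampler with per-event toggles. The class and method names (`Sampler`, `build_buffer`, `sample_batch`) mirror the terminology above but are illustrative assumptions, not the library's actual API.

```python
import random

class Sampler:
    # Illustrative sketch only; not the library's real Sampler class.
    def __init__(self, samples, buffer_active=True, batch_active=True):
        self.samples = samples
        self.buffer_active = buffer_active  # add samples during build_buffer?
        self.batch_active = batch_active    # add samples during sample_batch?

    def build_buffer(self, buffer, n):
        # During the build_buffer event, push samples into the buffer
        if self.buffer_active:
            buffer.extend(random.sample(self.samples, n))

    def sample_batch(self, batch_state, n):
        # During the sample_batch event, push samples into the current batch
        if self.batch_active:
            batch_state.extend(random.sample(self.samples, n))

# Offline RL setup: disable live sampling so training draws only from the buffer
offline_sampler = Sampler(["a", "b", "c", "d"], batch_active=False)
buffer, batch = [], []
offline_sampler.build_buffer(buffer, 2)
offline_sampler.sample_batch(batch, 2)
print(len(buffer), len(batch))  # → 2 0
```

With `batch_active=False`, the sampler still fills the buffer but contributes nothing live to each batch, matching the offline RL example above.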
Sampler Size
Samplers have two main parameters that control sample size.

The `buffer_size` parameter is an integer that controls how many samples are generated during the `build_buffer` event.

The `p_batch` parameter is a float between 0 and 1 that determines what percentage of a batch should be drawn from a specific sampler. When using multiple samplers, the sum of all `p_batch` values should be less than or equal to 1. The difference between the sum of `p_batch` values and the desired batch size is made up by sampling from the buffer.
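The `p_batch` arithmetic can be sketched as follows. The partitioning logic mirrors the description above; it is a worked example, not the library's internal code.

```python
# Two samplers contributing to a batch of 100 samples
batch_size = 100
p_batches = [0.3, 0.5]  # one p_batch value per sampler
assert sum(p_batches) <= 1  # required when using multiple samplers

# Samples drawn from each sampler
sampler_counts = [int(p * batch_size) for p in p_batches]
# The shortfall is filled by sampling from the buffer
from_buffer = batch_size - sum(sampler_counts)
print(sampler_counts, from_buffer)  # → [30, 50] 20
```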
Model Sampler
The `ModelSampler` can be used to draw samples from any `GenerativeModel` subclass. By default, it will track the following sample metrics:

- diversity - how many duplicate samples were generated
- valid - how many samples are left after filtering
- rewards - the average reward of samples generated by the model sampler
- new - how many samples are novel to the training run
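One round of these metrics could be computed as sketched below. The filter function, the reward lookup, and the `seen_in_run` set are hypothetical stand-ins for whatever filtering, scoring, and sample-tracking the training run uses.

```python
def passes_filter(s):
    # Hypothetical validity check; a real filter might check chemistry rules
    return len(s) > 1

generated = ["CCO", "CCO", "CCN", "C", "CCC"]  # one round of model samples
seen_in_run = {"CCN"}  # samples produced earlier in the training run

# diversity: fraction of generated samples that are unique
diversity = len(set(generated)) / len(generated)
# valid: samples left after filtering
valid = [s for s in generated if passes_filter(s)]
# new: unique samples not seen earlier in the run
new = [s for s in set(generated) if s not in seen_in_run]
# rewards: average reward over the unique valid samples (toy reward table)
rewards = {"CCO": 0.8, "CCN": 0.5, "CCC": 0.9}
mean_reward = sum(rewards[s] for s in set(valid)) / len(set(valid))

print(diversity, len(valid), len(new))  # → 0.8 4 3
```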
Prior Sampler
`PriorSampler` allows for sampling based on latent vectors drawn from a specific prior distribution. If desired, this prior can also be optimized during the fit loop.
Latent Sampler
`LatentSampler` allows for sampling based on specific latent vectors. If desired, the latent vectors can also be optimized during the fit loop.
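To illustrate the difference between the two latent-based samplers, here is a minimal sketch. A real sampler would decode the latent vectors through a generative model; the function names and the standard-normal prior are assumptions for illustration only.

```python
import random

latent_dim = 8

def sample_from_prior(n, dim):
    # PriorSampler-style: draw fresh latent vectors from a prior
    # distribution (a standard normal here) on every call
    return [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]

# LatentSampler-style: hold a fixed set of specific latent vectors
# (optionally trainable) and reuse them each time
fixed_latents = [[0.0] * latent_dim for _ in range(4)]

prior_latents = sample_from_prior(4, latent_dim)
print(len(prior_latents), len(prior_latents[0]))  # → 4 8
```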
Contrastive Sampler
So far, samplers have focused on drawing individual samples from a model, dataset, or other source. Contrastive sampling looks at the task of generating a new sample based on an old sample. For example, we could want to train a model to take in a compound and produce a different compound with a high similarity to the original compound but a better score based on some metric.

In these cases, the samples we create will be tuples of the form `(source_sample, target_sample)`. When training a contrastive metric, we may have a pre-made dataset of `(source, target)` pairs to use. However, if such paired data doesn't exist, we need some way to generate it on the fly. This is where the `ContrastiveSampler` class comes in.
`ContrastiveSampler` turns any normal Sampler into a contrastive sampler. The contrastive sampler uses a `base_sampler` to generate an initial set of source samples, then uses an `output_model` to generate the corresponding target samples on the fly. This generation process can be run during `build_buffer` or `sample_batch`.
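The pairing process can be sketched as below. The `base_sampler` and `output_model` stand-ins are hypothetical placeholders for a real sampler and a real generative model.

```python
def base_sampler(n):
    # Stand-in for any normal sampler producing source samples
    return [f"compound_{i}" for i in range(n)]

def output_model(source):
    # Stand-in for a model mapping a source sample to a target sample
    return source + "_optimized"

def contrastive_sample(n):
    # Draw sources from the base sampler, then generate a target for each,
    # yielding (source_sample, target_sample) tuples
    sources = base_sampler(n)
    return [(s, output_model(s)) for s in sources]

pairs = contrastive_sample(3)
print(pairs[0])  # → ('compound_0', 'compound_0_optimized')
```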
Note that the `ContrastiveSampler` does not do any batch or buffer filtering to ensure `(source, target)` pairs match external constraints, like a minimum similarity. This must be handled by other callbacks, like `ContrastiveTemplate`.