Callbacks for buffer

Buffer

The Buffer class holds samples generated during the BuildBuffer event. Samples added to the buffer can be any hashable Python object.

Sample hashing is used to determine unique samples. For this reason, samples should avoid containers such as PyTorch tensors, which are hashed by object identity rather than by numeric value, so two tensors holding the same number count as two distinct samples.

import torch

set([torch.tensor(0.), torch.tensor(0.)])
>> {tensor(0.), tensor(0.)}

set([0., 0.])
>> {0.0}
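One way to sidestep the issue is to convert each sample to a value-hashed form (for example a tuple of floats) before adding it to the buffer. The sketch below illustrates the idea without requiring PyTorch: `IdentityHashed` and `to_hashable` are hypothetical names introduced here, with `IdentityHashed` standing in for any container that hashes by object identity.

```python
class IdentityHashed:
    """Stand-in for a tensor-like container (hypothetical).

    It defines no __eq__ or __hash__, so Python hashes it by object identity.
    """

    def __init__(self, values):
        self.values = list(values)


def to_hashable(sample):
    # Convert the container to a tuple, which is hashed by its contents.
    return tuple(sample.values)


a = IdentityHashed([0.0, 1.0])
b = IdentityHashed([0.0, 1.0])

by_object = {a, b}                           # same numbers, two entries
by_value = {to_hashable(a), to_hashable(b)}  # same numbers, one entry

print(len(by_object), len(by_value))  # 2 1
```

With real tensors, a conversion such as `tuple(t.tolist())` plays the same role as `to_hashable` for one-dimensional data.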

class Buffer

Buffer(p_total) :: Callback

Buffer - training buffer

Inputs:

  • p_total float: batch percentage for sample_batch
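As a rough illustration of how deduplication and p_total could interact, the sketch below uses a hypothetical `ToyBuffer` class, not the library's Buffer. It assumes that p_total is the fraction of a batch of size `bs` drawn from the buffer, and that sampling leaves the buffer unchanged. Both are assumptions of this sketch rather than documented behavior.

```python
import random


class ToyBuffer:
    """Hypothetical stand-in for a training buffer; not the library's Buffer."""

    def __init__(self, p_total):
        self.p_total = p_total  # assumed: fraction of each batch taken from the buffer
        self.items = []         # unique samples in insertion order
        self._seen = set()      # hashing decides which samples count as duplicates

    def add(self, samples):
        for sample in samples:
            if sample not in self._seen:
                self._seen.add(sample)
                self.items.append(sample)

    def sample_batch(self, bs, rng=random):
        # Assumed rule: draw int(bs * p_total) samples, capped by the buffer size.
        n = min(int(bs * self.p_total), len(self.items))
        return rng.sample(self.items, n)


buf = ToyBuffer(p_total=0.5)
buf.add(["CCO", "CCO", "c1ccccc1", "CCN"])  # the duplicate "CCO" is dropped
batch = buf.sample_batch(bs=4, rng=random.Random(0))

print(len(buf.items), len(batch))  # 3 2
```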

class WeightedBuffer

WeightedBuffer(p_total, refresh_predictions, pct_argmax=0.0) :: Buffer

WeightedBuffer - base class for buffer with weighted sampling

Inputs:

  • p_total float: batch percentage for sample_batch

  • refresh_predictions int: how often to generate new predictions for all items in the buffer

  • pct_argmax float[0., 1.]: percent of samples to draw with argmax over the calculated weight versus weighted random sampling
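The split controlled by pct_argmax can be sketched as follows. `weighted_draw` is a hypothetical helper, not part of the library. It assumes strictly positive weights and draws without replacement, which may differ from what WeightedBuffer actually does.

```python
import random


def weighted_draw(items, weights, n, pct_argmax=0.0, rng=random):
    """Draw n distinct items: a pct_argmax share by top weight, the rest at random by weight."""
    n = min(n, len(items))
    n_top = int(n * pct_argmax)

    # Indices ordered from highest to lowest weight.
    order = sorted(range(len(items)), key=lambda i: weights[i], reverse=True)
    chosen = order[:n_top]       # argmax portion
    remaining = order[n_top:]    # candidates for weighted random sampling

    while len(chosen) < n:
        # Weighted random pick among the items not chosen yet.
        pick = rng.choices(remaining, weights=[weights[i] for i in remaining], k=1)[0]
        chosen.append(pick)
        remaining.remove(pick)

    return [items[i] for i in chosen]


items = ["a", "b", "c", "d"]
weights = [0.1, 0.7, 0.15, 0.05]

greedy = weighted_draw(items, weights, n=2, pct_argmax=1.0)
mixed = weighted_draw(items, weights, n=2, pct_argmax=0.5, rng=random.Random(0))

print(greedy)  # ['b', 'c']
```

With pct_argmax=0.5 the first item of `mixed` is always the highest-weight item, while the second depends on the random draw.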

class PredictiveBuffer

PredictiveBuffer(p_total, refresh_predictions, predictive_agent, pred_bs, supervised_frequency, supervised_epochs, supervised_bs, supervised_lr, train_silent=True, pct_argmax=0.0, track=True) :: WeightedBuffer

PredictiveBuffer - buffer with active learning score prediction

Inputs:

  • p_total float: batch percentage for sample_batch

  • refresh_predictions int: how often to generate new predictions for all items in the buffer

  • predictive_agent PredictiveAgent: active learning agent to train

  • pred_bs int: prediction batch size for predictive_agent

  • supervised_frequency int: how often to run offline supervised training of the predictive agent

  • supervised_epochs int: how many epochs to run during offline supervised training

  • supervised_bs int: batch size for offline supervised training

  • supervised_lr float: learning rate for offline supervised training

  • train_silent bool: if True, offline supervised training results are not printed

  • pct_argmax float[0., 1.]: percent of samples to draw with argmax over the calculated weight versus weighted random sampling

  • track bool: if True, predictive buffer metrics are added to the environment printout
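To make the two frequency parameters concrete, the sketch below shows which iterations would trigger each event if both refresh_predictions and supervised_frequency are counted in training iterations. That counting unit, the `is_due` helper, the example values, and the order of the two events are all assumptions of this sketch.

```python
def is_due(iteration, every):
    """True when an event scheduled every `every` iterations should fire."""
    return every > 0 and iteration > 0 and iteration % every == 0


refresh_predictions = 4   # illustrative value, not a library default
supervised_frequency = 6  # illustrative value, not a library default

schedule = {}
for iteration in range(1, 13):
    events = []
    if is_due(iteration, supervised_frequency):
        events.append("train")    # offline supervised training of the predictive agent
    if is_due(iteration, refresh_predictions):
        events.append("refresh")  # re-predict scores for every item in the buffer
    if events:
        schedule[iteration] = events

print(schedule)
# {4: ['refresh'], 6: ['train'], 8: ['refresh'], 12: ['train', 'refresh']}
```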

class BufferSizeCallback

BufferSizeCallback() :: Callback

BufferSizeCallback - print out current buffer size during training