Base callback class

Callbacks

The training cycle in MRL is built around the Callback system. Rather than explicitly defining every training cycle variant, MRL defines a series of events (see Events) that occur during training and lets users hook into those events with Callbacks. The result is an extremely flexible framework that can adapt to most generative design challenges.

Callbacks use the __call__ method to organize events. __call__ is passed an event name, such as compute_reward. If the Callback has an attribute matching that event name, the attribute is called.

Callbacks have access to the training environment (see Environment), which gives them access to the model/agent, the training buffer, the training log, other callbacks, and all other aspects of the training state.
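
As a rough illustration of this mechanism, the sketch below defines a custom callback that responds to the after_batch event. The import path and the attributes read from the environment are assumptions for illustration; name and order follow the Callback signature listed below.

```python
# NOTE: import path is an assumption; adjust to wherever Callback lives in your MRL install
from mrl.train.callbacks import Callback

class PrintBatchCallback(Callback):
    "Toy callback: reports what the batch state contains after each batch"
    def __init__(self):
        # `name` and `order` follow the Callback signature shown below
        super().__init__(name='print_batch', order=10)

    def after_batch(self):
        # When the Environment fires the 'after_batch' event, this method runs
        # because its name matches the event name. Callbacks reach the training
        # state through the environment.
        batch_state = self.environment.batch_state
        print(f"batch finished; batch_state keys: {list(batch_state.keys())}")
```

Only callbacks that define a method matching the fired event name respond; all other callbacks simply ignore that event.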

class Callback[source]

Callback(name='base_callback', order=10)

class Event[source]

Event()

Event

Base class for events

class Setup[source]

Setup() :: Event

Setup

Setup is called after an Environment is created. The setup step is used to do things like set attributes or add logging terms
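
For example, a callback might use setup for one-time initialization of its own state once the Environment exists. This is only a sketch; registering log terms would go through the Environment's logging API, which is not shown here.

```python
import time
from mrl.train.callbacks import Callback  # import path assumed, as above

class TimerCallback(Callback):
    "Records when the Environment was set up"
    def __init__(self):
        super().__init__(name='timer', order=10)

    def setup(self):
        # runs once, right after the Environment is created
        self.setup_time = time.time()
```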

class BeforeTrain[source]

BeforeTrain() :: Event

BeforeTrain

This event is called by Environment.fit before the first batch is run

class BuildBuffer[source]

BuildBuffer() :: Event

BuildBuffer

The build buffer event is used to add samples to the Buffer

class FilterBuffer[source]

FilterBuffer() :: Event

FilterBuffer

The filter buffer event is used to screen items added to the buffer during build_buffer and remove ones that do not match the filter criteria
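
The sketch below hooks both build_buffer and filter_buffer. How the buffer is accessed and mutated here is an assumption (it is treated as a mutable sequence reachable via self.environment.buffer); consult the Buffer documentation for the real interface.

```python
from mrl.train.callbacks import Callback  # import path assumed, as above

class LengthBufferHooks(Callback):
    "Hypothetical buffer hooks: add placeholder samples, then drop overly long ones"
    def __init__(self, max_len=100):
        super().__init__(name='length_buffer_hooks', order=10)
        self.max_len = max_len

    def build_buffer(self):
        # ASSUMPTION: the buffer behaves like a mutable sequence of samples
        self.environment.buffer.extend(['CCO', 'c1ccccc1'])  # placeholder SMILES

    def filter_buffer(self):
        # remove items that do not satisfy the filter criteria
        buffer = self.environment.buffer
        buffer[:] = [item for item in buffer if len(item) <= self.max_len]
```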

class AfterBuildBuffer[source]

AfterBuildBuffer() :: Event

AfterBuildBuffer

This event is called after the buffer has been filtered and before the next batch starts. This event can be used to evaluate metrics and statistics related to the buffer creation

class BeforeBatch[source]

BeforeBatch() :: Event

BeforeBatch

This event is called before the next batch is sampled

class SampleBatch[source]

SampleBatch() :: Event

SampleBatch

This event produces a series of samples that are added to the next batch

class BeforeFilterBatch[source]

BeforeFilterBatch() :: Event

BeforeFilterBatch

This event is called before the current batch is filtered

class FilterBatch[source]

FilterBatch() :: Event

FilterBatch

This event is used to screen items in the current batch and remove items that do not match the filter criteria

class AfterSample[source]

AfterSample() :: Event

AfterSample

This event is called after a batch is sampled and filtered. This event can be used to log stats about the last batch

class BeforeComputeReward[source]

BeforeComputeReward() :: Event

BeforeComputeReward

This event is called prior to computing rewards on the current batch. This event can be used to generate any inputs required for computing rewards

class ComputeReward[source]

ComputeReward() :: Event

ComputeReward

This event is used to compute rewards for the current batch

All rewards should be added to self.environment.batch_state.rewards
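
A sketch of a reward callback. Writing into self.environment.batch_state.rewards follows the note above; the batch_state.samples attribute and the scoring rule are placeholders for whatever your task defines, and rewards is assumed to support element-wise addition.

```python
import numpy as np
from mrl.train.callbacks import Callback  # import path assumed, as above

class LengthReward(Callback):
    "Toy reward: score each sample by its length"
    def __init__(self):
        super().__init__(name='length_reward', order=10)

    def compute_reward(self):
        batch_state = self.environment.batch_state
        # ASSUMPTION: the current batch's samples are stored as `batch_state.samples`
        samples = batch_state.samples
        rewards = np.array([float(len(s)) for s in samples])
        # all rewards should be added to batch_state.rewards, per the note above
        batch_state.rewards += rewards
```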

class AfterComputeReward[source]

AfterComputeReward() :: Event

AfterComputeReward

This event is called after all rewards have been computed. This event can be used to log stats and metrics related to the rewards for the current batch

class RewardModification[source]

RewardModification() :: Event

RewardModification

This event is used to modify rewards before they are used to compute the model's loss. Reward modifications encompass changes to rewards in the context of the current training cycle. These are things like "give a score bonus to new samples that haven't been seen before" or "penalize the score of samples that have occurred in the last 5 batches".

These types of modifications are kept separate from the core reward for logging purposes. Samples are logged with their respective rewards, and these logged scores are referenced later when samples are drawn from the log. This means the logged score needs to be independent of "batch context" type scores.

All reward modifications should be applied to self.environment.batch_state.rewards
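
A sketch of the novelty bonus described above. The reward_modification method name is inferred from the snake_case convention shown for compute_reward; batch_state.samples and the bonus value are illustrative assumptions.

```python
from mrl.train.callbacks import Callback  # import path assumed, as above

class NoveltyBonus(Callback):
    "Adds a small bonus to samples this callback has not seen before"
    def __init__(self, bonus=0.1):
        super().__init__(name='novelty_bonus', order=10)
        self.bonus = bonus

    def setup(self):
        # one-time state, created once the Environment exists
        self.seen = set()

    def reward_modification(self):
        batch_state = self.environment.batch_state
        # ASSUMPTION: the current batch's samples live in `batch_state.samples`
        for i, sample in enumerate(batch_state.samples):
            if sample not in self.seen:
                batch_state.rewards[i] += self.bonus
                self.seen.add(sample)
```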

class AfterRewardModification[source]

AfterRewardModification() :: Event

AfterRewardModification

This event is called after all reward modifications have been computed. This event can be used to log stats and metrics related to the reward modifications for the current batch

class GetModelOutputs[source]

GetModelOutputs() :: Event

GetModelOutputs

This event is used to generate any model-derived outputs relevant to loss computation

class AfterGetModelOutputs[source]

AfterGetModelOutputs() :: Event

AfterGetModelOutputs

This event is called after get_model_outputs. This event can be used for any processing required prior to loss computation

class ComputeLoss[source]

ComputeLoss() :: Event

ComputeLoss

This event is used to compute loss values

All loss values should be added to self.environment.batch_state.loss
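
A sketch of a loss callback following the note above. The log_probs attribute (presumed to be filled during get_model_outputs) and the REINFORCE-style formula are illustrative assumptions, not MRL's built-in loss.

```python
from mrl.train.callbacks import Callback  # import path assumed, as above

class PolicyGradientLoss(Callback):
    "Illustrative REINFORCE-style loss term"
    def __init__(self):
        super().__init__(name='pg_loss', order=10)

    def compute_loss(self):
        batch_state = self.environment.batch_state
        # ASSUMPTION: per-sample log-probs were stored during get_model_outputs
        log_probs = batch_state.log_probs
        rewards = batch_state.rewards
        loss = -(log_probs * rewards).mean()
        # all loss values should be added to batch_state.loss, per the note above
        batch_state.loss += loss
```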

class ZeroGrad[source]

ZeroGrad() :: Event

ZeroGrad

This event is used to zero gradients in any optimizers relevant to the fit cycle

loss.backward() is called after the ZeroGrad event

class BeforeStep[source]

BeforeStep() :: Event

BeforeStep

This event is used for any processing needed after loss.backward() but before opt.step(), e.g. gradient clipping
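
Since gradient clipping is the example given above, here is a sketch. torch.nn.utils.clip_grad_norm_ is standard PyTorch; the path to the model through the environment (self.environment.agent.model) is an assumption.

```python
import torch
from mrl.train.callbacks import Callback  # import path assumed, as above

class GradClip(Callback):
    "Clip the gradient norm between loss.backward() and opt.step()"
    def __init__(self, max_norm=1.0):
        super().__init__(name='grad_clip', order=10)
        self.max_norm = max_norm

    def before_step(self):
        # ASSUMPTION: the trainable model is reachable through the agent
        model = self.environment.agent.model
        torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_norm)
```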

class Step[source]

Step() :: Event

Step

This event is used to step all optimizers

class AfterBatch[source]

AfterBatch() :: Event

AfterBatch

This event is called after step. This event can be used to compute batch stats and clean up values before the next batch

class AfterTrain[source]

AfterTrain() :: Event

AfterTrain

This event is called after all batch steps have been completed
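
Putting the event descriptions above together, Environment.fit dispatches events roughly in the order sketched below. This is a schematic reconstruction from the descriptions on this page, not the literal fit source; in particular, how often the buffer events run relative to batches depends on the Environment configuration.

```python
# schematic event order (reconstructed from the descriptions above)
#
# setup                  once, when the Environment is created
# before_train           once, before the first batch
# repeated for each batch:
#   build_buffer -> filter_buffer -> after_build_buffer
#   before_batch -> sample_batch -> before_filter_batch -> filter_batch -> after_sample
#   before_compute_reward -> compute_reward -> after_compute_reward
#   reward_modification -> after_reward_modification
#   get_model_outputs -> after_get_model_outputs
#   compute_loss
#   zero_grad -> loss.backward() -> before_step -> step
#   after_batch
# after_train            once, after all batch steps have completed
```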

class Events[source]

Events()

class SettrDict[source]

SettrDict() :: dict

dict() -> new empty dictionary
dict(mapping) -> new dictionary initialized from a mapping object's (key, value) pairs
dict(iterable) -> new dictionary initialized as if via: d = {}; for k, v in iterable: d[k] = v
dict(**kwargs) -> new dictionary initialized with the name=value pairs in the keyword argument list. For example: dict(one=1, two=2)

Batch State

The BatchState class is used by an Environment to track values generated or computed during a batch. Every batch, the old BatchState is deleted and a new BatchState is created.

Attributes in BatchState can be set or accessed with a key like a dictionary or as an attribute. BatchState can hold any arbitrary value during a batch. However, it was designed for the use case where every attribute is either a single value or a list/container with length equal to the current batch size.
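
A small sketch of the two access styles described above; the keys used here are arbitrary examples, and the import path is an assumption (BatchState is documented just below).

```python
from mrl.train.callbacks import BatchState  # import path assumed

batch_state = BatchState()

# set with a key, read as an attribute
batch_state['samples'] = ['CCO', 'c1ccccc1']
print(batch_state.samples)          # ['CCO', 'c1ccccc1']

# set as an attribute, read with a key
batch_state.temperature = 1.0
print(batch_state['temperature'])   # 1.0
```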

Rewards

BatchState holds the rewards value for a batch. All reward functions should ultimately add their reward value to BatchState.rewards. See Reward for more information.

Loss

BatchState holds the loss value for a batch. This is the value that will be backpropagated during the optimizer update. All loss functions should ultimately add their value to BatchState.loss. See Loss for more information.

class BatchState[source]

BatchState() :: SettrDict

dict() -> new empty dictionary
dict(mapping) -> new dictionary initialized from a mapping object's (key, value) pairs
dict(iterable) -> new dictionary initialized as if via: d = {}; for k, v in iterable: d[k] = v
dict(**kwargs) -> new dictionary initialized with the name=value pairs in the keyword argument list. For example: dict(one=1, two=2)