Loss callback functions
Loss Function Callbacks
Loss function callbacks compute a loss value from the current batch state and add the result to `BatchState.loss`.
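The accumulation contract can be pictured with a minimal sketch. `BatchState` below is a hypothetical stand-in that models only the `loss` field, and `ConstantLoss` with its `add_loss` method is a hypothetical callback invented for illustration; the real classes carry more state and may use a different method name.

```python
from dataclasses import dataclass


@dataclass
class BatchState:
    # Hypothetical stand-in: only the running loss is modeled here
    loss: float = 0.0


class ConstantLoss:
    # Hypothetical callback that contributes a fixed value
    def __init__(self, value):
        self.value = value

    def add_loss(self, batch_state):
        # Each loss callback adds its contribution to the running total
        batch_state.loss += self.value


state = BatchState()
for callback in (ConstantLoss(0.5), ConstantLoss(1.25)):
    callback.add_loss(state)

print(state.loss)  # 1.75
```

Because every callback adds to the same field, several loss terms can be combined simply by running their callbacks in sequence.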
`LossCallback` provides a simple hook for custom loss functions. Any object with a `from_batch_state` method that returns a scalar value can be passed to `LossCallback`. For example:
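```python
class MyLoss:
    def from_batch_state(self, batch_state):
        # Compute a scalar loss from the current batch state;
        # do_loss_calculation is a placeholder for your own logic
        loss = self.do_loss_calculation(batch_state)
        return loss
```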
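As a slightly more concrete sketch, the loss object below computes a mean squared error from two fields of the batch state. The `predictions` and `targets` attributes, the `SimpleNamespace` stand-in for the batch state, and `MiniLossCallback` with its `apply` method are all hypothetical; only the `from_batch_state` hook comes from the description above.

```python
from types import SimpleNamespace


class MSELoss:
    def from_batch_state(self, batch_state):
        # Mean squared error between hypothetical prediction/target fields
        pairs = zip(batch_state.predictions, batch_state.targets)
        errors = [(p - t) ** 2 for p, t in pairs]
        return sum(errors) / len(errors)


class MiniLossCallback:
    # Hypothetical stand-in for LossCallback: wraps any object exposing
    # from_batch_state and adds its scalar result to the running loss
    def __init__(self, loss_obj):
        self.loss_obj = loss_obj

    def apply(self, batch_state):
        batch_state.loss += self.loss_obj.from_batch_state(batch_state)


batch_state = SimpleNamespace(
    predictions=[1.0, 2.0, 3.0, 4.0],
    targets=[1.0, 2.0, 3.0, 6.0],
    loss=0.0,
)
MiniLossCallback(MSELoss()).apply(batch_state)
print(batch_state.loss)  # 1.0
```

The wrapper relies only on duck typing: `MSELoss` does not inherit from anything, it just exposes `from_batch_state`.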
Policy Loss
`PolicyLoss` interfaces with any of the `BasePolicy` subclasses, such as `PolicyGradient`, `TRPO`, or `PPO`.
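For orientation, the sketch below writes out the textbook objectives behind two of these algorithms: the REINFORCE-style policy-gradient loss and PPO's clipped surrogate loss. This is the standard formulation from the literature, not this library's implementation; the function names and inputs (per-sample log-probabilities, advantages, `clip_eps`) are assumptions for illustration. TRPO is left out because its KL trust-region constraint is enforced in the update step rather than in a simple loss term.

```python
import math


def policy_gradient_loss(log_probs, advantages):
    # REINFORCE-style loss: -mean(log pi(a|s) * A)
    terms = [lp * adv for lp, adv in zip(log_probs, advantages)]
    return -sum(terms) / len(terms)


def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    # PPO clipped surrogate: -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A))
    terms = []
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        ratio = math.exp(lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        terms.append(min(ratio * adv, clipped * adv))
    return -sum(terms) / len(terms)


print(policy_gradient_loss([-1.0, -2.0], [1.0, 0.5]))  # 1.0
print(ppo_clip_loss([0.0, 0.0], [0.0, 0.0], [1.0, 3.0]))  # -2.0
```

When the new and old policies agree, the ratio is 1 and the PPO loss reduces to the negative mean advantage; clipping only matters once the ratio leaves the `[1 - eps, 1 + eps]` band.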
`PolicyLoss` can optionally contain a `value_head`, a model that predicts state values. `PolicyLoss` looks for a `batch_state.value_input` attribute, which is typically set by `Agent.get_model_outputs`.
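A common way to train a value head is to regress its predictions onto observed returns with a squared error. The sketch below assumes a linear value head over `value_input` feature vectors and a list of `returns`; the linear form, `linear_value_head`, `value_loss`, and those argument names are hypothetical and not taken from `PolicyLoss` itself.

```python
def linear_value_head(weights, bias, features):
    # Hypothetical value head: V(s) = w . x + b
    return sum(w * x for w, x in zip(weights, features)) + bias


def value_loss(weights, bias, value_input, returns):
    # Mean squared error between predicted state values and observed returns
    errors = [
        (linear_value_head(weights, bias, x) - g) ** 2
        for x, g in zip(value_input, returns)
    ]
    return sum(errors) / len(errors)


value_input = [[1.0, 0.0], [0.0, 1.0]]  # two states, two features each
returns = [1.0, 3.0]
print(value_loss([1.0, 1.0], 0.0, value_input, returns))  # 2.0
```

With weights `[1.0, 3.0]` the head predicts both returns exactly and the loss drops to zero.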