Loss callback functions

Loss function callbacks compute some loss value from the current batch state and add the resulting value to BatchState.loss.

LossCallback provides a simple hook for custom loss functions. Any object with a from_batch_state method that returns a scalar value can be passed to LossCallback. For example:

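class MyLoss():
    def from_batch_state(self, batch_state):
        # do_loss_calculation is a placeholder for your own logic;
        # it should return a scalar loss computed from batch_state
        loss = self.do_loss_calculation(batch_state)
        return loss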
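
As a slightly more concrete sketch of the same contract, the loss object below returns the mean of a list of per-sample penalties. The penalties attribute and the use of SimpleNamespace as a stand-in for the batch state are assumptions made for illustration; the fields available on a real BatchState depend on how the environment is set up.

```python
from types import SimpleNamespace


class MeanPenaltyLoss:
    """Hypothetical loss object: averages per-sample penalties."""

    def from_batch_state(self, batch_state):
        # `penalties` is an assumed attribute, not a documented BatchState field
        penalties = batch_state.penalties
        return sum(penalties) / len(penalties)


# Stand-in for a real batch state, used only to exercise the method
batch_state = SimpleNamespace(penalties=[0.5, 1.5, 1.0])
loss_value = MeanPenaltyLoss().from_batch_state(batch_state)
```

The only requirement stated above is that from_batch_state returns a scalar; in a training run that value would presumably be a scalar tensor rather than a Python float, so that gradients can flow through it.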

class LossCallback

LossCallback(loss_function, name, weight=1.0, track=True, order=20) :: Callback

LossCallback - basic loss callback

Inputs:

  • loss_function: any object with a from_batch_state method

  • name str: loss name

  • weight float: loss scaling weight

  • track bool: if values from this loss should be tracked

  • order int: callback order
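
The interaction between loss_function, weight and track can be sketched roughly as follows. This is a simplified stand-in, not the library's LossCallback: the compute_loss hook name, its signature, the history list, and the choice to record the unweighted value are all assumptions. It only illustrates the documented behavior of scaling the loss by weight and adding it to the batch state's loss.

```python
from types import SimpleNamespace


class SketchLossCallback:
    """Illustrative stand-in, not the library's LossCallback."""

    def __init__(self, loss_function, name, weight=1.0, track=True):
        self.loss_function = loss_function
        self.name = name
        self.weight = weight
        self.track = track
        self.history = []  # assumed place to keep tracked values

    def compute_loss(self, batch_state):
        # Hook name and signature are assumptions for this sketch
        loss = self.loss_function.from_batch_state(batch_state)
        if self.track:
            # Recording the unweighted value is an assumption
            self.history.append(loss)
        # Scale by the weight and accumulate into the running batch loss
        batch_state.loss += self.weight * loss
        return batch_state


class ConstantLoss:
    """Hypothetical loss object that always returns 2.0."""

    def from_batch_state(self, batch_state):
        return 2.0


batch_state = SimpleNamespace(loss=0.0)
callback = SketchLossCallback(ConstantLoss(), name="constant", weight=0.5)
callback.compute_loss(batch_state)
```

Because the callback adds to batch_state.loss rather than overwriting it, several loss callbacks can contribute to the same total, each with its own weight.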

Policy Loss

PolicyLoss interfaces with any BasePolicy subclass, such as PolicyGradient, TRPO, or PPO.

PolicyLoss can optionally contain a value_head, a model that predicts state values. The code looks for a batch_state.value_input attribute, which is typically set by Agent.get_model_outputs.

class PolicyLoss

PolicyLoss(policy_function, name, value_head=None, v_update=0.95, v_update_iter=10, vopt_kwargs={}, track=True) :: LossCallback

PolicyLoss - Loss callback for BasePolicy subclasses

Inputs:

  • policy_function BasePolicy: policy

  • name str: loss name

  • value_head Optional[nn.Module]: state value prediction model

  • v_update float: update fraction for the baseline value head

  • v_update_iter int: update frequency for baseline value head

  • vopt_kwargs dict: dictionary of keyword args passed to optim.Adam

  • track bool: if values from this loss should be tracked
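
The v_update and v_update_iter parameters suggest a baseline copy of the value head that is refreshed periodically. One plausible reading, sketched below with plain floats in place of model parameters, is an interpolation in which the baseline keeps a fraction v_update of its old weights once every v_update_iter iterations. The function names, the direction of the interpolation, and the dict-of-floats representation are all assumptions, not the library's code.

```python
def soft_update(baseline, current, v_update=0.95):
    """Blend `current` into `baseline`, keeping `v_update` of the old values."""
    return {
        key: v_update * baseline[key] + (1.0 - v_update) * current[key]
        for key in baseline
    }


def maybe_update_baseline(iteration, baseline, current,
                          v_update=0.95, v_update_iter=10):
    # Refresh the baseline only every `v_update_iter` iterations (skipping 0)
    if iteration > 0 and iteration % v_update_iter == 0:
        return soft_update(baseline, current, v_update)
    return baseline


baseline = {"w": 0.0}
current = {"w": 1.0}
unchanged = maybe_update_baseline(5, baseline, current)   # off-schedule
updated = maybe_update_baseline(10, baseline, current)    # on-schedule
```

Under this reading, a v_update close to 1.0 makes the baseline change slowly, while a smaller v_update_iter makes it change more often.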

class PriorLoss

PriorLoss(prior, base_prior=None, clip=10.0)

PriorLoss - loss for a trainable prior

Inputs:

  • prior nn.Module: trainable prior

  • base_prior Optional[nn.Module]: optional baseline prior

  • clip float: loss clipping value

class HistoricPriorLoss

HistoricPriorLoss(prior_loss, model, dataset, percentile, n, above_percent, start_iter, frequency, log_term='rewards', weight=1.0, track=True) :: Callback

HistoricPriorLoss - draws samples from batch log to send to prior_loss

Inputs:

  • prior_loss PriorLoss: prior loss function

  • model nn.Module: model used to convert samples to latent vectors

  • dataset Base_Dataset: dataset to convert samples to tensors

  • percentile int: percentile (1-100) to sample from

  • n int: number of samples to draw

  • above_percent float: what percentage of the samples should score above the percentile

  • start_iter int: iteration at which to start using the historical loss

  • frequency int: how often, in batches, the historical loss is called

  • log_term str: which term in the batch log to use for the percentile computation

  • weight float: loss scaling weight

  • track bool: if values from this callback should be tracked
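
The sampling parameters above can be illustrated with a small standalone sketch. It assumes the batch log can be viewed as a list of (sample, score) pairs, that above_percent is a fraction between 0 and 1, that "above" means at or above the threshold, and that the callback fires when the iteration is at least start_iter and divisible by frequency. None of these details are confirmed by the source, and all function names are hypothetical.

```python
import math
import random


def percentile_threshold(scores, percentile):
    """Score at the given percentile, using the nearest-rank method."""
    ordered = sorted(scores)
    rank = max(1, math.ceil(len(ordered) * percentile / 100))
    return ordered[rank - 1]


def sample_history(log, percentile, n, above_percent, rng):
    """Draw up to n entries, about `above_percent` of them at or above the threshold."""
    threshold = percentile_threshold([score for _, score in log], percentile)
    above = [item for item in log if item[1] >= threshold]
    below = [item for item in log if item[1] < threshold]
    # Cap each draw at the number of entries actually available
    n_above = min(len(above), round(n * above_percent))
    n_below = min(len(below), n - n_above)
    return rng.sample(above, n_above) + rng.sample(below, n_below)


def should_run(iteration, start_iter, frequency):
    # Assumed schedule: wait for start_iter, then run every `frequency` batches
    return iteration >= start_iter and iteration % frequency == 0


# Toy log with scores 0.0 through 9.0 standing in for the `rewards` term
log = [(f"sample_{i}", float(i)) for i in range(10)]
picked = sample_history(log, percentile=90, n=4, above_percent=0.5,
                        rng=random.Random(0))
```

Mixing samples from above and below the percentile, rather than taking only the top scorers, would let the prior see some lower-scoring examples as well; whether that is the motivation behind above_percent is not stated in the source.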