Log
The `Log` callback serves the purpose of logging data generated during a training run. The log holds the following objects of interest (a short access sketch follows the list):

- `batch_log` - a dictionary of batch-wise logged data. Each key is a string denoting the name of a logged attribute. The values are lists of lists, where each sub-list holds the value of the given attribute for each item in a batch. For example, `log.batch_log['rewards'][5]` would yield an array of rewards for each item in batch 5.
- `timelog` - a dictionary of lists. The keys denote different training steps (`build_buffer`, `sample_batch`, etc.), and the values record how long each step took for each batch. To view the times for all training stages, use the `Log.plot_timelog` function.
- `metrics` - a dictionary of lists. Each key is the name of a tracked metric; each value is a list containing the value of that metric for each batch. Metrics are single scalar values, one per batch, and can be plotted with the `Log.plot_metrics` function.
- `unique_samples` - a set containing every unique sample seen during training. This can be used to quickly check whether a sample has been seen before.
- `df` - a dataframe containing every unique sample and every `batch_log` value associated with that sample.
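As a concrete illustration, here is a minimal access sketch. It assumes `log` is the `Log` instance from a completed run (reachable as `environment.log`) and that `'rewards'` is among the logged attributes; adapt the key names to whatever your run actually tracks:

```python
# Minimal sketch: `log` is assumed to be the Log from a finished run,
# e.g. log = environment.log
batch5_rewards = log.batch_log['rewards'][5]       # per-item rewards for batch 5
reward_metric = log.metrics.get('rewards')         # one scalar per batch, if tracked
is_new = 'some_sample' not in log.unique_samples   # has this sample been seen before?
df = log.df                                        # one row per unique sample

log.plot_timelog()   # time per training stage (build_buffer, sample_batch, ...)
log.plot_metrics()   # plot every tracked metric over batches
```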
Custom Metrics Logging
Adding new items to metric tracking and batch logging is easy.

Use `Log.add_log` to add a new term to the batch log. At some point during the batch, assign the values to be logged to the current `BatchState` under an attribute name that matches the name added to the log. The batch log will pick up the values automatically.

Use `Log.add_metric` to add a new term to the metric log. At some point during the batch, compute the metric you wish to log, then use `Log.update_metric` to add the value to the metric log.
Here's an outline implementation:
```python
import torch

# Callback and to_device are assumed to be in scope from the library's imports
class MyCallback(Callback):
    def __init__(self):
        super().__init__(name='my_callback')

    def setup(self):
        log = self.environment.log
        log.add_log(self.name)    # adding new term to batch log
        log.add_metric(self.name) # adding new term to metrics

    def compute_reward(self):
        # make a tensor of dummy rewards
        batch_state = self.environment.batch_state
        bs = len(batch_state.samples)
        rewards = to_device(torch.ones(bs).float())
        batch_state.rewards += rewards
        batch_state[self.name] = rewards # this is the value the batch log will pick up

    def after_compute_reward(self):
        log = self.environment.log
        batch_state = self.environment.batch_state
        my_callback_rewards = batch_state[self.name]
        my_metric = my_callback_rewards.mean()
        log.update_metric(self.name, my_metric.detach().cpu().numpy()) # update metric value
```
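Once attached, the new terms can be read back from the log after training. The following is a hedged usage sketch; `env` is a hypothetical `Environment` that was created with `MyCallback()` among its callbacks and has already been run:

```python
# Usage sketch (hypothetical): `env` is assumed to be an Environment
# constructed with MyCallback() among its callbacks, after training.
log = env.log

per_item = log.batch_log['my_callback'] # per-batch lists of per-item reward values
per_batch = log.metrics['my_callback']  # mean reward metric, one value per batch

print(per_item[0])   # rewards for each item in batch 0
print(per_batch[:5]) # metric values for the first five batches
```

Because both the batch log term and the metric were registered under `self.name`, the same `'my_callback'` key indexes the per-item values in `batch_log` and the per-batch scalars in `metrics`.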