Model agents

Agent

The Agent class bundles a model, a loss function, and a dataset into a single object. The Agent is also a callback and serves several roles in the fit loop.

Notable Functions

class Agent[source]

Agent(model, loss_function, dataset, opt_kwargs={}, clip=1.0, name='agent') :: Callback

Agent - class for bundling a model, loss function, and dataset

Inputs:

  • model nn.Module: model

  • loss_function Callable: loss function for supervised training. Should function as loss = loss_function(model_output, y)

  • dataset Base_Dataset: dataset

  • opt_kwargs dict: dictionary of keyword arguments passed to optim.Adam

  • clip float: gradient clipping

  • name str: agent name
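
A minimal construction sketch. The toy model, loss function, and SMILES strings below are illustrative placeholders rather than part of the library, and Agent is assumed to be in scope, as in the examples at the bottom of this page:

import torch.nn as nn
import torch.nn.functional as F
from mrl.vocab import *
from mrl.dataloaders import *

# Toy dataset: a Text_Dataset over a few illustrative SMILES strings
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Text_Dataset(['CCO', 'c1ccccc1', 'CC(=O)O', 'c1ccncc1']*10, vocab)

# Stand-in model and loss; any nn.Module and any callable satisfying
# loss = loss_function(model_output, y) can be used
model = nn.Linear(16, 1)
loss_function = F.mse_loss

agent = Agent(model,
              loss_function,
              ds,
              opt_kwargs={'lr': 1e-3},  # forwarded to optim.Adam
              clip=1.0,                 # gradient clipping value
              name='agent')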

class PredictiveAgent[source]

PredictiveAgent(model, loss_function, dataset, opt_kwargs={}, clip=1.0, name='agent') :: Agent

PredictiveAgent - Agent class for predictive models

Inputs:

  • model nn.Module: model

  • loss_function Callable: loss function for supervised training. Should function as loss = loss_function(model_output, y)

  • dataset Base_Dataset: dataset

  • opt_kwargs dict: dictionary of keyword arguments passed to optim.Adam

  • clip float: gradient clipping

  • name str: agent name

Baseline Agent

Many RL algorithms make use of two agents. The main agent is trained every batch, while the other agent (the baseline agent) is updated every n batches. Algorithms like PPO and TRPO make use of the ratio between the action probabilities assigned by the main agent and those assigned by the baseline agent.

The BaselineAgent creates a copy of the main model that serves as the baseline. The baseline model is updated every base_update_iter batches.

The baseline is updated following w_baseline_new = alpha*w_baseline_old + (1-alpha)*w_main where alpha is set by the base_update parameter. Setting base_update=0 will cause the weights of the main agent to be simply copied into the baseline.
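
Illustratively, this update is an exponential moving average over the two models' weights. The helper below is a standalone sketch of that rule, not the library's internal implementation:

import torch

def ema_update_(baseline_model, main_model, alpha):
    "In-place update: w_baseline = alpha*w_baseline + (1-alpha)*w_main"
    with torch.no_grad():
        for w_base, w_main in zip(baseline_model.parameters(),
                                  main_model.parameters()):
            w_base.mul_(alpha).add_(w_main, alpha=1-alpha)

# Here alpha corresponds to base_update; alpha=0 copies the main model's
# weights directly into the baseline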

class BaselineAgent[source]

BaselineAgent(model, loss_function, dataset, base_update=0.99, base_update_iter=10, base_model=True, opt_kwargs={}, clip=1.0, name='baseline_agent') :: Agent

BaselineAgent - agent for a model with a baseline model

Inputs:

  • model nn.Module: model

  • loss_function Callable: loss function for supervised training. Should function as loss = loss_function(model_output, y)

  • dataset Base_Dataset: dataset

  • base_update float: update fraction for the baseline model. Updates the base model following base_model = base_update*base_model + (1-base_update)*model

  • base_update_iter int: update frequency for baseline model

  • base_model bool: if False, baseline model will not be created

  • opt_kwargs dict: dictionary of keyword arguments passed to optim.Adam

  • clip float: gradient clipping

  • name str: agent name

class CriticAgent[source]

CriticAgent(model, loss_function, dataset, base_update=0.99, base_update_iter=10, base_model=True, opt_kwargs={}, clip=1.0, name='baseline_agent') :: BaselineAgent

CriticAgent - baseline agent for critic models

Inputs:

  • model nn.Module: model

  • loss_function Callable: loss function for supervised training. Should function as loss = loss_function(model_output, y)

  • dataset Base_Dataset: dataset

  • base_update float: update fraction for the baseline model. Updates the base model following base_model = base_update*base_model + (1-base_update)*model

  • base_update_iter int: update frequency for baseline model

  • base_model bool: if False, baseline model will not be created

  • opt_kwargs dict: dictionary of keyword arguments passed to optim.Adam

  • clip float: gradient clipping

  • name str: agent name

Generative Agent

The GenerativeAgent class adds a vocab input used to reconstruct generated samples. It also updates before_compute_reward and get_model_outputs to produce the values needed for training. A full usage example is shown at the bottom of this page.

class GenerativeAgent[source]

GenerativeAgent(model, vocab, loss_function, dataset, base_update=0.99, base_update_iter=10, base_model=True, opt_kwargs={}, clip=1.0, name='generative_agent') :: BaselineAgent

GenerativeAgent - baseline agent for generative models

Inputs:

  • model nn.Module: model

  • vocab Vocab: vocabulary

  • loss_function Callable: loss function for supervised training. Should function as loss = loss_function(model_output, y)

  • dataset Base_Dataset: dataset

  • base_update float: update fraction for the baseline model. Updates the base model following base_model = base_update*base_model + (1-base_update)*model

  • base_update_iter int: update frequency for baseline model

  • base_model bool: if False, baseline model will not be created

  • opt_kwargs dict: dictionary of keyword arguments passed to optim.Adam

  • clip float: gradient clipping

  • name str: agent name

Training Callbacks

These callbacks incorporate supervised training into the RL fit cycle (a construction sketch follows the list below):

  • SupervisedCB - at a set frequency, runs supervised training on the top percentile of logged samples

  • Rollback - if a chosen metric crosses a set threshold, moves the main model's weights back toward the baseline model's weights

  • RetrainRollback - runs supervised training if a chosen metric crosses a set threshold

  • ResetAndRetrain - at a set frequency, reloads a saved checkpoint and runs supervised training on samples from the log

  • MetricResetAndRetrain - reloads a saved checkpoint and runs supervised training if a chosen metric crosses a set threshold

  • SaveAgentWeights - saves model weights at a set frequency
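
As a construction sketch using the signatures documented below (argument values are illustrative, agent is assumed to be a BaselineAgent or subclass like the one built at the bottom of this page, and passing the callbacks into an RL fit loop is not covered here):

# Periodically run supervised training on high-reward samples from the log
supervised_cb = SupervisedCB(agent,
                             frequency=50,      # every 50 batches
                             base_update=0.99,  # baseline update after training
                             percentile=90,     # log percentile cutoff for sampling
                             lr=1e-4,
                             bs=64,
                             log_term='rewards',
                             epochs=1)

# If 'rewards' (assumed to be a logged metric) drops below 0.5, checked over
# the last 20 batches, blend the main model's weights back toward the baseline
rollback_cb = Rollback(agent,
                       metric_name='rewards',
                       lookback=20,
                       target=0.5,
                       alpha=0.5,
                       name='reward_rollback',
                       mode='lesser')

# Save weights every 1000 batches to file_path/filename_iterations.pt
save_cb = SaveAgentWeights('untracked_files/', 'agent_weights', 1000, agent)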

class SupervisedCB[source]

SupervisedCB(agent, frequency, base_update, percentile, lr, bs, log_term='rewards', epochs=1, silent=True) :: Callback

SupervisedCB - supervised training callback. When triggered, this callback grabs the top percentile of samples from the log and runs supervised training with the sampled data

Inputs:

  • agent Agent: agent

  • frequency int: how often to run supervised training

  • base_update float: how much to update the baseline model after supervised training (if applicable)

  • percentile int: percentile (int value 1-100) of data to sample from the log

  • lr float: learning rate

  • bs int: batch size

  • log_term str: what term in the log to take the percentile of

  • epochs int: number of training epochs

  • silent bool: if True, training losses are not printed

class Rollback[source]

Rollback(agent, metric_name, lookback, target, alpha, name, mode='greater') :: Callback

Rollback - if metric_name crosses target (direction set by mode), updates the main model's weights using the baseline model's weights

Inputs:

  • agent BaselineAgent: agent

  • metric_name str: metric to track

  • lookback int: number of batches to look back. Also sets the maximum rollback frequency

  • target float: desired cutoff for metric_name

  • alpha float: during rollback, the main model weights are updated following model = alpha*model + (1-alpha)*base_model

  • name str: callback name

  • mode str['greater', 'lesser']: if greater, rollback is triggered by the metric going over target. If lesser, rollback is triggered by the metric falling below target

class RetrainRollback[source]

RetrainRollback(agent, metric_name, log_term, lookback, target, percentile, lr, bs, base_update, name, mode='greater', silent=False) :: Callback

RetrainRollback - triggers supervised training if metric_name crosses target (direction set by mode)

Inputs:

  • agent BaselineAgent: agent

  • metric_name str: metric to track

  • log_term str: what term in the log to take the percentile of

  • lookback int: number of batches to look back. Also sets the maximum rollback frequency

  • target float: desired cutoff for metric_name

  • percentile int: percentile (1-100) of data to sample from the log

  • lr float: learning rate

  • bs int: batch size

  • base_update float: after supervised training, the weights of the baseline model are updated following base_model = base_update*base_model + (1-base_update)*model

  • name str: callback name

  • mode str['greater', 'lesser']: if greater, rollback is triggered by the metric going over target. If lesser, rollback is triggered by the metric falling below target

  • silent bool: if True, training losses are not printed

class ResetAndRetrain[source]

ResetAndRetrain(agent, frequency, weight_fp, percentile, lr, bs, epochs, log_term='rewards', sample_term='samples', silent=False) :: Callback

ResetAndRetrain - with a set frequency, loads a file of saved weights and runs supervised training

Inputs:

  • agent BaselineAgent: agent

  • frequency int: how often to run supervised training

  • weight_fp str: filepath to weights

  • percentile int: percentile (int value 1-100) of data to sample from the log

  • lr float: learning rate

  • bs int: batch size

  • epochs int: number of epochs to run

  • log_term str: what term in the log to take the percentile of

  • sample_term str: what log term contains the samples to train on

  • silent bool: if True, training losses are not printed

class MetricResetAndRetrain[source]

MetricResetAndRetrain(agent, metric_name, lookback, target, weight_fp, percentile, lr, bs, epochs, log_term='rewards', sample_term='samples', mode='greater', silent=False) :: Callback

MetricResetAndRetrain - loads a file of saved weights and runs supervised training if metric_name crosses target (direction set by mode)

Inputs:

  • agent BaselineAgent: agent

  • metric_name str: metric to track

  • lookback int: number of batches to look back. Also sets the maximum rollback frequency

  • target float: desired cutoff for metric_name

  • weight_fp str: filepath to weights

  • percentile int: percentile (int value 1-100) of data to sample from the log

  • lr float: learning rate

  • bs int: batch size

  • epochs int: number of epochs to run

  • log_term str: what term in the log to take the percentile of

  • sample_term str: what log term contains the samples to train on

  • mode str['greater', 'lesser']: if greater, rollback is triggered by the metric going over target. If lesser, rollback is triggered by the metric falling below target

  • silent bool: if True, training losses are not printed

class SaveAgentWeights[source]

SaveAgentWeights(file_path, filename, n_batches, agent) :: Callback

SaveAgentWeights - saves weights every n_batches. Weights are saved to file_path/filename_iterations.pt

Inputs:

  • file_path str: directory to save weights in

  • filename str: base filename

  • n_batches int: how often to save weights

  • agent Agent: agent
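
The example below builds a GenerativeAgent from a pretrained LSTM language model and a SMILES dataset, then runs a round of supervised training. The data file and pretrained weights referenced here are assumed to be available locally.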

import pandas as pd
import torch

from mrl.vocab import *
from mrl.dataloaders import *
from mrl.g_models.all import *

# Build a character-level SMILES dataset
df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

ds = Text_Dataset(list(df.smiles.values)*10, vocab)  # replicate the data 10x

# LSTM language model hyperparameters
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bos_idx = vocab.stoi['bos']
bidir = False
tie_weights = True

model = LSTM_LM(d_vocab,
                d_embedding,
                d_hidden,
                n_layers,
                input_dropout,
                lstm_dropout,
                bos_idx,
                bidir,
                tie_weights)

# Load pretrained weights (assumed to exist locally)
model.load_state_dict(torch.load('untracked_files/lstm_lm_zinc.pt'))

# Bundle the model, vocab, loss function, and dataset into an agent
agent = GenerativeAgent(model, vocab, CrossEntropy(), ds, opt_kwargs={'lr':1e-4})

# Supervised training: batch size 64, 1 epoch, learning rate 1e-4
agent.train_supervised(64, 1, 1e-4, silent=False)
Epoch Train Loss Valid Loss Time
0 0.42942 0.51512 00:11
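
Passing fp16=True to train_supervised runs the same training in mixed precision (the run below repeats the setup above; note the reduced epoch time):
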
agent.train_supervised(64, 1, 1e-4, silent=False, fp16=True)
Epoch Train Loss Valid Loss Time
0 0.42942 0.51508 00:08