Agent
The Agent class holds a model, a dataset and a loss function in a single object. The Agent is also a callback and serves several roles in the fit loop.
Notable Functions
Agent.train_supervised - runs a supervised training loop using the items in Agent.dataset. Override this method in a subclass for custom supervised training loops.
Agent.update_dataset / Agent.update_dataset_from_inputs - update Agent.dataset with new items.
Agent.before_compute_reward - used during the fit loop to convert samples into tensors. Items from the current batch are used to create a version of Agent.dataset containing the new samples, and this dataset is then used to convert the samples into tensors.
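The roles above can be illustrated with a minimal pure-Python sketch. It is not MRL's implementation: `ToyAgent`, its `stoi` character map, the `PAD` index and the use of padded lists of token ids in place of tensors are all hypothetical, and `update_dataset` here simply replaces the stored items.

```python
from typing import Dict, List

PAD = 0  # hypothetical padding index


class ToyAgent:
    """Hypothetical stand-in: bundles a model, a loss and a dataset of strings."""

    def __init__(self, model, loss, dataset: List[str], stoi: Dict[str, int]):
        self.model = model
        self.loss = loss
        self.dataset = list(dataset)
        self.stoi = stoi  # character -> token id

    def update_dataset(self, new_items: List[str]) -> None:
        # In this sketch the stored items are simply replaced
        self.dataset = list(new_items)

    def numericalize(self, item: str) -> List[int]:
        return [self.stoi[ch] for ch in item]

    def before_compute_reward(self, batch_samples: List[str]) -> List[List[int]]:
        # Build a throwaway "dataset" that holds only the current batch
        batch_dataset = [self.numericalize(s) for s in batch_samples]
        # Collate: right-pad every sequence to the longest one in the batch
        max_len = max(len(seq) for seq in batch_dataset)
        return [seq + [PAD] * (max_len - len(seq)) for seq in batch_dataset]


stoi = {'C': 1, 'O': 2, 'N': 3}
agent = ToyAgent(model=None, loss=None, dataset=['CC'], stoi=stoi)
padded = agent.before_compute_reward(['CCO', 'N'])
print(padded)  # [[1, 1, 2], [3, 0, 0]]
```

Building a separate batch-only dataset keeps the long-lived training data untouched while reusing the same conversion logic for freshly generated samples.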
Baseline Agent
Many RL algorithms use two agents. The main agent is trained every batch, while the other agent (the baseline agent) is updated only every n batches. Algorithms such as PPO and TRPO use the ratio between the probabilities that the main agent and the baseline agent assign to the same samples.
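The probability ratio can be sketched in a few lines of plain Python. This is a generic illustration of the standard PPO clipped surrogate, not MRL's code; the function names, the `eps=0.2` default and the per-sample advantages are assumptions for the example.

```python
import math
from typing import List


def probability_ratios(logp_main: List[float], logp_base: List[float]) -> List[float]:
    # r_i = p_main(x_i) / p_base(x_i), computed from log-probabilities for stability
    return [math.exp(lm - lb) for lm, lb in zip(logp_main, logp_base)]


def ppo_clipped_objective(ratios: List[float], advantages: List[float], eps: float = 0.2) -> float:
    # PPO's clipped surrogate: mean over samples of min(r*A, clip(r, 1-eps, 1+eps)*A)
    terms = []
    for r, a in zip(ratios, advantages):
        clipped = min(max(r, 1 - eps), 1 + eps)
        terms.append(min(r * a, clipped * a))
    return sum(terms) / len(terms)


# The main agent is twice as likely as the baseline to produce this sample
ratios = probability_ratios([math.log(0.5)], [math.log(0.25)])
objective = ppo_clipped_objective(ratios, [1.0])
print(ratios, objective)
```

Because the baseline changes slowly, the ratio measures how far the main agent has drifted, and clipping it limits how much a single update can exploit that drift.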
The BaselineAgent creates a copy of the model that serves as the baseline. The baseline agent is updated every base_update_iter batches.
The baseline weights are updated as w_baseline_new = alpha*w_baseline_old + (1-alpha)*w_main, where alpha is set by the base_update parameter. Setting base_update=0 simply copies the weights of the main agent into the baseline.
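The update rule can be checked with a small pure-Python sketch. Weights are modeled as a dict of parameter name to list of floats, and `update_baseline` and `maybe_update` are hypothetical helpers; the schedule "update when the step is a multiple of `base_update_iter`" is an assumption about how the every-n-batches rule is counted.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def update_baseline(baseline: Weights, main: Weights, alpha: float) -> Weights:
    # w_baseline_new = alpha * w_baseline_old + (1 - alpha) * w_main, per parameter
    return {
        name: [alpha * b + (1 - alpha) * m for b, m in zip(baseline[name], main[name])]
        for name in baseline
    }


def maybe_update(step: int, base_update_iter: int, baseline: Weights,
                 main: Weights, alpha: float) -> Weights:
    # Only refresh the baseline every `base_update_iter` batches (assumed schedule)
    if step % base_update_iter == 0:
        return update_baseline(baseline, main, alpha)
    return baseline


baseline = {'w': [0.0, 2.0]}
main = {'w': [1.0, 4.0]}
print(update_baseline(baseline, main, 0.5))  # {'w': [0.5, 3.0]}
print(update_baseline(baseline, main, 0.0))  # {'w': [1.0, 4.0]}, a plain copy of main
```

With real PyTorch modules the same arithmetic would be applied to each parameter tensor; alpha close to 1 makes the baseline trail the main agent slowly.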
Generative Agent
The GenerativeAgent class adds a vocab input, which is used to reconstruct generated samples from their token ids. It also overrides before_compute_reward and get_model_outputs to produce the values needed for training.
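Reconstructing a sample from token ids amounts to looking each id up in the vocab's index-to-token list. The sketch below is a hypothetical helper, not the GenerativeAgent API; the special-token names 'eos' and 'pad' are assumptions (only 'bos' appears in the example code), and it stops at the first end-of-sequence token.

```python
from typing import List


def reconstruct(ids: List[int], itos: List[str],
                bos: str = 'bos', eos: str = 'eos', pad: str = 'pad') -> str:
    tokens = []
    for i in ids:
        tok = itos[i]
        if tok == eos:
            break  # generation ended; ignore anything after it
        if tok in (bos, pad):
            continue  # special tokens are not part of the sample
        tokens.append(tok)
    return ''.join(tokens)


itos = ['bos', 'eos', 'pad', 'C', 'O']  # hypothetical character vocab
print(reconstruct([0, 3, 3, 4, 1, 2], itos))  # CCO
```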
Training Callbacks
These callbacks incorporate supervised training into the RL fit cycle.
SupervisedCB - runs supervised training on the top x percentile of samples at a set frequency.
Rollback - if a chosen metric goes above or below a set threshold, the weights of the main model are reverted to those of the baseline model.
RetrainRollback - runs supervised training if a chosen metric goes above or below a set threshold.
ResetAndRetrain - at a set frequency, reloads a saved checkpoint and runs supervised training on samples from the sample log.
SaveAgentWeights - saves the agent's weights at a set frequency.
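Two of the decisions these callbacks make, picking top-scoring samples and deciding whether to roll back, can be sketched in plain Python. The helper names, the reading of "top x percentile" as "keep the best x percent", the 'above'/'below' direction flag and the dict-of-lists weights are all assumptions, not the MRL callback signatures.

```python
import copy
from typing import Dict, List

Weights = Dict[str, List[float]]


def top_percentile(samples: List[str], rewards: List[float], percentile: float) -> List[str]:
    # Rank samples by reward (highest first) and keep the best `percentile` percent
    ranked = sorted(zip(rewards, samples), key=lambda pair: pair[0], reverse=True)
    n_keep = max(1, int(len(ranked) * percentile / 100))
    return [sample for _, sample in ranked[:n_keep]]


def should_rollback(metric: float, threshold: float, direction: str) -> bool:
    # Fire when the metric crosses the threshold in the configured direction
    if direction == 'above':
        return metric > threshold
    if direction == 'below':
        return metric < threshold
    raise ValueError("direction must be 'above' or 'below'")


def rollback(main: Weights, baseline: Weights, metric: float,
             threshold: float, direction: str) -> Weights:
    # Replace the main weights with a copy of the baseline when triggered
    if should_rollback(metric, threshold, direction):
        return copy.deepcopy(baseline)
    return main


top = top_percentile(['a', 'b', 'c', 'd'], [0.1, 0.9, 0.5, 0.3], 50)
restored = rollback({'w': [9.0]}, {'w': [1.0]}, metric=0.2, threshold=0.5, direction='below')
print(top, restored)  # ['b', 'c'] {'w': [1.0]}
```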
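import pandas as pd
import torch

from mrl.vocab import *
from mrl.dataloaders import *
from mrl.g_models.all import *

# GenerativeAgent and CrossEntropy are assumed to be in scope (both are part of MRL)

# Load SMILES strings and build a character-level dataset (each string repeated 10 times)
df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Text_Dataset(list(df.smiles.values)*10, vocab)

# LSTM language model hyperparameters
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bos_idx = vocab.stoi['bos']  # index of the beginning-of-sequence token
bidir = False
tie_weights = True

model = LSTM_LM(d_vocab,
                d_embedding,
                d_hidden,
                n_layers,
                input_dropout,
                lstm_dropout,
                bos_idx,
                bidir,
                tie_weights)

# Load pretrained weights
model.load_state_dict(torch.load('untracked_files/lstm_lm_zinc.pt'))

# Bundle model, vocab, loss and dataset into one agent, then run supervised training
agent = GenerativeAgent(model, vocab, CrossEntropy(), ds, opt_kwargs={'lr':1e-4})
agent.train_supervised(64, 1, 1e-4, silent=False)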
agent.train_supervised(64, 1, 1e-4, silent=False, fp16=True)