Agent
The Agent class holds a model, a dataset and a loss function in a single object. The Agent is also a callback and serves several roles in the fit loop.
Notable Functions
Agent.train_supervised - runs a supervised training loop using the items in Agent.dataset. Subclass this function for custom supervised training loops.
Agent.update_dataset / Agent.update_dataset_from_inputs - updates Agent.dataset with new items (see the sketch after this list).
Agent.before_compute_reward - used during the fit loop to convert samples into tensors. Items from the current batch are used to create a version of Agent.dataset containing the new samples. This dataset is then used to convert the samples into tensors.
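For illustration, here is a hedged sketch of refreshing an agent's dataset with new raw samples. It assumes an agent set up as in the example at the bottom of this page, and that update_dataset_from_inputs accepts the same raw inputs (here, SMILES strings) the dataset was built from; the sample strings are made up.

# Hedged sketch: refresh Agent.dataset with new raw samples
new_smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']   # hypothetical new samples
agent.update_dataset_from_inputs(new_smiles)  # rebuild Agent.dataset around the new items
agent.train_supervised(64, 1, 1e-4)           # optionally fine-tune on the refreshed dataset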
Baseline Agent
Many RL algorithms make use of two agents. The main agent is trained every batch. The other agent (the baseline agent) is updated every n batches. RL algorithms like PPO and TRPO make use of the ratio between the outputs of the main agent and the baseline agent.
The BaselineAgent creates a copy of the model that serves as the baseline. The baseline agent is updated every base_update_iter batches.
The baseline is updated following w_baseline_new = alpha*w_baseline_old + (1-alpha)*w_main, where alpha is set by the base_update parameter. Setting base_update=0 causes the weights of the main agent to be copied directly into the baseline.
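As an illustration of this update rule (a minimal sketch in plain PyTorch, not MRL's actual implementation), blending the two sets of weights might look like this, with alpha standing in for base_update:

import torch

def blend_baseline(baseline_model, main_model, alpha):
    # w_baseline_new = alpha*w_baseline_old + (1-alpha)*w_main
    with torch.no_grad():
        for w_b, w_m in zip(baseline_model.parameters(), main_model.parameters()):
            w_b.copy_(alpha * w_b + (1 - alpha) * w_m)

# alpha=0 copies the main agent's weights directly into the baseline:
# blend_baseline(baseline_model, main_model, alpha=0.)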
Generative Agent
The GenerativeAgent class adds a vocab input used to reconstruct generated samples. This class also has updated before_compute_reward and get_model_outputs methods to create the relevant values for training.
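As a rough illustration of what reconstruction means here (a hedged sketch, not MRL's internal code), generated token indices can be mapped back to text through the vocab's itos table; the 'eos' and 'pad' token names are assumptions for illustration.

# Hedged sketch: decode a list (or 1D tensor) of sampled token indices into a string
def decode_sample(idxs, vocab):
    chars = [vocab.itos[i] for i in idxs]
    # drop special tokens; 'eos'/'pad' names are assumed for this sketch
    return ''.join(c for c in chars if c not in ('bos', 'eos', 'pad'))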
Training Callbacks
These callbacks manage the use of supervised training during the RL fit cycle.
SupervisedCB - runs supervised training on the top x percentile of samples at a set frequency
Rollback - if a chosen metric crosses a certain value (above or below, depending on configuration), the weights of the main model are reverted to the baseline model (see the sketch after this list)
RetrainRollback - runs supervised training if a chosen metric crosses a certain value (above or below, depending on configuration)
ResetAndRetrain - with a set frequency, reloads a saved checkpoint and runs supervised training from the sample log
SaveAgentWeights - saves weights with a set frequency
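The rollback idea can be sketched roughly as follows. This is a hedged illustration of the concept rather than the actual Rollback callback API; the function and argument names are made up.

# Hedged sketch of a rollback-style check (hypothetical names, not MRL's API):
# if the tracked metric degrades past a threshold, copy the baseline weights
# back into the main model.
def maybe_rollback(main_model, baseline_model, metric_value, threshold):
    if metric_value < threshold:  # direction is configurable in the real callback
        main_model.load_state_dict(baseline_model.state_dict())

The example below builds an LSTM language model over SMILES characters, loads saved weights, wraps the model in a GenerativeAgent, and runs a round of supervised training: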
import pandas as pd
import torch

from mrl.vocab import *
from mrl.dataloaders import *
from mrl.g_models.all import *

# build a character-level SMILES dataset
df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
ds = Text_Dataset(list(df.smiles.values)*10, vocab)

# model hyperparameters
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bos_idx = vocab.stoi['bos']
bidir = False
tie_weights = True

model = LSTM_LM(d_vocab,
                d_embedding,
                d_hidden,
                n_layers,
                input_dropout,
                lstm_dropout,
                bos_idx,
                bidir,
                tie_weights)

# load saved weights
model.load_state_dict(torch.load('untracked_files/lstm_lm_zinc.pt'))

agent = GenerativeAgent(model, vocab, CrossEntropy(), ds, opt_kwargs={'lr':1e-4})

# supervised training (batch size 64, 1 epoch, lr 1e-4)
agent.train_supervised(64, 1, 1e-4, silent=False)
# the same run with fp16 (half-precision) training enabled
agent.train_supervised(64, 1, 1e-4, silent=False, fp16=True)