MRL environment

Environment

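The Environment class holds all Callback classes involved in the fit cycle and runs the fit loop. All callbacks are treated the same, but some callback classes are distinguished for semantic convenience and are passed to the constructor as separate arguments: the agent (Agent), the template callback (TemplateCallback), the samplers (Sampler), the rewards (RewardCallback), the losses (LossCallback), the buffer, and the log (Log). Any other callbacks go in cbs.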

The Fit Loop

The following describes the order of events in Environment.fit:

  1. Callbacks added during Environment.fit are registered
  2. before_train event is called
  3. Iterate over iters batches. Steps 4 to 9 run once per batch:
  4. Call Environment.build_buffer. If the current buffer size is less than the current batch size:
    1. call build_buffer event
    2. call filter_buffer event
    3. call after_build_buffer event
  5. Call Environment.sample_batch
    1. create new BatchState
    2. call before_batch event
    3. call sample_batch event
    4. call before_filter_batch event
    5. call filter_batch event
    6. call after_sample event
  6. Call Environment.compute_reward
    1. call before_compute_reward event
    2. call compute_reward event
    3. call after_compute_reward event
    4. call reward_modification event
    5. call after_reward_modification event
  7. Call Environment.get_model_outputs
    1. call get_model_outputs event
    2. call after_get_model_outputs event
  8. Call Environment.compute_loss
    1. call compute_loss event
    2. call zero_grad event
    3. call before_step event
    4. call step event
  9. Call Environment.after_batch
    1. call after_batch event
  10. After all iters batches have completed, call after_train event
  11. Remove callbacks registered at the start of the fit loop
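The event order above can be traced with a minimal sketch. SketchEnvironment and BatchCounter are hypothetical names, not part of MRL. The sketch assumes dispatch works by calling the method named after the event on every callback that defines one, skips BatchState creation (step 5.1), and drops the sl, report, verbose and buffer_frequency arguments of the real fit. Because no callback fills the buffer here, the build_buffer events fire on every batch.

```python
from typing import Iterable, List, Optional


class SketchEnvironment:
    """Hypothetical stand-in that mirrors the documented event order of
    Environment.fit. It is not MRL's implementation."""

    def __init__(self, cbs: Optional[Iterable[object]] = None):
        self.cbs: List[object] = list(cbs or [])
        self.buffer_size = 0  # assumed counter; a buffer callback would update it
        self.event_log: List[str] = []

    def __call__(self, event: str) -> None:
        # Every callback is treated the same: run its method named after the
        # event if it defines one, otherwise skip it (assumed dispatch rule).
        self.event_log.append(event)
        for cb in self.cbs:
            handler = getattr(cb, event, None)
            if callable(handler):
                handler()

    def _run(self, *events: str) -> None:
        for event in events:
            self(event)

    def build_buffer(self, bs: int) -> None:
        if self.buffer_size < bs:
            self._run("build_buffer", "filter_buffer", "after_build_buffer")

    def sample_batch(self) -> None:
        # Step 5.1 (creating a fresh BatchState) is left out of this sketch.
        self._run("before_batch", "sample_batch", "before_filter_batch",
                  "filter_batch", "after_sample")

    def compute_reward(self) -> None:
        self._run("before_compute_reward", "compute_reward",
                  "after_compute_reward", "reward_modification",
                  "after_reward_modification")

    def get_model_outputs(self) -> None:
        self._run("get_model_outputs", "after_get_model_outputs")

    def compute_loss(self) -> None:
        self._run("compute_loss", "zero_grad", "before_step", "step")

    def after_batch(self) -> None:
        self._run("after_batch")

    def fit(self, bs: int, iters: int,
            cbs: Optional[Iterable[object]] = None) -> None:
        extra = list(cbs or [])
        self.cbs.extend(extra)         # step 1: register fit-only callbacks
        self("before_train")           # step 2
        for _ in range(iters):         # step 3
            self.build_buffer(bs)      # step 4
            self.sample_batch()        # step 5
            self.compute_reward()      # step 6
            self.get_model_outputs()   # step 7
            self.compute_loss()        # step 8
            self.after_batch()         # step 9
        self("after_train")            # step 10
        for cb in extra:               # step 11: drop fit-only callbacks
            self.cbs.remove(cb)


class BatchCounter:
    """Hypothetical callback that counts completed batches."""

    def __init__(self):
        self.n = 0

    def after_batch(self):
        self.n += 1


counter = BatchCounter()
env = SketchEnvironment()
env.fit(bs=4, iters=2, cbs=[counter])
print(counter.n)           # 2
print(len(env.event_log))  # 42: before_train + 2 * 20 per-batch events + after_train
print(env.cbs)             # []: the fit-only callback was removed
```

Keeping the per-step helpers as separate methods matches the way the list above groups events under Environment.build_buffer, Environment.sample_batch and so on, so a subclass could override one stage without touching the loop.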

class Environment

Environment(agent=None, template_cb=None, samplers=None, rewards=None, losses=None, cbs=None, buffer=None, log=None)

Environment - environment for training an agent

Inputs:

  • agent Optional[Agent]: agent to train

  • template_cb Optional[TemplateCallback]: template callback

  • samplers Optional[list[Sampler]]: any number of sampler callbacks

  • rewards Optional[list[RewardCallback]]: any reward callbacks

  • losses Optional[list[LossCallback]]: any loss callbacks

  • cbs Optional[list[Callback]]: any other callbacks

  • buffer_p_batch Optional[float]: percentage of each batch that should come from the buffer. If None, value is inferred from p_batch values in samplers

  • log Optional[Log]: custom log. If None, standard Log is used
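Since all callbacks are treated the same, the separately named constructor arguments can be thought of as being collapsed into one callback list. The sketch below is hypothetical: collect_callbacks and StandardLog are illustrative stand-ins (StandardLog plays the role of the standard Log used when log is None), and the order of the flattened list is an assumption.

```python
from typing import List, Optional, Sequence


class StandardLog:
    """Hypothetical placeholder for the standard Log used when log is None."""


def collect_callbacks(agent: Optional[object] = None,
                      template_cb: Optional[object] = None,
                      samplers: Optional[Sequence[object]] = None,
                      rewards: Optional[Sequence[object]] = None,
                      losses: Optional[Sequence[object]] = None,
                      cbs: Optional[Sequence[object]] = None,
                      buffer: Optional[object] = None,
                      log: Optional[object] = None) -> List[object]:
    """Hypothetical helper (not MRL's code): flatten the semantic groups
    accepted by Environment into one list of callbacks."""
    if log is None:
        log = StandardLog()  # mirrors "If None, standard Log is used"
    # Single-valued arguments are kept only when provided (log always is).
    out = [cb for cb in (agent, template_cb, buffer, log) if cb is not None]
    # List-valued arguments may hold any number of callbacks, or be None.
    for group in (samplers, rewards, losses, cbs):
        out.extend(group or [])
    return out


# Usage with placeholder callback classes.
class Reward:
    pass


class Loss:
    pass


all_cbs = collect_callbacks(rewards=[Reward(), Reward()], losses=[Loss()])
print(len(all_cbs))  # 4: the standard log, two rewards and one loss
```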

Environment.fit

Environment.fit(bs, sl, iters, report, cbs=None, verbose=False, buffer_frequency=None)

fit - runs the fit cycle

Inputs:

  • bs int: batch size

  • sl int: max sample length

  • iters int: number of batches to train

  • report int: interval, in batches, at which batch statistics are reported

  • cbs Optional[list[Callback]]: optional callbacks for the fit loop

  • verbose bool: if True, prints each event as it is called

  • buffer_frequency Optional[int]: minimum buffer generation frequency. If None, the buffer is regenerated whenever its size falls below the current batch size
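The buffer_frequency and report arguments can be read as simple per-batch predicates. The sketch is hypothetical: should_build_buffer and should_report are not MRL functions, and the rule that a set buffer_frequency also forces a rebuild every buffer_frequency batches (on top of the low-buffer rule) is an assumption drawn from the phrase "minimum buffer generation frequency".

```python
from typing import Optional


def should_build_buffer(batch_idx: int, buffer_size: int, bs: int,
                        buffer_frequency: Optional[int] = None) -> bool:
    """Hypothetical rule for step 4 of the fit loop.

    Assumption: the buffer is rebuilt whenever it holds fewer items than the
    batch size and, if buffer_frequency is set, also on every
    buffer_frequency-th batch regardless of its size.
    """
    if buffer_size < bs:
        return True
    if buffer_frequency is not None and batch_idx % buffer_frequency == 0:
        return True
    return False


def should_report(batch_idx: int, report: int) -> bool:
    """Assumed reading of report: report stats on every report-th batch."""
    return batch_idx % report == 0


# bs=8 and the buffer already holds 20 items; rebuild at least every 5 batches.
flags = [should_build_buffer(i, 20, 8, buffer_frequency=5) for i in range(6)]
print(flags)                                           # [True, False, False, False, False, True]
print([i for i in range(12) if should_report(i, 5)])  # [0, 5, 10]
print(should_build_buffer(3, 4, 8))                   # True: buffer smaller than the batch
```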