Skip to content

track_epoch

allennlp.training.callbacks.track_epoch

[SOURCE]


TrackEpochCallback

@TrainerCallback.register("track_epoch_callback")
class TrackEpochCallback(TrainerCallback)

A callback that you can pass to the GradientDescentTrainer to access the current epoch number in your model during training. This callback sets model.epoch, which can be read inside model.forward(). It sets model.epoch = epoch + 1, so that model.epoch denotes the number of epochs completed at that point in training.

on_start

class TrackEpochCallback(TrainerCallback):
 | ...
 | def on_start(
 |     self,
 |     trainer: "GradientDescentTrainer",
 |     is_primary: bool = True,
 |     **kwargs
 | ) -> None

on_epoch

class TrackEpochCallback(TrainerCallback):
 | ...
 | def on_epoch(
 |     self,
 |     trainer: "GradientDescentTrainer",
 |     metrics: Dict[str, Any],
 |     epoch: int,
 |     is_primary: bool = True,
 |     **kwargs
 | ) -> None