callback
allennlp.training.callbacks.callback
TrainerCallback¶
class TrainerCallback(Registrable):
| def __init__(self, serialization_dir: str) -> None
A general callback object that handles multiple events.
This class has on_backward, on_batch, on_epoch, and on_end methods, corresponding to
each callback type, as well as an on_start method. Each method receives the state of the wrapper object as self.
This enables easier state sharing between related callbacks.
In addition, this callback type is instantiated with serialization_dir, and on_start is called
with the trainer instance as an argument. This is handy when a callback needs to write its own
logs or save its own files next to the config, checkpoints, logs, etc.
on_start¶
class TrainerCallback(Registrable):
| ...
| def on_start(
| self,
| trainer: "GradientDescentTrainer",
| is_primary: bool = True,
| **kwargs
| ) -> None
This callback hook is called before training starts.
on_backward¶
class TrainerCallback(Registrable):
| ...
| def on_backward(
| self,
| trainer: "GradientDescentTrainer",
| batch_outputs: Dict[str, torch.Tensor],
| backward_called: bool,
| **kwargs
| ) -> bool
This callback hook performs backpropagation and allows for gradient manipulation.
backward_called indicates whether loss.backward has already been called before this callback runs.
on_backward should return True if and only if loss.backward is called in its body.
on_batch¶
class TrainerCallback(Registrable):
| ...
| def on_batch(
| self,
| trainer: "GradientDescentTrainer",
| batch_inputs: List[TensorDict],
| batch_outputs: List[Dict[str, Any]],
| batch_metrics: Dict[str, Any],
| epoch: int,
| batch_number: int,
| is_training: bool,
| is_primary: bool = True,
| batch_grad_norm: Optional[float] = None,
| **kwargs
| ) -> None
This callback hook is called after the end of each batch.
on_epoch¶
class TrainerCallback(Registrable):
| ...
| def on_epoch(
| self,
| trainer: "GradientDescentTrainer",
| metrics: Dict[str, Any],
| epoch: int,
| is_primary: bool = True,
| **kwargs
| ) -> None
This callback hook is called after the end of each epoch.
on_end¶
class TrainerCallback(Registrable):
| ...
| def on_end(
| self,
| trainer: "GradientDescentTrainer",
| metrics: Dict[str, Any] = None,
| epoch: int = None,
| is_primary: bool = True,
| **kwargs
| ) -> None
This callback hook is called after the final training epoch.
state_dict¶
class TrainerCallback(Registrable):
| ...
| def state_dict(self) -> Dict[str, Any]
load_state_dict¶
class TrainerCallback(Registrable):
| ...
| def load_state_dict(self, state_dict: Dict[str, Any]) -> None