backward
allennlp.training.callbacks.backward
MixedPrecisionBackwardCallback¶
@TrainerCallback.register("mixed_precision_backward")
class MixedPrecisionBackwardCallback(TrainerCallback)
Performs backpropagation for mixed precision training.
on_backward¶
class MixedPrecisionBackwardCallback(TrainerCallback):
| ...
| def on_backward(
| self,
| trainer: "GradientDescentTrainer",
| batch_outputs: Dict[str, torch.Tensor],
| backward_called: bool,
| **kwargs
| ) -> bool
OnBackwardException¶
class OnBackwardException(Exception):
| def __init__(self, message="") -> None
The exception type raised if an `on_backward` callback attempts to call `backward` when `backward_called` is `True`.