checkpoint_wrapper
allennlp.nn.checkpoint.checkpoint_wrapper
CheckpointWrapper
class CheckpointWrapper(Registrable)
A CheckpointWrapper is used to enable activation/gradient checkpointing on modules that you wrap via the .wrap_module() method.
default_implementation
class CheckpointWrapper(Registrable):
| ...
| default_implementation = "torch"
wrap_module
class CheckpointWrapper(Registrable):
| ...
| def wrap_module(self, module: nn.Module, **kwargs) -> nn.Module
TorchCheckpointWrapper
@CheckpointWrapper.register("torch")
class TorchCheckpointWrapper(CheckpointWrapper)
wrap_module
class TorchCheckpointWrapper(CheckpointWrapper):
| ...
| def wrap_module(self, module: nn.Module, **kwargs) -> nn.Module
Wrap a module so that the forward method uses PyTorch's checkpointing functionality.
Note
Currently this CheckpointWrapper implementation requires that the wrapped module is called with positional arguments only. We recommend you use the FairScaleCheckpointWrapper if you need more flexibility.
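To make the idea concrete, here is a rough sketch of what wrapping a module's forward pass with PyTorch's checkpointing can look like. It is not the TorchCheckpointWrapper implementation: `CheckpointedModule` is a hypothetical class written for illustration, it forwards positional arguments only (the call style the note above requires), and it assumes a PyTorch version in which `torch.utils.checkpoint.checkpoint` accepts the `use_reentrant` argument.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedModule(nn.Module):
    """Hypothetical wrapper that recomputes `inner`'s activations during backward."""

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner

    def forward(self, *args):
        # Only positional arguments are forwarded to the wrapped module.
        # `use_reentrant=False` is an assumption about the installed PyTorch version.
        return checkpoint(self.inner, *args, use_reentrant=False)


torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
wrapped = CheckpointedModule(layer)

# Call the wrapped module with positional arguments only.
x = torch.randn(3, 4, requires_grad=True)
out = wrapped(x)

# Intermediate activations are recomputed here instead of being stored.
out.sum().backward()
```

The trade-off is extra compute in the backward pass in exchange for lower activation memory, which tends to matter most for deep or wide modules.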