checkpoint_wrapper

allennlp.nn.checkpoint.checkpoint_wrapper

CheckpointWrapper

class CheckpointWrapper(Registrable)

A CheckpointWrapper is used to enable activation/gradient checkpointing on modules that you wrap via the .wrap_module() method.

default_implementation

class CheckpointWrapper(Registrable):
 | ...
 | default_implementation = "torch"

wrap_module

class CheckpointWrapper(Registrable):
 | ...
 | def wrap_module(self, module: nn.Module, **kwargs) -> nn.Module
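For example, here is a minimal sketch of wrapping a small module with the default "torch" implementation, looked up through the Registrable registry. The feed-forward block and tensor shapes are invented for illustration, and the wrapper is assumed to take no constructor arguments:

import torch
import torch.nn as nn

from allennlp.nn.checkpoint.checkpoint_wrapper import CheckpointWrapper

# A toy sub-module to checkpoint; any nn.Module works here (shapes are arbitrary).
feedforward = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# Look up the default implementation ("torch") via the Registrable API,
# then wrap the module so its forward pass is checkpointed.
wrapper = CheckpointWrapper.by_name("torch")()  # assumed no-argument constructor
checkpointed = wrapper.wrap_module(feedforward)

# The input requires grad so gradients flow back through the checkpointed segment.
x = torch.randn(4, 16, requires_grad=True)
out = checkpointed(x)  # activations are recomputed during the backward pass
out.sum().backward()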

TorchCheckpointWrapper

@CheckpointWrapper.register("torch")
class TorchCheckpointWrapper(CheckpointWrapper)

wrap_module

class TorchCheckpointWrapper(CheckpointWrapper):
 | ...
 | def wrap_module(self, module: nn.Module, **kwargs) -> nn.Module

Wrap a module so that the forward method uses PyTorch's checkpointing functionality.

Note

Currently this CheckpointWrapper implementation requires that the wrapped module be called with positional arguments only.

We recommend you use the FairScaleCheckpointWrapper if you need more flexibility.
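As a rough sketch of the restriction above (the toy module, shapes, and no-argument constructor are assumptions made for illustration), pass every argument positionally when calling the wrapped module:

import torch
import torch.nn as nn

from allennlp.nn.checkpoint.checkpoint_wrapper import TorchCheckpointWrapper


class ScaledLinear(nn.Module):
    """Toy module whose forward takes two arguments."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * scale


wrapped = TorchCheckpointWrapper().wrap_module(ScaledLinear())

x = torch.randn(2, 8, requires_grad=True)
scale = torch.tensor(0.5)

out = wrapped(x, scale)           # OK: positional arguments only
# out = wrapped(x, scale=scale)   # keyword arguments are not supported by this wrapper
out.sum().backward()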