allennlp.nn.checkpoint.fairscale_checkpoint_wrapper
FairScaleCheckpointWrapper
@CheckpointWrapper.register("fairscale")
class FairScaleCheckpointWrapper(CheckpointWrapper):
| def __init__(
| self,
| offload_to_cpu: Optional[bool] = True,
| maintain_forward_counter: Optional[bool] = None
| ) -> None
Provides FairScale's activation/gradient checkpointing functionality.
The parameters and their defaults are the same as in FairScale, and any of them
can be overridden on a per-module basis by passing the corresponding parameter
to .wrap_module().
This can also be used in conjunction with the
FairScaleFsdpAccelerator.
See the T5 implementation for an example
of how to use the two together.
Note
If you're using the FairScaleFsdpAccelerator, maintain_forward_counter needs to be True.
For convenience, if maintain_forward_counter is left unset, it is set internally to True
when training in a distributed setup, and to False otherwise.
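
As a point of reference, here is a minimal sketch of using the wrapper on its own, outside of a distributed run. It assumes FairScaleCheckpointWrapper can be imported from allennlp.nn.checkpoint (it is defined in allennlp.nn.checkpoint.fairscale_checkpoint_wrapper), and MyBlock is a hypothetical stand-in for whatever submodule you want to checkpoint.

import torch
import torch.nn as nn

from allennlp.nn.checkpoint import FairScaleCheckpointWrapper


class MyBlock(nn.Module):
    """Hypothetical stand-in for a submodule whose activations you want to checkpoint."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).relu()


# In a non-distributed run like this one, maintain_forward_counter is left
# unset and therefore defaults to False internally.
checkpoint_wrapper = FairScaleCheckpointWrapper(offload_to_cpu=True)

# The wrapped module discards intermediate activations during the forward
# pass and recomputes them during the backward pass.
block = checkpoint_wrapper.wrap_module(MyBlock())

x = torch.randn(4, 16, requires_grad=True)
block(x).sum().backward()
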
wrap_module
class FairScaleCheckpointWrapper(CheckpointWrapper):
| ...
| @overrides
| def wrap_module(self, module: nn.Module, **kwargs) -> nn.Module
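
To illustrate the per-module overrides mentioned above, the sketch below (reusing the hypothetical MyBlock from the earlier example) passes offload_to_cpu directly to wrap_module(), which takes precedence over the value given at construction time for that module only.

wrapper = FairScaleCheckpointWrapper(offload_to_cpu=True)

# Uses the constructor's setting: checkpointed activations are offloaded to CPU.
block_a = wrapper.wrap_module(MyBlock())

# Per-module override: keep this block's checkpointed activations on device.
block_b = wrapper.wrap_module(MyBlock(), offload_to_cpu=False)
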