fairscale_checkpoint_wrapper
allennlp.nn.checkpoint.fairscale_checkpoint_wrapper
FairScaleCheckpointWrapper¶
@CheckpointWrapper.register("fairscale")
class FairScaleCheckpointWrapper(CheckpointWrapper):
| def __init__(
|     self,
|     offload_to_cpu: Optional[bool] = True,
|     maintain_forward_counter: Optional[bool] = None
| ) -> None
Provides FairScale's activation/gradient checkpointing functionality.
The parameters and their defaults are the same as they are in FairScale, and any of them can be overridden on a per-module basis by passing the corresponding parameter to .wrap_module().
This can also be used in conjunction with the FairScaleFsdpAccelerator. See the T5 implementation for an example of how to use the two together.
Note
If using the FairScaleFsdpAccelerator, you need to set maintain_forward_counter to True. For convenience, if maintain_forward_counter is not set, it will internally be set to True when training in a distributed setup, or False otherwise.
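A minimal construction sketch (not part of the docs), assuming FairScaleCheckpointWrapper is importable from allennlp.nn.checkpoint, allennlp's usual convention for re-exporting submodule classes:

from allennlp.nn.checkpoint import FairScaleCheckpointWrapper

# maintain_forward_counter is left unset here: per the note above, it then
# defaults to True in a distributed run and False otherwise.
wrapper = FairScaleCheckpointWrapper(offload_to_cpu=True)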
wrap_module¶
class FairScaleCheckpointWrapper(CheckpointWrapper):
| ...
| @overrides
| def wrap_module(self, module: nn.Module, **kwargs,) -> nn.Module
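A hedged usage sketch for wrap_module: the nn.Sequential block below is a hypothetical stand-in for any module worth checkpointing, and the keyword argument relies on the per-module override behavior described above.

import torch.nn as nn
from allennlp.nn.checkpoint import FairScaleCheckpointWrapper

wrapper = FairScaleCheckpointWrapper()

# Any nn.Module can be wrapped; this feed-forward block is only illustrative.
block = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))

# The keyword argument overrides the wrapper's offload_to_cpu default for this
# module only; its activations are recomputed during the backward pass.
checkpointed_block = wrapper.wrap_module(block, offload_to_cpu=False)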