fairscale_checkpoint_wrapper
allennlp.nn.checkpoint.fairscale_checkpoint_wrapper
FairScaleCheckpointWrapper
@CheckpointWrapper.register("fairscale")
class FairScaleCheckpointWrapper(CheckpointWrapper):
| def __init__(self, offload_to_cpu: Optional[bool] = True) -> None
Provides FairScale's activation/gradient checkpointing functionality.

The parameters and their defaults are the same as they are in FairScale, and any of them can be overridden on a per-module basis by passing the corresponding parameter to .wrap_module().
This can also be used in conjunction with the FairScaleFsdpAccelerator. See the T5 implementation for an example of how to use the two together.
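Below is a minimal usage sketch, not taken from the library's docs: it assumes the class can be imported from the module path shown above and uses a stand-in feed-forward block. The per-module override of offload_to_cpu mirrors the behavior described above.

```python
import torch.nn as nn

from allennlp.nn.checkpoint.fairscale_checkpoint_wrapper import FairScaleCheckpointWrapper

# Wrapper with the default behavior (checkpointed activations offloaded to CPU).
wrapper = FairScaleCheckpointWrapper(offload_to_cpu=True)

# A stand-in feed-forward block to checkpoint.
block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# Wrapped module: activations are recomputed during the backward pass
# instead of being stored for the whole forward pass.
checkpointed_block = wrapper.wrap_module(block)

# The constructor default can be overridden on a per-module basis:
no_offload_block = wrapper.wrap_module(block, offload_to_cpu=False)
```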
wrap_module
class FairScaleCheckpointWrapper(CheckpointWrapper):
| ...
| def wrap_module(self, module: nn.Module, **kwargs) -> nn.Module
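As a follow-on to the sketch above (reusing the same checkpointed_block), the returned module is used exactly like the unwrapped one; checkpointing only changes what is stored versus recomputed during the backward pass:

```python
import torch

# Forward and backward through the wrapped module, just as with the original block.
x = torch.randn(8, 512, requires_grad=True)
loss = checkpointed_block(x).sum()
loss.backward()
```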