
fairscale_checkpoint_wrapper

allennlp.nn.checkpoint.fairscale_checkpoint_wrapper



FairScaleCheckpointWrapper

@CheckpointWrapper.register("fairscale")
class FairScaleCheckpointWrapper(CheckpointWrapper):
 | def __init__(self, offload_to_cpu: Optional[bool] = True) -> None

Provides FairScale's activation/gradient checkpointing functionality.

The parameters and their defaults are the same as they are in FairScale, and any of them can be overridden on a per-module basis by passing the corresponding keyword argument to .wrap_module().
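The per-module override rule can be sketched in plain Python. This is a hypothetical stand-in, not AllenNLP's code. SketchCheckpointWrapper and fake_checkpoint_wrapper are illustrative names, and the fake function only records the keyword arguments that a real wrapper would forward to FairScale's checkpoint_wrapper. The sketch assumes that the constructor's offload_to_cpu acts only as a default, which a keyword passed to wrap_module() replaces.

```python
from typing import Any, Dict, Optional


def fake_checkpoint_wrapper(module: Any, **kwargs: Any) -> Dict[str, Any]:
    # Hypothetical stand-in for FairScale's checkpoint_wrapper: instead of
    # wrapping the module, it returns the arguments it was called with.
    return {"module": module, "kwargs": kwargs}


class SketchCheckpointWrapper:
    # Hypothetical sketch of the default/override behaviour, not the real class.
    def __init__(self, offload_to_cpu: Optional[bool] = True) -> None:
        self._offload_to_cpu = offload_to_cpu

    def wrap_module(self, module: Any, **kwargs: Any) -> Dict[str, Any]:
        # In this sketch, the constructor value is used only when the caller
        # did not pass offload_to_cpu and the constructor value is not None.
        if "offload_to_cpu" not in kwargs and self._offload_to_cpu is not None:
            kwargs["offload_to_cpu"] = self._offload_to_cpu
        return fake_checkpoint_wrapper(module, **kwargs)


wrapper = SketchCheckpointWrapper()  # constructor default: offload_to_cpu=True
print(wrapper.wrap_module("encoder")["kwargs"])  # {'offload_to_cpu': True}
print(wrapper.wrap_module("decoder", offload_to_cpu=False)["kwargs"])  # {'offload_to_cpu': False}
```

Keeping the constructor value as a fallback, rather than forcing it on every call, lets one wrapper instance apply a shared setting across a model while still allowing individual layers to opt out.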

This can also be used in conjunction with the FairScaleFsdpAccelerator. See the T5 implementation for an example of how to use the two together.

wrap_module

class FairScaleCheckpointWrapper(CheckpointWrapper):
 | ...
 | def wrap_module(self, module: nn.Module, **kwargs,) -> nn.Module