class FairScaleCheckpointWrapper(CheckpointWrapper):
 | def __init__(
 |     self,
 |     offload_to_cpu: Optional[bool] = True,
 |     maintain_forward_counter: Optional[bool] = None
 | ) -> None

Provides FairScale's activation/gradient checkpointing functionality.

The parameters and their defaults are the same as they are in FairScale, and any of them can be overridden on a per-module basis by passing the corresponding parameter to .wrap_module().
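The per-module override described above can be pictured as a dictionary merge: the constructor values act as defaults, and any keyword passed for a single module takes precedence for that module only. The following is a minimal, library-free sketch of that pattern. The names SketchCheckpointWrapper and resolve_options are hypothetical and exist only for illustration; they are not part of AllenNLP or FairScale, and the real .wrap_module() may resolve its options differently.

```python
from typing import Any, Dict, Optional


class SketchCheckpointWrapper:
    """Hypothetical stand-in that only models how options are resolved."""

    def __init__(
        self,
        offload_to_cpu: Optional[bool] = True,
        maintain_forward_counter: Optional[bool] = None,
    ) -> None:
        # Values given here serve as the defaults for every wrapped module.
        self._defaults: Dict[str, Any] = {
            "offload_to_cpu": offload_to_cpu,
            "maintain_forward_counter": maintain_forward_counter,
        }

    def resolve_options(self, **kwargs: Any) -> Dict[str, Any]:
        # Assumption: per-module keyword arguments win over the defaults.
        # A new dict is returned, so the stored defaults are never mutated.
        return {**self._defaults, **kwargs}


wrapper = SketchCheckpointWrapper(offload_to_cpu=True)

# No per-module arguments: the constructor defaults apply unchanged.
default_options = wrapper.resolve_options()

# One module opts out of CPU offloading without affecting the others.
override_options = wrapper.resolve_options(offload_to_cpu=False)
```

Because the merge produces a fresh dictionary on each call, overriding an option for one module leaves the defaults used for every other module untouched.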

This can also be used in conjunction with the FairScaleFsdpAccelerator. See the T5 implementation for an example of how to use the two together.


When using the FairScaleFsdpAccelerator, maintain_forward_counter must be True. For convenience, if maintain_forward_counter is left unset, it defaults internally to True when training in a distributed setup and to False otherwise.
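The defaulting rule for maintain_forward_counter can be sketched as a small function. This is an illustration under stated assumptions, not the library's code: resolve_maintain_forward_counter is a hypothetical name, and the is_distributed flag stands in for whatever check the library actually performs to detect a distributed setup, which this page does not show.

```python
from typing import Optional


def resolve_maintain_forward_counter(
    maintain_forward_counter: Optional[bool],
    is_distributed: bool,
) -> bool:
    """Hypothetical helper mirroring the documented defaulting rule."""
    # An explicit True or False from the caller is always respected.
    if maintain_forward_counter is not None:
        return maintain_forward_counter
    # Left unset (None): enable the counter only for distributed training.
    return is_distributed


# Unset, distributed training: the counter is switched on.
unset_distributed = resolve_maintain_forward_counter(None, is_distributed=True)

# Unset, single-process training: the counter stays off.
unset_single = resolve_maintain_forward_counter(None, is_distributed=False)

# An explicit value is kept regardless of the training setup.
explicit_false = resolve_maintain_forward_counter(False, is_distributed=True)
```

Note that under this rule an explicit False is kept even in a distributed run, so it remains the caller's responsibility to pass True when combining the wrapper with the FairScaleFsdpAccelerator.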


class FairScaleCheckpointWrapper(CheckpointWrapper):
 | ...
 | @overrides
 | def wrap_module(self, module: nn.Module, **kwargs) -> nn.Module