Skip to content

fairscale_fsdp_accelerator

allennlp.nn.parallel.fairscale_fsdp_accelerator

[SOURCE]


FairScaleFsdpWrappedModel

class FairScaleFsdpWrappedModel(DdpWrappedModel)

The wrapped model type returned from `FairScaleFsdpAccelerator.wrap_model`.

consolidate_sharded_state

class FairScaleFsdpWrappedModel(DdpWrappedModel):
 | ...
 | @staticmethod
 | def consolidate_sharded_state(
 |     sharded_state_files: Sequence[Union[str, os.PathLike]]
 | ) -> StateDictType

load_state_dict

class FairScaleFsdpWrappedModel(DdpWrappedModel):
 | ...
 | def load_state_dict(
 |     self,
 |     state_dict: StateDictType,
 |     strict: bool = True
 | ) -> LoadStateDictReturnType

state_dict

class FairScaleFsdpWrappedModel(DdpWrappedModel):
 | ...
 | def state_dict(self, *args, **kwargs) -> StateDictType

clip_grad_norm_

class FairScaleFsdpWrappedModel(DdpWrappedModel):
 | ...
 | def clip_grad_norm_(self, max_norm: Union[float, int]) -> torch.Tensor

init_grad_scaler

class FairScaleFsdpWrappedModel(DdpWrappedModel):
 | ...
 | def init_grad_scaler(self) -> amp.GradScaler

FairScaleFsdpAccelerator

@DdpAccelerator.register("fairscale_fsdp")
class FairScaleFsdpAccelerator(DdpAccelerator):
 | def __init__(
 |     self,
 |     *,
 |     mixed_precision: bool = False,
 |     reshard_after_forward: bool = True,
 |     flatten_parameters: bool = True,
 |     local_rank: Optional[int] = None,
 |     world_size: Optional[int] = None,
 |     cuda_device: Union[torch.device, int] = -1
 | ) -> None

A DdpAccelerator for FairScale's FullyShardedDataParallel.

To save memory while initializing a model, call `.wrap_module()` on each submodule as it is created, rather than wrapping the whole model only at the end.

See the T5 class for an example of how to use this.

wrap_model

class FairScaleFsdpAccelerator(DdpAccelerator):
 | ...
 | def wrap_model(
 |     self,
 |     model: "Model"
 | ) -> Tuple["Model", DdpWrappedModel]

wrap_module

class FairScaleFsdpAccelerator(DdpAccelerator):
 | ...
 | def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module