fairscale_fsdp_accelerator
allennlp.nn.parallel.fairscale_fsdp_accelerator
FairScaleFsdpWrappedModel¶
class FairScaleFsdpWrappedModel(DdpWrappedModel)
The wrapped model type returned from FairScaleFsdpAccelerator.wrap_model.
consolidate_sharded_state¶
class FairScaleFsdpWrappedModel(DdpWrappedModel):
| ...
| @staticmethod
| def consolidate_sharded_state(
| sharded_state_files: Sequence[Union[str, os.PathLike]]
| ) -> StateDictType
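For instance, after distributed training has written one sharded model state file per worker, those shards can be merged back into a single full state dict. A minimal sketch; the file naming below is a hypothetical convention, not something this module prescribes:

```python
import glob

import torch

from allennlp.nn.parallel.fairscale_fsdp_accelerator import FairScaleFsdpWrappedModel

# One sharded state file per worker (hypothetical naming convention).
shard_files = sorted(glob.glob("serialization_dir/model_state_worker*.pt"))

# Merge the shards into a single, full state dict ...
full_state = FairScaleFsdpWrappedModel.consolidate_sharded_state(shard_files)

# ... which can then be saved like an ordinary (unsharded) model state.
torch.save(full_state, "serialization_dir/consolidated_model_state.pt")
```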
load_state_dict¶
class FairScaleFsdpWrappedModel(DdpWrappedModel):
| ...
| def load_state_dict(
| self,
| state_dict: StateDictType,
| strict: bool = True
| ) -> LoadStateDictReturnType
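A minimal sketch of restoring a rank's own shard, assuming `wrapped_model` is the FairScaleFsdpWrappedModel returned by wrap_model, that each worker previously saved its sharded state, and that the process group is already initialized; the path convention is hypothetical:

```python
import torch
import torch.distributed as dist

# Load the shard this rank saved earlier and push it back into the wrapped model.
sharded_state = torch.load(
    f"serialization_dir/model_state_worker{dist.get_rank()}.pt", map_location="cpu"
)
wrapped_model.load_state_dict(sharded_state, strict=True)
```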
state_dict¶
class FairScaleFsdpWrappedModel(DdpWrappedModel):
| ...
| def state_dict(self, *args, **kwargs) -> StateDictType
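For example, each worker can write out its own shard when checkpointing. A minimal sketch, assuming `wrapped_model` is this rank's FairScaleFsdpWrappedModel and the process group is initialized; the file naming is hypothetical:

```python
import torch
import torch.distributed as dist

# Each rank gets back only its own shard of the parameters from state_dict().
sharded_state = wrapped_model.state_dict()
torch.save(sharded_state, f"serialization_dir/model_state_worker{dist.get_rank()}.pt")
```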
clip_grad_norm_¶
class FairScaleFsdpWrappedModel(DdpWrappedModel):
| ...
| def clip_grad_norm_(self, max_norm: Union[float, int]) -> torch.Tensor
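A minimal sketch of where gradient clipping fits in a training step; `wrapped_model`, `optimizer`, and `loss` are assumed to already exist:

```python
loss.backward()

# Clip through the wrapped model so the clipping accounts for the sharded parameters.
total_norm = wrapped_model.clip_grad_norm_(5.0)

optimizer.step()
optimizer.zero_grad()
```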
init_grad_scaler¶
class FairScaleFsdpWrappedModel(DdpWrappedModel):
| ...
| def init_grad_scaler(self) -> amp.GradScaler
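When mixed precision is enabled, training needs a grad scaler that works with the sharded model, which is what this method returns (an amp.GradScaler, per the signature). A minimal sketch of a mixed-precision step, assuming `model` is the Model returned by FairScaleFsdpAccelerator.wrap_model, and that `wrapped_model`, `optimizer`, and `batch` already exist:

```python
import torch

scaler = wrapped_model.init_grad_scaler()

with torch.cuda.amp.autocast():
    # AllenNLP models return a dict of outputs that includes the loss.
    loss = model(**batch)["loss"]

# Standard torch.cuda.amp.GradScaler usage: scale, step, update.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```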
FairScaleFsdpAccelerator¶
@DdpAccelerator.register("fairscale_fsdp")
class FairScaleFsdpAccelerator(DdpAccelerator):
| def __init__(
| self,
| *,
| mixed_precision: bool = False,
| reshard_after_forward: bool = True,
| flatten_parameters: bool = True,
| local_rank: Optional[int] = None,
| world_size: Optional[int] = None,
| cuda_device: Union[torch.device, int] = -1
| ) -> None
A DdpAccelerator for FairScale's FullyShardedDataParallel.

To save memory while initializing a model, you should call .wrap_module() on submodules as they're created. See the T5 class for an example of how to use this.
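A minimal sketch of constructing the accelerator directly in Python; in a configuration file it would instead be referred to by its registered name, "fairscale_fsdp". The rank and device values below are placeholders:

```python
from allennlp.nn.parallel.fairscale_fsdp_accelerator import FairScaleFsdpAccelerator

accelerator = FairScaleFsdpAccelerator(
    mixed_precision=True,         # train with FP16 parameters/gradients where FairScale allows
    reshard_after_forward=True,   # free the full parameters after each forward pass
    flatten_parameters=True,      # flatten parameters into contiguous buffers
    local_rank=0,                 # placeholder; normally set per worker
    world_size=1,                 # placeholder; normally the number of workers
    cuda_device=0,
)
```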
wrap_model¶
class FairScaleFsdpAccelerator(DdpAccelerator):
| ...
| def wrap_model(
| self,
| model: "Model"
| ) -> Tuple["Model", DdpWrappedModel]
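A minimal sketch, assuming `model` is an already-constructed AllenNLP Model and `accelerator` is a FairScaleFsdpAccelerator:

```python
# wrap_model returns both the (possibly re-wrapped) Model and a FairScaleFsdpWrappedModel
# handle, which is what you use for state_dict(), load_state_dict(), clip_grad_norm_(),
# and the grad scaler.
model, wrapped_model = accelerator.wrap_model(model)
```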
wrap_module¶
class FairScaleFsdpAccelerator(DdpAccelerator):
| ...
| def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module
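A minimal sketch of wrapping submodules as they are created, as recommended above. MyModel, its submodule layout, and the ddp_accelerator constructor argument are hypothetical:

```python
from typing import Optional

import torch

from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.nn.parallel.ddp_accelerator import DdpAccelerator


class MyModel(Model):
    def __init__(self, vocab: Vocabulary, ddp_accelerator: Optional[DdpAccelerator] = None) -> None:
        super().__init__(vocab)
        encoder = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8)
        # Shard the submodule right away so its full parameters never have to be
        # materialized on every worker at once.
        self.encoder = (
            ddp_accelerator.wrap_module(encoder) if ddp_accelerator is not None else encoder
        )
```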