checkpointer

allennlp.training.checkpointer

Checkpointer

class Checkpointer(Registrable):
 | def __init__(
 |     self,
 |     serialization_dir: Union[str, os.PathLike],
 |     save_completed_epochs: bool = True,
 |     save_every_num_seconds: Optional[float] = None,
 |     save_every_num_batches: Optional[int] = None,
 |     keep_most_recent_by_count: Optional[int] = 2,
 |     keep_most_recent_by_age: Optional[int] = None
 | ) -> None

This class implements the functionality for checkpointing your model and trainer state during training. It is agnostic about what those states contain (they are typed as Dict[str, Any]), but they are passed to torch.save, so they must be serializable in that sense. They are restored as Dict[str, Any] as well, which means the calling code is responsible for knowing what to do with them.

Parameters

  • serialization_dir : Union[str, os.PathLike]
    The directory in which checkpoints are saved and looked for.
  • save_completed_epochs : bool, optional (default = True)
    Saves model and trainer state at the end of each completed epoch.
  • save_every_num_seconds : float, optional (default = None)
    If set, makes sure we never go longer than this number of seconds between saving a model.
  • save_every_num_batches : int, optional (default = None)
    If set, makes sure we never go longer than this number of batches between saving a model.
  • keep_most_recent_by_count : int, optional (default = 2)
    Sets the number of model checkpoints to keep on disk. If both keep_most_recent_by_count and keep_most_recent_by_age are set, we'll keep checkpoints that satisfy either criterion. If both are None, we keep all checkpoints.
  • keep_most_recent_by_age : int, optional (default = None)
    Sets the number of seconds we'll keep a checkpoint before deleting it. If both keep_most_recent_by_count and keep_most_recent_by_age are set, we'll keep checkpoints that satisfy either criterion. If both are None, we keep all checkpoints.
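
As a minimal sketch of how these options combine, the following constructs a checkpointer that saves at every completed epoch and at least once an hour, keeping only the two most recent checkpoints on disk (the output directory here is hypothetical):

from allennlp.training.checkpointer import Checkpointer

checkpointer = Checkpointer(
    serialization_dir="/tmp/my_model",  # hypothetical output directory
    save_completed_epochs=True,         # checkpoint at every epoch boundary
    save_every_num_seconds=3600.0,      # and at least once an hour
    keep_most_recent_by_count=2,        # prune all but the two newest checkpoints
)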

default_implementation

class Checkpointer(Registrable):
 | ...
 | default_implementation = "default"
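
Since Checkpointer is Registrable, the name "default" can be resolved through the standard registry lookup. A small sketch; the assumption here is that "default" maps to the library's stock implementation:

# Registrable lookup; "default" is assumed to resolve to the standard Checkpointer.
checkpointer_cls = Checkpointer.by_name("default")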

maybe_save_checkpoint

class Checkpointer(Registrable):
 | ...
 | def maybe_save_checkpoint(
 |     self,
 |     trainer: Trainer,
 |     num_epochs_completed: int,
 |     num_batches_in_epoch_completed: int
 | ) -> bool

Figures out whether we need to save a checkpoint, and does so if necessary.
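
A hedged sketch of how this might be called from a batch loop; trainer, data_loader, and num_epochs are assumptions for illustration, not part of this API:

# Hypothetical training loop; `trainer` is the Trainer whose state is saved.
for epoch in range(num_epochs):
    for batch_index, batch in enumerate(data_loader):
        ...  # forward pass, backward pass, optimizer step
        saved = checkpointer.maybe_save_checkpoint(
            trainer,
            num_epochs_completed=epoch,
            num_batches_in_epoch_completed=batch_index + 1,
        )

The returned bool presumably reports whether a checkpoint was actually written on that call.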

save_checkpoint

class Checkpointer(Registrable):
 | ...
 | def save_checkpoint(self, trainer: Trainer) -> None

find_latest_checkpoint

class Checkpointer(Registrable):
 | ...
 | def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]

Return the location of the latest model and training state files. If there isn't a valid checkpoint then return None.
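
For example, continuing the sketch above:

latest = checkpointer.find_latest_checkpoint()
if latest is not None:
    model_path, training_state_path = latest  # paths to the two saved files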

load_checkpoint

class Checkpointer(Registrable):
 | ...
 | def load_checkpoint(self) -> Optional[TrainerCheckpoint]

Loads model state from a serialization_dir corresponding to the last saved checkpoint. This includes a training state, which is serialized separately from the model parameters. This function should only be used to continue training. If you wish to load a model for inference, or to load parts of a model into a new computation graph, use the native PyTorch function instead: model.load_state_dict(torch.load("/path/to/model/weights.th")).

If self._serialization_dir does not exist or does not contain any checkpointed weights, this function does nothing and returns None.

Returns

  • checkpoint : Optional[TrainerCheckpoint]
    The model state and the training state bundled together, or None if no valid checkpoint was found.
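
A minimal sketch of resuming training under these semantics; how the restored state is handed back to a trainer is left to the calling code:

checkpoint = checkpointer.load_checkpoint()
if checkpoint is not None:
    ...  # feed the restored model and training state back into the Trainer
else:
    ...  # no checkpoint found; start training from scratch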