allennlp.training.checkpointer
Checkpointer¶
class Checkpointer(Registrable):
| def __init__(
| self,
| serialization_dir: Union[str, os.PathLike],
| save_completed_epochs: bool = True,
| save_every_num_seconds: Optional[float] = None,
| save_every_num_batches: Optional[int] = None,
| keep_most_recent_by_count: Optional[int] = 2,
| keep_most_recent_by_age: Optional[int] = None
| ) -> None
This class implements the functionality for checkpointing your model and trainer state
during training. It is agnostic as to what those states look like (they are typed as
`Dict[str, Any]`), but they will be fed to `torch.save`, so they should be serializable
in that sense. They will also be restored as `Dict[str, Any]`, which means the calling
code is responsible for knowing what to do with them.
Parameters¶
- save_completed_epochs : `bool`, optional (default = `True`)
  Saves model and trainer state at the end of each completed epoch.
- save_every_num_seconds : `float`, optional (default = `None`)
  If set, makes sure we never go longer than this number of seconds between saving a model.
- save_every_num_batches : `int`, optional (default = `None`)
  If set, makes sure we never go longer than this number of batches between saving a model.
- keep_most_recent_by_count : `int`, optional (default = `2`)
  Sets the number of model checkpoints to keep on disk. If both `keep_most_recent_by_count`
  and `keep_most_recent_by_age` are set, we'll keep checkpoints that satisfy either criterion.
  If both are `None`, we keep all checkpoints.
- keep_most_recent_by_age : `int`, optional (default = `None`)
  Sets the number of seconds we'll keep a checkpoint before deleting it. If both `keep_most_recent_by_count`
  and `keep_most_recent_by_age` are set, we'll keep checkpoints that satisfy either criterion.
  If both are `None`, we keep all checkpoints.
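As a minimal sketch, here is how these options might be combined; the serialization directory path is a placeholder:

```python
from allennlp.training.checkpointer import Checkpointer

# Save at the end of every epoch and at least every 30 minutes within an
# epoch, keeping only the two most recent checkpoints on disk.
checkpointer = Checkpointer(
    serialization_dir="/tmp/my_experiment",  # placeholder output directory
    save_completed_epochs=True,
    save_every_num_seconds=30 * 60,
    keep_most_recent_by_count=2,
)
```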
default_implementation¶
class Checkpointer(Registrable):
| ...
| default_implementation = "default"
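Because `Checkpointer` is `Registrable`, this name can be used to look up the default implementation, e.g. from a configuration file. A small sketch, assuming the standard AllenNLP registration:

```python
from allennlp.training.checkpointer import Checkpointer

# Registrable lets configuration files refer to implementations by name;
# "default" resolves to the standard checkpointer implementation.
checkpointer_cls = Checkpointer.by_name("default")
checkpointer = checkpointer_cls(serialization_dir="/tmp/my_experiment")
```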
maybe_save_checkpoint¶
class Checkpointer(Registrable):
| ...
| def maybe_save_checkpoint(
| self,
| trainer: Trainer,
| num_epochs_completed: int,
| num_batches_in_epoch_completed: int
| ) -> bool
Figures out whether we need to save a checkpoint, and does so if necessary.
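A sketch of how a trainer's batch loop might call this; `trainer` and the two progress counters are assumed to come from the surrounding training code:

```python
# Inside a hypothetical training loop: `trainer` is the Trainer being
# checkpointed; the counters track progress through training so far.
saved = checkpointer.maybe_save_checkpoint(
    trainer,
    num_epochs_completed=epochs_completed,
    num_batches_in_epoch_completed=batches_in_epoch_completed,
)
if saved:
    print("Checkpoint written to disk.")
```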
save_checkpoint¶
class Checkpointer(Registrable):
| ...
| def save_checkpoint(self, trainer: Trainer) -> None
find_latest_checkpoint¶
class Checkpointer(Registrable):
| ...
| def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]
Return the location of the latest model and training state files. If there isn't a valid checkpoint, return `None`.
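For example (a sketch, assuming `checkpointer` was constructed as above):

```python
latest = checkpointer.find_latest_checkpoint()
if latest is not None:
    model_path, training_state_path = latest
    print(f"Latest checkpoint: {model_path}, {training_state_path}")
else:
    print("No valid checkpoint found; starting fresh.")
```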
load_checkpoint¶
class Checkpointer(Registrable):
| ...
| def load_checkpoint(self) -> Optional[TrainerCheckpoint]
Loads model state from a `serialization_dir` corresponding to the last saved checkpoint.
This includes a training state, which is serialized separately from model parameters. This function
should only be used to continue training - if you wish to load a model for inference, or load parts
of a model into a new computation graph, you should use the native PyTorch functions:
model.load_state_dict(torch.load("/path/to/model/weights.th"))
If `self._serialization_dir` does not exist or does not contain any checkpointed weights,
this function does nothing and returns `None`.
Returns¶
- checkpoint : `Optional[TrainerCheckpoint]`
  The model state and the training state, bundled together, or `None` if no valid checkpoint was found.
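A sketch of resuming training from the last checkpoint, assuming `model` is the `torch.nn.Module` being trained and that `TrainerCheckpoint` exposes `model_state` and `trainer_state` fields:

```python
checkpoint = checkpointer.load_checkpoint()
if checkpoint is not None:
    # Restore model parameters; the trainer state (optimizer, epoch
    # counters, etc.) is left for the calling code to interpret.
    model.load_state_dict(checkpoint.model_state)
    trainer_state = checkpoint.trainer_state
```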