

allennlp.training.checkpointer



Checkpointer#

class Checkpointer(Registrable):
 | def __init__(
 |     self,
 |     serialization_dir: str,
 |     keep_serialized_model_every_num_seconds: int = None,
 |     num_serialized_models_to_keep: int = 2,
 |     model_save_interval: float = None
 | ) -> None

This class implements the functionality for checkpointing your model and trainer state during training. It is agnostic as to what those states look like (they are typed as Dict[str, Any]), but they will be fed to torch.save so they should be serializable in that sense. They will also be restored as Dict[str, Any], which means the calling code is responsible for knowing what to do with them.

Parameters

  • num_serialized_models_to_keep : int, optional (default = 2)
    Number of previous model checkpoints to retain. Default is to keep 2 checkpoints. A value of None or -1 means all checkpoints will be kept.

    In a typical AllenNLP configuration file, this argument does not get an entry under the "checkpointer", it gets passed in separately.

  • keep_serialized_model_every_num_seconds : int, optional (default = None)
    If num_serialized_models_to_keep is not None, then occasionally it's useful to save models at a given interval in addition to the last num_serialized_models_to_keep. To do so, specify keep_serialized_model_every_num_seconds as the number of seconds between permanently saved checkpoints. Note that this option is only used if num_serialized_models_to_keep is not None; otherwise all checkpoints are kept.

  • model_save_interval : float, optional (default = None)
    If provided, then serialize models every model_save_interval seconds within single epochs. In all cases, models are also saved at the end of every epoch if serialization_dir is provided.
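
For example, a minimal construction sketch based on the signature above (the path and interval values are purely illustrative):

from allennlp.training.checkpointer import Checkpointer

# Keep the two most recent checkpoints, and additionally keep one checkpoint
# permanently every hour of training (values are illustrative).
checkpointer = Checkpointer(
    serialization_dir="/tmp/my_model",
    keep_serialized_model_every_num_seconds=3600,
    num_serialized_models_to_keep=2,
)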

default_implementation#

class Checkpointer(Registrable):
 | ...
 | default_implementation = "default"

maybe_save_checkpoint#

class Checkpointer(Registrable):
 | ...
 | def maybe_save_checkpoint(
 |     self,
 |     trainer: "allennlp.training.trainer.Trainer",
 |     epoch: int,
 |     batches_this_epoch: int
 | ) -> None

Given the amount of time elapsed between the last save and now (tracked internally), the current epoch, and the number of batches seen so far this epoch, this method decides whether to save a checkpoint. If it decides to save one, it grabs whatever state it needs out of the Trainer and saves it.

This function is intended to be called at the end of each batch in an epoch (perhaps because your data is large enough that you don't really have "epochs"). The default implementation only looks at time, not batch or epoch number, though those parameters are available to you if you want to customize the behavior of this function.
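
As a sketch of the kind of customization the paragraph above describes, the subclass below saves a checkpoint every fixed number of batches instead of on a time interval. The registered name "every-n-batches" and the n parameter are hypothetical; only the constructor and method signatures come from this page.

from typing import Any

from allennlp.training.checkpointer import Checkpointer


@Checkpointer.register("every-n-batches")  # hypothetical name, used only for this sketch
class EveryNBatchesCheckpointer(Checkpointer):
    def __init__(self, serialization_dir: str, n: int = 500, **kwargs: Any) -> None:
        super().__init__(serialization_dir=serialization_dir, **kwargs)
        self._n = n

    def maybe_save_checkpoint(
        self,
        trainer: "allennlp.training.trainer.Trainer",
        epoch: int,
        batches_this_epoch: int,
    ) -> None:
        # Save mid-epoch every n batches; otherwise fall back to the default
        # time-based behavior.
        if batches_this_epoch > 0 and batches_this_epoch % self._n == 0:
            self.save_checkpoint(f"{epoch}.{batches_this_epoch}", trainer, save_model_only=True)
        else:
            super().maybe_save_checkpoint(trainer, epoch, batches_this_epoch)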

save_checkpoint#

class Checkpointer(Registrable):
 | ...
 | def save_checkpoint(
 |     self,
 |     epoch: Union[int, str],
 |     trainer: "allennlp.training.trainer.Trainer",
 |     is_best_so_far: bool = False,
 |     save_model_only: bool = False
 | ) -> None
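
For example, a usage sketch in which training has just finished epoch 3 with the best validation metric so far (the trainer variable is assumed to exist):

checkpointer.save_checkpoint(epoch=3, trainer=trainer, is_best_so_far=True)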

find_latest_checkpoint#

class Checkpointer(Registrable):
 | ...
 | def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]

Return the location of the latest model and training state files. If there isn't a valid checkpoint then return None.
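
A typical resumption check, sketched with assumed variable names:

latest = checkpointer.find_latest_checkpoint()
if latest is None:
    print("No checkpoint found; starting training from scratch.")
else:
    model_path, training_state_path = latest
    print(f"Resuming from {model_path} and {training_state_path}")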

restore_checkpoint#

class Checkpointer(Registrable):
 | ...
 | def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]

Restores a model from a serialization_dir to the last saved checkpoint. This includes a training state (typically consisting of an epoch count and optimizer state), which is serialized separately from the model parameters. This function should only be used to continue training. If you wish to load a model for inference, or to load parts of a model into a new computation graph, use the native PyTorch functions instead: model.load_state_dict(torch.load("/path/to/model/weights.th"))

If self._serialization_dir does not exist or does not contain any checkpointed weights, this function will do nothing and return empty dicts.

Returns

  • states : Tuple[Dict[str, Any], Dict[str, Any]]
    The model state and the training state.
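
A resumption sketch, assuming model is an existing torch.nn.Module and checkpointer was constructed with the serialization_dir used during training:

model_state, training_state = checkpointer.restore_checkpoint()
if model_state:
    # Continue training from the last checkpoint.
    model.load_state_dict(model_state)
    # training_state typically holds things like the epoch count and optimizer
    # state; exactly what it contains is up to the Trainer that saved it.
else:
    # Empty dicts mean there was nothing to restore.
    print("No checkpoint to restore; training starts fresh.")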

best_model_state#

class Checkpointer(Registrable):
 | ...
 | def best_model_state(self) -> Dict[str, Any]
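
The page gives no description for this method. Assuming it returns the weights of the best model saved so far (or an empty dict when none are available), a typical use would be:

best_state = checkpointer.best_model_state()
if best_state:
    model.load_state_dict(best_state)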