Skip to content



Helper functions for Trainers


class HasBeenWarned


class HasBeenWarned:
 | ...
 | tqdm_ignores_underscores = False


def move_optimizer_to_cuda(optimizer)

Move the optimizer state to GPU, if necessary. After calling, any parameter-specific state in the optimizer will be located on the same device as the parameter.


def get_batch_size(batch: Union[Dict, torch.Tensor]) -> int

Returns the size of the batch dimension. Assumes a well-formed batch; returns 0 otherwise.


def time_to_str(timestamp: int) -> str

Convert seconds past the Epoch to a human-readable string.


def str_to_time(time_str: str) -> datetime.datetime

Convert a human-readable string to a datetime.datetime.


def data_loaders_from_params(
    params: Params,
    train: bool = True,
    validation: bool = True,
    test: bool = True,
    serialization_dir: Optional[Union[str, PathLike]] = None
) -> Dict[str, DataLoader]

Instantiate data loaders specified by the config.


def create_serialization_dir(
    params: Params,
    serialization_dir: Union[str, PathLike],
    recover: bool,
    force: bool
) -> None

This function creates the serialization directory if it doesn't exist. If it already exists and is non-empty, then it verifies that we're recovering from a training with an identical configuration.


  • params : Params
    A parameter object specifying an AllenNLP Experiment.
  • serialization_dir : Union[str, PathLike]
    The directory in which to save results and logs.
  • recover : bool
    If True, we will try to recover from an existing serialization directory and fail if the directory doesn't exist or doesn't match the configuration we're given.
  • force : bool
    If True, we will overwrite the serialization directory if it already exists.


def enable_gradient_clipping(
    model: Model,
    grad_clipping: Optional[float]
) -> None


def rescale_gradients(
    model: Model,
    grad_norm: Optional[float] = None
) -> Optional[float]

Performs gradient rescaling. This is a no-op if gradient rescaling is not enabled.


def get_metrics(
    model: Model,
    total_loss: float,
    total_reg_loss: Optional[float],
    batch_loss: Optional[float],
    batch_reg_loss: Optional[float],
    num_batches: int,
    reset: bool = False
) -> Dict[str, float]

Gets the metrics but sets "loss" to the total loss divided by num_batches, so that the "loss" metric is the average loss per batch. Returns the "batch_loss" separately.


def get_train_and_validation_metrics(
    metrics: Dict
) -> Tuple[Dict[str, Any], Dict[str, Any]]

Utility function to separate out train_metrics and val_metrics.


def evaluate(
    model: Model,
    data_loader: DataLoader,
    cuda_device: Union[int, torch.device] = -1,
    batch_weight_key: str = None,
    output_file: str = None,
    predictions_output_file: str = None
) -> Dict[str, Any]


  • model : Model
    The model to evaluate
  • data_loader : DataLoader
    The DataLoader that will iterate over the evaluation data (data loaders already contain their data).
  • cuda_device : Union[int, torch.device], optional (default = -1)
    The cuda device to use for this evaluation. The model is assumed to already be using this device; this parameter is only used for moving the input data to the correct device.
  • batch_weight_key : str, optional (default = None)
    If given, this is a key in the output dictionary for each batch that specifies how to weight the loss for that batch. If this is not given, we use a weight of 1 for every batch.
  • output_file : str, optional (default = None)
    Optional path to write the final metrics to.
  • predictions_output_file : str, optional (default = None)
    Optional path to write the predictions to.


  • Dict[str, Any]
    The final metrics.


def description_from_metrics(metrics: Dict[str, float]) -> str


def make_vocab_from_params(
    params: Params,
    serialization_dir: Union[str, PathLike],
    print_statistics: bool = False
) -> Vocabulary


def ngrams(
    tensor: torch.LongTensor,
    ngram_size: int,
    exclude_indices: Set[int]
) -> Dict[Tuple[int, ...], int]


def get_valid_tokens_mask(
    tensor: torch.LongTensor,
    exclude_indices: Set[int]
) -> torch.ByteTensor