
model

[ allennlp.models.model ]


Model is an abstract class representing an AllenNLP model.

Model Objects

class Model(torch.nn.Module, Registrable):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     regularizer: RegularizerApplicator = None
 | ) -> None

This abstract class represents a model to be trained. Rather than relying completely on the PyTorch Module, we modify the output specification of forward to be a dictionary.

Models built using this API are still compatible with other PyTorch models and can be used naturally as modules within other models: outputs are dictionaries, which can be unpacked and passed into other layers. One caveat is that if you wish to use an AllenNLP model inside a Container (such as nn.Sequential), you must interleave the models with a wrapper module that unpacks the dictionary into a list of tensors.
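The wrapper idea can be sketched in plain Python, with ordinary callables standing in for torch.nn.Module instances so the example runs without PyTorch. The names sequential, UnpackDict, toy_model and sum_layer are all hypothetical; in real code the wrapper would itself be an nn.Module placed between the AllenNLP model and the next layer inside nn.Sequential.

```python
from typing import Any, Callable, Dict, List


def sequential(*layers: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Compose layers left to right, mimicking how nn.Sequential chains modules."""
    def run(x: Any) -> Any:
        for layer in layers:
            x = layer(x)
        return x
    return run


class UnpackDict:
    """Hypothetical wrapper: turns a model's output dict into a list of values."""

    def __init__(self, keys: List[str]) -> None:
        # Fix the key order so downstream layers always see the same layout.
        self.keys = keys

    def __call__(self, output_dict: Dict[str, Any]) -> List[Any]:
        return [output_dict[key] for key in self.keys]


def toy_model(x: float) -> Dict[str, float]:
    # Stands in for an AllenNLP model whose forward() returns a dict.
    return {"logits": x * 2.0, "hidden": x + 1.0}


def sum_layer(values: List[float]) -> float:
    # Stands in for a downstream layer that expects positional inputs.
    return sum(values)


pipeline = sequential(toy_model, UnpackDict(["logits", "hidden"]), sum_layer)
print(pipeline(3.0))  # 3*2 + (3+1) = 10.0
```

Without the UnpackDict step, sum_layer would receive a dict and iterate over its keys, which is exactly the mismatch the wrapper exists to prevent.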

In order for your model to be trained using the Trainer API, the output dictionary of your Model must include a "loss" key, which will be optimized during training.

Finally, you can optionally implement Model.get_metrics in order to make use of early stopping and best-model serialization based on a validation metric in Trainer. Metrics that begin with "_" will not be logged to the progress bar by Trainer.
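A minimal sketch of the underscore convention, assuming metrics arrive as a plain dict from get_metrics: all metrics remain available (for example, for early stopping), but only names without a leading underscore are shown. The helper progress_bar_description and its output format are hypothetical, not the Trainer's actual code.

```python
from typing import Dict


def progress_bar_description(metrics: Dict[str, float]) -> str:
    """Build a progress-bar string, hiding metrics whose names start with "_"."""
    visible = {name: value for name, value in metrics.items() if not name.startswith("_")}
    return ", ".join(f"{name}: {value:.4f}" for name, value in visible.items())


metrics = {"accuracy": 0.91234, "_raw_count": 1200.0, "loss": 0.3}
print(progress_bar_description(metrics))  # accuracy: 0.9123, loss: 0.3000
```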

The from_archive method on this class is registered as a Model with name "from_archive". So, if you are using a configuration file, you can specify a model as {"type": "from_archive", "archive_file": "/path/to/archive.tar.gz"}, which will pull out the model from the given location and return it.

Parameters

  • vocab : Vocabulary
    There are two typical use-cases for the Vocabulary in a Model: getting vocabulary sizes when constructing embedding matrices or output classifiers (the vocabulary also holds the number of classes in your output), and translating model output into human-readable form.

    In a typical AllenNLP configuration file, this parameter does not get an entry under "model"; it is specified as a top-level parameter and then passed in to the model separately.
  • regularizer : RegularizerApplicator, optional
    If given, the Trainer will use this to regularize model parameters.

default_predictor

default_predictor = None

get_regularization_penalty

 | def get_regularization_penalty(self) -> Union[float, torch.Tensor]

Computes the regularization penalty for the model. Returns 0 if the model was not configured to use regularization.
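The documented behavior (a penalty from the configured regularizer, or 0 when none is configured) can be sketched without PyTorch. Here parameters are modeled as a dict of name to flat list of floats, and ToyModel and l2_regularizer are hypothetical; AllenNLP's RegularizerApplicator instead works on the model's torch parameters.

```python
from typing import Callable, Dict, List, Optional

# Parameters are modeled as name -> flat list of floats instead of torch tensors.
Parameters = Dict[str, List[float]]


def l2_regularizer(alpha: float) -> Callable[[Parameters], float]:
    """Return a function computing alpha * (sum of squared parameter values)."""
    def penalty(params: Parameters) -> float:
        return alpha * sum(v * v for values in params.values() for v in values)
    return penalty


class ToyModel:
    def __init__(self, params: Parameters,
                 regularizer: Optional[Callable[[Parameters], float]] = None) -> None:
        self.params = params
        self._regularizer = regularizer

    def get_regularization_penalty(self) -> float:
        # Mirrors the documented behavior: 0 when no regularizer is configured.
        if self._regularizer is None:
            return 0.0
        return self._regularizer(self.params)


params = {"weight": [1.0, -2.0], "bias": [0.5]}
print(ToyModel(params).get_regularization_penalty())                      # 0.0
print(ToyModel(params, l2_regularizer(0.1)).get_regularization_penalty())  # approximately 0.525
```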

get_parameters_for_histogram_tensorboard_logging

 | def get_parameters_for_histogram_tensorboard_logging(
 |     self
 | ) -> List[str]

Returns the names of the model parameters used for logging histograms to TensorBoard.

forward

 | def forward(self, *inputs) -> Dict[str, torch.Tensor]

Defines the forward pass of the model. In addition, to make training straightforward, this method is designed to also compute a user-defined loss.

The input comprises everything required to perform a training update, including labels; you define the signature here! It is up to the user to ensure that inference can be performed without these labels. Hence, any inputs not available at inference time should only be used inside a conditional block.

The intended sketch of this method is as follows:

def forward(self, input1, input2, targets=None):
    # Compute outputs from the inputs that are always available.
    output1 = self.layer1(input1)
    output2 = self.layer2(input2)
    output_dict = {"output1": output1, "output2": output2}
    # Labels are absent at inference time, so the loss is computed conditionally.
    if targets is not None:
        # Function returning a scalar torch.Tensor, defined by the user.
        loss = self._compute_loss(output1, output2, targets)
        output_dict["loss"] = loss
    return output_dict

Parameters

  • *inputs : Any
    Tensors comprising everything needed to perform a training update, including labels, which should be optional (i.e. have a default value of None). At inference time, simply pass the relevant inputs, not including the labels.

Returns

  • output_dict : Dict[str, torch.Tensor]
    The outputs from the model. In order to train a model using the Trainer API, you must provide a "loss" key pointing to a scalar torch.Tensor representing the loss to be optimized.
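To make the output contract concrete without depending on PyTorch, here is a hypothetical toy "model" in plain Python: a one-weight linear predictor whose forward returns a dict and adds a "loss" key (mean squared error) only when targets are given. Real AllenNLP models subclass Model and return torch.Tensor values; ToyRegressor and its float-based outputs are illustrative assumptions.

```python
from typing import Dict, List, Optional


class ToyRegressor:
    """Hypothetical stand-in for a Model subclass; values are floats, not tensors."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def forward(self, inputs: List[float],
                targets: Optional[List[float]] = None) -> Dict[str, object]:
        predictions = [self.weight * x for x in inputs]
        output_dict: Dict[str, object] = {"predictions": predictions}
        # Labels are optional so the same method works at inference time.
        if targets is not None:
            squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
            output_dict["loss"] = sum(squared_errors) / len(squared_errors)
        return output_dict


model = ToyRegressor(weight=2.0)
print(model.forward([1.0, 2.0]))              # no "loss" key at inference time
print(model.forward([1.0, 2.0], [2.0, 5.0]))  # loss = ((2-2)**2 + (4-5)**2) / 2 = 0.5
```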

forward_on_instance

 | def forward_on_instance(
 |     self,
 |     instance: Instance
 | ) -> Dict[str, numpy.ndarray]

Takes an Instance, which typically has raw text in it, converts that text into arrays using this model's Vocabulary, passes those arrays through self.forward() and self.make_output_human_readable() (which by default does nothing) and returns the result. Before returning the result, we convert any torch.Tensors into numpy arrays and remove the batch dimension.

forward_on_instances

 | def forward_on_instances(
 |     self,
 |     instances: List[Instance]
 | ) -> List[Dict[str, numpy.ndarray]]

Takes a list of Instances, converts their text into arrays using this model's Vocabulary, passes those arrays through self.forward() and self.make_output_human_readable() (which by default does nothing) and returns the result. Before returning the result, we convert any torch.Tensors into numpy arrays and separate the batched output into a list of individual dicts, one per instance. Note that this will typically be faster on a GPU (and conditionally, on a CPU) than repeated calls to forward_on_instance.

Parameters

  • instances : List[Instance]
    The instances to run the model on.

Returns

  • A list of the model's outputs, one dictionary per instance.
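The per-instance splitting described above can be sketched in plain Python, with lists standing in for batched tensors. This is a simplified illustration under that assumption, not the library's code: outputs that are not lists with one entry per instance (such as a single batch-level loss value) are skipped. Removing the batch dimension in forward_on_instance corresponds to the batch_size=1 case. The function name split_batch_output is hypothetical.

```python
from typing import Any, Dict, List


def split_batch_output(outputs: Dict[str, Any], batch_size: int) -> List[Dict[str, Any]]:
    """Split a dict of batched outputs into one dict per instance."""
    per_instance: List[Dict[str, Any]] = [{} for _ in range(batch_size)]
    for name, value in outputs.items():
        # Only outputs with one entry per instance can be split; others
        # (e.g. a scalar loss for the whole batch) are left out here.
        if not isinstance(value, list) or len(value) != batch_size:
            continue
        for instance_output, element in zip(per_instance, value):
            instance_output[name] = element
    return per_instance


batched = {"label": ["pos", "neg"], "probs": [[0.9, 0.1], [0.2, 0.8]], "loss": 0.35}
print(split_batch_output(batched, batch_size=2))
# [{'label': 'pos', 'probs': [0.9, 0.1]}, {'label': 'neg', 'probs': [0.2, 0.8]}]
```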

make_output_human_readable

 | def make_output_human_readable(
 |     self,
 |     output_dict: Dict[str, torch.Tensor]
 | ) -> Dict[str, torch.Tensor]

Takes the result of forward and makes it human readable. Most of the time, the only thing this method does is convert tokens / predicted labels from tensors to strings that humans might actually understand. Sometimes you'll also do an argmax or a similar operation here, but that most often happens in Model.forward, before you compute your metrics.

This method modifies the input dictionary, and also returns the same dictionary.

By default in the base class we do nothing.
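A sketch of a typical override, assuming the output dict holds per-instance class probabilities under a hypothetical "probs" key and that an index-to-label mapping is available as a plain dict. In an AllenNLP model that mapping would come from the model's Vocabulary (for example via vocab.get_token_from_index with the label namespace); the constant INDEX_TO_LABEL here is an illustrative assumption. The function adds a "label" entry in place and returns the same dict, as described above.

```python
from typing import Any, Dict, List

# Hypothetical index-to-label mapping; a real model would query its Vocabulary.
INDEX_TO_LABEL = {0: "negative", 1: "positive"}


def make_output_human_readable(output_dict: Dict[str, Any]) -> Dict[str, Any]:
    labels: List[str] = []
    for probs in output_dict["probs"]:
        # argmax over the class probabilities for one instance
        best_index = max(range(len(probs)), key=lambda i: probs[i])
        labels.append(INDEX_TO_LABEL[best_index])
    output_dict["label"] = labels  # modify the input dict in place...
    return output_dict             # ...and return that same dict


output = {"probs": [[0.2, 0.8], [0.7, 0.3]]}
result = make_output_human_readable(output)
print(result["label"])   # ['positive', 'negative']
print(result is output)  # True
```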

get_metrics

 | def get_metrics(self, reset: bool = False) -> Dict[str, float]

Returns a dictionary of metrics. This method will be called by allennlp.training.Trainer in order to compute and use model metrics for early stopping and model serialization. We return an empty dictionary here rather than raising an error, as it is not required to implement metrics for a new model. A boolean reset parameter is passed, as a metric accumulator frequently has some state that should be reset between epochs. This is also compatible with Metrics, whose get_metric method takes the same reset flag. Metrics should be populated during the call to forward, with the Metric handling the accumulation of the metric until this method is called.
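The accumulate-then-reset pattern can be sketched with a hypothetical, framework-free accuracy accumulator. Its get_metric(reset) method follows the same idea as AllenNLP's Metric classes, but RunningAccuracy and ToyClassifier are illustrations only: forward would update the accumulator on every batch, and get_metrics reads it and optionally clears it at the end of an epoch.

```python
from typing import Dict, List


class RunningAccuracy:
    """Hypothetical accumulator; counts correct predictions across batches."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def __call__(self, predictions: List[int], gold_labels: List[int]) -> None:
        self.correct += sum(p == g for p, g in zip(predictions, gold_labels))
        self.total += len(gold_labels)

    def get_metric(self, reset: bool = False) -> float:
        value = self.correct / self.total if self.total else 0.0
        if reset:
            # Clear accumulated state, e.g. between epochs.
            self.correct = 0
            self.total = 0
        return value


class ToyClassifier:
    def __init__(self) -> None:
        self._accuracy = RunningAccuracy()

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self._accuracy.get_metric(reset)}


model = ToyClassifier()
model._accuracy([1, 0, 1], [1, 1, 1])  # would normally happen inside forward()
model._accuracy([0], [0])
print(model.get_metrics())            # {'accuracy': 0.75}
print(model.get_metrics(reset=True))  # {'accuracy': 0.75}, then state is cleared
print(model.get_metrics())            # {'accuracy': 0.0}
```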

load

 | @classmethod
 | def load(
 |     cls,
 |     config: Params,
 |     serialization_dir: str,
 |     weights_file: Optional[str] = None,
 |     cuda_device: int = -1,
 |     opt_level: Optional[str] = None
 | ) -> "Model"

Instantiates an already-trained model, based on the experiment configuration and some optional overrides.

Parameters

  • config : Params
    The configuration that was used to train the model. It should definitely have a model section, and should probably have a trainer section as well.
  • serialization_dir : str
    The directory containing the serialized weights, parameters, and vocabulary of the model.
  • weights_file : str, optional (default = None)
    By default we load the weights from best.th in the serialization directory, but you can override that value here.
  • cuda_device : int, optional (default = -1)
    By default we load the model on the CPU, but if you want to load it for GPU usage, you can specify the id of your GPU here.
  • opt_level : str, optional (default = None)
    Each opt_level establishes a set of properties that govern Amp’s implementation of pure or mixed precision training. Must be one of "O0", "O1", "O2", or "O3". See the Apex documentation for more details. If None, defaults to the opt_level found in the model params. If cuda_device == -1, Amp is not used and this argument is ignored.

Returns

  • model : Model
    The model specified in the configuration, loaded with the serialized vocabulary and the trained weights.
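The defaults documented for load can be summarized in a small helper. resolve_load_options is hypothetical (it is not part of AllenNLP), as is its config_opt_level argument standing in for "the opt_level found in the model params"; it only restates the rules above: weights default to best.th in the serialization directory, Amp is ignored when cuda_device is -1, and otherwise opt_level falls back to the configured value.

```python
import os
from typing import Optional, Tuple

VALID_OPT_LEVELS = {"O0", "O1", "O2", "O3"}


def resolve_load_options(
    serialization_dir: str,
    weights_file: Optional[str] = None,
    cuda_device: int = -1,
    opt_level: Optional[str] = None,
    config_opt_level: Optional[str] = None,  # hypothetical: opt_level from the model params
) -> Tuple[str, Optional[str]]:
    """Return (weights path, effective Amp opt_level) following the documented defaults."""
    # Default to best.th inside the serialization directory.
    weights = weights_file or os.path.join(serialization_dir, "best.th")
    if cuda_device == -1:
        # On the CPU, Amp is not used and opt_level is ignored.
        return weights, None
    effective = opt_level if opt_level is not None else config_opt_level
    if effective is not None and effective not in VALID_OPT_LEVELS:
        raise ValueError(f"opt_level must be one of {sorted(VALID_OPT_LEVELS)}, got {effective!r}")
    return weights, effective


print(resolve_load_options("runs/my_model"))
print(resolve_load_options("runs/my_model", "runs/custom.th", cuda_device=0, config_opt_level="O1"))
```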

extend_embedder_vocab

 | def extend_embedder_vocab(
 |     self,
 |     embedding_sources_mapping: Dict[str, str] = None
 | ) -> None

Iterates through all embedding modules in the model and ensures that each one can embed tokens from the extended vocabulary. This is required in fine-tuning or transfer-learning scenarios where the model was trained with the original vocabulary but, during fine-tuning or transfer learning, must work with an extended vocabulary (original + new-data vocabulary).

Parameters

  • embedding_sources_mapping : Dict[str, str], optional (default = None)
    Mapping from model_path to the pretrained-file path of the embedding modules. If the pretrained file used when the embedding was initialized isn't available now, the user should pass this mapping. The model path is the path of attribute names traversed from the model down to this embedding module, e.g. "_text_field_embedder.token_embedder_tokens".
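A simplified sketch of what extending one embedding to a larger vocabulary involves, using nested lists instead of a weight tensor: rows for the original tokens are kept unchanged and rows for the new tokens are appended. Here the new rows are randomly initialized; when a pretrained file is available (for instance through a mapping such as {"_text_field_embedder.token_embedder_tokens": "/path/to/embeddings.txt"}, where the path is a placeholder), vectors for new tokens could be read from it instead. The function name and initialization scheme are assumptions, not AllenNLP's implementation.

```python
import random
from typing import List

Matrix = List[List[float]]


def extend_embedding_rows(weight: Matrix, new_vocab_size: int, seed: int = 0) -> Matrix:
    """Return a copy of `weight` padded with random rows up to `new_vocab_size`."""
    old_vocab_size = len(weight)
    if new_vocab_size < old_vocab_size:
        raise ValueError("the extended vocabulary cannot be smaller than the original")
    dim = len(weight[0])
    rng = random.Random(seed)
    # Keep existing rows so trained vectors for original tokens are preserved.
    extended = [row[:] for row in weight]
    for _ in range(new_vocab_size - old_vocab_size):
        # Hypothetical init: small Gaussian noise for each new token's vector.
        extended.append([rng.gauss(0.0, 0.1) for _ in range(dim)])
    return extended


original = [[0.1, 0.2], [0.3, 0.4]]
extended = extend_embedding_rows(original, new_vocab_size=4)
print(len(extended), extended[:2] == original)  # 4 True
```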

from_archive

 | @classmethod
 | def from_archive(
 |     cls,
 |     archive_file: str,
 |     vocab: Vocabulary = None
 | ) -> "Model"

Loads a model from an archive file. This essentially just returns archival.load_archive(archive_file).model. It exists as a method here for convenience, and so that we can register it for easy use when fine-tuning an existing model from a config file.

If vocab is given, we will extend the loaded model's vocabulary using the passed vocab object (including calling extend_embedder_vocab, which extends embedding layers).
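Extending a vocabulary in this sense means keeping every existing token at its current index and appending unseen tokens at the end, so rows of already-trained embedding matrices stay aligned with their tokens. A minimal sketch with a plain token-to-index dict, assuming indices are contiguous from 0; the function name is hypothetical, and AllenNLP's Vocabulary additionally handles namespaces, padding, and out-of-vocabulary tokens.

```python
from typing import Dict, Iterable


def extend_token_index(token_to_index: Dict[str, int], new_tokens: Iterable[str]) -> Dict[str, int]:
    """Return a new mapping that keeps old indices and appends unseen tokens."""
    extended = dict(token_to_index)  # leave the caller's mapping untouched
    for token in new_tokens:
        if token not in extended:
            # Assumes indices are 0..n-1, so len() is the next free index.
            extended[token] = len(extended)
    return extended


vocab = {"the": 0, "cat": 1}
print(extend_token_index(vocab, ["cat", "dog", "sat", "dog"]))
# {'the': 0, 'cat': 1, 'dog': 2, 'sat': 3}
```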

remove_pretrained_embedding_params

def remove_pretrained_embedding_params(params: Params)
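No description is given for this function. Judging by its name, it strips pretrained-embedding file references from a parameter tree, presumably so that a model can be rebuilt from its configuration without the original embedding file. The sketch below is an illustration under two assumptions: the relevant key is "pretrained_file", and the parameters are plain nested dicts rather than Params objects. The name remove_pretrained_file_keys is hypothetical and deliberately differs from the library function.

```python
from typing import Any


def remove_pretrained_file_keys(params: Any) -> None:
    """Recursively delete "pretrained_file" keys from nested dicts, in place."""
    if not isinstance(params, dict):
        return  # leaves (strings, numbers, lists) are left untouched
    params.pop("pretrained_file", None)
    for value in params.values():
        remove_pretrained_file_keys(value)


config = {
    "text_field_embedder": {
        "token_embedders": {
            "tokens": {"type": "embedding", "embedding_dim": 100,
                       "pretrained_file": "/path/to/glove.txt"}  # placeholder path
        }
    }
}
remove_pretrained_file_keys(config)
print(config)
```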