

[ allennlp.modules.token_embedders.pretrained_transformer_embedder ]

PretrainedTransformerEmbedder Objects

class PretrainedTransformerEmbedder(TokenEmbedder):
 | def __init__(
 |     self,
 |     model_name: str,
 |     *,
 |     max_length: int = None,
 |     sub_module: str = None,
 |     train_parameters: bool = True,
 |     override_weights_file: Optional[str] = None,
 |     override_weights_strip_prefix: Optional[str] = None
 | ) -> None

Uses a pretrained model from transformers as a TokenEmbedder.

Registered as a TokenEmbedder with name "pretrained_transformer".
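The registered name is what a configuration uses as the `type` of the embedder. Below is a minimal sketch, assuming the usual AllenNLP layout in which a `text_field_embedder` maps each indexer key to a token embedder. The configuration is written as plain Python dicts rather than Jsonnet. The model name `bert-base-uncased`, the key `"tokens"`, and the `max_length` value are illustrative choices. `mismatched_settings` is a hypothetical helper, not part of AllenNLP.

```python
MODEL_NAME = "bert-base-uncased"  # illustrative model choice
MAX_LENGTH = 512  # illustrative; should be identical for indexer and embedder

# Dataset reader side: the PretrainedTransformerIndexer that produces token_ids, mask, ...
token_indexers = {
    "tokens": {
        "type": "pretrained_transformer",
        "model_name": MODEL_NAME,
        "max_length": MAX_LENGTH,
    }
}

# Model side: the embedder, registered under the name "pretrained_transformer".
# The key "tokens" must match the indexer key above.
text_field_embedder = {
    "token_embedders": {
        "tokens": {
            "type": "pretrained_transformer",
            "model_name": MODEL_NAME,
            "max_length": MAX_LENGTH,
            "train_parameters": True,
        }
    }
}


def mismatched_settings(indexer: dict, embedder: dict) -> list:
    """Hypothetical helper: list the settings that differ between indexer and embedder."""
    return [key for key in ("model_name", "max_length") if indexer.get(key) != embedder.get(key)]
```

Keeping `model_name` and `max_length` in shared variables is one simple way to ensure that the indexer and the embedder cannot drift apart.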

Parameters

  • model_name : str
    The name of the transformers model to use. Should be the same as the model_name of the corresponding PretrainedTransformerIndexer.
  • max_length : int, optional (default = None)
    If positive, the embedder folds the input token IDs into multiple segments of this length, passes each segment through the transformer model independently, and concatenates the final representations. Should be set to the same value as the max_length option on the PretrainedTransformerIndexer.
  • sub_module : str, optional (default = None)
    The name of a submodule of the transformer to use as the embedder. Some transformers, such as BERT, naturally act as embedders. Other models consist of an encoder and a decoder, in which case only the encoder should be used.
  • train_parameters : bool, optional (default = True)
    If True, the transformer weights are updated during training; if False, they are kept frozen.
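The max_length folding can be sketched in plain Python. This is a simplified, unbatched illustration. Strings stand in for wordpiece IDs and embedding vectors. The sketch assumes one start token and one end token per segment, as in BERT. The function names `fold`, `unfold`, `embed`, and `run_model` are hypothetical and do not belong to AllenNLP's API. The real embedder works on batched tensors and calls the transformer where the sketch calls `run_model`. The sketch also takes `original_length` as an argument, which stands in for the per-example lengths that the batched code recovers from segment_concat_mask.

```python
from typing import Callable, List

PAD = "[PAD]"


def fold(segment_concat: List[str], max_length: int) -> List[List[str]]:
    """Split a segment-concatenated sequence into chunks of max_length,
    padding the final chunk so that every chunk has the same length."""
    chunks = [segment_concat[i:i + max_length]
              for i in range(0, len(segment_concat), max_length)]
    if chunks:
        chunks[-1] = chunks[-1] + [PAD] * (max_length - len(chunks[-1]))
    return chunks


def unfold(chunks: List[List[str]], original_length: int, max_length: int,
           num_start: int = 1, num_end: int = 1) -> List[str]:
    """Concatenate per-chunk outputs, drop padding, and remove the special
    tokens inserted between segments, keeping only the first start token(s)
    and the last end token(s)."""
    # original_length stands in for what segment_concat_mask provides in the batched case.
    flat = [tok for chunk in chunks for tok in chunk][:original_length]
    segments = [flat[i:i + max_length] for i in range(0, len(flat), max_length)]
    start = segments[0][:num_start]
    end = segments[-1][len(segments[-1]) - num_end:]
    middle = [tok for seg in segments for tok in seg[num_start:len(seg) - num_end]]
    return start + middle + end


def embed(segment_concat: List[str], max_length: int,
          run_model: Callable[[List[str]], List[str]]) -> List[str]:
    """Fold, run each chunk through the model independently, then unfold."""
    chunks = fold(segment_concat, max_length)
    outputs = [run_model(chunk) for chunk in chunks]
    return unfold(outputs, len(segment_concat), max_length)
```

Consider the input "[CLS] A B C [SEP] [CLS] D E F [SEP]" with max_length=5 and a run_model that lowercases each token. The result is the eight tokens "[cls] a b c d e f [sep]". The inner "[SEP] [CLS]" pair is dropped, so the output has num_wordpieces positions rather than num_segment_concat_wordpieces.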


 | @overrides
 | def get_output_dim(self)


 | @overrides
 | def forward(
 |     self,
 |     token_ids: torch.LongTensor,
 |     mask: torch.BoolTensor,
 |     type_ids: Optional[torch.LongTensor] = None,
 |     segment_concat_mask: Optional[torch.BoolTensor] = None
 | ) -> torch.Tensor

Parameters

  • token_ids : torch.LongTensor
    Shape: [batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]. num_segment_concat_wordpieces is num_wordpieces plus the special tokens inserted between segments, e.g. the length of "[CLS] A B C [SEP] [CLS] D E F [SEP]" (see the indexer logic).
  • mask : torch.BoolTensor
    Shape: [batch_size, num_wordpieces].
  • type_ids : Optional[torch.LongTensor]
    Shape: [batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces].
  • segment_concat_mask : Optional[torch.BoolTensor]
    Shape: [batch_size, num_segment_concat_wordpieces].
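The relationship between num_wordpieces and num_segment_concat_wordpieces follows from the token_ids example. Here is a minimal sketch of that arithmetic. It assumes that each segment carries exactly one start token and one end token, as with BERT's [CLS] and [SEP]. It also assumes that the indexer packs max_length - 2 content wordpieces into every segment except possibly the last. `segment_concat_length` is a hypothetical helper, not an AllenNLP function.

```python
import math


def segment_concat_length(num_wordpieces: int, max_length: int,
                          num_start: int = 1, num_end: int = 1) -> int:
    """Length of token_ids when max_length folding is enabled.

    num_wordpieces counts a single start token and a single end token for the
    whole sequence; after folding, every segment gets its own start/end pair.
    """
    content = num_wordpieces - num_start - num_end
    per_segment = max_length - num_start - num_end  # content wordpieces per segment
    num_segments = max(1, math.ceil(content / per_segment))
    return content + num_segments * (num_start + num_end)
```

The token_ids example has num_wordpieces = 8. With max_length = 5, this gives 10, which is the length of "[CLS] A B C [SEP] [CLS] D E F [SEP]". When max_length is None, no folding happens, so token_ids has num_wordpieces columns.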

Returns

  • torch.Tensor
    Shape: [batch_size, num_wordpieces, embedding_size].