# pretrained_transformer_mismatched_embedder

allennlp.modules.token_embedders.pretrained_transformer_mismatched_embedder

## PretrainedTransformerMismatchedEmbedder
```python
@TokenEmbedder.register("pretrained_transformer_mismatched")
class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
 | def __init__(
 |     self,
 |     model_name: str,
 |     max_length: int = None,
 |     train_parameters: bool = True,
 |     last_layer_only: bool = True,
 |     gradient_checkpointing: Optional[bool] = None,
 |     tokenizer_kwargs: Optional[Dict[str, Any]] = None,
 |     transformer_kwargs: Optional[Dict[str, Any]] = None
 | ) -> None
```
Use this embedder to embed wordpieces given by `PretrainedTransformerMismatchedIndexer`
and to pool the resulting vectors to get word-level representations.

Registered as a `TokenEmbedder` with name "pretrained_transformer_mismatched".
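Because of that registration, the embedder can be built from a configuration fragment through the registry. The sketch below is illustrative only: `"bert-base-uncased"` is an arbitrary model choice, not something this module prescribes.

```python
from allennlp.common import Params
from allennlp.modules.token_embedders import TokenEmbedder

# Build the embedder through the registry, the same way a config file would.
# "bert-base-uncased" is an illustrative model name.
embedder = TokenEmbedder.from_params(
    Params({
        "type": "pretrained_transformer_mismatched",
        "model_name": "bert-base-uncased",
    })
)
```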
#### Parameters

- model_name : `str`
    The name of the `transformers` model to use. Should be the same as the corresponding
    `PretrainedTransformerMismatchedIndexer`.
- max_length : `int`, optional (default = `None`)
    If positive, folds input token IDs into multiple segments of this length, passes them
    through the transformer model independently, and concatenates the final representations.
    Should be set to the same value as the `max_length` option on the
    `PretrainedTransformerMismatchedIndexer`.
- train_parameters : `bool`, optional (default = `True`)
    If this is `True`, the transformer weights get updated during training.
- last_layer_only : `bool`, optional (default = `True`)
    When `True` (the default), only the final layer of the pretrained transformer is used
    for the embeddings. If set to `False`, a scalar mix of all of the layers is used.
- gradient_checkpointing : `bool`, optional (default = `None`)
    Enable or disable gradient checkpointing.
- tokenizer_kwargs : `Dict[str, Any]`, optional (default = `None`)
    Dictionary with additional arguments for `AutoTokenizer.from_pretrained`.
- transformer_kwargs : `Dict[str, Any]`, optional (default = `None`)
    Dictionary with additional arguments for `AutoModel.from_pretrained`.
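As a quick sketch of the constructor options, the embedder can also be instantiated directly. This is illustrative only: `"bert-base-uncased"` is an arbitrary model choice, and `max_length=512` simply mirrors whatever value the corresponding indexer uses.

```python
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

embedder = PretrainedTransformerMismatchedEmbedder(
    model_name="bert-base-uncased",   # illustrative model choice
    max_length=512,                   # should match the indexer's max_length, if set
    train_parameters=False,           # freeze the transformer weights
    last_layer_only=True,             # no scalar mix over layers
)
print(embedder.get_output_dim())      # 768, the hidden size of bert-base-uncased
```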
### get_output_dim

```python
class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
 | ...
 | @overrides
 | def get_output_dim(self)
```
### forward

```python
class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
 | ...
 | @overrides
 | def forward(
 |     self,
 |     token_ids: torch.LongTensor,
 |     mask: torch.BoolTensor,
 |     offsets: torch.LongTensor,
 |     wordpiece_mask: torch.BoolTensor,
 |     type_ids: Optional[torch.LongTensor] = None,
 |     segment_concat_mask: Optional[torch.BoolTensor] = None
 | ) -> torch.Tensor
```
#### Parameters

- token_ids : `torch.LongTensor`
    Shape: `[batch_size, num_wordpieces]` (for exception see `PretrainedTransformerEmbedder`).
- mask : `torch.BoolTensor`
    Shape: `[batch_size, num_orig_tokens]`.
- offsets : `torch.LongTensor`
    Shape: `[batch_size, num_orig_tokens, 2]`. Maps indices for the original tokens, i.e.
    those given as input to the indexer, to a span in `token_ids`.
    `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]` corresponds to the original j-th
    token from the i-th batch.
- wordpiece_mask : `torch.BoolTensor`
    Shape: `[batch_size, num_wordpieces]`.
- type_ids : `Optional[torch.LongTensor]`
    Shape: `[batch_size, num_wordpieces]`.
- segment_concat_mask : `Optional[torch.BoolTensor]`
    See `PretrainedTransformerEmbedder`.

#### Returns

- `torch.Tensor`
    Shape: `[batch_size, num_orig_tokens, embedding_size]`.
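For context on where these tensors come from, here is a minimal end-to-end sketch (illustrative only: the model name `"bert-base-uncased"` and the index name `"tokens"` are arbitrary choices). The mismatched indexer produces exactly the keyword arguments `forward` expects, so the resulting tensor dictionary can be splatted straight in.

```python
from allennlp.data import Batch, Instance, Token, Vocabulary
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import PretrainedTransformerMismatchedIndexer
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

model_name = "bert-base-uncased"  # illustrative choice
indexer = PretrainedTransformerMismatchedIndexer(model_name=model_name)
embedder = PretrainedTransformerMismatchedEmbedder(model_name=model_name)

# Four original tokens; the indexer may split each of them into several wordpieces.
tokens = [Token(t) for t in ["AllenNLP", "is", "great", "!"]]
field = TextField(tokens, {"tokens": indexer})

batch = Batch([Instance({"text": field})])
batch.index_instances(Vocabulary())
tensors = batch.as_tensor_dict(batch.get_padding_lengths())

# tensors["text"]["tokens"] holds token_ids, mask, offsets, wordpiece_mask, type_ids.
embeddings = embedder(**tensors["text"]["tokens"])
print(embeddings.shape)  # torch.Size([1, 4, 768]): one vector per original token
```

Note that the second dimension of the output follows the original-token mask, not the wordpiece mask, so downstream modules see one vector per word regardless of how the tokenizer split it.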