
pytorch_transformer_wrapper

allennlp.modules.seq2seq_encoders.pytorch_transformer_wrapper


PytorchTransformer

class PytorchTransformer(Seq2SeqEncoder):
 | def __init__(
 |     self,
 |     input_dim: int,
 |     num_layers: int,
 |     feedforward_hidden_dim: int = 2048,
 |     num_attention_heads: int = 8,
 |     positional_encoding: Optional[str] = None,
 |     positional_embedding_size: int = 512,
 |     dropout_prob: float = 0.1,
 |     activation: str = "relu"
 | ) -> None

Implements a stacked self-attention encoder similar to the Transformer architecture in "Attention Is All You Need" (Vaswani et al., 2017).
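Each layer of this encoder is built around self-attention, which only compares elements with one another. The toy sketch below shows why that matters for word order. It uses a simplified, single-head scaled dot-product attention with identity query/key/value projections and no masking or dropout. These are assumptions made for brevity: the torch.nn layers learn the projections and use several heads. Reordering the input rows only reorders the output rows, so without added positional features the encoder cannot tell one token order from another.

```python
import math
from typing import List

Vector = List[float]


def self_attention(xs: List[Vector]) -> List[Vector]:
    # Simplified single-head scaled dot-product self-attention (assumption:
    # identity query/key/value projections, no masking, no dropout).
    d = len(xs[0])
    outputs = []
    for q in xs:
        # Pairwise similarity between this query and every key, scaled by sqrt(d).
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in xs]
        # Softmax over the keys (subtract the max for numerical stability).
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Each output is the attention-weighted average of the values (the inputs here).
        outputs.append([sum(w * v[j] for w, v in zip(weights, xs)) for j in range(d)])
    return outputs


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
perm = [2, 0, 1]  # reorder the sequence
out = self_attention(x)
out_perm = self_attention([x[i] for i in perm])

# Reordering the inputs only reorders the outputs: the order itself is invisible.
print(all(
    math.isclose(a, b)
    for i, p in enumerate(perm)
    for a, b in zip(out_perm[i], out[p])
))  # True
```

This property is called permutation equivariance. The positional encodings described below exist to break it.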

This class adapts the Transformer encoder from torch.nn (torch.nn.TransformerEncoder) for use in AllenNLP. Optionally, it adds positional encodings to the input.
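Sinusoidal positional encodings, as described in the Transformer paper, add a fixed vector to each input element that depends only on that element's position. The minimal pure-Python sketch below uses the paper's formula, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). The function names are hypothetical. AllenNLP's internal helper may lay out the sin and cos features differently (for example concatenated rather than interleaved), so treat this as an illustration of the idea, not of the library's exact output.

```python
import math
from typing import List


def sinusoidal_encoding(num_positions: int, dim: int) -> List[List[float]]:
    # Interleaved sin/cos pairs at geometrically spaced frequencies
    # (Vaswani et al. layout; hypothetical helper, not AllenNLP's).
    table = []
    for pos in range(num_positions):
        row = []
        for i in range(0, dim, 2):
            angle = pos / (10000 ** (i / dim))
            row.append(math.sin(angle))
            if i + 1 < dim:  # keep the row length equal to dim for odd dims
                row.append(math.cos(angle))
        table.append(row)
    return table


def add_positions(xs: List[List[float]]) -> List[List[float]]:
    # Element-wise addition of the encoding to a (sequence, dim) input.
    pe = sinusoidal_encoding(len(xs), len(xs[0]))
    return [[x + p for x, p in zip(row, pe_row)] for row, pe_row in zip(xs, pe)]


pe = sinusoidal_encoding(num_positions=4, dim=6)
print(pe[0])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]: sin(0) and cos(0) in every pair
```

After add_positions, two identical input vectors at different positions no longer look the same to the attention layers.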

Registered as a Seq2SeqEncoder with name "pytorch_transformer".

Parameters

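  • input_dim : int
    The input dimension of the encoder. The output has the same dimension, and input_dim must be divisible by num_attention_heads.
  • num_layers : int
    The number of stacked self-attention -> feedforward -> layer normalisation blocks.
  • feedforward_hidden_dim : int, optional (default = 2048)
    The middle dimension of the feedforward network. Its input and output dimensions are fixed to input_dim so that sizes match up for the self-attention layers.
  • num_attention_heads : int, optional (default = 8)
    The number of attention heads to use per layer.
  • positional_encoding : str, optional (default = None)
    Which positional information to add to the input: None adds nothing, "sinusoidal" adds fixed sinusoidal frequencies, and "embedding" adds a learned positional embedding. Setting one of the two is strongly recommended: without it, the self-attention layers have no notion of absolute or relative position (they only compute pairwise similarity between element vectors), and position can be an important feature for many tasks.
  • positional_embedding_size : int, optional (default = 512)
    The number of rows in the learned positional embedding, and therefore the longest sequence that can be encoded, when positional_encoding is "embedding". Ignored otherwise.
  • dropout_prob : float, optional (default = 0.1)
    The dropout probability passed to each torch.nn.TransformerEncoderLayer, which applies it to the attention weights, the feedforward network, and the residual connections.
  • activation : str, optional (default = "relu")
    The activation function of the feedforward network, passed to torch.nn.TransformerEncoderLayer (for example "relu" or "gelu").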
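Some constraints follow from these parameters. Multi-head attention in torch.nn splits input_dim evenly across the heads. The wrapper accepts None, "sinusoidal", or "embedding" for positional_encoding. A learned positional embedding has positional_embedding_size rows. The hypothetical helper check_transformer_args below is not part of AllenNLP. It mirrors these checks so that a configuration can be sanity-checked before the encoder is built. Its max_sequence_length argument is an assumed extra input describing the longest sequence you expect.

```python
from typing import Optional

# Hypothetical helper, not part of AllenNLP: it only mirrors constraints
# imposed by the wrapped torch.nn layers and the positional options.
VALID_POSITIONAL_ENCODINGS = (None, "sinusoidal", "embedding")


def check_transformer_args(
    input_dim: int,
    num_attention_heads: int = 8,
    positional_encoding: Optional[str] = None,
    positional_embedding_size: int = 512,
    max_sequence_length: Optional[int] = None,
) -> int:
    """Validate a configuration and return the per-head dimension."""
    # Multi-head attention splits input_dim evenly across the heads.
    if input_dim % num_attention_heads != 0:
        raise ValueError(
            f"input_dim ({input_dim}) must be divisible by "
            f"num_attention_heads ({num_attention_heads})"
        )
    if positional_encoding not in VALID_POSITIONAL_ENCODINGS:
        raise ValueError(
            f"positional_encoding must be one of {VALID_POSITIONAL_ENCODINGS}"
        )
    # A learned embedding table has one row per position, so longer
    # sequences would index past its end.
    if (
        positional_encoding == "embedding"
        and max_sequence_length is not None
        and max_sequence_length > positional_embedding_size
    ):
        raise ValueError(
            f"sequences of length {max_sequence_length} exceed "
            f"positional_embedding_size ({positional_embedding_size})"
        )
    return input_dim // num_attention_heads


print(check_transformer_args(512))  # 64: the default 8 heads of size 64
```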

get_input_dim

class PytorchTransformer(Seq2SeqEncoder):
 | ...
 | @overrides
 | def get_input_dim(self) -> int

get_output_dim

class PytorchTransformer(Seq2SeqEncoder):
 | ...
 | @overrides
 | def get_output_dim(self) -> int

is_bidirectional

class PytorchTransformer(Seq2SeqEncoder):
 | ...
 | @overrides
 | def is_bidirectional(self)

forward

class PytorchTransformer(Seq2SeqEncoder):
 | ...
 | @overrides
 | def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor)
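forward takes inputs of shape (batch_size, sequence_length, input_dim) and a boolean mask of shape (batch_size, sequence_length). Following AllenNLP's convention, True in the mask marks a real token. torch.nn.TransformerEncoder's src_key_padding_mask uses the opposite convention, where True marks padding to ignore. With its default batch_first=False, the encoder also expects inputs laid out as (sequence, batch, features). A wrapper therefore has to invert the mask and swap the first two axes around the call. The pure-Python sketch below illustrates just those two conversions on nested lists. The helper names are hypothetical. On real tensors the equivalents would be ~mask and inputs.permute(1, 0, 2).

```python
from typing import List, TypeVar

T = TypeVar("T")


def to_key_padding_mask(mask: List[List[bool]]) -> List[List[bool]]:
    # AllenNLP masks: True marks a real token.
    # torch.nn src_key_padding_mask: True marks a position to ignore.
    return [[not keep for keep in row] for row in mask]


def batch_first_to_seq_first(batch: List[List[T]]) -> List[List[T]]:
    # (batch, sequence, ...) -> (sequence, batch, ...); on a tensor this is
    # inputs.permute(1, 0, 2) or inputs.transpose(0, 1).
    return [list(step) for step in zip(*batch)]


# Two sequences of lengths 2 and 1, padded to length 3, feature size 1.
mask = [[True, True, False], [True, False, False]]
inputs = [[[1.0], [2.0], [0.0]], [[3.0], [0.0], [0.0]]]

print(to_key_padding_mask(mask))  # [[False, False, True], [False, True, True]]
print(batch_first_to_seq_first(inputs))
# [[[1.0], [3.0]], [[2.0], [0.0]], [[0.0], [0.0]]]
```

Applying the same transpose to the encoder's output restores the (batch_size, sequence_length, input_dim) layout that a Seq2SeqEncoder returns.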