pytorch_transformer_wrapper
allennlp.modules.seq2seq_encoders.pytorch_transformer_wrapper
PytorchTransformer
@Seq2SeqEncoder.register("pytorch_transformer")
class PytorchTransformer(Seq2SeqEncoder):
| def __init__(
| self,
| input_dim: int,
| num_layers: int,
| feedforward_hidden_dim: int = 2048,
| num_attention_heads: int = 8,
| positional_encoding: Optional[str] = None,
| positional_embedding_size: int = 512,
| dropout_prob: float = 0.1,
| activation: str = "relu"
| ) -> None
Implements a stacked self-attention encoder similar to the Transformer architecture in "Attention Is All You Need" (Vaswani et al., 2017).
This class adapts the Transformer encoder from torch.nn (nn.TransformerEncoder) for use in AllenNLP. Optionally, it adds positional encodings.
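Each stacked block is built around scaled dot-product self-attention: every element is compared with every other element, and the resulting weights mix the element vectors. The pure-Python sketch below is a simplification. It uses one head, and the queries, keys and values are the raw inputs with no learned projections. `self_attention` is a hypothetical name. The sketch shows why positional information matters: shuffling the input rows only shuffles the output rows, so the layer cannot tell where an element sits in the sequence.

```python
import math
from typing import List

Matrix = List[List[float]]


def self_attention(x: Matrix) -> Matrix:
    """Single-head scaled dot-product self-attention with identity projections.

    Simplification (assumption): queries, keys and values are the inputs
    themselves; a real Transformer layer applies learned weight matrices first.
    """
    d = len(x[0])
    out = []
    for q in x:
        # Pairwise similarity between this query and every key, scaled by sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in x]
        # Numerically stable softmax over the scores.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Output row = attention-weighted average of the value vectors.
        out.append([sum(w * v[j] for w, v in zip(weights, x)) for j in range(d)])
    return out


tokens = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
perm = [2, 0, 1]
original = self_attention(tokens)
shuffled = self_attention([tokens[i] for i in perm])

# Row r of the shuffled output equals row perm[r] of the original output:
# without positional features, order carries no information.
for row, i in zip(shuffled, perm):
    assert all(math.isclose(a, b) for a, b in zip(row, original[i]))
```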
Registered as a Seq2SeqEncoder with name "pytorch_transformer".
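Registration lets a configuration refer to this encoder by the string "pytorch_transformer". The sketch below is a hypothetical, stripped-down registry that illustrates the idea and is not AllenNLP's implementation. `ToyRegistrable`, `ToyTransformer` and the config dict are illustrative names. The decorator records a name-to-class mapping. A config's "type" key is resolved through that mapping, and the remaining keys are passed to the constructor.

```python
from typing import Callable, Dict


class ToyRegistrable:
    """Hypothetical stand-in for a name-based class registry."""

    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        def decorator(subclass: type) -> type:
            cls._registry[name] = subclass  # map the config name to the class
            return subclass
        return decorator

    @classmethod
    def by_name(cls, name: str) -> type:
        return cls._registry[name]


@ToyRegistrable.register("pytorch_transformer")
class ToyTransformer(ToyRegistrable):
    def __init__(self, input_dim: int, num_layers: int) -> None:
        self.input_dim = input_dim
        self.num_layers = num_layers


# A config names the encoder by its registered name under "type";
# the remaining keys become constructor arguments.
config = {"type": "pytorch_transformer", "input_dim": 512, "num_layers": 6}
params = dict(config)
encoder = ToyRegistrable.by_name(params.pop("type"))(**params)
```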
Parameters
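- input_dim : int
    The input dimension of the encoder.
- num_layers : int
    The number of stacked self attention -> feedforward -> layer normalisation blocks.
- feedforward_hidden_dim : int, optional (default = 2048)
    The middle dimension of the FeedForward network. The input and output dimensions are fixed to ensure sizes match up for the self attention layers.
- num_attention_heads : int, optional (default = 8)
    The number of attention heads to use per layer.
- positional_encoding : str, optional (default = None)
    The type of positional encoding to add to the input tensor. Use None for no positional encoding, "sinusoidal" for fixed sinusoidal frequencies, or "embedding" for learned position embeddings. A positional encoding is strongly recommended. Without one, the self attention layers have no idea of absolute or relative position, because they only compute pairwise similarity between element vectors. Position can be an important feature for many tasks.
- positional_embedding_size : int, optional (default = 512)
    The number of learned position embeddings. This bounds the maximum sequence length when positional_encoding is "embedding".
- dropout_prob : float, optional (default = 0.1)
    The dropout probability used inside each encoder layer: on the attention weights, inside the feedforward network, and on the residual connections.
- activation : str, optional (default = "relu")
    The activation function of the intermediate feedforward layer, passed to torch.nn.TransformerEncoderLayer. "relu" and "gelu" are the supported string values.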
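With sinusoidal positional encoding, a fixed, position-dependent vector is added to each input element, so identical tokens at different positions get different representations. The pure-Python sketch below follows the formulation in "Attention Is All You Need". `sinusoidal_positions` and `add_positions` are hypothetical helper names. The exact feature layout used by AllenNLP (for example, interleaved versus concatenated sines and cosines) may differ.

```python
import math
from typing import List


def sinusoidal_positions(seq_len: int, dim: int) -> List[List[float]]:
    """Sinusoidal position features in the layout of the original paper.

    Even feature 2i holds sin(pos / 10000**(2i/dim)); odd feature 2i+1 holds
    the matching cosine. (Assumption: the library may order these differently.)
    """
    table = []
    for pos in range(seq_len):
        row = []
        for j in range(dim):
            i = j // 2
            angle = pos / (10000 ** (2 * i / dim))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def add_positions(inputs: List[List[float]]) -> List[List[float]]:
    """Add position features element-wise to a (seq_len, dim) input."""
    pe = sinusoidal_positions(len(inputs), len(inputs[0]))
    return [[x + p for x, p in zip(row, pe_row)] for row, pe_row in zip(inputs, pe)]


pe = sinusoidal_positions(seq_len=2, dim=4)
# Position 0: sin(0) = 0 and cos(0) = 1 in alternating slots.
print(pe[0])  # [0.0, 1.0, 0.0, 1.0]
```

Because the features are fixed functions of position, this option adds no parameters. The "embedding" option instead learns one vector per position, up to positional_embedding_size positions.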
get_input_dim
class PytorchTransformer(Seq2SeqEncoder):
| ...
| @overrides
| def get_input_dim(self) -> int
get_output_dim
class PytorchTransformer(Seq2SeqEncoder):
| ...
| @overrides
| def get_output_dim(self) -> int
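get_output_dim reports the same value as get_input_dim. As the parameter description notes, the input and output sizes are fixed so that they match up across the stacked blocks, because each sublayer's output is added back to its input. Separately, torch.nn.MultiheadAttention requires the embedding size to be divisible by the number of heads. `check_transformer_dims` is a hypothetical helper that makes both constraints explicit for a given configuration.

```python
from typing import Tuple


def check_transformer_dims(input_dim: int, num_attention_heads: int) -> Tuple[int, int]:
    """Return (per_head_dim, output_dim) for a hypothetical encoder config.

    Multi-head attention splits input_dim evenly across the heads, so the head
    count must divide it. Residual connections keep the width constant, so
    output_dim equals input_dim.
    """
    if input_dim % num_attention_heads != 0:
        raise ValueError(
            f"input_dim ({input_dim}) must be divisible by "
            f"num_attention_heads ({num_attention_heads})"
        )
    return input_dim // num_attention_heads, input_dim


print(check_transformer_dims(512, 8))  # (64, 512)
```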
is_bidirectional
class PytorchTransformer(Seq2SeqEncoder):
| ...
| @overrides
| def is_bidirectional(self)
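This encoder reports that it is not bidirectional, even though every position attends to positions on both sides. For a Seq2SeqEncoder, "bidirectional" means that the final dimension of the output is split into a forward half followed by a backward half, as with a bidirectional RNN. A Transformer's output is a single undivided representation. The hypothetical helper `split_directions` below sketches how downstream code might rely on that flag.

```python
from typing import List, Optional, Tuple


def split_directions(
    vector: List[float], bidirectional: bool
) -> Tuple[List[float], Optional[List[float]]]:
    """Split an encoder output vector into (forward, backward) halves.

    Hypothetical helper: if the encoder is not bidirectional, the whole vector
    is treated as one representation and no backward half is returned.
    """
    if not bidirectional:
        return vector, None
    half = len(vector) // 2
    return vector[:half], vector[half:]


# A Transformer output stays whole.
whole = split_directions([0.1, 0.2, 0.3, 0.4], bidirectional=False)
# A bidirectional RNN output would be laid out as [forward..., backward...].
halves = split_directions([0.1, 0.2, 0.3, 0.4], bidirectional=True)
```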
forward
class PytorchTransformer(Seq2SeqEncoder):
| ...
| @overrides
| def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor)
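forward takes inputs of shape (batch_size, sequence_length, input_dim) and a boolean mask of shape (batch_size, sequence_length) in which True marks real, non-padding elements. It returns a tensor with the same shape as inputs. AllenNLP and torch.nn.TransformerEncoder follow two different conventions:

1. The src_key_padding_mask argument of torch.nn.TransformerEncoder marks positions to ignore with True, so the AllenNLP mask has to be inverted.
2. Unless batch_first=True is set (an option added in later PyTorch releases), torch.nn.TransformerEncoder expects (sequence_length, batch_size, features) input.

The sketch below shows both conversions on nested Python lists standing in for tensors. `to_torch_layout` is a hypothetical name.

```python
from typing import List, Tuple

Tensor3 = List[List[List[float]]]  # nested lists standing in for a 3-D tensor
Mask = List[List[bool]]


def to_torch_layout(inputs: Tensor3, mask: Mask) -> Tuple[Tensor3, Mask]:
    """Convert (batch, seq, dim) inputs and an AllenNLP-style mask.

    Returns (seq, batch, dim) inputs and a (batch, seq) key padding mask in
    which True means "ignore this position", the inverse of the AllenNLP mask.
    """
    batch_size, seq_len = len(inputs), len(inputs[0])
    # Swap the batch and sequence axes: seq_first[s][b] = inputs[b][s].
    seq_first = [[inputs[b][s] for b in range(batch_size)] for s in range(seq_len)]
    # Invert the mask; it keeps its (batch, seq) shape.
    padding_mask = [[not keep for keep in row] for row in mask]
    return seq_first, padding_mask


inputs = [
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],  # sequence of length 3
    [[4.0, 4.0], [5.0, 5.0], [0.0, 0.0]],  # length 2, padded by one step
]
mask = [[True, True, True], [True, True, False]]
seq_first, padding_mask = to_torch_layout(inputs, mask)
print(padding_mask)  # [[False, False, False], [False, False, True]]
```

After the encoder runs, the output has to be transposed back to (batch, seq, dim) so that it matches the shape of inputs.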