




@Seq2SeqEncoder.register("multi_head_self_attention", exist_ok=True)
class MultiHeadSelfAttention(Seq2SeqEncoder):
 | def __init__(
 |     self,
 |     num_heads: int,
 |     input_dim: int,
 |     attention_dim: int,
 |     values_dim: int,
 |     output_projection_dim: int = None,
 |     attention_dropout_prob: float = 0.1
 | ) -> None

This class implements the key-value scaled dot product attention mechanism detailed in the paper Attention Is All You Need.

For each attention head, the output is a weighted sum of V, a linear projection of the inputs. The weights are the scaled, normalised dot products of Q and K, which are also linear projections of the inputs. This procedure is repeated for each attention head, using different parameters.
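The weighted-sum computation described above can be sketched for a single head in plain Python. This is only an illustration under stated assumptions: the function names `softmax` and `scaled_dot_product_attention` and the toy Q, K and V values are hypothetical, and the sketch skips the learned linear projections, the dropout, and the masking that the class itself applies.

```python
import math


def softmax(scores):
    """Normalise a list of scores into a probability distribution."""
    highest = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(queries, keys, values):
    """Single-head attention over lists of vectors, one vector per timestep.

    Query and key vectors share the size d_k; value vectors may differ.
    Returns the attended outputs and the attention weights.
    """
    d_k = len(keys[0])
    value_size = len(values[0])
    outputs, all_weights = [], []
    for q in queries:
        # Dot product of this query with every key, scaled by sqrt(d_k).
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
            for k in keys
        ]
        weights = softmax(scores)
        # Weighted sum of the value vectors.
        attended = [
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(value_size)
        ]
        outputs.append(attended)
        all_weights.append(weights)
    return outputs, all_weights


# Toy projections for a sequence of 3 timesteps (hypothetical numbers).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
outputs, weights = scaled_dot_product_attention(Q, K, V)
```

Each row of `weights` is a distribution over timesteps, so every output vector is a convex combination of the rows of `V`.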


  • num_heads : int
    The number of attention heads to use.
  • input_dim : int
    The size of the last dimension of the input tensor.
  • attention_dim : int
    The total dimension of the query and key projections which comprise the dot product attention function. Must be divisible by num_heads.
  • values_dim : int
    The total dimension which the input is projected to for representing the values, which are combined using the attention. Must be divisible by num_heads.
  • output_projection_dim : int, optional (default = None)
    The dimensionality of the final output projection. If this is not passed explicitly, the projection has size input_dim.
  • attention_dropout_prob : float, optional (default = 0.1)
    The dropout probability applied to the normalised attention distributions.
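The divisibility requirements on attention_dim and values_dim follow from splitting each total dimension evenly across the heads. A minimal sketch of that check, assuming an even split; the helper name `per_head_dims` is hypothetical and is not part of the class, and the exception the class itself raises may differ:

```python
def per_head_dims(num_heads, attention_dim, values_dim):
    """Return the (query/key, value) sizes each head would work with.

    Hypothetical helper: it assumes the total dimensions are divided
    evenly across heads, as the divisibility requirements suggest.
    """
    if attention_dim % num_heads != 0:
        raise ValueError(
            f"attention_dim ({attention_dim}) must be divisible by "
            f"num_heads ({num_heads})"
        )
    if values_dim % num_heads != 0:
        raise ValueError(
            f"values_dim ({values_dim}) must be divisible by "
            f"num_heads ({num_heads})"
        )
    return attention_dim // num_heads, values_dim // num_heads


# 8 heads over a 64-dimensional query/key space and a 128-dimensional value space.
head_dims = per_head_dims(num_heads=8, attention_dim=64, values_dim=128)
```

With these numbers each head would see 8-dimensional queries and keys and 16-dimensional values.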


class MultiHeadSelfAttention(Seq2SeqEncoder):
 | ...
 | def get_input_dim(self)


class MultiHeadSelfAttention(Seq2SeqEncoder):
 | ...
 | def get_output_dim(self)


class MultiHeadSelfAttention(Seq2SeqEncoder):
 | ...
 | def is_bidirectional(self)


class MultiHeadSelfAttention(Seq2SeqEncoder):
 | ...
 | def forward(
 |     self,
 |     inputs: torch.Tensor,
 |     mask: torch.BoolTensor = None
 | ) -> torch.FloatTensor


  • inputs : torch.FloatTensor
    A tensor of shape (batch_size, timesteps, input_dim)
  • mask : torch.BoolTensor, optional (default = None)
    A tensor of shape (batch_size, timesteps).


  • A tensor of shape (batch_size, timesteps, output_projection_dim), where output_projection_dim = input_dim by default.
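The shape rule above can be written out as a small function. This is a sketch of the documented behaviour only; `expected_output_shape` is a hypothetical helper, not a method of the class:

```python
def expected_output_shape(batch_size, timesteps, input_dim,
                          output_projection_dim=None):
    """Shape of the tensor returned by forward, per the rule above.

    Hypothetical helper: the last dimension falls back to input_dim
    when no output_projection_dim is given.
    """
    last_dim = input_dim if output_projection_dim is None else output_projection_dim
    return (batch_size, timesteps, last_dim)


default_shape = expected_output_shape(4, 10, 32)
projected_shape = expected_output_shape(4, 10, 32, output_projection_dim=64)
```

Note that attention_dim and values_dim do not appear in the result: only the final projection determines the size of the last dimension.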