attention_module
allennlp.modules.transformer.attention_module
AttentionOutput¶
@dataclass
class AttentionOutput
Encapsulates the outputs of the Attention module.
hidden_states¶
class AttentionOutput:
| ...
| hidden_states: FloatT = None
key_value_state¶
class AttentionOutput:
| ...
| key_value_state: Optional[Tuple[FloatT, FloatT]] = None
position_bias¶
class AttentionOutput:
| ...
| position_bias: Optional[FloatT] = None
attention_probs¶
class AttentionOutput:
| ...
| attention_probs: Optional[FloatT] = None
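For orientation, here is a minimal hand-built sketch of an AttentionOutput; the tensor shape and the comments about when the optional fields are populated are assumptions drawn from the forward() documentation below, not guarantees.

```python
import torch

from allennlp.modules.transformer.attention_module import AttentionOutput

# Hand-built example; in practice this object is returned by the attention modules below.
output = AttentionOutput(
    hidden_states=torch.randn(2, 5, 512),  # batch_size x seq_len x hidden_dim (assumed shape)
    key_value_state=None,                  # assumed to be filled only when caching is enabled
    position_bias=None,                    # assumed to be filled only for relative attention
    attention_probs=None,                  # assumed to be filled only when output_attentions=True
)
print(output.hidden_states.shape)  # torch.Size([2, 5, 512])
```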
AttentionModule¶
class AttentionModule(TransformerModule, FromParams):
| def __init__(
| self,
| hidden_size: int = 512,
| attention_head_size: int = 64,
| num_attention_heads: int = 8,
| scoring_func: str = "scaled_dot_product",
| output_linear: bool = False,
| dropout: float = 0.0,
| bias: bool = True,
| normalize_weights: bool = False,
| is_decoder: bool = False,
| is_cross_attention: bool = False,
| relative_attention_num_buckets: Optional[int] = None
| )
This module computes self-attention (or cross-attention), similar to the architecture in BERT. Details in the paper: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al., 2019.
Additionally, it has the following functionality:
- the attention scoring function can be specified.
- it can be used in encoders as well as decoders.
- position_bias can be used, which makes it suitable for T5-style attention as well.
Parameters¶
- hidden_size : int, optional (default = 512)
  The size of the expected input tensor.
- attention_head_size : int, optional (default = 64)
  The size of a single attention head.
- num_attention_heads : int, optional (default = 8)
  The number of attention heads.
- scoring_func : str, optional (default = scaled_dot_product)
  The name of the attention-calculating function to be used. E.g. additive, linear, etc. For a complete list, please check matrix_attention.
- output_linear : bool, optional (default = False)
  Whether to add an additional output linear layer at the end.
- dropout : float, optional (default = 0.0)
  The dropout probability.
- bias : bool, optional (default = True)
  Whether to include bias weights in the query, key, value (and output) linear layers.
- normalize_weights : bool, optional (default = False)
  Whether to normalize the initial weights.
- is_decoder : bool, optional (default = False)
  Whether this module is being used in a decoder stack or not.
- is_cross_attention : bool, optional (default = False)
  Whether this module is being used for cross-attention in a decoder stack or not. If is_cross_attention is True, then is_decoder must also be True.
- relative_attention_num_buckets : int, optional (default = None)
  The number of buckets to use in relative attention; if None, relative attention will not be applied.
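As a usage sketch (not taken from the library's documentation), the module can be constructed with the defaults above, or with relative_attention_num_buckets set to enable T5-style position bias; the argument values below are illustrative.

```python
from allennlp.modules.transformer.attention_module import AttentionModule

# BERT-style self-attention for an encoder, spelling out the documented defaults.
attention = AttentionModule(
    hidden_size=512,
    attention_head_size=64,
    num_attention_heads=8,
    scoring_func="scaled_dot_product",
    dropout=0.0,
    is_decoder=False,
)

# T5-style variant: setting relative_attention_num_buckets enables relative
# position bias (32 buckets here is an illustrative choice).
t5_style_attention = AttentionModule(
    hidden_size=512,
    attention_head_size=64,
    num_attention_heads=8,
    relative_attention_num_buckets=32,
)
```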
forward¶
class AttentionModule(TransformerModule, FromParams):
| ...
| def forward(
| self,
| query_states: torch.Tensor,
| past_key_states: Optional[torch.Tensor] = None,
| past_value_states: Optional[torch.Tensor] = None,
| attention_mask: Optional[torch.BoolTensor] = None,
| source_states: Optional[torch.Tensor] = None,
| source_attention_mask: Optional[torch.BoolTensor] = None,
| head_mask: Optional[torch.Tensor] = None,
| position_bias: Optional[torch.Tensor] = None,
| output_attentions: bool = False,
| use_cache: bool = False,
| query_length: Optional[int] = None
| )
Parameters¶
- query_states : torch.Tensor
  Shape batch_size x seq_len x hidden_dim
- past_key_states : torch.Tensor, optional
  Shape batch_size x seq_len x hidden_dim
  These are the key_states from the previous step of the decoder.
- past_value_states : torch.Tensor, optional
  Shape batch_size x seq_len x hidden_dim
  These are the value_states from the previous step of the decoder.
- attention_mask : torch.BoolTensor, optional
  Shape batch_size x seq_len
- source_states : torch.Tensor, optional
  Shape batch_size x source_seq_len x hidden_dim
  This is from the final state of attention over the source (encoder); it is passed when this module is being used for cross-attention.
- source_attention_mask : torch.BoolTensor, optional
  Shape batch_size x source_seq_len
- head_mask : torch.BoolTensor, optional
- position_bias : torch.Tensor, optional
- output_attentions : bool
  Whether to also return the attention probabilities (default = False).
Note
source_states needs to be passed in case of cross-attention.
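A minimal forward-pass sketch, assuming the module returns an AttentionOutput as T5Attention.forward does below; the tensor values are random and the shapes follow the parameter descriptions above.

```python
import torch

from allennlp.modules.transformer.attention_module import AttentionModule

attention = AttentionModule(hidden_size=512, attention_head_size=64, num_attention_heads=8)

batch_size, seq_len, hidden_dim = 2, 7, 512
query_states = torch.randn(batch_size, seq_len, hidden_dim)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)  # all positions visible

output = attention(
    query_states,
    attention_mask=attention_mask,
    output_attentions=True,
)
# Assumed to be an AttentionOutput: hidden_states keeps the input shape, and
# attention_probs is populated because output_attentions=True.
print(output.hidden_states.shape)
```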
compute_bias¶
class AttentionModule(TransformerModule, FromParams):
| ...
| def compute_bias(self, query_length: int, key_length: int) -> FloatT
Compute binned relative position bias
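A sketch of calling compute_bias directly, assuming the module was constructed with relative_attention_num_buckets set (otherwise there is no relative bias to compute); the shape comment is an assumption, not documented behavior.

```python
from allennlp.modules.transformer.attention_module import AttentionModule

attention = AttentionModule(
    hidden_size=512,
    attention_head_size=64,
    num_attention_heads=8,
    relative_attention_num_buckets=32,
)

# Binned relative position bias between a 5-token query and an 8-token key.
bias = attention.compute_bias(query_length=5, key_length=8)
# The bias is added to the attention scores; a shape broadcastable to
# num_heads x query_length x key_length is assumed here.
```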
T5Attention¶
class T5Attention(AttentionModule):
| def __init__(
| self,
| is_decoder: bool = False,
| hidden_size: int = 512,
| key_value_proj_dim: int = 64,
| num_heads: int = 8,
| has_relative_attention_bias: bool = False,
| relative_attention_num_buckets: int = 32,
| dropout: float = 0.1,
| normalize: bool = True,
| is_cross_attention: bool = False
| )
forward¶
class T5Attention(AttentionModule):
| ...
| def forward(
| self,
| hidden_states: torch.Tensor,
| mask: Optional[torch.BoolTensor] = None,
| key_value_states: Optional[FloatT] = None,
| position_bias: Optional[FloatT] = None,
| past_key_value: Optional[
| Tuple[FloatT, FloatT]
| ] = None,
| layer_head_mask: Optional[BoolT] = None,
| query_length: Optional[int] = None,
| use_cache: bool = False,
| output_attentions: bool = False
| ) -> AttentionOutput
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
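A usage sketch for encoder-style T5 self-attention; the constructor values mirror the defaults above, has_relative_attention_bias=True is an illustrative choice, and the mask shape follows the attention_mask convention documented for AttentionModule.forward.

```python
import torch

from allennlp.modules.transformer.attention_module import T5Attention

self_attention = T5Attention(
    is_decoder=False,
    hidden_size=512,
    key_value_proj_dim=64,
    num_heads=8,
    has_relative_attention_bias=True,   # illustrative; lets this layer own the position bias
    relative_attention_num_buckets=32,
)

hidden_states = torch.randn(2, 7, 512)      # batch_size x seq_len x hidden_size
mask = torch.ones(2, 7, dtype=torch.bool)   # assumed batch_size x seq_len mask

# key_value_states is omitted, so this is self-attention; passing encoder states
# there would make it cross-attention instead.
output = self_attention(hidden_states, mask=mask, output_attentions=True)
print(output.hidden_states.shape)           # AttentionOutput.hidden_states
```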
SelfAttention¶
class SelfAttention(AttentionModule):
| def __init__(
| self,
| hidden_size: int,
| num_attention_heads: int,
| dropout: float = 0.0,
| scoring_func: str = "scaled_dot_product",
| output_linear: bool = False,
| is_decoder: bool = False,
| is_cross_attention: bool = False
| )
This module computes the self-attention, similar to the architecture in BERT. Additionally, the attention scoring function can be specified. Details in the paper: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al., 2019.
Parameters¶
- hidden_size : int
- num_attention_heads : int
- dropout : float, optional (default = 0.0)
- scoring_func : str, optional (default = scaled_dot_product)
  The name of the attention-calculating function to be used. E.g. additive, linear, etc. For a complete list, please check attention.
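A BERT-style usage sketch; SelfAttention is assumed to use the inherited AttentionModule.forward signature, and the sizes below (768 hidden units, 12 heads) are illustrative rather than required.

```python
import torch

from allennlp.modules.transformer.attention_module import SelfAttention

attention = SelfAttention(
    hidden_size=768,
    num_attention_heads=12,
    dropout=0.1,
    scoring_func="scaled_dot_product",  # alternatives such as "additive" or "linear" per the docs above
)

query_states = torch.randn(2, 10, 768)                 # batch_size x seq_len x hidden_size
attention_mask = torch.ones(2, 10, dtype=torch.bool)   # batch_size x seq_len
output = attention(query_states, attention_mask=attention_mask)
```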