@SpanExtractor.register("self_attentive") class SelfAttentiveSpanExtractor(SpanExtractor): | def __init__(self, input_dim: int) -> None
Computes span representations by generating an unnormalized attention score for each word in the document. Each span representation is then computed from these scores by normalizing the attention scores of the words inside that span.
Given these attention distributions over every span, this module weights the corresponding vector representations of the words in the span by this distribution, returning a weighted representation of each span.
Registered as a SpanExtractor with name "self_attentive".
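The computation described above can be sketched for a single span in plain Python. This is an illustration only, not the module's implementation: the helper name `span_representation` is hypothetical, the scores are hard-coded rather than produced by a learned scorer, and span boundaries are assumed to be inclusive.

```python
import math
from typing import List, Tuple


def span_representation(
    token_vectors: List[List[float]],
    token_scores: List[float],
    span: Tuple[int, int],
) -> List[float]:
    """Weighted sum of the token vectors inside `span`.

    `token_scores` holds one unnormalized attention score per token.
    `span` is (start, end); treating `end` as inclusive is an assumption.
    """
    start, end = span
    scores = token_scores[start:end + 1]
    vectors = token_vectors[start:end + 1]

    # Softmax restricted to the tokens inside the span
    # (subtracting the max keeps the exponentials numerically stable).
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Weight each token vector by its attention weight and sum them.
    dim = len(vectors[0])
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]


# Three tokens with 2-dimensional embeddings and made-up scores.
token_vectors = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
token_scores = [0.0, 0.0, 5.0]

two_token_span = span_representation(token_vectors, token_scores, (0, 1))
one_token_span = span_representation(token_vectors, token_scores, (2, 2))
```

With equal scores, the two tokens in span `(0, 1)` each receive weight 0.5, so the result is the average of their vectors. A single-token span returns that token's vector unchanged.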
- input_dim :
The final dimension of the
- attended_text_embeddings :
A tensor of shape (batch_size, num_spans, input_dim), where each span representation is formed by locally normalizing a global attention over the sequence. The attention distributions of different spans differ only in the set of words over which they are normalized.
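One consequence of sharing a single global score per word is that two spans containing the same pair of words give those words the same relative weight. The sketch below checks this in plain Python. The helper `span_weights` and the score values are hypothetical, and inclusive span boundaries are again an assumption.

```python
import math
from typing import List, Tuple


def span_weights(token_scores: List[float], span: Tuple[int, int]) -> List[float]:
    """Softmax of the global scores, restricted to tokens in `span` (end assumed inclusive)."""
    start, end = span
    scores = token_scores[start:end + 1]
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# One global, unnormalized score per token (made-up values).
token_scores = [0.5, 1.0, 2.0, -1.0]

# Two overlapping spans that both contain tokens 1 and 2.
weights_a = span_weights(token_scores, (0, 2))  # tokens 0, 1, 2
weights_b = span_weights(token_scores, (1, 3))  # tokens 1, 2, 3

# Relative weight of token 1 versus token 2 within each span.
ratio_a = weights_a[1] / weights_a[2]
ratio_b = weights_b[0] / weights_b[1]

# Both ratios depend only on the global scores of the two tokens.
expected_ratio = math.exp(token_scores[1] - token_scores[2])
```

The absolute weights differ between the two spans because the normalizing sums differ, but `ratio_a` and `ratio_b` both equal `expected_ratio`.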
class SelfAttentiveSpanExtractor(SpanExtractor): | ... | def get_input_dim(self) -> int
class SelfAttentiveSpanExtractor(SpanExtractor): | ... | def get_output_dim(self) -> int
class SelfAttentiveSpanExtractor(SpanExtractor): | ... | @overrides | def forward( | self, | sequence_tensor: torch.FloatTensor, | span_indices: torch.LongTensor, | span_indices_mask: torch.BoolTensor = None | ) -> torch.FloatTensor
shape (batch_size, sequence_length, 1)