allennlp.modules.span_extractors.span_extractor_with_span_width_embedding
SpanExtractorWithSpanWidthEmbedding¶
class SpanExtractorWithSpanWidthEmbedding(SpanExtractor):
| def __init__(
| self,
| input_dim: int,
| num_width_embeddings: int = None,
| span_width_embedding_dim: int = None,
| bucket_widths: bool = False
| ) -> None
SpanExtractorWithSpanWidthEmbedding implements some common code for span extractors that need to embed span width. Specifically, we initialize the span width embedding matrix and other attributes in `__init__`, leave an `_embed_spans` method that can be implemented to compute span embeddings in different ways, and in `forward` we concatenate the span embeddings returned by `_embed_spans` with the span width embeddings to form the final span representations.
We keep SpanExtractor as a purely abstract base class, in case someone wants to build a totally different span extractor.
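The pattern above can be sketched in plain PyTorch (a simplified illustration, not the actual AllenNLP implementation): the base class owns the width embedding and the concatenation in `forward`, while a subclass only fills in `_embed_spans`. The class name and the start-token strategy below are hypothetical.

```python
import torch
import torch.nn as nn

class WidthEmbeddingSpanExtractor(nn.Module):
    """Sketch of the SpanExtractorWithSpanWidthEmbedding pattern."""

    def __init__(self, input_dim: int, num_width_embeddings: int,
                 span_width_embedding_dim: int):
        super().__init__()
        self._input_dim = input_dim
        # The span width embedding matrix is initialized in __init__,
        # as the class description above says.
        self._span_width_embedding = nn.Embedding(
            num_width_embeddings, span_width_embedding_dim
        )

    def _embed_spans(self, sequence_tensor, span_indices):
        # One possible strategy a subclass might choose: represent each
        # span by the embedding of its start token.
        starts = span_indices[:, :, 0]                      # (batch, num_spans)
        batch_index = torch.arange(sequence_tensor.size(0)).unsqueeze(-1)
        return sequence_tensor[batch_index, starts]         # (batch, num_spans, dim)

    def forward(self, sequence_tensor, span_indices):
        span_embeddings = self._embed_spans(sequence_tensor, span_indices)
        # Widths from inclusive start/end indices.
        widths = span_indices[:, :, 1] - span_indices[:, :, 0]
        width_embeddings = self._span_width_embedding(widths)
        # Concatenate content and width embeddings, as in forward above.
        return torch.cat([span_embeddings, width_embeddings], dim=-1)
```

A subclass overriding `_embed_spans` with, say, endpoint concatenation or attention pooling would reuse the width-embedding machinery unchanged.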
Parameters¶
- input_dim : `int`
  The final dimension of the `sequence_tensor`.
- num_width_embeddings : `int`, optional (default = `None`)
  Specifies the number of buckets to use when representing span width features.
- span_width_embedding_dim : `int`, optional (default = `None`)
  The embedding size for the span width features.
- bucket_widths : `bool`, optional (default = `False`)
  Whether to bucket the span widths into log-space buckets. If `False`, the raw span widths are used.
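To make the `bucket_widths` option concrete, here is a sketch of log-space bucketing (a hypothetical helper, not the exact AllenNLP bucketing code): small widths each keep their own bucket, while larger widths share logarithmically sized buckets so that rare, very wide spans do not each need their own embedding row.

```python
import math

def bucket_width(width: int, num_identity_buckets: int = 4,
                 num_total_buckets: int = 10) -> int:
    """Map a raw span width to a log-space bucket index (illustrative)."""
    if width <= num_identity_buckets:
        # Small widths get their own identity bucket.
        return max(width, 0)
    # Wider spans share buckets that grow logarithmically in size.
    bucket = num_identity_buckets + int(
        math.log2(width) - math.log2(num_identity_buckets)
    )
    # Cap at the last available bucket.
    return min(bucket, num_total_buckets - 1)
```

With bucketing, `num_width_embeddings` only needs to cover the bucket count rather than the longest possible span.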
Returns¶
- span_embeddings : `torch.FloatTensor`
  A tensor of shape `(batch_size, num_spans, embedded_span_size)`, where `embedded_span_size` depends on the way spans are represented.
forward¶
class SpanExtractorWithSpanWidthEmbedding(SpanExtractor):
| ...
| def forward(
| self,
| sequence_tensor: torch.FloatTensor,
| span_indices: torch.LongTensor,
| sequence_mask: torch.BoolTensor = None,
| span_indices_mask: torch.BoolTensor = None
| )
Given a sequence tensor, extracts spans, concatenates span width embeddings when needed, and returns representations of the spans.
Parameters¶
- sequence_tensor : `torch.FloatTensor`
  A tensor of shape `(batch_size, sequence_length, embedding_size)` representing an embedded sequence of words.
- span_indices : `torch.LongTensor`
  A tensor of shape `(batch_size, num_spans, 2)`, where the last dimension represents the inclusive start and end indices of the span to be extracted from the `sequence_tensor`.
- sequence_mask : `torch.BoolTensor`, optional (default = `None`)
  A tensor of shape `(batch_size, sequence_length)` representing padded elements of the sequence.
- span_indices_mask : `torch.BoolTensor`, optional (default = `None`)
  A tensor of shape `(batch_size, num_spans)` representing the valid spans in the `indices` tensor. This mask is optional because sometimes it's easier to worry about masking after calling this function, rather than passing a mask directly.
Returns¶
- A tensor of shape `(batch_size, num_spans, embedded_span_size)`, where `embedded_span_size` depends on the way spans are represented.
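The note on `span_indices_mask` says it is sometimes easier to mask after calling this function. A short sketch of what that looks like (the tensors here are illustrative stand-ins for real outputs):

```python
import torch

# Stand-in for the (batch_size, num_spans, embedded_span_size) output
# of forward; random values take the place of real span embeddings.
span_embeddings = torch.randn(2, 3, 6)

# True marks valid spans; False marks padded span slots.
span_indices_mask = torch.tensor([[True, True, False],
                                  [True, False, False]])

# Zero out the representations of padded spans after the call,
# instead of passing span_indices_mask into forward.
masked = span_embeddings * span_indices_mask.unsqueeze(-1)
```

Downstream modules (e.g. a span scorer) can then treat zeroed rows as padding, or apply their own masking logic on the same `span_indices_mask`.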
get_input_dim¶
class SpanExtractorWithSpanWidthEmbedding(SpanExtractor):
| ...
| def get_input_dim(self) -> int