class SpanExtractorWithSpanWidthEmbedding(SpanExtractor):
 | def __init__(
 |     self,
 |     input_dim: int,
 |     num_width_embeddings: int = None,
 |     span_width_embedding_dim: int = None,
 |     bucket_widths: bool = False
 | ) -> None

SpanExtractorWithSpanWidthEmbedding implements the code shared by span extractors that need to embed span widths.

Specifically, __init__ initializes the span width embedding matrix and related attributes. The _embed_spans method is left for subclasses to implement, so span embeddings can be computed in different ways. In forward, the span embeddings returned by _embed_spans are concatenated with the span width embeddings to form the final span representations.

We keep SpanExtractor as a purely abstract base class, just in case someone wants to build a totally different span extractor.
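
The division of labour described above can be illustrated with a minimal sketch in plain Python lists rather than tensors. Everything in it is hypothetical: the class names ToySpanExtractorBase and ToyEndpointExtractor, the toy width table, and the assumption that a span's width is end - start are for illustration only and are not taken from the library.

```python
from typing import List, Optional, Tuple

Vector = List[float]
Span = Tuple[int, int]  # inclusive (start, end) token indices


class ToySpanExtractorBase:
    """Hypothetical stand-in for the shared base class (not library code)."""

    def __init__(
        self,
        num_width_embeddings: Optional[int] = None,
        span_width_embedding_dim: Optional[int] = None,
    ) -> None:
        # Build a toy "embedding matrix" only when both sizes are given.
        # Row w is simply filled with the value w so results are easy to read.
        self.width_table: Optional[List[Vector]] = None
        if num_width_embeddings is not None and span_width_embedding_dim is not None:
            self.width_table = [
                [float(w)] * span_width_embedding_dim
                for w in range(num_width_embeddings)
            ]

    def _embed_spans(self, sequence: List[Vector], spans: List[Span]) -> List[Vector]:
        # Subclasses decide how a span is turned into a vector.
        raise NotImplementedError

    def forward(self, sequence: List[Vector], spans: List[Span]) -> List[Vector]:
        span_embeddings = self._embed_spans(sequence, spans)
        if self.width_table is None:
            return span_embeddings
        combined = []
        for (start, end), embedding in zip(spans, span_embeddings):
            width = end - start  # assumption: width 0 means a single token
            # Concatenate the span embedding with its width embedding.
            # The table must have a row for every width that can occur.
            combined.append(embedding + self.width_table[width])
        return combined


class ToyEndpointExtractor(ToySpanExtractorBase):
    """One possible subclass: concatenate the start and end token vectors."""

    def _embed_spans(self, sequence: List[Vector], spans: List[Span]) -> List[Vector]:
        return [sequence[start] + sequence[end] for start, end in spans]


sequence = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
extractor = ToyEndpointExtractor(num_width_embeddings=4, span_width_embedding_dim=2)
representations = extractor.forward(sequence, [(0, 2), (1, 1)])
```

Because the subclass only overrides _embed_spans, a different span representation can reuse the same width handling unchanged.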


  • input_dim : int
    The final dimension of the sequence_tensor.
  • num_width_embeddings : int, optional (default = None)
    Specifies the number of buckets to use when representing span width features.
  • span_width_embedding_dim : int, optional (default = None)
    The embedding size for the span_width features.
  • bucket_widths : bool, optional (default = False)
    Whether to bucket the span widths into log-space buckets. If False, the raw span widths are used.
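
To make "log-space buckets" concrete, here is a small sketch of one possible bucketing scheme. The function name bucket_width, the num_identity_buckets and num_total_buckets parameters, and the exact bucket boundaries are assumptions chosen for illustration; the scheme the library actually applies is not specified in this section.

```python
import math


def bucket_width(
    width: int,
    num_identity_buckets: int = 4,  # hypothetical: small widths keep their own bucket
    num_total_buckets: int = 10,    # hypothetical: total number of buckets available
) -> int:
    """Map a non-negative span width to a bucket index (illustrative scheme)."""
    if width <= num_identity_buckets:
        # Small widths map to themselves, so short spans stay distinguishable.
        return width
    # Larger widths share buckets whose size doubles each time (log base 2).
    log_index = int(math.floor(math.log2(width))) + (num_identity_buckets - 1)
    # Everything beyond the last bucket is clamped into it.
    return min(log_index, num_total_buckets - 1)


buckets = [bucket_width(w) for w in (0, 3, 4, 5, 7, 8, 16, 64, 1000)]
```

With a scheme like this, num_width_embeddings would need to be at least the number of buckets, whereas with bucket_widths=False it would need to cover every raw width that can occur.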


  • span_embeddings : torch.FloatTensor
    A tensor of shape (batch_size, num_spans, embedded_span_size), where embedded_span_size depends on the way spans are represented.


class SpanExtractorWithSpanWidthEmbedding(SpanExtractor):
 | ...
 | def forward(
 |     self,
 |     sequence_tensor: torch.FloatTensor,
 |     span_indices: torch.LongTensor,
 |     sequence_mask: torch.BoolTensor = None,
 |     span_indices_mask: torch.BoolTensor = None
 | )

Given a sequence tensor, extract spans, concatenate width embeddings when needed, and return the resulting span representations.


  • sequence_tensor : torch.FloatTensor
    A tensor of shape (batch_size, sequence_length, embedding_size) representing an embedded sequence of words.
  • span_indices : torch.LongTensor
    A tensor of shape (batch_size, num_spans, 2), where the last dimension represents the inclusive start and end indices of the span to be extracted from the sequence_tensor.
  • sequence_mask : torch.BoolTensor, optional (default = None)
    A tensor of shape (batch_size, sequence_length) representing padded elements of the sequence.
  • span_indices_mask : torch.BoolTensor, optional (default = None)
    A tensor of shape (batch_size, num_spans) representing the valid spans in the indices tensor. This mask is optional because sometimes it's easier to worry about masking after calling this function, rather than passing a mask directly.
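
The inclusive index convention and the optional span mask can be illustrated on a single unbatched example using plain Python lists. The helper slice_spans is hypothetical, and returning an empty list for a masked span is only a choice made for this sketch; how masked spans appear in the real output tensor is not stated here.

```python
from typing import List, Optional, Sequence, Tuple


def slice_spans(
    tokens: Sequence[str],
    span_indices: Sequence[Tuple[int, int]],
    span_indices_mask: Optional[Sequence[bool]] = None,
) -> List[List[str]]:
    """Return the tokens covered by each span (hypothetical helper)."""
    results: List[List[str]] = []
    for i, (start, end) in enumerate(span_indices):
        if span_indices_mask is not None and not span_indices_mask[i]:
            # Masked (invalid) span: emit a placeholder instead of real content.
            results.append([])
            continue
        # Both indices are inclusive, hence end + 1 in the Python slice.
        results.append(list(tokens[start:end + 1]))
    return results


tokens = ["the", "quick", "brown", "fox"]
spans = [(0, 1), (2, 2), (0, 0)]
covered = slice_spans(tokens, spans, span_indices_mask=[True, True, False])
```

Omitting span_indices_mask treats every span as valid, which mirrors the option of handling masking after the call.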


  • A tensor of shape (batch_size, num_spans, embedded_span_size), where embedded_span_size depends on the way spans are represented.


class SpanExtractorWithSpanWidthEmbedding(SpanExtractor):
 | ...
 | def get_input_dim(self) -> int