class MaxPoolingSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
 | def __init__(
 |     self,
 |     input_dim: int,
 |     num_width_embeddings: int = None,
 |     span_width_embedding_dim: int = None,
 |     bucket_widths: bool = False
 | ) -> None

Represents spans by applying a dimension-wise max-pooling operation. Given a span x_i, ..., x_j, where i and j are span_start and span_end (both inclusive), each dimension d of the resulting span representation s is computed as s_d = max(x_id, ..., x_jd).
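The formula can be sketched in plain Python. This sketch uses nested lists in place of tensors and a hypothetical helper, `max_pool_span`. The real module works on batched `torch` tensors, so treat this only as an illustration of the arithmetic for one span.

```python
from typing import List


def max_pool_span(sequence: List[List[float]], span_start: int, span_end: int) -> List[float]:
    """Hypothetical helper: dimension-wise max over rows span_start..span_end (inclusive)."""
    rows = sequence[span_start : span_end + 1]  # inclusive end, matching x_i, ..., x_j
    # zip(*rows) groups the d-th component of every row, so max() yields s_d.
    return [max(column) for column in zip(*rows)]


x = [
    [0.1, 0.9, -1.0],
    [0.5, 0.2, -0.3],
    [0.4, 0.7, -2.0],
]
print(max_pool_span(x, 0, 2))  # [0.5, 0.9, -0.3]
```

Each output dimension can come from a different token. Here s_0 comes from the second row and s_1 from the first, so the result need not equal any single x_t.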

Elements masked out by sequence_mask are ignored when computing the max-pool. Spans masked out by span_mask receive a representation of all zeros.
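The following loop-based sketch shows the two masking rules for a single sequence. The function `masked_max_pool` and the (start, end) tuple layout are hypothetical. The text above does not say what happens to an unmasked span whose tokens are all masked, so returning zeros in that case is an assumption.

```python
from typing import List, Sequence, Tuple


def masked_max_pool(
    sequence: Sequence[Sequence[float]],
    sequence_mask: Sequence[bool],
    spans: Sequence[Tuple[int, int]],
    span_mask: Sequence[bool],
) -> List[List[float]]:
    """Hypothetical helper: max-pool each inclusive (start, end) span, honouring both masks."""
    dim = len(sequence[0])
    pooled = []
    for (start, end), keep_span in zip(spans, span_mask):
        if not keep_span:
            # Rule 2: a span masked out by span_mask becomes an all-zero vector.
            pooled.append([0.0] * dim)
            continue
        # Rule 1: skip positions masked out by sequence_mask (e.g. padding).
        rows = [sequence[t] for t in range(start, end + 1) if sequence_mask[t]]
        if not rows:
            # Assumption: a span with no unmasked tokens also falls back to zeros.
            pooled.append([0.0] * dim)
        else:
            pooled.append([max(column) for column in zip(*rows)])
    return pooled


x = [[1.0, -1.0], [2.0, 3.0], [9.0, 9.0]]  # last row is padding
print(masked_max_pool(x, [True, True, False], [(0, 2), (1, 1)], [True, False]))
# [[2.0, 3.0], [0.0, 0.0]]
```

The padded row holds the largest values. Because it is masked, it does not leak into the first span's representation.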

Registered as a SpanExtractor with name "max_pooling".


Parameters

  • input_dim : int
    The final dimension of the sequence_tensor.
  • num_width_embeddings : int, optional (default = None)
    Specifies the number of buckets to use when representing span width features.
  • span_width_embedding_dim : int, optional (default = None)
    The embedding size for the span_width features.
  • bucket_widths : bool, optional (default = False)
    Whether to bucket the span widths into log-space buckets. If False, the raw span widths are used.
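The bucket boundaries are not spelled out here, so the following is only a sketch of one common log-space scheme. Widths up to a small threshold keep their own bucket, and larger widths share buckets whose ranges double in size. The helper name `bucket_width` and its defaults (4 identity buckets, 10 buckets in total) are assumptions. In this scheme the total bucket count would play the role of num_width_embeddings, since each bucket id indexes one row of the width-embedding table.

```python
import math


def bucket_width(width: int, num_identity_buckets: int = 4, num_total_buckets: int = 10) -> int:
    """Hypothetical helper: map a span width to a log-space bucket id."""
    if width <= num_identity_buckets:
        # Small widths (0..num_identity_buckets) each keep their own bucket.
        return width
    # Larger widths share buckets covering 5-7, 8-15, 16-31, ... (doubling ranges).
    bucket = int(math.log2(width)) + (num_identity_buckets - 1)
    # Everything beyond the last bucket is clamped into it.
    return min(bucket, num_total_buckets - 1)


print([bucket_width(w) for w in [0, 1, 4, 5, 7, 8, 15, 16, 100, 1000]])
# [0, 1, 4, 5, 5, 6, 6, 7, 9, 9]
```

With bucket_widths=False, the raw width itself would presumably serve as the index, so widths of num_width_embeddings or more would need some other handling. Bucketing caps every width at the last bucket.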


Returns

  • max_pooling_text_embeddings : torch.FloatTensor.
    A tensor of shape (batch_size, num_spans, input_dim), in which each span representation is the result of a max-pooling operation.


class MaxPoolingSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
 | ...
 | def get_output_dim(self) -> int
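The following sketch shows what get_output_dim plausibly reports. It assumes that a configured width embedding is concatenated after the pooled vector, which is suggested by the base class SpanExtractorWithSpanWidthEmbedding rather than stated on this page. Under that assumption, the last output dimension grows by span_width_embedding_dim. The helper `span_output_dim` is hypothetical and not part of the API.

```python
from typing import Optional


def span_output_dim(
    input_dim: int,
    num_width_embeddings: Optional[int] = None,
    span_width_embedding_dim: Optional[int] = None,
) -> int:
    """Hypothetical helper: size of the last dimension of the extractor's output."""
    if num_width_embeddings is None and span_width_embedding_dim is None:
        # No width embedding: the output is just the max-pooled vector.
        return input_dim
    if num_width_embeddings is None or span_width_embedding_dim is None:
        # Assumption: a width embedding needs both its table size and its dimension.
        raise ValueError("set both num_width_embeddings and span_width_embedding_dim")
    # Assumption: the width embedding is concatenated after the pooled vector.
    return input_dim + span_width_embedding_dim


print(span_output_dim(768))          # 768
print(span_output_dim(768, 10, 20))  # 788
```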