
Using Span Representations in AllenNLP#

Note that this tutorial covers some quite advanced usage of AllenNLP - you may want to familiarize yourself with the repository before working through it.

Many state-of-the-art deep NLP models use representations of spans, rather than representations of individual words, as their basic building block. In AllenNLP (starting from version 0.4), span representations are extremely easy to use in your model.

Examples of papers which use span representations include:

- End-to-end Neural Coreference Resolution (Lee et al., 2017)
- A Minimal Span-Based Neural Constituency Parser (Stern et al., 2017)

In order to use span representations in your model, there are three things you probably need to think about: (1) enumerating all possible spans in a DatasetReader as input to your model; (2) extracting embedded span representations for the span indices; and (3) pruning the spans in your model to keep only the most promising ones. We'll describe how to do each of these steps.

Generating SpanFields from text in a DatasetReader#

SpanFields are a type of Field in AllenNLP which take a start index, an end index and a SequenceField which the indices refer to. Once a batch of SpanFields has been converted to a tensor, we will have a matrix of shape (batch_size, 2), where the last dimension contains the start and end indices passed to the SpanField constructor. However, for many models you'll want to represent many spans for a single batch element - the way to do this is to use a ListField[SpanField], which will create a tensor of shape (batch_size, num_spans, 2) once indexed.
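For example, here is a minimal sketch of constructing these fields, as you might in a DatasetReader (the sentence and span indices are purely illustrative):

from allennlp.data.fields import ListField, SpanField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers.token import Token

tokens = [Token(word) for word in "The quick brown fox jumped .".split()]
text_field = TextField(tokens, {"tokens": SingleIdTokenIndexer()})

# A single span, with inclusive start and end indices into text_field.
single_span = SpanField(1, 2, text_field)

# Many spans for one instance: a ListField of SpanFields, which becomes
# a tensor of shape (batch_size, num_spans, 2) once indexed.
many_spans = ListField([SpanField(0, 1, text_field),
                        SpanField(1, 2, text_field),
                        SpanField(3, 4, text_field)])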

Extracting Span Representations from a text sequence#

In many cases, you will want to extract span representations from a vector representation of a sentence. In order to do this in AllenNLP, you will need to use a SpanExtractor. Broadly, a SpanExtractor takes a sequence tensor of shape (batch_size, sentence_length, embedding_size) and some span indices of shape (batch_size, num_spans, 2), and returns an encoded representation of each span as a tensor of shape (batch_size, num_spans, encoded_size).

The simplest SpanExtractor is the EndpointSpanExtractor, which represents spans as a combination of the embeddings of their endpoints.

import torch
from torch.autograd import Variable
from allennlp.modules.span_extractors import EndpointSpanExtractor
sequence_tensor = Variable(torch.randn([2, 5, 7]))
# Concatenate start and end points together to form our representation.
extractor = EndpointSpanExtractor(input_dim=7, combination="x,y")

# Typically these would come from a SpanField,
# rather than being created directly.
indices = Variable(torch.LongTensor([[[1, 3],
                                      [2, 4]],
                                     [[0, 2],
                                      [3, 4]]]))

# We concatenated the representations for the start and end of
# the span, so the embedded span size is 2 * embedding_size.
# Shape (batch_size, num_spans, 2 * embedding_size).
span_representations = extractor(sequence_tensor, indices)
assert list(span_representations.size()) == [2, 2, 14]

There are other types of SpanExtractor - for instance, the SelfAttentiveSpanExtractor, which computes an unnormalized attention score for each word in the sentence. Each span representation is then a sum of the word embeddings inside the span, weighted by the attention scores normalized over just the words in that span.
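Continuing the example above (reusing sequence_tensor and indices), using this extractor looks like this:

from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor

attentive_extractor = SelfAttentiveSpanExtractor(input_dim=7)

# Each span is an attention-weighted sum of the word embeddings inside
# it, so the encoded size equals the input embedding size.
# Shape (batch_size, num_spans, embedding_size).
attended_representations = attentive_extractor(sequence_tensor, indices)
assert list(attended_representations.size()) == [2, 2, 7]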

Scoring and Pruning Spans#

Span-based representations have been effective for modeling and approximating structured prediction problems. However, many models which leverage this type of representation also involve some kind of span enumeration (i.e. considering all possible spans in a sentence or document). For a given sentence of length n, there are O(n²) possible spans. In itself, this is not too problematic, but the coreference model in AllenNLP, for instance, compares pairs of spans - meaning that naively we would consider O(n⁴) pairs - with potential document lengths upwards of 3000 tokens (roughly 4.5 million spans, and trillions of span pairs).

In order to solve this problem, we need to be able to prune spans as we go inside our model. There are several ways to do this. One is to score spans inside the model and keep only the top-scoring ones (a generic sketch of this follows); another is to filter spans heuristically in your DatasetReader, described in the next section.
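Here is a minimal sketch of model-side top-k pruning in plain PyTorch. The random span_scores below are a stand-in for a learned scorer (e.g. a feedforward network over the span representations), which your model would provide:

import torch
from torch.autograd import Variable

batch_size, num_spans, embedded_span_size = 2, 10, 14
span_representations = Variable(torch.randn(batch_size, num_spans, embedded_span_size))
# Placeholder scores - a real model would compute these with a
# learned module over the span representations.
span_scores = Variable(torch.randn(batch_size, num_spans))

num_spans_to_keep = 3
# Indices of the top-scoring spans for each batch element.
_, top_indices = span_scores.topk(num_spans_to_keep, dim=1)

# Gather the corresponding span representations.
# Shape (batch_size, num_spans_to_keep, embedded_span_size).
top_indices = top_indices.unsqueeze(-1).expand(-1, -1, embedded_span_size)
pruned_spans = span_representations.gather(1, top_indices)
assert list(pruned_spans.size()) == [2, 3, 14]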

Heuristically prune spans in your DatasetReader#

We have added a utility method for enumerating all spans in a sentence, excluding those which fail some condition based on the input text or any Spacy Token attribute. For instance, in coreference, mentions (spans which are co-referent with something) never start or end with punctuation and never cross sentence boundaries, because of the way the OntoNotes 5.0 dataset was created. This means that we can exclude any span which starts or ends with punctuation using a very simple python function:

from typing import List
from allennlp.data.dataset_readers.dataset_utils import span_utils
from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from allennlp.data.tokenizers.token import Token

tokenizer = SpacyTokenizer(pos_tags=True)
sentence = tokenizer.tokenize("This is a sentence.")

def no_prefixed_punctuation(tokens: List[Token]) -> bool:
    # Only include spans which don't start or end with punctuation.
    return tokens[0].pos_ != "PUNCT" and tokens[-1].pos_ != "PUNCT"

spans = span_utils.enumerate_spans(sentence,
                                   max_span_width=3,
                                   min_span_width=2,
                                   filter_function=no_prefixed_punctuation)

# 'spans' won't include (2, 4) or (3, 4) as these have
# punctuation as their last element. Note that these spans
# have inclusive start and end indices!
assert spans == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]

There are other helpful functions in allennlp.data.dataset_readers.dataset_utils.span_utils, such as bio_tags_to_spans, which converts BIO labelings into span-based representations.
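For instance, a quick sketch of bio_tags_to_spans (note that the returned spans again have inclusive indices):

from allennlp.data.dataset_readers.dataset_utils import span_utils

tags = ["B-PER", "I-PER", "O", "B-LOC", "O"]
typed_spans = span_utils.bio_tags_to_spans(tags)

# Each span is a (label, (start, end)) tuple with inclusive indices.
assert set(typed_spans) == {("PER", (0, 1)), ("LOC", (3, 3))}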

Existing AllenNLP examples for generating SpanFields#

We've already started using SpanFields in AllenNLP - you can see some examples in the Coreference DatasetReader, where we enumerate all possible spans in the sentences of a document, or in the PennTreeBankConstituencySpanDatasetReader, where we enumerate all spans in a sentence in order to classify whether or not they are constituents in a constituency parse of the sentence.

Existing AllenNLP models which use SpanExtractors#

Currently, both the Coreference Model and the Span Based Constituency Parser use span representations from the output of bi-directional LSTMs. Take a look and see how they're used in a model context!