constituency_parser

allennlp_models.structured_prediction.models.constituency_parser

SpanInformation#

class SpanInformation(NamedTuple)

A helper namedtuple for handling decoding information.

Parameters

  • start : int
    The start index of the span.
  • end : int
    The exclusive end index of the span.
  • no_label_prob : float
    The probability of this span being assigned the NO-LABEL label.
  • label_prob : float
    The probability of the most likely label.

start#

class SpanInformation(NamedTuple):
 | ...
 | start: int = None

end#

class SpanInformation(NamedTuple):
 | ...
 | end: int = None

label_prob#

class SpanInformation(NamedTuple):
 | ...
 | label_prob: float = None

no_label_prob#

class SpanInformation(NamedTuple):
 | ...
 | no_label_prob: float = None

label_index#

class SpanInformation(NamedTuple):
 | ...
 | label_index: int = None
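
For illustration, a SpanInformation is constructed like any other NamedTuple; the values below are hypothetical:

# A span covering tokens 2 and 3 (the end index is exclusive), whose best
# label has probability 0.8 and whose NO-LABEL probability is 0.1. The
# label index points into the model's span label vocabulary.
span = SpanInformation(start=2, end=4, no_label_prob=0.1, label_prob=0.8, label_index=17)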

SpanConstituencyParser#

@Model.register("constituency_parser")
class SpanConstituencyParser(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     text_field_embedder: TextFieldEmbedder,
 |     span_extractor: SpanExtractor,
 |     encoder: Seq2SeqEncoder,
 |     feedforward: FeedForward = None,
 |     pos_tag_embedding: Embedding = None,
 |     initializer: InitializerApplicator = InitializerApplicator(),
 |     evalb_directory_path: str = DEFAULT_EVALB_DIR,
 |     **kwargs
 | ) -> None

This SpanConstituencyParser simply encodes a sequence of text with a stacked Seq2SeqEncoder, extracts span representations using a SpanExtractor, and then predicts a label for each span in the sequence. These labels are non-terminal nodes in a constituency parse tree, which we then greedily reconstruct.

Parameters

  • vocab : Vocabulary
    A Vocabulary, required in order to compute sizes for input/output projections.
  • text_field_embedder : TextFieldEmbedder
    Used to embed the tokens TextField we get as input to the model.
  • span_extractor : SpanExtractor
    The method used to extract the spans from the encoded sequence.
  • encoder : Seq2SeqEncoder
    The encoder that we will use in between embedding tokens and generating span representations.
  • feedforward : FeedForward, optional (default = None)
    The FeedForward layer that we will use in between the encoder and the linear projection to a distribution over span labels.
  • pos_tag_embedding : Embedding, optional (default = None)
    Used to embed the pos_tags SequenceLabelField we get as input to the model.
  • initializer : InitializerApplicator, optional (default = InitializerApplicator())
    Used to initialize the model parameters.
  • evalb_directory_path : str, optional (default = DEFAULT_EVALB_DIR)
    The path to the directory containing the EVALB executable used to score bracketed parses. By default, this will use the EVALB included with allennlp, which is located at allennlp/tools/EVALB. If None, EVALB scoring is not used.
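
Below is a minimal sketch of constructing the model directly in Python. The component choices (a bidirectional LSTM encoder, a bidirectional endpoint span extractor) and all dimensions are illustrative assumptions rather than the released configuration, and the empty Vocabulary stands in for one built from training data by a dataset reader:

from allennlp.data import Vocabulary
from allennlp.modules import FeedForward
from allennlp.modules.seq2seq_encoders import LstmSeq2SeqEncoder
from allennlp.modules.span_extractors import BidirectionalEndpointSpanExtractor
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn import Activation
from allennlp_models.structured_prediction.models.constituency_parser import SpanConstituencyParser

vocab = Vocabulary()  # in practice, built from your training data

embedder = BasicTextFieldEmbedder(
    {"tokens": Embedding(embedding_dim=100, num_embeddings=vocab.get_vocab_size("tokens"))}
)
encoder = LstmSeq2SeqEncoder(input_size=100, hidden_size=250, num_layers=2, bidirectional=True)
span_extractor = BidirectionalEndpointSpanExtractor(input_dim=encoder.get_output_dim())
feedforward = FeedForward(
    input_dim=span_extractor.get_output_dim(),
    num_layers=1,
    hidden_dims=250,
    activations=Activation.by_name("relu")(),
)

model = SpanConstituencyParser(
    vocab=vocab,
    text_field_embedder=embedder,
    span_extractor=span_extractor,
    encoder=encoder,
    feedforward=feedforward,
    evalb_directory_path=None,  # skip EVALB scoring in this sketch
)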

forward#

class SpanConstituencyParser(Model):
 | ...
 | def forward(
 |     self,
 |     tokens: TextFieldTensors,
 |     spans: torch.LongTensor,
 |     metadata: List[Dict[str, Any]],
 |     pos_tags: TextFieldTensors = None,
 |     span_labels: torch.LongTensor = None
 | ) -> Dict[str, torch.Tensor]

Parameters

  • tokens : TextFieldTensors
    The output of TextField.as_array(), which should typically be passed directly to a TextFieldEmbedder. This output is a dictionary mapping keys to TokenIndexer tensors. At its most basic, using a SingleIdTokenIndexer, this is: {"tokens": Tensor(batch_size, num_tokens)}. This dictionary will have the same keys as were used for the TokenIndexers when you created the TextField representing your sequence. The dictionary is designed to be passed directly to a TextFieldEmbedder, which knows how to combine different word representations into a single vector per token in your input.
  • spans : torch.LongTensor
    A tensor of shape (batch_size, num_spans, 2) representing the inclusive start and end indices of all possible spans in the sentence.
  • metadata : List[Dict[str, Any]]
    A dictionary of metadata for each batch element, with the following keys:
      • tokens : List[str], required. The original string tokens in the sentence.
      • gold_tree : nltk.Tree, optional (default = None). Gold NLTK trees for use in evaluation.
      • pos_tags : List[str], optional. The POS tags for the sentence. These can be used in the model as embedded features, but they are passed here in addition for use in constructing the tree.
  • pos_tags : TextFieldTensors, optional (default = None)
    The output of a SequenceLabelField containing POS tags.
  • span_labels : torch.LongTensor, optional (default = None)
    A torch tensor representing the integer gold class labels for all possible spans, of shape (batch_size, num_spans).

Returns

An output dictionary consisting of:

  • class_probabilities : torch.FloatTensor
    A tensor of shape (batch_size, num_spans, span_label_vocab_size) representing a distribution over the label classes per span.
  • spans : torch.LongTensor
    The original spans tensor.
  • tokens : List[List[str]]
    A list of tokens in the sentence for each element in the batch.
  • pos_tags : List[List[str]]
    A list of POS tags in the sentence for each element in the batch.
  • num_spans : torch.LongTensor
    A tensor of shape (batch_size), representing the lengths of non-padded spans in enumerated_spans.
  • loss : torch.FloatTensor, optional
    A scalar loss to be optimised.
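
In practice, forward is usually exercised through the model's predictor (see default_predictor below). A sketch, where the archive path is a placeholder for a trained model:

from allennlp.predictors import Predictor
import allennlp_models.structured_prediction  # registers the model and predictor

predictor = Predictor.from_path("/path/to/model.tar.gz", predictor_name="constituency_parser")
output = predictor.predict(sentence="The fox jumped over the fence.")
print(output["trees"])  # bracketed parse produced by make_output_human_readable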

make_output_human_readable#

class SpanConstituencyParser(Model):
 | ...
 | def make_output_human_readable(
 |     self,
 |     output_dict: Dict[str, torch.Tensor]
 | ) -> Dict[str, torch.Tensor]

Constructs an NLTK Tree given the scored spans. We also switch to exclusive span ends when constructing the tree representation, because it makes indexing into lists cleaner for ranges of text, rather than individual indices.

Finally, for batch prediction, we will have padded spans and class probabilities. In order to make this less confusing, we remove all the padded spans and distributions from spans and class_probabilities respectively.
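
Concretely, a span with inclusive end (start=2, end=4) covers the same text as the exclusive-end span (2, 5), which maps directly onto Python slicing:

tokens = ["The", "fox", "jumped", "over", "the", "fence"]
inclusive = (2, 4)                            # representation used in the rest of the model
exclusive = (inclusive[0], inclusive[1] + 1)  # representation used when building trees
assert tokens[exclusive[0] : exclusive[1]] == ["jumped", "over", "the"]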

construct_trees#

class SpanConstituencyParser(Model):
 | ...
 | def construct_trees(
 |     self,
 |     predictions: torch.FloatTensor,
 |     all_spans: torch.LongTensor,
 |     num_spans: torch.LongTensor,
 |     sentences: List[List[str]],
 |     pos_tags: List[List[str]] = None
 | ) -> List[Tree]

Construct an nltk.Tree for each batch element by greedily nesting spans. The trees use exclusive end indices, which contrasts with how spans are represented in the rest of the model.

Parameters

  • predictions : torch.FloatTensor
    A tensor of shape (batch_size, num_spans, span_label_vocab_size) representing a distribution over the label classes per span.
  • all_spans : torch.LongTensor
    A tensor of shape (batch_size, num_spans, 2), representing the span indices we scored.
  • num_spans : torch.LongTensor
    A tensor of shape (batch_size), representing the lengths of non-padded spans in enumerated_spans.
  • sentences : List[List[str]]
    A list of tokens in the sentence for each element in the batch.
  • pos_tags : List[List[str]], optional (default = None)
    A list of POS tags for each word in the sentence for each element in the batch.

Returns

  • A List[Tree] containing the decoded trees for each element in the batch.
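
A sketch of calling this method directly on a model's forward outputs, mirroring what make_output_human_readable does internally; the batch variable is hypothetical:

output = model(**batch)  # a batch produced by an AllenNLP data loader
trees = model.construct_trees(
    predictions=output["class_probabilities"].cpu(),
    all_spans=output["spans"].cpu(),
    num_spans=output["num_spans"],
    sentences=output["tokens"],
    pos_tags=output.get("pos_tags"),
)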

resolve_overlap_conflicts_greedily#

class SpanConstituencyParser(Model):
 | ...
 | @staticmethod
 | def resolve_overlap_conflicts_greedily(
 |     spans: List[SpanInformation]
 | ) -> List[SpanInformation]

Given a set of spans, removes spans which overlap by evaluating the difference in probability between one being labeled and the other explicitly having no label, and vice-versa. The worst-case time complexity of this method is O(k * n^4), where n is the length of the sentence from which the spans were enumerated (equivalently, O(k * m^2) with respect to the number of spans m) and k is the number of conflicts. However, in practice, there are very few conflicts. Hopefully.

This function modifies spans to remove overlapping spans.

Parameters

  • spans : List[SpanInformation]
    A list of spans, where each span is a namedtuple containing the following attributes:
      • start : int. The start index of the span.
      • end : int. The exclusive end index of the span.
      • no_label_prob : float. The probability of this span being assigned the NO-LABEL label.
      • label_prob : float. The probability of the most likely label.

Returns

  • A modified list of spans, with the conflicts resolved by considering local differences between pairs of spans and removing one of the two spans.
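
The core comparison can be sketched as follows; this is an illustrative reimplementation of the decision rule, not the library's exact code:

def keep_one(span_one: SpanInformation, span_two: SpanInformation) -> SpanInformation:
    # Keep span_one if labeling it (and leaving span_two unlabeled) is more
    # probable than the reverse assignment; otherwise keep span_two.
    first = span_one.label_prob + span_two.no_label_prob
    second = span_two.label_prob + span_one.no_label_prob
    return span_one if first >= second else span_two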

construct_tree_from_spans#

class SpanConstituencyParser(Model):
 | ...
 | @staticmethod
 | def construct_tree_from_spans(
 |     spans_to_labels: Dict[Tuple[int, int], str],
 |     sentence: List[str],
 |     pos_tags: List[str] = None
 | ) -> Tree

Parameters

  • spans_to_labels : Dict[Tuple[int, int], str]
    A mapping from spans to constituency labels.
  • sentence : List[str]
    A list of tokens forming the sentence to be parsed.
  • pos_tags : List[str], optional (default = None)
    A list of the POS tags for the words in the sentence, if they were either predicted or taken as input to the model.

Returns

  • An nltk.Tree constructed from the labelled spans.
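
A small, hypothetical example of the expected input and output, using exclusive end indices for the span keys:

spans_to_labels = {(0, 2): "NP", (2, 3): "VP", (0, 3): "S"}
sentence = ["the", "fox", "jumped"]
pos_tags = ["DT", "NN", "VBD"]
tree = SpanConstituencyParser.construct_tree_from_spans(spans_to_labels, sentence, pos_tags)
print(tree)  # expected: (S (NP (DT the) (NN fox)) (VP (VBD jumped)))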

get_metrics#

class SpanConstituencyParser(Model):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]

default_predictor#

class SpanConstituencyParser(Model):
 | ...
 | default_predictor = "constituency_parser"