allennlp_models.coref.models.coref


CoreferenceResolver#

@Model.register("coref")
class CoreferenceResolver(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     text_field_embedder: TextFieldEmbedder,
 |     context_layer: Seq2SeqEncoder,
 |     mention_feedforward: FeedForward,
 |     antecedent_feedforward: FeedForward,
 |     feature_size: int,
 |     max_span_width: int,
 |     spans_per_word: float,
 |     max_antecedents: int,
 |     coarse_to_fine: bool = False,
 |     inference_order: int = 1,
 |     lexical_dropout: float = 0.2,
 |     initializer: InitializerApplicator = InitializerApplicator(),
 |     **kwargs
 | ) -> None

This Model implements the coreference resolution model described in Higher-order Coreference Resolution with Coarse-to-fine Inference by Lee et al., 2018. The basic outline of this model is to get an embedded representation of each span in the document. These span representations are scored and used to prune away spans that are unlikely to occur in a coreference cluster. For the remaining spans, the model decides which antecedent span (if any) they are coreferent with. The resulting coreference links, after applying transitivity, imply a clustering of the spans in the document.

Parameters

  • vocab : Vocabulary
  • text_field_embedder : TextFieldEmbedder
    Used to embed the text TextField we get as input to the model.
  • context_layer : Seq2SeqEncoder
    This layer incorporates contextual information for each word in the document.
  • mention_feedforward : FeedForward
    This feedforward network is applied to the span representations, which are then scored by a linear layer.
  • antecedent_feedforward : FeedForward
    This feedforward network is applied to pairs of span representations, along with any pairwise features, which are then scored by a linear layer.
  • feature_size : int
    The embedding size for all the embedded features, such as distances or span widths.
  • max_span_width : int
    The maximum width of candidate spans.
  • spans_per_word : float
    A multiplier between zero and one which controls what percentage of candidate mention spans we retain with respect to the number of words in the document.
  • max_antecedents : int
    For each mention which survives the pruning stage, we consider this many antecedents.
  • coarse_to_fine : bool, optional (default = False)
    Whether or not to apply the coarse-to-fine filtering.
  • inference_order : int, optional (default = 1)
    The number of higher-order inference iterations. When greater than 1, the span representations are updated and the coreference scores re-computed after each iteration.
  • lexical_dropout : float, optional (default = 0.2)
    The probability of dropping out dimensions of the embedded text.
  • initializer : InitializerApplicator, optional (default = InitializerApplicator())
    Used to initialize the model parameters.
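To build intuition for how max_span_width and spans_per_word interact, here is a hedged sketch (not the library's actual code) of the resulting budgets: candidate spans are every span of up to max_span_width tokens, and the pruning stage keeps a number of spans proportional to the document length.

```python
# Sketch (not AllenNLP's implementation) of the span budgets implied
# by the constructor arguments above.

def num_candidate_spans(document_length: int, max_span_width: int) -> int:
    """Count every span of width 1..max_span_width that fits in the document."""
    return sum(
        max(document_length - width + 1, 0)
        for width in range(1, max_span_width + 1)
    )

def num_spans_to_keep(document_length: int, spans_per_word: float) -> int:
    """The pruner retains a number of spans proportional to document length."""
    return int(spans_per_word * document_length)

# A 100-token document with max_span_width=10 and spans_per_word=0.4:
candidates = num_candidate_spans(100, 10)   # 955 candidate spans
kept = num_spans_to_keep(100, 0.4)          # 40 spans survive pruning
```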

forward#

class CoreferenceResolver(Model):
 | ...
 | def forward(
 |     self,
 |     text: TextFieldTensors,
 |     spans: torch.IntTensor,
 |     span_labels: torch.IntTensor = None,
 |     metadata: List[Dict[str, Any]] = None
 | ) -> Dict[str, torch.Tensor]

Parameters

  • text : TextFieldTensors
    The output of a TextField representing the text of the document.
  • spans : torch.IntTensor
    A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end indices of candidate spans for mentions. Comes from a ListField[SpanField] of indices into the text of the document.
  • span_labels : torch.IntTensor, optional (default = None)
    A tensor of shape (batch_size, num_spans), representing the cluster ids of each span, or -1 for those which do not appear in any clusters.
  • metadata : List[Dict[str, Any]], optional (default = None)
    A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys from this dictionary, which respectively have the original text and the annotated gold coreference clusters for that instance.
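For intuition about what the spans tensor contains, the following is a minimal sketch, assuming candidates are all spans of up to max_span_width tokens with inclusive (start, end) indices; plain Python lists stand in for the (num_spans, 2) tensor a ListField[SpanField] would produce.

```python
from typing import List, Tuple

def enumerate_candidate_spans(
    num_tokens: int, max_span_width: int
) -> List[Tuple[int, int]]:
    """All (start, end) pairs with inclusive indices and width <= max_span_width."""
    spans = []
    for start in range(num_tokens):
        for end in range(start, min(start + max_span_width, num_tokens)):
            spans.append((start, end))
    return spans

# A 4-token document with max_span_width=2:
print(enumerate_candidate_spans(4, 2))
# [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3), (3, 3)]
```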

Returns

  • An output dictionary consisting of:

  • top_spans : torch.IntTensor
    A tensor of shape (batch_size, num_spans_to_keep, 2) representing the start and end word indices of the top spans that survived the pruning stage.

  • antecedent_indices : torch.IntTensor
    A tensor of shape (num_spans_to_keep, max_antecedents) representing, for each top span, the indices (with respect to top_spans) of the possible antecedents the model considered.
  • predicted_antecedents : torch.IntTensor
    A tensor of shape (batch_size, num_spans_to_keep) representing, for each top span, the index (with respect to antecedent_indices) of the most likely antecedent. -1 means there was no predicted link.
  • loss : torch.FloatTensor, optional
    A scalar loss to be optimised.

make_output_human_readable#

class CoreferenceResolver(Model):
 | ...
 | def make_output_human_readable(
 |     self,
 |     output_dict: Dict[str, torch.Tensor]
 | )

Converts the list of spans and predicted antecedent indices into clusters of spans for each element in the batch.

Parameters

  • output_dict : Dict[str, torch.Tensor]
    The result of calling forward on an instance or batch of instances.

Returns

  • The same output dictionary, but with an additional clusters key:

  • clusters : List[List[List[Tuple[int, int]]]]
    A nested list, representing, for each instance in the batch, the list of clusters, which are in turn comprised of a list of (start, end) inclusive spans into the original document.
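The transitivity step can be sketched as follows. This is a simplified illustration, not the library's exact decoding code: each span that predicts an antecedent joins that antecedent's cluster, and because antecedents always precede the spans that point to them, a single forward pass suffices.

```python
from typing import Dict, List, Tuple

def decode_clusters(
    top_spans: List[Tuple[int, int]],
    predicted_antecedents: List[int],
    antecedent_indices: List[List[int]],
) -> List[List[Tuple[int, int]]]:
    """Group spans into clusters by following each span's predicted antecedent.

    predicted_antecedents[i] == -1 means span i starts no coreference link;
    otherwise it indexes into antecedent_indices[i], which gives the
    antecedent's position within top_spans.
    """
    span_to_cluster: Dict[int, int] = {}
    clusters: List[List[Tuple[int, int]]] = []
    for i, pred in enumerate(predicted_antecedents):
        if pred < 0:
            continue  # no predicted link for this span
        antecedent = antecedent_indices[i][pred]  # index into top_spans
        if antecedent not in span_to_cluster:
            # The antecedent starts a new cluster.
            span_to_cluster[antecedent] = len(clusters)
            clusters.append([top_spans[antecedent]])
        cluster_id = span_to_cluster[antecedent]
        clusters[cluster_id].append(top_spans[i])
        span_to_cluster[i] = cluster_id  # transitivity: later links reuse this
    return clusters

# Three spans where span 2 links back to span 0:
spans = [(0, 1), (3, 4), (6, 6)]
print(decode_clusters(spans, [-1, -1, 0], [[0], [0], [0]]))
# [[(0, 1), (6, 6)]]
```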

get_metrics#

class CoreferenceResolver(Model):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]

default_predictor#

class CoreferenceResolver(Model):
 | ...
 | default_predictor = "coreference_resolution"