allennlp_models.coref.models.coref
CoreferenceResolver
@Model.register("coref")
class CoreferenceResolver(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| text_field_embedder: TextFieldEmbedder,
| context_layer: Seq2SeqEncoder,
| mention_feedforward: FeedForward,
| antecedent_feedforward: FeedForward,
| feature_size: int,
| max_span_width: int,
| spans_per_word: float,
| max_antecedents: int,
| coarse_to_fine: bool = False,
| inference_order: int = 1,
| lexical_dropout: float = 0.2,
| initializer: InitializerApplicator = InitializerApplicator(),
| **kwargs
| ) -> None
This `Model` implements the coreference resolution model described in [Higher-order Coreference Resolution with Coarse-to-fine Inference](https://arxiv.org/abs/1804.05392) by Lee et al., 2018.
The basic outline of this model is to get an embedded representation of each span in the
document. These span representations are scored and used to prune away spans that are unlikely
to occur in a coreference cluster. For the remaining spans, the model decides which antecedent
span (if any) they are coreferent with. The resulting coreference links, after applying
transitivity, imply a clustering of the spans in the document.
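As a rough illustration of the pruning step, here is a minimal sketch in plain Python (not the model's actual PyTorch implementation): given mention scores such as those produced by `mention_feedforward` and its linear scoring layer, we keep roughly `spans_per_word * num_words` candidate spans.

```python
import math

def prune_spans(spans, scores, num_words, spans_per_word):
    """Keep the top `spans_per_word * num_words` candidate spans by score."""
    num_to_keep = max(1, math.floor(spans_per_word * num_words))
    # Rank span indices by descending mention score and keep the best ones.
    ranked = sorted(range(len(spans)), key=lambda i: scores[i], reverse=True)
    kept = sorted(ranked[:num_to_keep])  # restore document order
    return [spans[i] for i in kept]

# Toy document with 10 words and four candidate spans (inclusive indices).
spans = [(0, 1), (2, 2), (4, 6), (7, 9)]
scores = [0.9, -0.5, 1.3, 0.1]
print(prune_spans(spans, scores, num_words=10, spans_per_word=0.2))
# → [(0, 1), (4, 6)]
```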
Parameters

- vocab : `Vocabulary`
- text_field_embedder : `TextFieldEmbedder`
    Used to embed the `text` `TextField` we get as input to the model.
- context_layer : `Seq2SeqEncoder`
    This layer incorporates contextual information for each word in the document.
- mention_feedforward : `FeedForward`
    This feedforward network is applied to the span representations, which are then scored by a linear layer.
- antecedent_feedforward : `FeedForward`
    This feedforward network is applied to pairs of span representations, along with any pairwise features, which are then scored by a linear layer.
- feature_size : `int`
    The embedding size for all the embedded features, such as distances or span widths.
- max_span_width : `int`
    The maximum width of candidate spans.
- spans_per_word : `float`
    A multiplier between zero and one which controls what percentage of candidate mention spans we retain with respect to the number of words in the document.
- max_antecedents : `int`
    For each mention which survives the pruning stage, we consider this many antecedents.
- coarse_to_fine : `bool`, optional (default = `False`)
    Whether or not to apply coarse-to-fine filtering.
- inference_order : `int`, optional (default = `1`)
    The number of inference iterations. When greater than 1, the span representations are updated and the coreference scores re-computed.
- lexical_dropout : `float`, optional (default = `0.2`)
    The probability of dropping out dimensions of the embedded text.
- initializer : `InitializerApplicator`, optional (default = `InitializerApplicator()`)
    Used to initialize the model parameters.
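Because the class is registered as `"coref"`, it can be selected by that `type` key in a configuration file. A hypothetical fragment (the nested component configurations and numeric values here are illustrative placeholders, not prescribed by this class):

```jsonnet
"model": {
    "type": "coref",                 // matches @Model.register("coref")
    "text_field_embedder": { ... },  // any TextFieldEmbedder configuration
    "context_layer": { ... },        // any Seq2SeqEncoder, e.g. a BiLSTM
    "mention_feedforward": { ... },
    "antecedent_feedforward": { ... },
    "feature_size": 20,
    "max_span_width": 10,
    "spans_per_word": 0.4,
    "max_antecedents": 50,
    "coarse_to_fine": true
}
```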
forward
class CoreferenceResolver(Model):
| ...
| def forward(
| self,
| text: TextFieldTensors,
| spans: torch.IntTensor,
| span_labels: torch.IntTensor = None,
| metadata: List[Dict[str, Any]] = None
| ) -> Dict[str, torch.Tensor]
Parameters

- text : `TextFieldTensors`
    The output of a `TextField` representing the text of the document.
- spans : `torch.IntTensor`
    A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of indices into the text of the document.
- span_labels : `torch.IntTensor`, optional (default = `None`)
    A tensor of shape (batch_size, num_spans), representing the cluster ids of each span, or -1 for spans which do not appear in any cluster.
- metadata : `List[Dict[str, Any]]`, optional (default = `None`)
    A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys, which respectively contain the original text and the annotated gold coreference clusters for that instance.
Returns

An output dictionary consisting of:

- top_spans : `torch.IntTensor`
    A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing the start and end word indices of the top spans that survived the pruning stage.
- antecedent_indices : `torch.IntTensor`
    A tensor of shape `(num_spans_to_keep, max_antecedents)` representing, for each top span, the indices (with respect to `top_spans`) of the possible antecedents the model considered.
- predicted_antecedents : `torch.IntTensor`
    A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the index (with respect to `antecedent_indices`) of the most likely antecedent; -1 means there was no predicted link.
- loss : `torch.FloatTensor`, optional
    A scalar loss to be optimised.
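The indirection between these tensors is easy to get wrong: `predicted_antecedents` indexes into `antecedent_indices`, which in turn indexes into `top_spans`. A minimal sketch of recovering (span, antecedent span) links for a single instance, using plain Python lists in place of tensors:

```python
def decode_links(top_spans, antecedent_indices, predicted_antecedents):
    """Recover (span, antecedent_span) pairs for a single instance.

    top_spans: list of (start, end) word indices for the kept spans.
    antecedent_indices: for each kept span, the top_spans indices of its
        candidate antecedents.
    predicted_antecedents: for each kept span, an index into its row of
        antecedent_indices, or -1 for "no antecedent".
    """
    links = []
    for i, predicted in enumerate(predicted_antecedents):
        if predicted == -1:
            continue  # the model predicted no coreference link for this span
        antecedent_span = top_spans[antecedent_indices[i][predicted]]
        links.append((top_spans[i], antecedent_span))
    return links

top_spans = [(0, 1), (4, 6), (8, 8)]
antecedent_indices = [[0, 0], [0, 0], [1, 0]]
predicted_antecedents = [-1, 0, 0]
print(decode_links(top_spans, antecedent_indices, predicted_antecedents))
# → [((4, 6), (0, 1)), ((8, 8), (4, 6))]
```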
make_output_human_readable
class CoreferenceResolver(Model):
| ...
| def make_output_human_readable(
| self,
| output_dict: Dict[str, torch.Tensor]
| )
Converts the list of spans and predicted antecedent indices into clusters of spans for each element in the batch.
Parameters

- output_dict : `Dict[str, torch.Tensor]`
    The result of calling `forward` on an instance or batch of instances.

Returns

The same output dictionary, but with an additional `clusters` key:

- clusters : `List[List[List[Tuple[int, int]]]]`
    A nested list representing, for each instance in the batch, the list of clusters, each of which is a list of (start, end) inclusive spans into the original document.
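The grouping performed here can be sketched as a simple union-find over the predicted links (a conceptual re-implementation, not the library's actual code): links are transitively merged, and only groups with more than one span count as clusters.

```python
def links_to_clusters(links):
    """Group (span, antecedent_span) links into clusters by transitivity."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    # Union each span with its predicted antecedent.
    for span, antecedent in links:
        parent[find(span)] = find(antecedent)

    # Collect spans sharing a root into a cluster.
    clusters = {}
    for span in parent:
        clusters.setdefault(find(span), []).append(span)
    return [sorted(c) for c in clusters.values() if len(c) > 1]

links = [((4, 6), (0, 1)), ((8, 8), (4, 6)), ((12, 13), (10, 10))]
print(links_to_clusters(links))
# → [[(0, 1), (4, 6), (8, 8)], [(10, 10), (12, 13)]]
```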
get_metrics
class CoreferenceResolver(Model):
| ...
| def get_metrics(self, reset: bool = False) -> Dict[str, float]
default_predictor
class CoreferenceResolver(Model):
| ...
| default_predictor = "coreference_resolution"