allennlp_models.rc.models.transformer_qa

TransformerQA#

@Model.register("transformer_qa")
class TransformerQA(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     transformer_model_name: str = "bert-base-cased",
 |     **kwargs
 | ) -> None

Registered as "transformer_qa", this class implements a reading comprehension model patterned after the model proposed in Devlin et al., with improvements borrowed from the SQuAD model in the transformers project.

It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.

If you want to use this model on SQuAD datasets, you can use it with the TransformerSquadReader dataset reader, registered as "transformer_squad".

Note that the metrics that the model produces are calculated on a per-instance basis only. Since there could be more than one instance per question, these metrics are not the official numbers on either SQuAD task.

To get official numbers for SQuAD v1.1, for example, you can run

python -m allennlp_models.rc.tools.transformer_qa_eval

Parameters

  • vocab : Vocabulary

  • transformer_model_name : str, optional (default = 'bert-base-cased')
    This model chooses the embedder according to this setting. You probably want to make sure this is set to the same thing as the reader.
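
For example, a minimal sketch of constructing the model directly in Python (assuming an empty vocabulary is acceptable, since the embedder is built from transformer_model_name rather than from the vocabulary):

from allennlp.data import Vocabulary
from allennlp_models.rc.models.transformer_qa import TransformerQA

# A sketch: the transformer embedder is created from transformer_model_name,
# so the vocabulary itself can start out empty.
model = TransformerQA(
    vocab=Vocabulary.empty(),
    transformer_model_name="bert-base-cased",
)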

forward#

class TransformerQA(Model):
 | ...
 | def forward(
 |     self,
 |     question_with_context: Dict[str, Dict[str, torch.LongTensor]],
 |     context_span: torch.IntTensor,
 |     cls_index: torch.LongTensor = None,
 |     answer_span: torch.IntTensor = None,
 |     metadata: List[Dict[str, Any]] = None
 | ) -> Dict[str, torch.Tensor]

Parameters

  • question_with_context : Dict[str, Dict[str, torch.LongTensor]]
    From a TextField. The model assumes that this text field contains the context followed by the question. It further assumes that the tokens have type ids set such that any token that can be part of the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and [SEP]) has type id 1. See the layout sketch after this parameter list.

  • context_span : torch.IntTensor
    From a SpanField. This marks the span of word pieces in question_with_context from which answers can come.

  • cls_index : torch.LongTensor, optional
    A tensor of shape (batch_size,) that provides the index of the [CLS] token in the question_with_context for each instance.

    This is needed because the [CLS] token is used to indicate that the question is impossible.

    If this is None, it's assumed that the [CLS] token is at index 0 for each instance in the batch.

  • answer_span : torch.IntTensor, optional
    From a SpanField. This is the thing we are trying to predict - the span of text that marks the answer. If given, we compute a loss that gets included in the output dictionary.

  • metadata : List[Dict[str, Any]], optional
    If present, this should contain the question id, the original texts of the context and question, tokenized versions of both, and a list of possible answers. The length of the metadata list should be the batch size, and each dictionary should have the keys id, question, context, question_tokens, context_tokens, and answers.
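
To make the type-id contract concrete, here is a hypothetical instance layout (a sketch only; the exact special tokens and ordering depend on the tokenizer and dataset reader):

tokens:    [CLS] Allen ##NLP is a library . [SEP] What is Allen ##NLP ? [SEP]
type_ids:    1     0     0   0  0    0    0    1    1   1    1     1   1   1

Here context_span would mark the word-piece span covering the context ("Allen ##NLP is a library ."), and answer_span, if given, would be a sub-span of context_span.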

Returns

  • Dict[str, torch.Tensor] :
    An output dictionary with the following fields:

    • span_start_logits (torch.FloatTensor) : A tensor of shape (batch_size, passage_length) representing unnormalized log probabilities of the span start position.
    • span_end_logits (torch.FloatTensor) : A tensor of shape (batch_size, passage_length) representing unnormalized log probabilities of the span end position (inclusive).
    • best_span_scores (torch.FloatTensor) : The score for each of the best spans.
    • loss (torch.FloatTensor, optional) : A scalar loss to be optimised, evaluated against answer_span.
    • best_span (torch.IntTensor, optional) : Provided when not in train mode and sufficient metadata given for the instance. The result of a constrained inference over span_start_logits and span_end_logits to find the most probable span. Shape is (batch_size, 2) and each offset is a token index, unless the best span for an instance was predicted to be the [CLS] token, in which case the span will be (-1, -1).
    • best_span_str (List[str], optional) : Provided when not in train mode and sufficient metadata given for the instance. This is the string from the original passage that the model thinks is the best answer to the question.
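
The constrained inference used for best_span can be sketched as follows (an illustrative simplification in plain torch with a hypothetical helper; the actual model additionally handles the [CLS] no-answer case):

import torch

def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
    # Score every (start, end) pair: shape (batch_size, length, length).
    pair_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Rule out spans with end < start by adding log(0) = -inf below the
    # diagonal; torch.triu keeps the upper triangle of the last two dims.
    pair_scores = pair_scores + torch.triu(torch.ones_like(pair_scores)).log()
    batch_size, length, _ = pair_scores.shape
    best_flat = pair_scores.view(batch_size, -1).argmax(dim=-1)
    starts = torch.div(best_flat, length, rounding_mode="floor")
    ends = best_flat % length
    return torch.stack([starts, ends], dim=-1)  # shape (batch_size, 2)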

get_metrics#

class TransformerQA(Model):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]
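
Returns the metrics tracked during training and validation. A brief usage sketch (the specific keys shown are assumptions; see the per-instance caveat above):

# reset=True clears the accumulators, e.g. at the end of an epoch.
metrics = model.get_metrics(reset=True)
# Assumed keys; check the source for the exact metric names.
print(metrics.get("per_instance_em"), metrics.get("per_instance_f1"))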

default_predictor#

class TransformerQA(Model):
 | ...
 | default_predictor = "transformer_qa"
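
Because the default predictor is "transformer_qa", a trained archive can be queried like this (a sketch assuming a hypothetical local model.tar.gz; a pretrained archive from the AllenNLP model zoo works the same way):

from allennlp.predictors import Predictor
import allennlp_models.rc  # noqa: F401 -- registers the model, reader, and predictor

predictor = Predictor.from_path("model.tar.gz")  # hypothetical archive path
result = predictor.predict(
    question="Who patented the telephone?",
    passage="Alexander Graham Bell was awarded the first US patent for the telephone in 1876.",
)
print(result["best_span_str"])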