transformer_qa

TransformerQA#

class TransformerQA(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     transformer_model_name: str = "bert-base-cased",
 |     **kwargs
 | ) -> None

This class implements a reading comprehension model patterned after the model proposed in https://arxiv.org/abs/1810.04805 (Devlin et al.), with improvements borrowed from the SQuAD model in the Hugging Face transformers project.

It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.
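
As an illustration only (a minimal sketch, not the model's actual code), the span-prediction head amounts to a single linear layer that produces two logits per word piece, split into start and end scores:

import torch

# A minimal sketch of the span-prediction head. `hidden_states` stands in
# for the transformer's final word piece embeddings.
batch_size, seq_len, hidden_dim = 2, 384, 768
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)

# One linear layer maps each word piece to two scores: start and end.
span_head = torch.nn.Linear(hidden_dim, 2)
logits = span_head(hidden_states)       # (batch_size, seq_len, 2)
span_start_logits = logits[..., 0]      # (batch_size, seq_len)
span_end_logits = logits[..., 1]        # (batch_size, seq_len)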

Note that the metrics that the model produces are calculated on a per-instance basis only. Since there could be more than one instance per question, these metrics are not the official numbers on the SQuAD task. To get official numbers, run the script in scripts/transformer_qa_eval.py.

Parameters

  • vocab : Vocabulary

  • transformer_model_name : str, optional (default = 'bert-base-cased')
    This model chooses the embedder according to this setting. You probably want to make sure this matches the transformer_model_name used by the dataset reader.
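
For example, the model can be instantiated directly in Python. This is a sketch; the import path assumes the allennlp-models package:

from allennlp.data import Vocabulary
from allennlp_models.rc import TransformerQA

# Transformer models bring their own vocabulary, so an empty AllenNLP
# vocabulary is enough here.
model = TransformerQA(
    vocab=Vocabulary.empty(),
    transformer_model_name="bert-base-cased",  # keep in sync with the reader
)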

forward#

class TransformerQA(Model):
 | ...
 | def forward(
 |     self,
 |     question_with_context: Dict[str, Dict[str, torch.LongTensor]],
 |     context_span: torch.IntTensor,
 |     answer_span: Optional[torch.IntTensor] = None,
 |     metadata: Optional[List[Dict[str, Any]]] = None
 | ) -> Dict[str, torch.Tensor]

Parameters

  • question_with_context : Dict[str, Dict[str, torch.LongTensor]]
    From a TextField. The model assumes that this text field contains the context followed by the question. It further assumes that the tokens have type ids set such that any token that can be part of the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and [SEP]) has type id 1.
  • context_span : torch.IntTensor
    From a SpanField. This marks the span of word pieces in question_with_context from which answers can come.
  • answer_span : torch.IntTensor, optional
    From a SpanField. This is the thing we are trying to predict: the span of text that marks the answer. If given, we compute a loss that gets included in the output dictionary.
  • metadata : List[Dict[str, Any]], optional
    If present, this should contain the question id, the original question and context texts, the tokenized versions of both, and a list of possible answers. The length of the metadata list should be the batch size, and each dictionary should have the keys id, question, context, question_tokens, context_tokens, and answers.

Returns

  • An output dictionary consisting of:

  • span_start_logits : torch.FloatTensor
    A tensor of shape (batch_size, passage_length) representing unnormalized log probabilities of the span start position.
  • span_start_probs : torch.FloatTensor
    The result of softmax(span_start_logits).
  • span_end_logits : torch.FloatTensor
    A tensor of shape (batch_size, passage_length) representing unnormalized log probabilities of the span end position (inclusive).
  • span_end_probs : torch.FloatTensor
    The result of softmax(span_end_logits).
  • best_span : torch.IntTensor
    The result of a constrained inference over span_start_logits and span_end_logits to find the most probable span (see the sketch after this list). Shape is (batch_size, 2) and each offset is a token index.
  • best_span_scores : torch.FloatTensor
    The score for each of the best spans.
  • loss : torch.FloatTensor, optional
    A scalar loss to be optimised.
  • best_span_str : List[str]
    If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question.
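
The constrained inference behind best_span can be pictured as follows. This is an illustrative sketch, not the library's implementation: among all pairs (start, end) with end >= start, pick the pair that maximizes span_start_logits[start] + span_end_logits[end].

import torch

def best_span_sketch(span_start_logits: torch.Tensor,
                     span_end_logits: torch.Tensor) -> torch.Tensor:
    # Score every (start, end) pair by summing the two logits.
    batch_size, seq_len = span_start_logits.shape
    scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Disallow spans whose end precedes their start.
    valid = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~valid, float("-inf"))
    # Argmax over the flattened (start, end) grid, then unflatten.
    flat = scores.view(batch_size, -1).argmax(dim=1)
    return torch.stack([flat // seq_len, flat % seq_len], dim=1)  # (batch_size, 2)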

get_metrics#

class TransformerQA(Model):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]

default_predictor#

class TransformerQA(Model):
 | ...
 | default_predictor = "transformer_qa"
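
In practice this means an archived TransformerQA model can be queried through the transformer_qa predictor. A usage sketch follows; the archive path is a placeholder, and predict(question=..., passage=...) assumes the predictor's documented interface:

import allennlp_models.rc  # registers the model and the predictor
from allennlp.predictors import Predictor

predictor = Predictor.from_path(
    "/path/to/transformer-qa-model.tar.gz",  # placeholder archive path
    predictor_name="transformer_qa",
)
result = predictor.predict(
    question="Who wrote the BERT paper?",
    passage="BERT was introduced by Devlin et al. in 2018.",
)
print(result["best_span_str"])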