# transformer_qa

`allennlp_models.rc.models.transformer_qa`

## TransformerQA
```python
@Model.register("transformer_qa")
class TransformerQA(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model_name: str = "bert-base-cased",
        **kwargs
    ) -> None
```
Registered as `"transformer_qa"`, this class implements a reading comprehension model patterned after the model proposed in Devlin et al., with improvements borrowed from the SQuAD model in the transformers project. It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.

If you want to use this model on SQuAD datasets, you can pair it with the `TransformerSquadReader` dataset reader, registered as `"transformer_squad"`.

Note that the metrics the model produces are calculated on a per-instance basis only. Since there can be more than one instance per question, these metrics are not the official numbers for either SQuAD task. To get official numbers for SQuAD v1.1, for example, you can run

```
python -m allennlp_models.rc.tools.transformer_qa_eval
```
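For orientation, here is a rough sketch of how the two registered names above might be paired in an AllenNLP configuration. Only the registered names and `transformer_model_name` come from this page; the surrounding structure is illustrative, not a complete runnable config.

```python
# Illustrative fragment of an AllenNLP config, written as a Python dict.
# A real config also needs data paths, a data loader, a trainer section, etc.
config_fragment = {
    "dataset_reader": {
        "type": "transformer_squad",  # TransformerSquadReader
        "transformer_model_name": "bert-base-cased",
    },
    "model": {
        "type": "transformer_qa",  # this model
        # Keep this in sync with the reader so tokenization matches.
        "transformer_model_name": "bert-base-cased",
    },
}
```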
### Parameters

- vocab : `Vocabulary`
- transformer_model_name : `str`, optional (default = `"bert-base-cased"`)
  This model chooses the embedder according to this setting. You probably want to make sure this is set to the same thing as the reader.
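As a quick illustration of the constructor, the model can also be instantiated directly in Python. This is a minimal sketch; it assumes the default `bert-base-cased` weights can be downloaded in your environment.

```python
from allennlp.data import Vocabulary
from allennlp_models.rc.models.transformer_qa import TransformerQA

# Minimal sketch: the pretrained transformer carries its own wordpiece
# vocabulary, so an empty AllenNLP Vocabulary is typically enough here.
model = TransformerQA(
    vocab=Vocabulary(),
    transformer_model_name="bert-base-cased",
)
```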
## forward
```python
class TransformerQA(Model):
    ...
    def forward(
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        cls_index: torch.LongTensor = None,
        answer_span: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None
    ) -> Dict[str, torch.Tensor]
```
### Parameters

- question_with_context : `Dict[str, Dict[str, torch.LongTensor]]`
  From a `TextField`. The model assumes that this text field contains the context followed by the question. It further assumes that the tokens have type ids set such that any token that can be part of the answer (i.e., tokens from the context) has type id 0, and any other token (including `[CLS]` and `[SEP]`) has type id 1.
- context_span : `torch.IntTensor`
  From a `SpanField`. This marks the span of word pieces in `question_with_context` from which answers can come.
- cls_index : `torch.LongTensor`, optional
  A tensor of shape `(batch_size,)` that provides the index of the `[CLS]` token in `question_with_context` for each instance. This is needed because the `[CLS]` token is used to indicate that the question is impossible. If this is `None`, it is assumed that the `[CLS]` token is at index 0 for each instance in the batch.
- answer_span : `torch.IntTensor`, optional
  From a `SpanField`. This is the thing we are trying to predict - the span of text that marks the answer. If given, we compute a loss that gets included in the output dictionary.
- metadata : `List[Dict[str, Any]]`, optional
  If present, this should contain the question id, the original texts of the context and the question, the tokenized versions of both, and a list of possible answers. The length of the `metadata` list should be the batch size, and each dictionary should have the keys `id`, `question`, `context`, `question_tokens`, `context_tokens`, and `answers`.
### Returns

- `Dict[str, torch.Tensor]` :
  An output dictionary with the following fields:
  - span_start_logits (`torch.FloatTensor`) : A tensor of shape `(batch_size, passage_length)` representing unnormalized log probabilities of the span start position.
  - span_end_logits (`torch.FloatTensor`) : A tensor of shape `(batch_size, passage_length)` representing unnormalized log probabilities of the span end position (inclusive).
  - best_span_scores (`torch.FloatTensor`) : The score for each of the best spans.
  - loss (`torch.FloatTensor`, optional) : A scalar loss to be optimised, evaluated against `answer_span`.
  - best_span (`torch.IntTensor`, optional) : Provided when not in train mode and sufficient metadata is given for the instance. The result of a constrained inference over `span_start_logits` and `span_end_logits` to find the most probable span. Shape is `(batch_size, 2)` and each offset is a token index, unless the best span for an instance was predicted to be the `[CLS]` token, in which case the span will be (-1, -1).
  - best_span_str (`List[str]`, optional) : Provided when not in train mode and sufficient metadata is given for the instance. This is the string from the original passage that the model thinks is the best answer to the question.
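To make the `best_span` / `best_span_str` conventions concrete, here is a small, hypothetical post-processing helper for the output dictionary. It assumes the model was run in evaluation mode with metadata supplied, so the optional keys are present; `extract_answers` is not part of the library.

```python
# Hypothetical helper, not part of allennlp-models: turn the forward()
# output into one answer string (or None for "no answer") per instance.
def extract_answers(output_dict):
    answers = []
    for span, text in zip(output_dict["best_span"], output_dict["best_span_str"]):
        start, end = int(span[0]), int(span[1])
        if (start, end) == (-1, -1):
            # The best span was the [CLS] token, i.e. the question was
            # predicted to be impossible for this context.
            answers.append(None)
        else:
            answers.append(text)
    return answers
```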
## get_metrics
```python
class TransformerQA(Model):
    ...
    def get_metrics(self, reset: bool = False) -> Dict[str, float]
```
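The metrics accumulate over calls to `forward`; a brief sketch of reading them out during evaluation (the exact metric names are not enumerated on this page):

```python
# Sketch: after running the model over a validation set, read the
# accumulated per-instance metrics and reset the counters.
metrics = model.get_metrics(reset=True)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```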
## default_predictor
```python
class TransformerQA(Model):
    ...
    default_predictor = "transformer_qa"
```
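Because `"transformer_qa"` is the default predictor, a trained archive of this model can be queried through the standard `Predictor` API. A minimal sketch, assuming a hypothetical local archive at `model.tar.gz`; the `question`/`passage` keyword names follow the reading comprehension predictor convention and are an assumption here.

```python
from allennlp.predictors.predictor import Predictor
import allennlp_models.rc  # noqa: F401  (registers the model, reader, and predictor)

# "model.tar.gz" is a placeholder for a trained TransformerQA archive.
predictor = Predictor.from_path("model.tar.gz", predictor_name="transformer_qa")

result = predictor.predict(
    question="Who wrote the novel?",
    passage="The novel was written by Jane Austen in 1813.",
)
print(result["best_span_str"])
```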