transformer_qa
TransformerQA#
class TransformerQA(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| transformer_model_name: str = "bert-base-cased",
| **kwargs
| ) -> None
This class implements a reading comprehension model patterned after the model proposed in https://arxiv.org/abs/1810.04805 (Devlin et al.), with improvements borrowed from the SQuAD model in the transformers project.
It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.
Note that the metrics that the model produces are calculated on a per-instance basis only. Since there could be more than one instance per question, these metrics are not the official numbers on the SQuAD task. To get official numbers, run the script in scripts/transformer_qa_eval.py.
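The span-prediction head described above can be illustrated with a small sketch. It is written in plain Python so that the shapes stay explicit, and it assumes a single linear layer with two outputs per word piece (one start logit, one end logit). The name `span_logits`, the weight layout, and the toy numbers are illustrative assumptions, not the model's actual code.

```python
from typing import List, Sequence, Tuple


def span_logits(
    embeddings: Sequence[Sequence[float]],
    weight: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> Tuple[List[float], List[float]]:
    """Apply one linear layer with two outputs to every word piece embedding.

    embeddings: one vector of size `hidden` per word piece.
    weight: two rows of size `hidden`; row 0 scores "start", row 1 scores "end"
            (this layout is an assumption made for the sketch).
    bias: two values, one per output.
    """
    start_logits: List[float] = []
    end_logits: List[float] = []
    for vector in embeddings:
        start_logits.append(sum(w * x for w, x in zip(weight[0], vector)) + bias[0])
        end_logits.append(sum(w * x for w, x in zip(weight[1], vector)) + bias[1])
    return start_logits, end_logits


# Toy input: 3 word pieces, hidden size 2.
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weight = [[2.0, 0.0], [0.0, 3.0]]
bias = [0.5, -0.5]
start_logits, end_logits = span_logits(embeddings, weight, bias)
```

Each word piece receives its own pair of scores, so both outputs have the same length as the input sequence.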
Parameters
- vocab : Vocabulary
- transformer_model_name : str, optional (default = 'bert-base-cased')
  This model chooses the embedder according to this setting. It should normally be set to the same model name that the dataset reader uses, so that the reader's word piece tokenization matches the embedder.
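One way to keep the two settings in sync is to define the model name once and reuse it. The sketch below expresses a configuration as a Python dictionary. Only `transformer_model_name` on the model comes from the signature above; the `"type"` values and the reader's `transformer_model_name` key are assumptions made for illustration and should be checked against the reader actually in use.

```python
# Single source of truth for the pretrained transformer name.
MODEL_NAME = "bert-base-cased"

config = {
    # Assumed reader settings: the "type" value and key name are hypothetical here.
    "dataset_reader": {
        "type": "transformer_squad",
        "transformer_model_name": MODEL_NAME,
    },
    # The model parameter name matches the constructor signature documented above;
    # the "type" value is an assumption.
    "model": {
        "type": "transformer_qa",
        "transformer_model_name": MODEL_NAME,
    },
}
```

Sharing one constant avoids the mismatch where the reader tokenizes with one vocabulary and the model embeds with another.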
forward#
class TransformerQA(Model):
| ...
| def forward(
| self,
| question_with_context: Dict[str, Dict[str, torch.LongTensor]],
| context_span: torch.IntTensor,
| answer_span: Optional[torch.IntTensor] = None,
| metadata: List[Dict[str, Any]] = None
| ) -> Dict[str, torch.Tensor]
Parameters
- question_with_context : Dict[str, Dict[str, torch.LongTensor]]
  From a TextField. The model assumes that this text field contains the context followed by the question. It further assumes that the tokens have type ids set such that any token that can be part of the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and [SEP]) has type id 1.
- context_span : torch.IntTensor
  From a SpanField. This marks the span of word pieces in question_with_context from which answers can come.
- answer_span : torch.IntTensor, optional
  From a SpanField. This is what the model is trying to predict: the span of text that marks the answer. If given, a loss is computed and included in the output dictionary.
- metadata : List[Dict[str, Any]], optional
  If present, this should contain the question id, the original texts of the context and the question, tokenized versions of both, and a list of possible answers. The length of the metadata list should be the batch size, and each dictionary should have the keys id, question, context, question_tokens, context_tokens, and answers.
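The metadata requirements above can be checked before calling forward. The following is a minimal sketch: `check_metadata` is a hypothetical helper, not part of the library, and the value types in the toy entry (plain strings and lists of strings) are assumptions, since the text only specifies the key names.

```python
from typing import Any, Dict, List

# Key names taken from the parameter description above.
REQUIRED_KEYS = {
    "id",
    "question",
    "context",
    "question_tokens",
    "context_tokens",
    "answers",
}


def check_metadata(metadata: List[Dict[str, Any]], batch_size: int) -> None:
    """Raise ValueError if the metadata list does not fit the batch."""
    if len(metadata) != batch_size:
        raise ValueError(
            f"expected {batch_size} metadata entries, got {len(metadata)}"
        )
    for index, entry in enumerate(metadata):
        missing = REQUIRED_KEYS - set(entry)
        if missing:
            raise ValueError(f"metadata[{index}] is missing keys: {sorted(missing)}")


# Toy entry; the value types are illustrative assumptions.
metadata = [
    {
        "id": "q1",
        "question": "Who wrote Hamlet?",
        "context": "Hamlet was written by William Shakespeare.",
        "question_tokens": ["Who", "wrote", "Hamlet", "?"],
        "context_tokens": ["Hamlet", "was", "written", "by", "William", "Shakespeare", "."],
        "answers": ["William Shakespeare"],
    }
]
check_metadata(metadata, batch_size=1)  # passes silently
```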
Returns
An output dictionary consisting of:
- span_start_logits : torch.FloatTensor
  A tensor of shape (batch_size, passage_length) representing unnormalized log probabilities of the span start position.
- span_start_probs : torch.FloatTensor
  The result of softmax(span_start_logits).
- span_end_logits : torch.FloatTensor
  A tensor of shape (batch_size, passage_length) representing unnormalized log probabilities of the span end position (inclusive).
- span_end_probs : torch.FloatTensor
  The result of softmax(span_end_logits).
- best_span : torch.IntTensor
  The result of a constrained inference over span_start_logits and span_end_logits to find the most probable span. Shape is (batch_size, 2) and each offset is a token index.
- best_span_scores : torch.FloatTensor
  The score for each of the best spans.
- loss : torch.FloatTensor, optional
  A scalar loss to be optimised.
- best_span_str : List[str]
  If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question.
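To make the relationship between the logits, the probabilities, and best_span concrete, here is a plain-Python sketch for a single sequence (no batch dimension). It assumes that the constraint is start <= end and that a span's score is the sum of its start and end logits. Both are assumptions for illustration, as the text above only says "constrained inference". The function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn unnormalized log probabilities into probabilities."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def best_span(
    start_logits: Sequence[float], end_logits: Sequence[float]
) -> Tuple[int, int, float]:
    """Return (start, end, score) maximizing start + end logit with start <= end."""
    best = (0, 0, start_logits[0] + end_logits[0])
    best_start = 0  # index of the largest start logit seen so far
    for end in range(len(end_logits)):
        if start_logits[end] > start_logits[best_start]:
            best_start = end
        score = start_logits[best_start] + end_logits[end]
        if score > best[2]:
            best = (best_start, end, score)
    return best


start_logits = [0.1, 2.0, 0.3, 0.2]
end_logits = [3.0, 0.1, 1.5, 0.4]

start_probs = softmax(start_logits)
end_probs = softmax(end_logits)
start, end, score = best_span(start_logits, end_logits)
```

In this toy input the independent argmaxes would give start 1 and end 0, which is not a valid span. The constrained search instead returns start 1 and end 2, both inclusive token indices.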
get_metrics#
class TransformerQA(Model):
| ...
| def get_metrics(self, reset: bool = False) -> Dict[str, float]
default_predictor#
class TransformerQA(Model):
| ...
| default_predictor = "transformer_qa"