

allennlp_models.rc.models.naqanet



NumericallyAugmentedQaNet

@Model.register("naqanet")
class NumericallyAugmentedQaNet(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     text_field_embedder: TextFieldEmbedder,
 |     num_highway_layers: int,
 |     phrase_layer: Seq2SeqEncoder,
 |     matrix_attention_layer: MatrixAttention,
 |     modeling_layer: Seq2SeqEncoder,
 |     dropout_prob: float = 0.1,
 |     initializer: InitializerApplicator = InitializerApplicator(),
 |     regularizer: Optional[RegularizerApplicator] = None,
 |     answering_abilities: List[str] = None
 | ) -> None

This class augments the QANet model with some rudimentary numerical reasoning abilities, as published in the original DROP paper.

The main idea is that, instead of only predicting a passage span after the QANet encoding and modeling layers, we add several different "answer abilities": predicting a span from the question, predicting a count, or predicting an arithmetic expression (additions and subtractions over numbers in the passage). Near the end of the QANet model, a classifier predicts which answer type is needed, and each branch has its own modeling logic to predict an answer of that type. We then marginalize over all possible ways of reaching the correct answer through each of these answer types.
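The marginalization can be sketched in plain Python for a single instance. This is a minimal sketch, not the model's implementation: it works on scalar log-probabilities rather than batched tensors. The function names are hypothetical, and the ability names in the usage example are only meant to mirror the kinds of answer abilities described above. The probability of the gold answer is the sum, over abilities, of P(ability) times the summed probability of every gold option that ability could produce.

```python
import math
from typing import Dict, List


def logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    if not values:
        return float("-inf")
    m = max(values)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_log_likelihood(
    ability_log_probs: Dict[str, float],
    gold_option_log_probs: Dict[str, List[float]],
) -> float:
    """Hypothetical helper: log P(answer) =
    logsumexp over abilities a of [log P(a) + logsumexp over gold options of log P(option | a)].
    Abilities with no gold option for this instance are skipped."""
    per_ability = []
    for ability, log_p_ability in ability_log_probs.items():
        options = gold_option_log_probs.get(ability, [])
        if options:
            per_ability.append(log_p_ability + logsumexp(options))
    return logsumexp(per_ability)


# Illustrative numbers only: two gold passage spans and one gold count.
abilities = {
    "passage_span_extraction": math.log(0.5),
    "question_span_extraction": math.log(0.1),
    "addition_subtraction": math.log(0.3),
    "counting": math.log(0.1),
}
gold = {
    "passage_span_extraction": [math.log(0.2), math.log(0.1)],
    "counting": [math.log(0.4)],
}
print(math.exp(marginal_log_likelihood(abilities, gold)))  # 0.5 * 0.3 + 0.1 * 0.4
```

Working in log space with `logsumexp` avoids underflow when many options have tiny probabilities. The training loss would be the negative of this value.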

forward

class NumericallyAugmentedQaNet(Model):
 | ...
 | def forward(
 |     self,
 |     question: Dict[str, torch.LongTensor],
 |     passage: Dict[str, torch.LongTensor],
 |     number_indices: torch.LongTensor,
 |     answer_as_passage_spans: torch.LongTensor = None,
 |     answer_as_question_spans: torch.LongTensor = None,
 |     answer_as_add_sub_expressions: torch.LongTensor = None,
 |     answer_as_counts: torch.LongTensor = None,
 |     metadata: List[Dict[str, Any]] = None
 | ) -> Dict[str, torch.Tensor]
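In `forward`, `number_indices` locates the numbers in the passage. `answer_as_add_sub_expressions` is assumed here to hold candidate sign assignments over those numbers: each number is added, subtracted, or left out, and the assignment's signed sum equals the answer. The sketch below brute-forces such assignments. The function name, the {-1, 0, +1} sign encoding, and the `max_terms` limit are illustrative assumptions, not the DROP reader's actual encoding.

```python
from itertools import product
from typing import List, Tuple


def find_add_sub_sign_assignments(
    numbers: List[float], target: float, max_terms: int = 3
) -> List[Tuple[int, ...]]:
    """Hypothetical helper: return every sign tuple over `numbers`
    (each sign in {-1, 0, +1}, where 0 means "not used") whose signed sum
    equals `target` and that uses between 1 and `max_terms` numbers."""
    valid = []
    for signs in product((-1, 0, 1), repeat=len(numbers)):
        used = sum(1 for s in signs if s != 0)
        if used == 0 or used > max_terms:
            continue
        total = sum(s * n for s, n in zip(signs, numbers))
        if abs(total - target) < 1e-6:  # tolerance for float inputs
            valid.append(signs)
    return valid


# Passage numbers 10, 3 and 7 with answer 7: either "7" alone or "10 - 3".
print(find_add_sub_sign_assignments([10, 3, 7], 7))
```

Brute force is exponential (3^n tuples), so it is only practical for a handful of numbers. Every tuple returned is one more "way of getting to the right answer" that the arithmetic branch's probability is summed over.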

get_metrics

class NumericallyAugmentedQaNet(Model):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]

default_predictor

class NumericallyAugmentedQaNet(Model):
 | ...
 | default_predictor = "reading_comprehension"