
bidaf

allennlp_models.rc.models.bidaf



BidirectionalAttentionFlow

@Model.register("bidaf")
class BidirectionalAttentionFlow(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     text_field_embedder: TextFieldEmbedder,
 |     num_highway_layers: int,
 |     phrase_layer: Seq2SeqEncoder,
 |     matrix_attention: MatrixAttention,
 |     modeling_layer: Seq2SeqEncoder,
 |     span_end_encoder: Seq2SeqEncoder,
 |     dropout: float = 0.2,
 |     mask_lstms: bool = True,
 |     initializer: InitializerApplicator = InitializerApplicator(),
 |     regularizer: Optional[RegularizerApplicator] = None
 | ) -> None

This class implements Minjoon Seo's Bidirectional Attention Flow model for answering reading comprehension questions (ICLR 2017).

The basic layout is pretty simple: encode words as a combination of word embeddings and a character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of attentions to put question information into the passage word representations (this is the only part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and do a softmax over span start and span end.
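
The attention fusion is the one non-standard step, so it is worth spelling out. Below is a minimal, self-contained PyTorch sketch of just that step, assuming the passage and question have already gone through the phrase layer and ignoring masking for brevity (the real model applies the passage and question masks inside these softmaxes); the function and tensor names here are illustrative, not the model's internal variable names.

import torch

def bidirectional_attention(encoded_passage: torch.Tensor,   # (batch, passage_len, dim)
                            encoded_question: torch.Tensor,  # (batch, question_len, dim)
                            similarity: torch.Tensor,        # (batch, passage_len, question_len)
                            ) -> torch.Tensor:
    passage_len = encoded_passage.size(1)
    # Passage-to-question: for each passage token, a distribution over question tokens.
    p2q = torch.softmax(similarity, dim=-1)
    question_in_passage = torch.bmm(p2q, encoded_question)          # (batch, passage_len, dim)
    # Question-to-passage: weight passage tokens by how well they match any question token.
    q2p = torch.softmax(similarity.max(dim=-1)[0], dim=-1)          # (batch, passage_len)
    passage_summary = torch.bmm(q2p.unsqueeze(1), encoded_passage)  # (batch, 1, dim)
    passage_summary = passage_summary.expand(-1, passage_len, -1)
    # The four concatenated views, (batch, passage_len, 4 * dim), feed the modeling layer.
    return torch.cat([encoded_passage,
                      question_in_passage,
                      encoded_passage * question_in_passage,
                      encoded_passage * passage_summary], dim=-1)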

Parameters

  • vocab : Vocabulary
  • text_field_embedder : TextFieldEmbedder
    Used to embed the question and passage TextFields we get as input to the model.
  • num_highway_layers : int
    The number of highway layers to use in between embedding the input and passing it through the phrase layer.
  • phrase_layer : Seq2SeqEncoder
    The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention.
  • matrix_attention : MatrixAttention
    The attention function that we will use when comparing encoded passage and question representations.
  • modeling_layer : Seq2SeqEncoder
    The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end.
  • span_end_encoder : Seq2SeqEncoder
    The encoder that we will use to incorporate span start predictions into the passage state before predicting span end.
  • dropout : float, optional (default = 0.2)
    If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer).
  • mask_lstms : bool, optional (default = True)
    If False, we will skip passing the mask to the LSTM layers. This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs.
  • initializer : InitializerApplicator, optional (default = InitializerApplicator())
    Used to initialize the model parameters.
  • regularizer : RegularizerApplicator, optional (default = None)
    If provided, will be used to calculate the regularization penalty during training.
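
For concreteness, here is one hedged way these pieces might be wired together in Python. The module choices and dimensions below (a 100-dimensional word-embedding-only embedder, 100-dimensional bidirectional LSTM hidden states, a linear matrix attention) are illustrative assumptions, not the reference configuration; the original model also concatenates a character-level CNN encoder. The constructor enforces the dimension constraints (for example, the modeling layer must accept four times the phrase-encoding dimension), so mismatched sizes fail at construction time.

import torch
from allennlp.data import Vocabulary
from allennlp.modules.matrix_attention import LinearMatrixAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp_models.rc.models.bidaf import BidirectionalAttentionFlow

vocab = Vocabulary()  # normally built from the training data by a dataset reader

# Word embeddings only (100-dim) for brevity.
embedder = BasicTextFieldEmbedder(
    {"tokens": Embedding(embedding_dim=100, num_embeddings=vocab.get_vocab_size("tokens"))}
)

def bi_lstm(input_dim: int, hidden_dim: int, num_layers: int = 1) -> PytorchSeq2SeqWrapper:
    return PytorchSeq2SeqWrapper(
        torch.nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                      batch_first=True, bidirectional=True)
    )

phrase_layer = bi_lstm(100, 100)                     # output dim 200
modeling_layer = bi_lstm(4 * 200, 100, num_layers=2) # input is 4 * phrase-encoding dim
span_end_encoder = bi_lstm(4 * 200 + 3 * 200, 100)   # 4 * encoding dim + 3 * modeling dim

model = BidirectionalAttentionFlow(
    vocab=vocab,
    text_field_embedder=embedder,
    num_highway_layers=2,
    phrase_layer=phrase_layer,
    matrix_attention=LinearMatrixAttention(200, 200, combination="x,y,x*y"),
    modeling_layer=modeling_layer,
    span_end_encoder=span_end_encoder,
    dropout=0.2,
)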

forward

class BidirectionalAttentionFlow(Model):
 | ...
 | def forward(
 |     self,
 |     question: Dict[str, torch.LongTensor],
 |     passage: Dict[str, torch.LongTensor],
 |     span_start: torch.IntTensor = None,
 |     span_end: torch.IntTensor = None,
 |     metadata: List[Dict[str, Any]] = None
 | ) -> Dict[str, torch.Tensor]

Parameters

  • question : Dict[str, torch.LongTensor]
    From a TextField.
  • passage : Dict[str, torch.LongTensor]
    From a TextField. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage.
  • span_start : torch.IntTensor, optional
    From an IndexField. This is one of the things we are trying to predict - the beginning position of the answer within the passage. This is an inclusive token index. If this is given, we will compute a loss that gets included in the output dictionary.
  • span_end : torch.IntTensor, optional
    From an IndexField. This is one of the things we are trying to predict - the ending position of the answer within the passage. This is an inclusive token index. If this is given, we will compute a loss that gets included in the output dictionary.
  • metadata : List[Dict[str, Any]], optional
    If present, this should contain the question tokens, passage tokens, original passage text, and token offsets into the passage for each instance in the batch. The length of this list should be the batch size, and each dictionary should have the keys question_tokens, passage_tokens, original_passage, and token_offsets.

Returns

An output dictionary consisting of:

  • span_start_logits : torch.FloatTensor
    A tensor of shape (batch_size, passage_length) representing unnormalized log probabilities of the span start position.
  • span_start_probs : torch.FloatTensor
    The result of softmax(span_start_logits).
  • span_end_logits : torch.FloatTensor
    A tensor of shape (batch_size, passage_length) representing unnormalized log probabilities of the span end position (inclusive).
  • span_end_probs : torch.FloatTensor
    The result of softmax(span_end_logits).
  • best_span : torch.IntTensor
    The result of a constrained inference over span_start_logits and span_end_logits to find the most probable span. Shape is (batch_size, 2) and each offset is a token index.
  • loss : torch.FloatTensor, optional
    A scalar loss to be optimised.
  • best_span_str : List[str]
    If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question.
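
As a hedged illustration of how these outputs fit together with the metadata keys documented above (the names question_tensors, passage_tensors, and metadata below are placeholders for already-batched inputs, not fixed variable names), a predicted token span can be mapped back to passage text roughly like this:

output = model(question=question_tensors, passage=passage_tensors, metadata=metadata)

start_token, end_token = output["best_span"][0].tolist()   # inclusive token indices
char_start = metadata[0]["token_offsets"][start_token][0]  # character offsets into the passage
char_end = metadata[0]["token_offsets"][end_token][1]
answer = metadata[0]["original_passage"][char_start:char_end]
# When metadata is supplied, the model performs this lookup itself and returns the
# resulting string as output["best_span_str"][0].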

get_metrics

class BidirectionalAttentionFlow(Model):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]

get_best_span

class BidirectionalAttentionFlow(Model):
 | ...
 | @staticmethod
 | def get_best_span(
 |     span_start_logits: torch.Tensor,
 |     span_end_logits: torch.Tensor
 | ) -> torch.Tensor
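
The intent here is the constrained inference mentioned in forward's return values: pick the (start, end) pair maximizing span_start_logits[start] + span_end_logits[end] subject to start <= end. A self-contained sketch of that decoding (illustrative, not the library's exact implementation):

import torch

def best_span_sketch(span_start_logits: torch.Tensor,
                     span_end_logits: torch.Tensor) -> torch.Tensor:
    # Both inputs are (batch_size, passage_length); returns (batch_size, 2) inclusive indices.
    batch_size, passage_length = span_start_logits.size()
    # Score every (start, end) pair as start_logit + end_logit.
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Forbid spans whose end precedes their start.
    valid = torch.triu(torch.ones(passage_length, passage_length,
                                  device=span_start_logits.device)).bool()
    span_scores = span_scores.masked_fill(~valid, float("-inf"))
    best_flat = span_scores.view(batch_size, -1).argmax(dim=-1)
    return torch.stack([best_flat // passage_length, best_flat % passage_length], dim=-1)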

default_predictor

class BidirectionalAttentionFlow(Model):
 | ...
 | default_predictor = "reading_comprehension"
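
This means a trained archive of this model can be loaded through the "reading_comprehension" predictor without further configuration. A minimal usage sketch, where the archive path is a placeholder rather than a real published model URL:

from allennlp.predictors import Predictor
import allennlp_models.rc  # registers the "bidaf" model and its predictor

predictor = Predictor.from_path("/path/to/bidaf-model.tar.gz")  # placeholder archive path
result = predictor.predict(
    question="When did The Matrix premiere?",
    passage="The Matrix premiered in 1999.",
)
print(result["best_span_str"])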