utils
get_best_span
def get_best_span(
    span_start_logits: torch.Tensor,
    span_end_logits: torch.Tensor
) -> torch.Tensor
This behaves the same as the static method BidirectionalAttentionFlow.get_best_span() in allennlp/models/reading_comprehension/bidaf.py. It is kept here so that users can import this function directly, without the class.
The inputs are called "logits", but they may be either unnormalized logits or normalized log probabilities. A log_softmax operation shifts the entire logit vector by a constant, so taking an argmax over either one gives the same result.
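The computation can be sketched in PyTorch as follows. This is a minimal illustration, not the library's code; it assumes that both inputs have shape (batch_size, passage_length), that a span is valid only when its start index is at most its end index, and that the result holds inclusive [start, end] indices with shape (batch_size, 2). The usage lines at the end also show the log_softmax invariance described above.

```python
import torch


def get_best_span(
    span_start_logits: torch.Tensor,
    span_end_logits: torch.Tensor,
) -> torch.Tensor:
    # Sketch only: inputs are assumed to have shape (batch_size, passage_length).
    if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
        raise ValueError("Input shapes must be (batch_size, passage_length)")
    batch_size, passage_length = span_start_logits.size()

    # scores[b, i, j] = start_logit[b, i] + end_logit[b, j] for the span i..j.
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)

    # log(1) = 0 on and above the diagonal, log(0) = -inf below it, so spans
    # that end before they start can never win the argmax.
    valid_mask = torch.triu(
        torch.ones(
            passage_length,
            passage_length,
            device=span_start_logits.device,
            dtype=span_start_logits.dtype,
        )
    ).log()
    valid_scores = span_scores + valid_mask

    # Flatten the (i, j) grid, take the argmax, then recover i and j.
    best_flat = valid_scores.view(batch_size, -1).argmax(dim=-1)
    span_starts = torch.div(best_flat, passage_length, rounding_mode="floor")
    span_ends = best_flat % passage_length
    return torch.stack([span_starts, span_ends], dim=-1)


start = torch.tensor([[0.1, 5.0, 0.2, 0.0]])
end = torch.tensor([[3.0, 0.3, 2.0, 0.1]])
# The unconstrained maximum (start 1, end 0) is excluded because it ends before it starts.
print(get_best_span(start, end))  # tensor([[1, 2]])
# Log probabilities select the same span, since log_softmax only shifts each row.
print(get_best_span(start.log_softmax(-1), end.log_softmax(-1)))  # tensor([[1, 2]])
```

Scoring all (start, end) pairs at once costs O(passage_length²) memory per example, which is usually acceptable for passage-length inputs and keeps the search to a single argmax.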
replace_masked_values_with_big_negative_number
def replace_masked_values_with_big_negative_number(
    x: torch.Tensor,
    mask: torch.Tensor
)
Replaces the masked values in a tensor with a very large negative number so that they do not affect a max operation.
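A minimal sketch of this behaviour, under two assumptions not stated above: the mask follows the common AllenNLP convention of True (or 1) at positions to keep and False (or 0) at positions to replace, and the "big negative number" is the most negative finite value of x's floating-point dtype. The library's exact constant may differ.

```python
import torch


def replace_masked_values_with_big_negative_number(
    x: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    # Assumed convention: mask is True/1 where x is kept, False/0 where it is replaced.
    # x is assumed to be a floating-point tensor (torch.finfo requires one).
    big_negative = torch.finfo(x.dtype).min
    return x.masked_fill(~mask.bool(), big_negative)


scores = torch.tensor([[1.0, 5.0, 3.0]])
mask = torch.tensor([[True, False, True]])
masked = replace_masked_values_with_big_negative_number(scores, mask)
print(masked.argmax(dim=-1))  # tensor([2]): the masked 5.0 can no longer win
```

One reason to prefer a finite value over -inf: if every entry in a row is masked, a later softmax over finite values returns a uniform distribution, whereas a row of -inf produces NaN.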