esim
allennlp_models.pair_classification.models.esim
ESIM
@Model.register("esim")
class ESIM(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| text_field_embedder: TextFieldEmbedder,
| encoder: Seq2SeqEncoder,
| matrix_attention: MatrixAttention,
| projection_feedforward: FeedForward,
| inference_encoder: Seq2SeqEncoder,
| output_feedforward: FeedForward,
| output_logit: FeedForward,
| dropout: float = 0.5,
| initializer: InitializerApplicator = InitializerApplicator(),
| **kwargs
| ) -> None
This `Model` implements the ESIM sequence model described in [Enhanced LSTM for Natural Language Inference](https://arxiv.org/abs/1609.06038) by Chen et al., 2017.

Registered as a `Model` with name "esim".
Parameters

- vocab : `Vocabulary`
- text_field_embedder : `TextFieldEmbedder`
  Used to embed the `premise` and `hypothesis` `TextFields` we get as input to the model.
- encoder : `Seq2SeqEncoder`
  Used to encode the premise and hypothesis.
- matrix_attention : `MatrixAttention`
  This is the attention function used when computing the similarity matrix between encoded words in the premise and words in the hypothesis.
- projection_feedforward : `FeedForward`
  The feedforward network used to project down the encoded and enhanced premise and hypothesis.
- inference_encoder : `Seq2SeqEncoder`
  Used to encode the projected premise and hypothesis for prediction.
- output_feedforward : `FeedForward`
  Used to prepare the concatenated premise and hypothesis for prediction.
- output_logit : `FeedForward`
  This feedforward network computes the output logits.
- dropout : `float`, optional (default = `0.5`)
  The dropout probability to use.
- initializer : `InitializerApplicator`, optional (default = `InitializerApplicator()`)
  Used to initialize the model parameters.
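For orientation, here is a minimal construction sketch in Python. The module choices (a small `Embedding`, bidirectional `LstmSeq2SeqEncoder`s, `DotProductMatrixAttention`) and every dimension are illustrative assumptions rather than a reference configuration; in practice these pieces are usually declared in a training config, and the vocabulary is built from your dataset.

```python
from allennlp.data import Vocabulary
from allennlp.modules import FeedForward
from allennlp.modules.matrix_attention import DotProductMatrixAttention
from allennlp.modules.seq2seq_encoders import LstmSeq2SeqEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn import Activation
from allennlp_models.pair_classification.models.esim import ESIM

# Normally built from your dataset; a blank vocab with three example
# NLI labels keeps this sketch self-contained.
vocab = Vocabulary()
for nli_label in ("entailment", "contradiction", "neutral"):
    vocab.add_token_to_namespace(nli_label, namespace="labels")

relu = Activation.by_name("relu")()

embedder = BasicTextFieldEmbedder(
    {"tokens": Embedding(embedding_dim=50, num_embeddings=vocab.get_vocab_size("tokens"))}
)
# Bidirectional, so the encoder output dim is 2 * 100 = 200.
encoder = LstmSeq2SeqEncoder(input_size=50, hidden_size=100, bidirectional=True)
# The "enhanced" representation concatenates the encoded text with the attended
# text plus their difference and elementwise product: 4 * 200 = 800 inputs here.
projection_feedforward = FeedForward(
    input_dim=800, num_layers=1, hidden_dims=100, activations=relu
)
inference_encoder = LstmSeq2SeqEncoder(input_size=100, hidden_size=100, bidirectional=True)
# Max- and average-pooled premise and hypothesis are concatenated: 4 * 200 again.
output_feedforward = FeedForward(
    input_dim=800, num_layers=1, hidden_dims=100, activations=relu, dropout=0.5
)
output_logit = FeedForward(
    input_dim=100, num_layers=1, hidden_dims=3,  # 3 labels, as added above
    activations=Activation.by_name("linear")(),
)

model = ESIM(
    vocab=vocab,
    text_field_embedder=embedder,
    encoder=encoder,
    matrix_attention=DotProductMatrixAttention(),
    projection_feedforward=projection_feedforward,
    inference_encoder=inference_encoder,
    output_feedforward=output_feedforward,
    output_logit=output_logit,
)
```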
forward
class ESIM(Model):
| ...
| def forward(
| self,
| premise: TextFieldTensors,
| hypothesis: TextFieldTensors,
| label: torch.IntTensor = None,
| metadata: List[Dict[str, Any]] = None
| ) -> Dict[str, torch.Tensor]
Parameters

- premise : `TextFieldTensors`
  From a `TextField`.
- hypothesis : `TextFieldTensors`
  From a `TextField`.
- label : `torch.IntTensor`, optional (default = `None`)
  From a `LabelField`.
- metadata : `List[Dict[str, Any]]`, optional (default = `None`)
  Metadata containing the original tokenization of the premise and hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
Returns¶
-
An output dictionary consisting of:
-
label_logits :
torch.FloatTensor
A tensor of shape(batch_size, num_labels)
representing unnormalised log probabilities of the entailment label. - label_probs :
torch.FloatTensor
A tensor of shape(batch_size, num_labels)
representing probabilities of the entailment label. - loss :
torch.FloatTensor
, optional
A scalar loss to be optimised.
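Continuing the hedged construction sketch above, a forward pass looks roughly like this. The `{"tokens": {"tokens": ...}}` nesting assumes a `single_id` token indexer named `"tokens"`; real inputs come from a `DatasetReader`, and the ids below are dummies.

```python
import torch

batch_size, premise_len, hypothesis_len = 2, 7, 5
# Dummy token ids (1 is the default OOV index, so nothing is masked as padding).
premise = {"tokens": {"tokens": torch.ones(batch_size, premise_len, dtype=torch.long)}}
hypothesis = {"tokens": {"tokens": torch.ones(batch_size, hypothesis_len, dtype=torch.long)}}
label = torch.tensor([0, 2])  # gold label indices in the "labels" namespace

output = model(premise=premise, hypothesis=hypothesis, label=label)
assert output["label_logits"].shape == (batch_size, 3)
# label_probs is the softmax of label_logits.
assert torch.allclose(output["label_probs"], output["label_logits"].softmax(dim=-1))
output["loss"].backward()  # "loss" is only present because a label was supplied
```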
get_metrics
class ESIM(Model):
| ...
| def get_metrics(self, reset: bool = False) -> Dict[str, float]
make_output_human_readable
class ESIM(Model):
| ...
| def make_output_human_readable(
| self,
| output_dict: Dict[str, torch.Tensor]
| ) -> Dict[str, torch.Tensor]
Does a simple argmax over the probabilities, converts the index to a string label, and adds a `"label"` key to the dictionary with the result.
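In spirit, the step does something like the following sketch (a paraphrase of the described behaviour, not the literal implementation; the `"labels"` vocabulary namespace is the usual default and is an assumption here):

```python
predicted_indices = output_dict["label_probs"].argmax(dim=-1)
# Map each predicted index back to its string label via the vocabulary.
output_dict["label"] = [
    model.vocab.get_token_from_index(int(index), namespace="labels")
    for index in predicted_indices
]
```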
default_predictor
class ESIM(Model):
| ...
| default_predictor = "textual_entailment"
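Since `default_predictor` is "textual_entailment", a trained archive can be loaded and queried through the standard `Predictor` API. In the sketch below the archive path is a placeholder and the sentences are just example inputs:

```python
from allennlp.predictors import Predictor
import allennlp_models.pair_classification  # noqa: F401 -- registers the model and predictor

predictor = Predictor.from_path(
    "/path/to/esim-model.tar.gz",  # placeholder for an archive trained with this model
    predictor_name="textual_entailment",  # optional: inferred from default_predictor
)
result = predictor.predict(
    premise="Two children are playing in the park.",
    hypothesis="The children are outside.",
)
print(result["label"], result["label_probs"])
```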