
transformer_mc_tt

allennlp_models.mc.models.transformer_mc_tt



TransformerMCTransformerToolkit

@Model.register("transformer_mc_tt")
class TransformerMCTransformerToolkit(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     transformer_model: str = "roberta-large",
 |     override_weights_file: Optional[str] = None,
 |     **kwargs
 | ) -> None

This class implements a multiple-choice model patterned after the model proposed in RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019).

It is identical to the TransformerMC model, except that it takes its input as a TransformerTextField.

It computes a score for each alternative from that alternative's CLS token and then chooses the alternative with the highest score.
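A minimal, framework-free sketch of the scoring step follows. It assumes the score is a single linear layer (a dot product plus a bias) applied to each alternative's CLS vector; this page does not specify the scoring head. `linear_score` and `choose_alternative` are hypothetical names, and real CLS vectors are far larger than these toy 2-dimensional ones.

```python
from typing import List, Sequence


def linear_score(cls_vector: Sequence[float], weights: Sequence[float], bias: float) -> float:
    """Score one alternative: dot product of its CLS vector with a weight vector, plus a bias."""
    return sum(x * w for x, w in zip(cls_vector, weights)) + bias


def choose_alternative(
    cls_vectors: List[Sequence[float]], weights: Sequence[float], bias: float
) -> int:
    """Return the index of the alternative whose CLS vector gets the highest score."""
    scores = [linear_score(v, weights, bias) for v in cls_vectors]
    # max() with a key returns the first index on ties
    return max(range(len(scores)), key=scores.__getitem__)


# Three alternatives for one question, each encoded to a toy 2-dimensional CLS vector.
cls_vectors = [[0.1, 0.4], [0.9, -0.2], [0.3, 0.3]]
weights = [1.0, 0.5]
print(choose_alternative(cls_vectors, weights, bias=0.0))  # 1
```

The scores here are 0.3, 0.8 and 0.45, so the second alternative (index 1) wins.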

Parameters

vocab : ``Vocabulary``

transformer_model : str, optional (default="roberta-large")

The name of the pretrained transformer; the model chooses its embedder according to this setting. Make sure it matches the corresponding setting in the dataset reader.
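To illustrate keeping the model and reader in sync, here is the relevant part of an experiment configuration, written as a Python dict rather than the usual Jsonnet file. The model `type` ("transformer_mc_tt") and the `transformer_model` key come from this page; the reader's `type` and its `transformer_model_name` parameter are hypothetical placeholders, so check your reader's documentation for the real names. `transformer_names_match` is a hypothetical helper that flags a mismatch.

```python
from typing import Any, Dict

# Hypothetical configuration fragment. Only the "model" section uses names
# documented on this page; the reader section uses placeholder names.
config: Dict[str, Any] = {
    "dataset_reader": {
        "type": "my_mc_reader",  # hypothetical reader name
        "transformer_model_name": "roberta-large",  # hypothetical parameter name
    },
    "model": {
        "type": "transformer_mc_tt",
        "transformer_model": "roberta-large",
    },
}


def transformer_names_match(
    cfg: Dict[str, Any], reader_key: str = "transformer_model_name"
) -> bool:
    """Check that the reader and the model name the same pretrained transformer."""
    # Fall back to the documented model default when the key is omitted.
    model_name = cfg["model"].get("transformer_model", "roberta-large")
    reader_name = cfg["dataset_reader"].get(reader_key)
    return model_name == reader_name


print(transformer_names_match(config))  # True
```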

forward

class TransformerMCTransformerToolkit(Model):
 | ...
 | def forward(
 |     self,
 |     alternatives: Dict[str, torch.Tensor],
 |     correct_alternative: Optional[torch.IntTensor] = None,
 |     qid: Optional[List[str]] = None
 | ) -> Dict[str, torch.Tensor]

Parameters

alternatives : ``Dict[str, torch.LongTensor]``

From a ``ListField[TransformerTextField]``. Contains the list of alternatives to evaluate for every instance.

correct_alternative : Optional[torch.IntTensor]

From an IndexField. Contains the index of the correct answer for every instance.

qid : Optional[List[str]]

A list of question IDs for the questions being processed now.
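Because `alternatives` holds several alternatives per instance, one common way to score them is to flatten the batch so that every alternative is scored independently, then reshape the flat scores back into one row per instance. The framework-free sketch below shows that idea; whether this model does exactly this internally is an assumption, and `score_batch` and the toy scorer are hypothetical. It also assumes every instance has the same number of alternatives, as in a padded batch.

```python
from typing import Callable, List, Sequence


def score_batch(
    batch: List[List[Sequence[int]]],
    score_fn: Callable[[Sequence[int]], float],
) -> List[List[float]]:
    """Score every alternative of every instance.

    batch[i][j] is the token-ID sequence for alternative j of instance i.
    Returns a (batch_size, num_alternatives) nested list of scores.
    """
    num_alternatives = len(batch[0])
    # Flatten to batch_size * num_alternatives independent sequences.
    flat = [alternative for instance in batch for alternative in instance]
    flat_scores = [score_fn(seq) for seq in flat]
    # Reshape back to one row of scores per instance.
    return [
        flat_scores[i : i + num_alternatives]
        for i in range(0, len(flat_scores), num_alternatives)
    ]


# Two instances with three alternatives each; the toy scorer just sums token IDs.
batch = [
    [[1, 2], [3, 4], [0, 1]],
    [[5, 0], [1, 1], [2, 2]],
]
logits = score_batch(batch, score_fn=lambda seq: float(sum(seq)))
correct_alternative = [1, 0]  # index of the right answer within each row
print(logits)  # [[3.0, 7.0, 1.0], [5.0, 2.0, 4.0]]
print([row[k] for row, k in zip(logits, correct_alternative)])  # [7.0, 5.0]
```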

Returns

An output dictionary consisting of:

loss : torch.FloatTensor, optional

A scalar loss to be optimised. This is only returned when correct_alternative is not None.

logits : torch.FloatTensor

The logits for every possible answer choice.

best_alternative : List[int]

The index of the highest-scoring alternative for every instance in the batch.
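This page does not name the loss function. The sketch below assumes softmax cross-entropy over each instance's alternatives, a standard choice for multiple-choice scoring that is not confirmed here, and uses hypothetical helpers `cross_entropy` and `make_output` to show how `loss`, `logits` and `best_alternative` relate.

```python
import math
from typing import Dict, List, Optional


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Mean softmax cross-entropy; each row holds one instance's alternative scores."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[target]
    return total / len(logits)


def make_output(
    logits: List[List[float]], correct_alternative: Optional[List[int]] = None
) -> Dict[str, object]:
    """Build an output dictionary with the keys listed under Returns."""
    output: Dict[str, object] = {
        "logits": logits,
        # argmax of each row: the highest-scoring alternative per instance
        "best_alternative": [max(range(len(row)), key=row.__getitem__) for row in logits],
    }
    if correct_alternative is not None:
        # The loss is only computed when the gold answers are available.
        output["loss"] = cross_entropy(logits, correct_alternative)
    return output


out = make_output([[0.0, 0.0], [2.0, 0.0]], correct_alternative=[0, 0])
print(out["best_alternative"])  # [0, 0]
print(round(out["loss"], 4))  # 0.41
```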

get_metrics

class TransformerMCTransformerToolkit(Model):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]