transformer_classification_tt

allennlp_models.classification.models.transformer_classification_tt


TransformerClassificationTT#

@Model.register("transformer_classification_tt")
class TransformerClassificationTT(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     transformer_model: str = "roberta-large",
 |     num_labels: Optional[int] = None,
 |     label_namespace: str = "labels",
 |     override_weights_file: Optional[str] = None,
 |     **kwargs
 | ) -> None

This class implements a classification model patterned after the model proposed in RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019).

Parameters

vocab : ``Vocabulary``

transformer_model : ``str``, optional (default = ``"roberta-large"``)

    This model chooses the embedder according to this setting. You probably want to make sure this matches the setting in the reader.
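For orientation, here is a minimal, untested sketch of constructing the model directly in Python. The explicit ``num_labels=2`` and the empty vocabulary are illustrative assumptions; in a normal AllenNLP workflow the model is built from a training config and the label namespace is populated by the reader.

```python
# A minimal sketch (untested). Assumes allennlp and allennlp-models are
# installed. Passing num_labels explicitly is an illustrative assumption
# that avoids needing labels already present in the vocabulary.
from allennlp.data import Vocabulary
from allennlp_models.classification.models.transformer_classification_tt import (
    TransformerClassificationTT,
)

vocab = Vocabulary()
model = TransformerClassificationTT(
    vocab=vocab,
    transformer_model="roberta-large",  # should match the reader's setting
    num_labels=2,
)
```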

forward#

class TransformerClassificationTT(Model):
 | ...
 | def forward(
 |     self,
 |     text: Dict[str, torch.Tensor],
 |     label: Optional[torch.Tensor] = None
 | ) -> Dict[str, torch.Tensor]

Parameters

text : ``Dict[str, torch.LongTensor]``

    From a ``TensorTextField``. Contains the text to be classified.

label : ``Optional[torch.LongTensor]``

    From a ``LabelField``, specifies the true class of the instance.

Returns

An output dictionary consisting of:

loss : ``torch.FloatTensor``, optional

    A scalar loss to be optimised. This is only returned when ``label`` is not ``None``.

logits : ``torch.FloatTensor``

    The logits for every possible label.
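As a hedged sketch of a forward call: the exact keys inside ``text`` come from the dataset reader's ``TensorTextField``, so the HuggingFace-style ``input_ids`` / ``attention_mask`` keys and the token values below are assumptions for illustration only.

```python
# Untested sketch of a single-instance forward pass. The keys in `text`
# are assumed here; check the paired reader for the keys it actually emits.
import torch

text = {
    "input_ids": torch.tensor([[0, 31414, 232, 2]]),  # one tokenized sentence
    "attention_mask": torch.tensor([[1, 1, 1, 1]]),
}
label = torch.tensor([1])  # gold class index, from a LabelField

outputs = model(text=text, label=label)
print(outputs["logits"].shape)  # (batch_size, num_labels)
print(outputs["loss"])          # present because `label` was provided
```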

get_metrics#

class TransformerClassificationTT(Model):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]
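A small usage sketch: after running forward passes with labels, read the accumulated metrics and reset the counters, e.g. at the end of an epoch. The exact metric keys returned (such as an accuracy entry) depend on the model's internal metric setup.

```python
# Sketch: fetch accumulated metrics and reset them for the next epoch.
metrics = model.get_metrics(reset=True)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```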