transformer_classification_tt
allennlp_models.classification.models.transformer_classification_tt
TransformerClassificationTT
@Model.register("transformer_classification_tt")
class TransformerClassificationTT(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| transformer_model: str = "roberta-large",
| num_labels: Optional[int] = None,
| label_namespace: str = "labels",
| override_weights_file: Optional[str] = None,
| **kwargs
| ) -> None
This class implements a classification model patterned after the model proposed in RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al.).
Parameters
vocab : ``Vocabulary``
transformer_model : str, optional (default="roberta-large")
The name of the pretrained transformer. The model chooses its embedder according to this setting, so make sure it matches the transformer setting in the dataset reader.
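Because `transformer_model` must agree with the reader, one way to keep the two in sync is to define the name once and reuse it in both config sections. The sketch below is an illustration under stated assumptions: the `"model"` keys come from the constructor signature and the `@Model.register` name above, but the reader type `"my_transformer_reader"` and its `transformer_model_name` parameter are hypothetical placeholders for whatever reader you actually use.

```python
import json

# Define the transformer name once so the reader and the model cannot drift apart.
TRANSFORMER_MODEL = "roberta-large"

config = {
    "dataset_reader": {
        # Hypothetical reader type and parameter name: substitute your real reader.
        "type": "my_transformer_reader",
        "transformer_model_name": TRANSFORMER_MODEL,
    },
    "model": {
        # Registered name from @Model.register("transformer_classification_tt").
        "type": "transformer_classification_tt",
        "transformer_model": TRANSFORMER_MODEL,
        "label_namespace": "labels",
    },
}

print(json.dumps(config, indent=2))
```

Changing `TRANSFORMER_MODEL` then updates both sections together.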
forward
class TransformerClassificationTT(Model):
| ...
| def forward(
| self,
| text: Dict[str, torch.Tensor],
| label: Optional[torch.Tensor] = None
| ) -> Dict[str, torch.Tensor]
Parameters
text : ``Dict[str, torch.LongTensor]``
From a ``TensorTextField``. Contains the text to be classified.
label : Optional[torch.LongTensor]
From a ``LabelField``; specifies the true class of the instance.
Returns
An output dictionary consisting of:
loss : torch.FloatTensor, optional
A scalar loss to be optimised. This is only returned when ``label`` is not None.
logits : torch.FloatTensor
The logits for every possible label.
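The output contract above (always `logits`, `loss` only when a label is supplied) can be mimicked in plain Python. This is a minimal sketch, not the model's code: it assumes a mean cross-entropy loss over the batch, and the helper names `log_softmax` and `build_output` are hypothetical.

```python
import math
from typing import Dict, List, Optional


def log_softmax(row: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def build_output(
    logits: List[List[float]], label: Optional[List[int]] = None
) -> Dict[str, object]:
    """Mimic the shape of the forward() output dictionary for one batch."""
    output: Dict[str, object] = {"logits": logits}
    if label is not None:
        # Assumed loss: mean cross-entropy of the gold class over the batch.
        losses = [-log_softmax(row)[gold] for row, gold in zip(logits, label)]
        output["loss"] = sum(losses) / len(losses)
    return output


batch_logits = [[2.0, 0.5], [0.1, 1.5]]
print(sorted(build_output(batch_logits)))          # ['logits']
print(sorted(build_output(batch_logits, [0, 1])))  # ['logits', 'loss']
```

At prediction time no label is passed, so callers should not expect a `loss` key.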
get_metrics
class TransformerClassificationTT(Model):
| ...
| def get_metrics(self, reset: bool = False) -> Dict[str, float]
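In AllenNLP, metrics accumulate across batches, and `reset=True` returns the accumulated value and then clears it (typically at the end of an epoch). The sketch below illustrates that `reset` behaviour in plain Python under stated assumptions: `RunningAccuracy` is a hypothetical stand-in for an AllenNLP metric, and the `"accuracy"` key is hypothetical, so check the model source for the key it actually reports.

```python
from typing import Dict, List


class RunningAccuracy:
    """Hypothetical metric: accumulates correct/total counts across batches."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def __call__(self, predictions: List[int], gold: List[int]) -> None:
        for p, g in zip(predictions, gold):
            self.correct += int(p == g)
            self.total += 1

    def get_metric(self, reset: bool = False) -> float:
        value = self.correct / self.total if self.total else 0.0
        if reset:
            # Clear the counters only after the value has been computed.
            self.correct = 0
            self.total = 0
        return value


accuracy = RunningAccuracy()


def get_metrics(reset: bool = False) -> Dict[str, float]:
    # "accuracy" is a hypothetical key name.
    return {"accuracy": accuracy.get_metric(reset)}


accuracy([1, 0, 1], [1, 1, 1])  # batch 1: 2 of 3 correct
accuracy([0], [0])              # batch 2: 1 of 1 correct
print(get_metrics())            # {'accuracy': 0.75}
print(get_metrics(reset=True))  # {'accuracy': 0.75}, then the counters clear
print(get_metrics())            # {'accuracy': 0.0}
```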