allennlp.models.bert_for_classification
class allennlp.models.bert_for_classification.BertForClassification(vocab: allennlp.data.vocabulary.Vocabulary, bert_model: Union[str, pytorch_pretrained_bert.modeling.BertModel], dropout: float = 0.0, num_labels: int = None, index: str = 'bert', label_namespace: str = 'labels', trainable: bool = True, initializer: allennlp.nn.initializers.InitializerApplicator = <allennlp.nn.initializers.InitializerApplicator object>, regularizer: Optional[allennlp.nn.regularizers.regularizer_applicator.RegularizerApplicator] = None)

Bases: allennlp.models.model.Model
An AllenNLP Model that runs pretrained BERT, takes the pooled output, and adds a Linear layer on top. If you want an easy way to use BERT for classification, this is it. Note that this is a somewhat non-AllenNLP-ish model architecture, in that it essentially requires you to use the “bert-pretrained” token indexer, rather than configuring whatever indexing scheme you like.
See allennlp/tests/fixtures/bert/bert_for_classification.jsonnet for an example of what your config might look like.
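In Python terms, the kind of configuration that fixture describes boils down to something like the following minimal sketch, assuming the model is registered under the name bert_for_classification; the pretrained model name and label count here are illustrative assumptions, not values taken from the fixture:

    # Minimal sketch, not the fixture itself: build the model from a Params blob
    # the way a jsonnet config would.  Values are illustrative assumptions.
    from allennlp.common import Params
    from allennlp.data.vocabulary import Vocabulary
    from allennlp.models import Model

    vocab = Vocabulary()  # normally built from your training instances
    model = Model.from_params(vocab=vocab, params=Params({
        "type": "bert_for_classification",
        "bert_model": "bert-base-uncased",
        "num_labels": 2,
        "trainable": True,
    }))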
Parameters

- vocab : Vocabulary
- bert_model : Union[str, BertModel]
  The BERT model to be wrapped. If a string is provided, we will call BertModel.from_pretrained(bert_model) and use the result.
- num_labels : int, optional (default = None)
  How many output classes to predict. If not provided, we'll use the vocab_size for the label_namespace.
- index : str, optional (default = "bert")
  The index of the token indexer that generates the BERT indices.
- label_namespace : str, optional (default = "labels")
  Used to determine the number of classes if num_labels is not supplied.
- trainable : bool, optional (default = True)
  If True, the weights of the pretrained BERT model will be updated during training. Otherwise, they will be frozen and only the final linear layer will be trained (see the sketch after this list).
- initializer : InitializerApplicator, optional
  If provided, will be used to initialize the final linear layer only.
- regularizer : RegularizerApplicator, optional (default = None)
  If provided, will be used to calculate the regularization penalty during training.
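Alternatively, the constructor can be called directly. For example, wrapping an already-loaded BertModel and freezing its weights so that only the classification head trains might look roughly like this sketch (the model name and label count are assumptions):

    # Sketch: pass a preloaded BertModel instead of a string, and freeze it.
    from pytorch_pretrained_bert.modeling import BertModel
    from allennlp.data.vocabulary import Vocabulary
    from allennlp.models.bert_for_classification import BertForClassification

    bert = BertModel.from_pretrained("bert-base-uncased")  # downloaded/cached weights
    model = BertForClassification(
        vocab=Vocabulary(),
        bert_model=bert,     # the module is used as-is, no further loading
        num_labels=3,        # otherwise inferred from the "labels" namespace
        trainable=False,     # BERT weights frozen; only the final linear layer trains
    )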
decode(self, output_dict: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

Does a simple argmax over the probabilities, converts each index to its string label, and adds a "label" key to the dictionary with the result.
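The behaviour can be pictured with a standalone sketch like the one below, where the index-to-label mapping stands in for the model's vocabulary lookup and the label strings are illustrative:

    # Sketch of decode's logic, written without the surrounding Model class.
    import torch

    def decode_like(output_dict, index_to_label):
        probs = output_dict["probs"]                 # (batch_size, num_labels)
        predictions = probs.argmax(dim=-1).tolist()  # argmax over the probabilities
        output_dict["label"] = [index_to_label[i] for i in predictions]
        return output_dict

    outputs = {"probs": torch.tensor([[0.1, 0.9], [0.8, 0.2]])}
    print(decode_like(outputs, {0: "negative", 1: "positive"})["label"])
    # ['positive', 'negative']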
forward(self, tokens: Dict[str, torch.LongTensor], label: torch.IntTensor = None) → Dict[str, torch.Tensor]

Parameters

- tokens : Dict[str, torch.LongTensor]
  From a TextField (that has a bert-pretrained token indexer).
- label : torch.IntTensor, optional (default = None)
  From a LabelField.

Returns

An output dictionary consisting of:

- logits : torch.FloatTensor
  A tensor of shape (batch_size, num_labels) representing unnormalized log probabilities of the label.
- probs : torch.FloatTensor
  A tensor of shape (batch_size, num_labels) representing probabilities of the label.
- loss : torch.FloatTensor, optional
  A scalar loss to be optimised.
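Putting the pieces together, an end-to-end sketch of building one instance with the bert-pretrained indexer and running forward might look like the following; the pretrained model name and label are assumptions, and the exact indexer output keys depend on the AllenNLP version:

    # End-to-end sketch: one instance, indexed with the bert-pretrained indexer,
    # then a single forward pass.  Shapes match the description above.
    from allennlp.data import Instance, Token, Vocabulary
    from allennlp.data.dataset import Batch
    from allennlp.data.fields import TextField, LabelField
    from allennlp.data.token_indexers import PretrainedBertIndexer
    from allennlp.models.bert_for_classification import BertForClassification

    indexer = PretrainedBertIndexer(pretrained_model="bert-base-uncased")
    tokens = [Token(t) for t in "a great movie".split()]
    instance = Instance({"tokens": TextField(tokens, {"bert": indexer}),
                         "label": LabelField("positive")})

    vocab = Vocabulary.from_instances([instance])
    model = BertForClassification(vocab, bert_model="bert-base-uncased")

    batch = Batch([instance])
    batch.index_instances(vocab)
    output_dict = model(**batch.as_tensor_dict())
    print(output_dict["logits"].shape)  # torch.Size([1, num_labels])
    print(output_dict["probs"].shape)   # same shape, softmax of the logits
    print(output_dict["loss"])          # scalar, because a label was provided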
get_metrics(self, reset: bool = False) → Dict[str, float]

Returns a dictionary of metrics. This method will be called by allennlp.training.Trainer in order to compute and use model metrics for early stopping and model serialization. We return an empty dictionary here rather than raising, as it is not required to implement metrics for a new model. A boolean reset parameter is passed, as frequently a metric accumulator will have some state which should be reset between epochs. This is also compatible with Metric objects: metrics should be populated during the call to forward, with the Metric handling the accumulation of the metric until this method is called.
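In a model that does track metrics, the usual pattern is a metric accumulator updated inside forward and read out (and optionally reset) here; a minimal sketch follows, with attribute and metric names that are illustrative rather than necessarily what this class uses:

    # Sketch of the typical get_metrics pattern: an accumulator that forward()
    # updates and this method reads out, optionally resetting it between epochs.
    from typing import Dict
    from allennlp.training.metrics import CategoricalAccuracy

    class SketchModel:  # stand-in for a Model subclass
        def __init__(self):
            self._accuracy = CategoricalAccuracy()

        # inside forward(): self._accuracy(logits, label)

        def get_metrics(self, reset: bool = False) -> Dict[str, float]:
            return {"accuracy": self._accuracy.get_metric(reset)}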