class allennlp.models.bert_for_classification.BertForClassification(vocab:, bert_model: Union[str, pytorch_pretrained_bert.modeling.BertModel], dropout: float = 0.0, num_labels: int = None, index: str = 'bert', label_namespace: str = 'labels', trainable: bool = True, initializer: allennlp.nn.initializers.InitializerApplicator = <allennlp.nn.initializers.InitializerApplicator object>, regularizer: Optional[allennlp.nn.regularizers.regularizer_applicator.RegularizerApplicator] = None)

Bases: allennlp.models.model.Model

An AllenNLP Model that runs pretrained BERT, takes the pooled output, and adds a Linear layer on top. If you want an easy way to use BERT for classification, this is it. Note that this is a somewhat non-AllenNLP-ish model architecture, in that it essentially requires you to use the “bert-pretrained” token indexer, rather than configuring whatever indexing scheme you like.
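To make the "pooled output plus a Linear layer" description concrete, here is a plain-Python sketch of the classification head only (no BERT, no PyTorch). It assumes the pooled output is a single vector of length hidden_size, that the dropout argument from the signature is applied to that vector before the linear layer, and that the weight matrix is laid out as (num_labels, hidden_size) in the style of torch.nn.Linear. The function name classification_head is hypothetical.

```python
import random
from typing import List, Optional


def classification_head(
    pooled: List[float],
    weight: List[List[float]],  # shape (num_labels, hidden_size)
    bias: List[float],          # shape (num_labels,)
    dropout: float = 0.0,
    training: bool = False,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical sketch: dropout on the pooled vector, then a linear layer."""
    if training and dropout > 0.0:
        rng = rng or random.Random()
        keep = 1.0 - dropout
        # Inverted dropout: zero each unit with probability `dropout`,
        # scale the survivors by 1 / keep so the expected value is unchanged.
        pooled = [x / keep if rng.random() < keep else 0.0 for x in pooled]
    # Linear layer: logits[j] = sum_i weight[j][i] * pooled[i] + bias[j]
    return [sum(w * x for w, x in zip(row, pooled)) + b for row, b in zip(weight, bias)]
```

With training=False the dropout step is skipped, which is how dropout behaves at evaluation time; the result is one unnormalized score (logit) per label.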

See allennlp/tests/fixtures/bert/bert_for_classification.jsonnet for an example of what your config might look like.

bert_model : Union[str, BertModel]

The BERT model to be wrapped. If a string is provided, we will call BertModel.from_pretrained(bert_model) and use the result.
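The string-or-instance handling can be sketched as a simple type dispatch. FakeBertModel below is a hypothetical stand-in for pytorch_pretrained_bert's BertModel: the real from_pretrained loads pretrained weights, while this stub only records the name it was given.

```python
from typing import Union


class FakeBertModel:
    """Hypothetical stand-in for BertModel; only tracks the model name."""

    def __init__(self, name: str) -> None:
        self.name = name

    @classmethod
    def from_pretrained(cls, name: str) -> "FakeBertModel":
        # The real classmethod loads weights; this stub just builds an instance.
        return cls(name)


def resolve_bert_model(bert_model: Union[str, FakeBertModel]) -> FakeBertModel:
    # A string is treated as a pretrained-model name; an instance is used as-is.
    if isinstance(bert_model, str):
        return FakeBertModel.from_pretrained(bert_model)
    return bert_model
```

Passing an already-built model is useful when you want to share one BERT instance between components or load it from a custom location.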

num_labels : int, optional (default: None)

How many output classes to predict. If not provided, we use the number of entries in the vocabulary's label_namespace.
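A minimal sketch of the fallback rule, assuming a plain dict from namespace name to a list of labels as a hypothetical stand-in for an AllenNLP Vocabulary (the real code would ask the vocabulary for its size in that namespace).

```python
from typing import Dict, List, Optional


def resolve_num_labels(
    num_labels: Optional[int],
    vocab: Dict[str, List[str]],  # hypothetical stand-in: namespace -> labels
    label_namespace: str = "labels",
) -> int:
    # An explicit num_labels wins; otherwise count the labels in the namespace.
    if num_labels is not None:
        return num_labels
    return len(vocab[label_namespace])
```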

index : str, optional (default: “bert”)

The name of the token indexer that generates the BERT indices; the wordpiece ids are looked up in the tokens dictionary under this key.
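Because tokens is a dictionary keyed by token-indexer name, index simply selects the entry holding the BERT wordpiece ids. The plain-Python sketch below assumes padded rows in which id 0 is padding, and uses that assumption to derive an attention mask; the function name select_bert_inputs and its output keys are hypothetical.

```python
from typing import Dict, List


def select_bert_inputs(
    tokens: Dict[str, List[List[int]]], index: str = "bert"
) -> Dict[str, List[List[int]]]:
    # Raises KeyError if no token indexer was registered under `index`.
    input_ids = tokens[index]
    # Assumption: id 0 is padding, so real wordpieces get mask 1 and padding gets 0.
    input_mask = [[1 if i != 0 else 0 for i in row] for row in input_ids]
    return {"input_ids": input_ids, "input_mask": input_mask}
```

If your dataset reader registers the BERT indexer under a different key, index must be set to match it.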

label_namespace : str, optional (default: “labels”)

Used to determine the number of classes if num_labels is not supplied.

trainable : bool, optional (default: True)

If True, the weights of the pretrained BERT model will be updated during training. Otherwise, they will be frozen and only the final linear layer will be trained.
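In PyTorch, freezing is usually done by setting requires_grad = False on a module's parameters, so no gradients are computed for them. The sketch below mimics that with a hypothetical Param dataclass standing in for torch.nn.Parameter, and assumes, as the description says, that trainable=False touches only the BERT parameters and leaves the final linear layer trainable.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class Param:
    """Hypothetical stand-in for torch.nn.Parameter; only the flag matters here."""
    name: str
    requires_grad: bool = True


def set_bert_trainable(bert_params: Iterable[Param], trainable: bool) -> None:
    # Applied to the BERT parameters only; the classifier's parameters are untouched.
    for p in bert_params:
        p.requires_grad = trainable


def trainable_names(params: Iterable[Param]) -> List[str]:
    # Names of parameters that would still receive gradient updates.
    return [p.name for p in params if p.requires_grad]
```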

initializer : InitializerApplicator, optional

If provided, will be used to initialize the final linear layer only.

regularizer : RegularizerApplicator, optional (default: None)

If provided, will be used to calculate the regularization penalty during training.

decode(self, output_dict: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

Does a simple argmax over the probabilities, converts each winning index to its string label, and adds a "label" key to the dictionary with the result.
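A plain-Python sketch of the argmax-and-lookup step, assuming one row of probabilities per instance and a hypothetical index_to_label dict standing in for the vocabulary lookup in label_namespace. The function name argmax_labels is hypothetical.

```python
from typing import Dict, List


def argmax_labels(probabilities: List[List[float]], index_to_label: Dict[int, str]) -> List[str]:
    labels = []
    for row in probabilities:
        # Index of the highest probability; on ties, max() keeps the first one.
        best = max(range(len(row)), key=lambda i: row[i])
        labels.append(index_to_label[best])
    return labels
```

The decoded strings would then be stored under the "label" key, e.g. output_dict["label"] = argmax_labels(rows, index_to_label).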

forward(self, tokens: Dict[str, torch.LongTensor], label: torch.IntTensor = None) → Dict[str, torch.Tensor]

tokens : Dict[str, torch.LongTensor]

From a TextField that uses a bert-pretrained token indexer.

label : torch.IntTensor, optional (default: None)

From a LabelField.

An output dictionary consisting of:

A tensor of shape (batch_size, num_labels) representing unnormalized log probabilities of the label.


A tensor of shape (batch_size, num_labels) representing probabilities of the label.

loss : torch.FloatTensor, optional

A scalar loss to be optimised. Only present when label is provided.
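The relationship between the three outputs can be sketched in plain Python: a softmax turns the unnormalized log probabilities into probabilities, and, assuming a standard mean-reduced cross-entropy, the loss is the average negative log probability of the gold label. The function names are hypothetical.

```python
import math
from typing import List, Optional, Tuple


def log_softmax(row: List[float]) -> List[float]:
    # Subtract the max before exponentiating (log-sum-exp trick) for numerical stability.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def head_outputs(
    logits: List[List[float]], labels: Optional[List[int]] = None
) -> Tuple[List[List[float]], Optional[float]]:
    """Return per-row probabilities and, if labels are given, the mean cross-entropy loss."""
    log_probs = [log_softmax(row) for row in logits]
    probs = [[math.exp(lp) for lp in row] for row in log_probs]
    if labels is None:
        return probs, None  # no gold labels, so no loss (e.g. at prediction time)
    # Cross-entropy: negative log probability of the gold label, averaged over the batch.
    loss = -sum(row[y] for row, y in zip(log_probs, labels)) / len(labels)
    return probs, loss
```

Computing the loss from log probabilities rather than from the probabilities themselves avoids taking log of values that have underflowed to zero.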

get_metrics(self, reset: bool = False) → Dict[str, float]

Returns a dictionary of metrics. This method will be called during training in order to compute and use model metrics for early stopping and model serialization. We return an empty dictionary here rather than raising, as it is not required to implement metrics for a new model. A boolean reset parameter is passed, as a metric accumulator frequently has some state that should be reset between epochs. This is also compatible with Metrics: metrics should be populated during the call to forward, with the Metric handling the accumulation of the metric until this method is called.
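The accumulate-then-reset pattern described above can be sketched for a model that does track a metric. RunningAccuracy and ToyModel are hypothetical names, not AllenNLP classes: forward would call the accumulator on each batch, and get_metrics reports the running value, clearing the state only when reset=True.

```python
from typing import Dict, List


class RunningAccuracy:
    """Hypothetical metric accumulator illustrating the reset pattern."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def __call__(self, predictions: List[int], gold: List[int]) -> None:
        # Called once per batch (e.g. from forward): accumulate counts.
        self.correct += sum(int(p == g) for p, g in zip(predictions, gold))
        self.total += len(gold)

    def get_metric(self, reset: bool = False) -> float:
        value = self.correct / self.total if self.total else 0.0
        if reset:
            # Clear state, e.g. at the end of an epoch.
            self.correct = 0
            self.total = 0
        return value


class ToyModel:
    """Hypothetical model exposing get_metrics in the shape described above."""

    def __init__(self) -> None:
        self._accuracy = RunningAccuracy()

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self._accuracy.get_metric(reset)}
```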