class BertPooler(Seq2VecEncoder):
 | def __init__(
 |     self,
 |     pretrained_model: str,
 |     *,
 |     override_weights_file: Optional[str] = None,
 |     override_weights_strip_prefix: Optional[str] = None,
 |     load_weights: bool = True,
 |     requires_grad: bool = True,
 |     dropout: float = 0.0,
 |     transformer_kwargs: Optional[Dict[str, Any]] = None
 | ) -> None

The pooling layer at the end of the BERT model. It returns an embedding for the [CLS] token after passing it through a dense layer with a non-linear tanh activation; this layer is also part of the pretrained BERT model. If you want to build a classifier on top of pretrained BERT using the AllenNLP token-indexer -> token-embedder -> seq2vec-encoder setup (for example, to experiment with other embedding / encoding combinations), this is the Seq2VecEncoder to use.
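To make the computation concrete, here is a minimal pure-Python sketch of what a BERT-style pooling layer does for a single sequence: take the vector at the [CLS] position (assumed to be index 0), apply a dense layer, then tanh. The function name `bert_style_pool` and the toy weights are hypothetical; the real layer uses the learned weights of the pretrained model and operates on batched tensors.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def bert_style_pool(hidden_states: Matrix, weight: Matrix, bias: Vector) -> Vector:
    """Pool one sequence the way a BERT-style pooler does (sketch).

    hidden_states: one vector per token; position 0 is assumed to be [CLS].
    weight: hidden_size x hidden_size dense-layer weights (row i -> output i).
    bias: hidden_size dense-layer bias.
    """
    cls_vector = hidden_states[0]  # keep only the [CLS] position; other tokens are ignored
    pooled = []
    for row, b in zip(weight, bias):
        # dense layer: dot product of one weight row with the [CLS] vector, plus bias
        activation = sum(w * x for w, x in zip(row, cls_vector)) + b
        pooled.append(math.tanh(activation))  # tanh squashes each value into (-1, 1)
    return pooled


# Toy 2-dimensional example with three tokens; the numbers are illustrative only.
hidden = [[1.0, -2.0], [0.5, 0.5], [3.0, 3.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
zeros = [0.0, 0.0]
print(bert_style_pool(hidden, identity, zeros))
```

Because only position 0 is read, changing any non-[CLS] token vector leaves the pooled output unchanged; the encoder relies on the pretrained model having already mixed sequence information into the [CLS] position.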

Registered as a Seq2VecEncoder with name "bert_pooler".
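Because the encoder is registered as "bert_pooler", it can be selected by that name in a configuration file. The sketch below writes a hypothetical model-section fragment as a Python dict; it assumes AllenNLP's "basic_classifier" model and the "pretrained_transformer" token embedder, and the checkpoint name "bert-base-uncased" and dropout value are illustrative choices, not requirements of this class.

```python
import json

# Hypothetical checkpoint name; any BERT model name should work the same way.
model_name = "bert-base-uncased"

# Fragment of an AllenNLP-style training config (model section only).
# Assumption: "basic_classifier" and "pretrained_transformer" are used around
# the pooler; only "bert_pooler" and its parameters come from this page.
config = {
    "model": {
        "type": "basic_classifier",
        "text_field_embedder": {
            "token_embedders": {
                "tokens": {"type": "pretrained_transformer", "model_name": model_name}
            }
        },
        "seq2vec_encoder": {
            "type": "bert_pooler",  # the registered name of this encoder
            "pretrained_model": model_name,  # should match the embedder's model
            "dropout": 0.1,  # illustrative value; the default is 0.0
        },
    }
}

print(json.dumps(config["model"]["seq2vec_encoder"], indent=2))
```

Using the same checkpoint for the token embedder and the pooler keeps the hidden sizes and the pooler's pretrained weights consistent with the embeddings it receives.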


  • pretrained_model : str
    The name of the pretrained BERT model to use. We will call transformers.AutoModel.from_pretrained(pretrained_model) and use that model's pooling layer.
  • override_weights_file : Optional[str], optional (default = None)
    If set, this specifies a file from which to load alternate weights that override the weights from huggingface. The file is expected to contain a PyTorch state_dict, created with torch.save().
  • override_weights_strip_prefix : Optional[str], optional (default = None)
    If set, strip the given prefix from the state dict when loading it.
  • load_weights : bool, optional (default = True)
    Whether to load the pretrained weights.
  • requires_grad : bool, optional (default = True)
    If True, the weights of the pooler will be updated during training. Otherwise they will not.
  • dropout : float, optional (default = 0.0)
    Dropout probability applied to the pooled output.
  • transformer_kwargs : Dict[str, Any], optional (default = None)
    Dictionary with additional arguments for AutoModel.from_pretrained.
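The override_weights_strip_prefix option is useful when the override file was saved from a larger model whose parameter names carry an extra prefix. The sketch below illustrates prefix stripping on a plain dict standing in for a state_dict; the helper name `strip_state_dict_prefix` and the example keys are hypothetical, and dropping keys that lack the prefix (and raising when none match) are assumptions of this sketch rather than documented behaviour.

```python
from typing import Dict


def strip_state_dict_prefix(state_dict: Dict[str, object], prefix: str) -> Dict[str, object]:
    """Return a copy of state_dict with `prefix` removed from matching keys.

    Sketch assumption: keys that do not start with the prefix are dropped.
    """
    stripped = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    if not stripped:
        # nothing matched, so loading would silently change nothing; fail loudly instead
        raise ValueError(f"no keys start with prefix {prefix!r}")
    return stripped


# Keys as they might look in a checkpoint saved from a larger model (made-up values).
saved = {
    "encoder.bert.pooler.dense.weight": "W",
    "encoder.bert.pooler.dense.bias": "b",
    "classifier.weight": "C",
}
print(strip_state_dict_prefix(saved, "encoder.bert."))
```

After stripping, the remaining keys line up with the parameter names of a bare BERT model, which is what the override weights need to match.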


class BertPooler(Seq2VecEncoder):
 | ...
 | @overrides
 | def get_input_dim(self) -> int


class BertPooler(Seq2VecEncoder):
 | ...
 | @overrides
 | def get_output_dim(self) -> int


class BertPooler(Seq2VecEncoder):
 | ...
 | def forward(
 |     self,
 |     tokens: torch.Tensor,
 |     mask: torch.BoolTensor = None,
 |     num_wrapping_dims: int = 0
 | )