allennlp.modules.seq2vec_encoders.bert_pooler
BertPooler¶
@Seq2VecEncoder.register("bert_pooler")
class BertPooler(Seq2VecEncoder):
| def __init__(
|     self,
|     pretrained_model: str,
|     *,
|     override_weights_file: Optional[str] = None,
|     override_weights_strip_prefix: Optional[str] = None,
|     load_weights: bool = True,
|     requires_grad: bool = True,
|     dropout: float = 0.0,
|     transformer_kwargs: Optional[Dict[str, Any]] = None
| ) -> None
The pooling layer at the end of the BERT model. This returns an embedding for the [CLS] token, after passing it through a non-linear tanh activation; the non-linear layer is also part of the BERT model. If you want to use the pretrained BERT model to build a classifier and you want to use the AllenNLP token-indexer -> token-embedder -> seq2vec encoder setup, this is the Seq2VecEncoder to use. (For example, if you want to experiment with other embedding / encoding combinations.)
Registered as a Seq2VecEncoder with name "bert_pooler".
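Because it is registered, the encoder can be built through the registry exactly the way a configuration file would build it. The sketch below assumes the usual AllenNLP from_params setup; "bert-base-uncased" is an illustrative model name, not a requirement.

from allennlp.common import Params
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder

# Build the encoder through the registry, the same way a config file would.
encoder = Seq2VecEncoder.from_params(
    Params({"type": "bert_pooler", "pretrained_model": "bert-base-uncased"})
)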
Parameters¶
- pretrained_model : Union[str, BertModel]
    The pretrained BERT model to use. If this is a string, we will call transformers.AutoModel.from_pretrained(pretrained_model) and use that.
- override_weights_file : Optional[str], optional (default = None)
    If set, this specifies a file from which to load alternate weights that override the weights from HuggingFace. The file is expected to contain a PyTorch state_dict, created with torch.save().
- override_weights_strip_prefix : Optional[str], optional (default = None)
    If set, strip the given prefix from the state dict when loading it.
- load_weights : bool, optional (default = True)
    Whether to load the pretrained weights.
- requires_grad : bool, optional (default = True)
    If True, the weights of the pooler will be updated during training. Otherwise they will not.
- dropout : float, optional (default = 0.0)
    Amount of dropout to apply after pooling.
- transformer_kwargs : Dict[str, Any], optional (default = None)
    Dictionary with additional arguments for AutoModel.from_pretrained.
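Constructing the pooler directly might look like the following sketch, assuming BertPooler is exported from the package as in recent AllenNLP releases; the model name and dropout value are illustrative choices, not defaults of this class.

from allennlp.modules.seq2vec_encoders import BertPooler

pooler = BertPooler(
    "bert-base-uncased",
    requires_grad=True,  # fine-tune the pooler weights during training
    dropout=0.1,         # applied to the pooled vector
)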
get_input_dim¶
class BertPooler(Seq2VecEncoder):
| ...
| def get_input_dim(self) -> int
get_output_dim¶
class BertPooler(Seq2VecEncoder):
| ...
| def get_output_dim(self) -> int
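Both methods report the transformer's hidden size, since the pooler maps the [CLS] embedding to a vector of the same width. A small sketch, continuing the pooler constructed above:

# Input and output widths are both the transformer's hidden size
# (768 for bert-base-uncased), since pooling does not change the width.
hidden = pooler.get_input_dim()
assert hidden == pooler.get_output_dim()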
forward¶
class BertPooler(Seq2VecEncoder):
| ...
| def forward(
| self,
| tokens: torch.Tensor,
| mask: torch.BoolTensor = None,
| num_wrapping_dims: int = 0
| )
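As a usage sketch (continuing the pooler above), forward takes embedded tokens of shape (batch_size, num_tokens, input_dim) and returns one pooled vector per sequence; the random tensors here are stand-ins for real embedded input from an upstream token embedder.

import torch

batch_size, num_tokens = 2, 7
hidden = pooler.get_input_dim()

# Stand-in for embedded input from a token embedder.
tokens = torch.randn(batch_size, num_tokens, hidden)
mask = torch.ones(batch_size, num_tokens, dtype=torch.bool)

pooled = pooler(tokens, mask)
assert pooled.shape == (batch_size, hidden)  # one vector per sequence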