bert
BertLanguageModelHead#
class BertLanguageModelHead(LanguageModelHead):
| def __init__(self, model_name: str) -> None
Loads just the LM head from `transformers.BertForMaskedLM`. It was easiest to load
the entire model first and then pull out only the head, so this is a bit slower than it
could be, but for practical use in a model, the few seconds of extra loading time is
probably not a big deal.
get_input_dim#
class BertLanguageModelHead(LanguageModelHead):
| ...
| @overrides
| def get_input_dim(self) -> int
get_output_dim#
class BertLanguageModelHead(LanguageModelHead):
| ...
| @overrides
| def get_output_dim(self) -> int
forward#
class BertLanguageModelHead(LanguageModelHead):
| ...
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor