gpt2
allennlp_models.lm.modules.language_model_heads.gpt2
Gpt2LanguageModelHead#
@LanguageModelHead.register("gpt2")
class Gpt2LanguageModelHead(LanguageModelHead):
| def __init__(self, model_name: str) -> None
Loads just the LM head from transformers.GPT2LMHeadModel. It was easiest to load
the entire model and then pull out only the head, so this is a bit slower than it could be,
but for practical use in a model, the few seconds of extra loading time is probably not a big
deal.
get_input_dim#
class Gpt2LanguageModelHead(LanguageModelHead):
| ...
| @overrides
| def get_input_dim(self) -> int
get_output_dim#
class Gpt2LanguageModelHead(LanguageModelHead):
| ...
| @overrides
| def get_output_dim(self) -> int
forward#
class Gpt2LanguageModelHead(LanguageModelHead):
| ...
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor