allennlp.modules.language_model_heads
class allennlp.modules.language_model_heads.language_model_head.LanguageModelHead
Bases: torch.nn.modules.module.Module, allennlp.common.registrable.Registrable
A LanguageModelHead encapsulates a function that maps hidden states to logits over a vocabulary.
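The contract can be illustrated without PyTorch. The snippet below is a hypothetical, framework-free sketch: `LanguageModelHeadSketch`, `LinearHead`, and the `register`/`by_name` lookup are illustrative names that only loosely mimic the role `Registrable` plays. They are not AllenNLP's API. A head takes one hidden-state vector and returns one logit per vocabulary entry.

```python
from typing import Callable, Dict, List


class LanguageModelHeadSketch:
    """Hypothetical stand-in for LanguageModelHead (not AllenNLP's API)."""

    # Name -> subclass table, loosely mimicking what Registrable provides.
    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        def decorator(subclass: type) -> type:
            cls._registry[name] = subclass
            return subclass
        return decorator

    @classmethod
    def by_name(cls, name: str) -> type:
        return cls._registry[name]

    def forward(self, hidden_state: List[float]) -> List[float]:
        # Abstract: each concrete head supplies its own mapping.
        raise NotImplementedError


@LanguageModelHeadSketch.register("linear")
class LinearHead(LanguageModelHeadSketch):
    """Hypothetical head computing logits = W @ hidden_state + b."""

    def __init__(self, weight: List[List[float]], bias: List[float]) -> None:
        self.weight = weight  # one row per vocabulary entry: (vocab_size, input_dim)
        self.bias = bias      # one bias per vocabulary entry: (vocab_size,)

    def forward(self, hidden_state: List[float]) -> List[float]:
        return [
            sum(w * h for w, h in zip(row, hidden_state)) + b
            for row, b in zip(self.weight, self.bias)
        ]


# Usage: look the head up by name, then map a hidden state of size 2 to 3 logits.
head = LanguageModelHeadSketch.by_name("linear")(
    weight=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], bias=[0.0, 0.5, -1.0]
)
logits = head.forward([2.0, 3.0])
print(logits)  # [2.0, 3.5, 4.0]
```

Looking heads up by name is what lets a configuration file choose between interchangeable implementations without the calling code changing.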
forward(self, hidden_states: torch.Tensor) → torch.Tensor
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the forward pass is defined in this method, call the Module instance itself (for example, head(hidden_states)) rather than calling forward() directly: calling the instance runs any registered hooks, while calling forward() directly silently skips them.
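The difference the note describes can be reproduced with a toy stand-in for `Module.__call__`. `ToyModule`, `DoublingHead`, and this `register_forward_hook` are simplified, hypothetical re-creations in plain Python. PyTorch's real hook machinery does more (pre-hooks, removable hook handles, hooks that can replace the output), but the principle is the same: hooks fire only when the instance is called.

```python
from typing import Any, Callable, List


class ToyModule:
    """Hypothetical, simplified stand-in for torch.nn.Module's call/forward split."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[..., Any]] = []

    def register_forward_hook(self, hook: Callable[..., Any]) -> None:
        # Hooks receive (module, inputs, output), loosely mirroring PyTorch's convention.
        self._forward_hooks.append(hook)

    def forward(self, *inputs: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *inputs: Any) -> Any:
        # Calling the instance runs forward() and then every registered hook.
        output = self.forward(*inputs)
        for hook in self._forward_hooks:
            hook(self, inputs, output)
        return output


class DoublingHead(ToyModule):
    def forward(self, hidden_state: List[float]) -> List[float]:
        return [2.0 * h for h in hidden_state]


calls: List[List[float]] = []
head = DoublingHead()
head.register_forward_hook(lambda module, inputs, output: calls.append(output))

head([1.0, 2.0])          # goes through __call__, so the hook runs
head.forward([3.0, 4.0])  # bypasses __call__, so the hook is silently skipped
print(calls)  # [[2.0, 4.0]]
```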
class allennlp.modules.language_model_heads.bert.BertLanguageModelHead(model_name: str)
Bases: allennlp.modules.language_model_heads.language_model_head.LanguageModelHead
Loads only the language-model head from pytorch_transformers.BertForMaskedLM. The simplest approach was to load the entire pretrained model and then extract the head, so loading is somewhat slower than strictly necessary; in practice, the few extra seconds of loading time are unlikely to matter.
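The computation such a head performs can be sketched in plain Python. This assumes the standard BERT masked-LM head layout: a dense layer with GELU activation and layer normalization, followed by a projection to the vocabulary plus a per-token bias. The function names and toy weights are hypothetical. The real module operates on batched tensors with pretrained parameters.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]

# Hypothetical sketch of a BERT-style masked-LM head; not the pytorch_transformers code.


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product: one output per row of m."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def gelu(x: float) -> float:
    """Exact GELU, x * Phi(x), via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def layer_norm(v: Vector, gamma: Vector, beta: Vector, eps: float = 1e-12) -> Vector:
    """Normalize v to zero mean and unit variance, then scale and shift."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [g * (x - mean) / math.sqrt(var + eps) + b for x, g, b in zip(v, gamma, beta)]


def bert_style_lm_head(
    hidden: Vector,
    dense_w: Matrix, dense_b: Vector,
    gamma: Vector, beta: Vector,
    decoder_w: Matrix, decoder_b: Vector,
) -> Vector:
    # Transform step (assumed layout): dense layer (hidden -> hidden), GELU, layer norm.
    h = [x + b for x, b in zip(matvec(dense_w, hidden), dense_b)]
    h = layer_norm([gelu(x) for x in h], gamma, beta)
    # Decoder step: project to vocabulary size and add a per-token bias.
    return [x + b for x, b in zip(matvec(decoder_w, h), decoder_b)]


# Toy values (made up): hidden size 2, vocabulary of 3 tokens.
logits = bert_style_lm_head(
    hidden=[1.0, -1.0],
    dense_w=[[1.0, 0.0], [0.0, 1.0]], dense_b=[0.0, 0.0],
    gamma=[1.0, 1.0], beta=[0.0, 0.0],
    decoder_w=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], decoder_b=[0.0, 0.0, 0.0],
)
print(logits)
```

With these toy values the logits come out very close to `[1.0, -1.0, 0.0]`, since layer normalization maps the two GELU outputs to roughly plus and minus one.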
forward(self, hidden_states: torch.Tensor) → torch.Tensor
Defines the computation performed at every call: maps hidden_states to logits over the vocabulary.
Note
Although the forward pass is defined in this method, call the Module instance itself rather than calling forward() directly: calling the instance runs any registered hooks, while calling forward() directly silently skips them.
class allennlp.modules.language_model_heads.gpt2.Gpt2LanguageModelHead(model_name: str)
Bases: allennlp.modules.language_model_heads.language_model_head.LanguageModelHead
Loads only the language-model head from pytorch_transformers.GPT2LMHeadModel. The simplest approach was to load the entire pretrained model and then extract the head, so loading is somewhat slower than strictly necessary; in practice, the few extra seconds of loading time are unlikely to matter.
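The GPT-2 head is simpler. Assuming the standard GPT-2 design, in which the LM head is a bias-free linear layer whose weights are tied to the token-embedding matrix, each logit is the dot product of the final hidden state with one token's embedding row. The function name and toy values below are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]

# Hypothetical sketch; assumes a bias-free head tied to the token embeddings.


def gpt2_style_lm_head(hidden: Vector, token_embeddings: Matrix) -> Vector:
    """Logit for token i is the dot product of hidden with embedding row i (no bias)."""
    return [sum(e * h for e, h in zip(row, hidden)) for row in token_embeddings]


# Toy vocabulary of 4 tokens with 3-dimensional embeddings (made-up values).
wte = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.5, 0.5, 0.0],
]
hidden = [0.2, 0.9, -0.3]
logits = gpt2_style_lm_head(hidden, wte)
predicted = max(range(len(logits)), key=logits.__getitem__)
print(logits, predicted)
```

Tying the output projection to the input embeddings means the head adds no parameters of its own; the token whose embedding best aligns with the hidden state (index 1 here) receives the highest logit.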
forward(self, hidden_states: torch.Tensor) → torch.Tensor
Defines the computation performed at every call: maps hidden_states to logits over the vocabulary.
Note
Although the forward pass is defined in this method, call the Module instance itself rather than calling forward() directly: calling the instance runs any registered hooks, while calling forward() directly silently skips them.