allennlp.modules.language_model_heads

class allennlp.modules.language_model_heads.language_model_head.LanguageModelHead[source]

Bases: torch.nn.modules.module.Module, allennlp.common.registrable.Registrable

A LanguageModelHead encapsulates a function that maps a hidden state to logits over a vocabulary.

forward(self, hidden_states: torch.Tensor) → torch.Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than this method, since the former takes care of running the registered hooks while the latter silently ignores them.

get_input_dim(self) → int[source]
get_output_dim(self) → int[source]
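
Concrete heads are registered subclasses of LanguageModelHead. Below is a minimal sketch of a custom head, a single linear projection from hidden states to vocabulary logits. The class and the registered name "linear" are illustrative only (not part of the library), and the import assumes the package re-exports LanguageModelHead from its __init__.

import torch
from allennlp.modules.language_model_heads import LanguageModelHead


@LanguageModelHead.register("linear")  # hypothetical name, not built into AllenNLP
class LinearLanguageModelHead(LanguageModelHead):
    # Projects hidden states directly to vocabulary logits with one linear layer.
    def __init__(self, input_dim: int, vocab_size: int) -> None:
        super().__init__()
        self._input_dim = input_dim
        self._vocab_size = vocab_size
        self._projection = torch.nn.Linear(input_dim, vocab_size)

    def get_input_dim(self) -> int:
        return self._input_dim

    def get_output_dim(self) -> int:
        return self._vocab_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # (batch_size, ..., input_dim) -> (batch_size, ..., vocab_size)
        return self._projection(hidden_states)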
class allennlp.modules.language_model_heads.bert.BertLanguageModelHead(model_name: str)[source]

Bases: allennlp.modules.language_model_heads.language_model_head.LanguageModelHead

Loads just the LM head from pytorch_transformers.BertForMaskedLM. It was easiest to load the entire model and then pull out only the head, so this is a bit slower than it could be; for practical use in a model, though, the few extra seconds of loading time are unlikely to matter.

forward(self, hidden_states: torch.Tensor) → torch.Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than this method, since the former takes care of running the registered hooks while the latter silently ignores them.

get_input_dim(self) → int[source]
get_output_dim(self) → int[source]
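
A brief usage sketch: the model name "bert-base-uncased" is only illustrative, and the random tensor stands in for hidden states that would normally come from a matching BERT encoder.

import torch
from allennlp.modules.language_model_heads.bert import BertLanguageModelHead

head = BertLanguageModelHead(model_name="bert-base-uncased")  # illustrative model name

# Fake hidden states with the head's expected input dimension: (batch, sequence, hidden).
hidden_states = torch.randn(2, 7, head.get_input_dim())
logits = head(hidden_states)  # (batch, sequence, vocab)
assert logits.size(-1) == head.get_output_dim()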
class allennlp.modules.language_model_heads.gpt2.Gpt2LanguageModelHead(model_name: str)[source]

Bases: allennlp.modules.language_model_heads.language_model_head.LanguageModelHead

Loads just the LM head from pytorch_transformers.GPT2LMHeadModel. It was easiest to load the entire model and then pull out only the head, so this is a bit slower than it could be; for practical use in a model, though, the few extra seconds of loading time are unlikely to matter.

forward(self, hidden_states: torch.Tensor) → torch.Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than this method, since the former takes care of running the registered hooks while the latter silently ignores them.

get_input_dim(self) → int[source]
get_output_dim(self) → int[source]
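
Because these heads are Registrable, they can also be constructed from configuration. A small sketch, assuming the GPT-2 head is registered under the name "gpt2" (check the registry in your AllenNLP version) and using "gpt2" as an illustrative model name:

from allennlp.common import Params
from allennlp.modules.language_model_heads import LanguageModelHead

head = LanguageModelHead.from_params(Params({"type": "gpt2", "model_name": "gpt2"}))
print(head.get_input_dim(), head.get_output_dim())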