next_token_lm

NextTokenLM#

class NextTokenLM(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     text_field_embedder: TextFieldEmbedder,
 |     language_model_head: LanguageModelHead,
 |     contextualizer: Seq2SeqEncoder = None,
 |     target_namespace: str = "bert",
 |     dropout: float = 0.0,
 |     initializer: InitializerApplicator = None,
 |     **kwargs
 | ) -> None

The NextTokenLM embeds some input tokens, contextualizes them, then predicts the next word, computing a loss against a known target.

NOTE: This was developed for use in a demo, not for training. You definitely don't want to train a language model using this code; it would be incredibly inefficient. This does compute correct gradients of the loss, however, so you can use it for interesting visualization of the gradients of a pretrained model, and it appears to be fast enough to sample from, at least for one word at a time. If you want to sample many tokens at a time, you'd want to re-use some intermediate computation, so you would either need to modify this code or use something else.
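In line with that intended use, here is a minimal usage sketch, assuming a trained archive exists at a hypothetical path and that the demo predictor accepts a JSON dict with a "sentence" key:

```python
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
from allennlp.interpret.saliency_interpreters import SimpleGradient

# Hypothetical archive path; "next_token_lm" is this model's default predictor.
archive = load_archive("/path/to/next_token_lm.tar.gz")
predictor = Predictor.from_archive(archive, "next_token_lm")

# Predict the next token for a prefix (the "sentence" key is an assumption
# about the demo predictor's JSON input format).
print(predictor.predict_json({"sentence": "The doctor ran to the emergency"}))

# Gradient visualization, as mentioned above: SimpleGradient backpropagates
# the loss to the input embeddings and reports per-token saliency scores.
interpreter = SimpleGradient(predictor)
print(interpreter.saliency_interpret_from_json({"sentence": "The doctor ran to the emergency"}))
```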

Parameters

  • vocab : Vocabulary
  • text_field_embedder : TextFieldEmbedder
    Used to embed the indexed tokens we get in forward.
  • language_model_head : LanguageModelHead
    The torch.nn.Module that goes from the hidden states output by the contextualizer to logits over some output vocabulary.
  • contextualizer : Seq2SeqEncoder, optional (default = None)
    Used to "contextualize" the embeddings. This is optional because the contextualization might actually be done in the text field embedder.
  • target_namespace : str, optional (default = 'bert')
    Namespace to use to convert predicted token ids to strings in Model.make_output_human_readable.
  • dropout : float, optional (default = 0.0)
    If specified, dropout is applied to the contextualized embeddings before computation of the softmax. The contextualized embeddings themselves are returned without dropout.
  • initializer : InitializerApplicator, optional (default = None)
    If provided, will be used to initialize the model parameters.
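For concreteness, a hedged construction sketch, assuming the BERT-based components from allennlp-models (the exact import paths and registered names may vary by version):

```python
from allennlp.data import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder
from allennlp_models.lm.models import NextTokenLM
from allennlp_models.lm.modules.language_model_heads import BertLanguageModelHead

model = NextTokenLM(
    vocab=Vocabulary(),  # in practice, the vocabulary built from your data
    text_field_embedder=BasicTextFieldEmbedder(
        {"bert": PretrainedTransformerEmbedder("bert-base-uncased")}
    ),
    language_model_head=BertLanguageModelHead("bert-base-uncased"),
    # No contextualizer: BERT already contextualizes inside the embedder.
)
```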

forward#

class NextTokenLM(Model):
 | ...
 | def forward(
 |     self,
 |     tokens: TextFieldTensors,
 |     target_ids: TextFieldTensors = None
 | ) -> Dict[str, torch.Tensor]

The embedded input has shape (batch_size, num_tokens, embedding_dim). The returned dictionary contains the model's next-token predictions and, when target_ids is provided, a loss tensor.
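A sketch of calling forward directly, where batch is a hypothetical dict produced by an AllenNLP DataLoader; the loss key is the one entry guaranteed by the AllenNLP Model contract when target_ids is supplied:

```python
# `batch` is assumed to contain "tokens" and "target_ids" TextFieldTensors.
output_dict = model(tokens=batch["tokens"], target_ids=batch["target_ids"])
output_dict["loss"].backward()  # gradients of the loss, e.g. for saliency maps
```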

get_metrics#

class NextTokenLM(Model):
 | ...
 | def get_metrics(self, reset: bool = False)

make_output_human_readable#

class NextTokenLM(Model):
 | ...
 | @overrides
 | def make_output_human_readable(
 |     self,
 |     output_dict: Dict[str, torch.Tensor]
 | ) -> Dict[str, torch.Tensor]

default_predictor#

class NextTokenLM(Model):
 | ...
 | default_predictor = "next_token_lm"
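
Because default_predictor is set, Predictor.from_path can resolve the predictor without an explicit name:

```python
from allennlp.predictors import Predictor

# Hypothetical archive path; the "next_token_lm" predictor is looked up
# automatically from the model's default_predictor attribute.
predictor = Predictor.from_path("/path/to/next_token_lm.tar.gz")
```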