next_token_lm
NextTokenLM#
class NextTokenLM(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| text_field_embedder: TextFieldEmbedder,
| language_model_head: LanguageModelHead,
| contextualizer: Seq2SeqEncoder = None,
| target_namespace: str = "bert",
| dropout: float = 0.0,
| initializer: InitializerApplicator = None,
| **kwargs
| ) -> None
The `NextTokenLM` embeds some input tokens, contextualizes them, then predicts the next word,
computing a loss against a known target.
NOTE: This was developed for use in a demo, not for training. You definitely
don't want to
train a language model using this code; it would be incredibly inefficient. This does
compute correct gradients of the loss, however, so you can use it for interesting visualization
of the gradients of a pretrained model, and it appears to be fast enough to sample from, at
least for one word at a time. If you want to sample many tokens at a time, you'd want to
re-use some intermediate computation, so you would either need to modify this code or use
something else.
Parameters

- vocab : `Vocabulary`
- text_field_embedder : `TextFieldEmbedder`
  Used to embed the indexed tokens we get in `forward`.
- language_model_head : `LanguageModelHead`
  The `torch.nn.Module` that goes from the hidden states output by the contextualizer to logits over some output vocabulary.
- contextualizer : `Seq2SeqEncoder`, optional (default = `None`)
  Used to "contextualize" the embeddings. This is optional because the contextualization might actually be done in the text field embedder.
- target_namespace : `str`, optional (default = `'bert'`)
  Namespace to use to convert predicted token ids to strings in `Model.make_output_human_readable`.
- dropout : `float`, optional (default = `0.0`)
  If specified, dropout is applied to the contextualized embeddings before computation of the softmax. The contextualized embeddings themselves are returned without dropout.
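As a quick orientation, here is a minimal construction sketch in Python. It is not a training recipe, and the import paths and the GPT-2 embedder/head classes are assumptions based on allennlp-models; swap in whichever embedder and `LanguageModelHead` you actually use.

```python
# Minimal construction sketch (assumptions: import paths and the GPT-2
# embedder/head classes from allennlp-models; adjust to your setup).
from allennlp.data import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder
from allennlp_models.lm.models import NextTokenLM
from allennlp_models.lm.modules.language_model_heads import Gpt2LanguageModelHead

vocab = Vocabulary()  # in practice filled by the dataset reader / transformer indexer
model = NextTokenLM(
    vocab=vocab,
    text_field_embedder=BasicTextFieldEmbedder(
        {"tokens": PretrainedTransformerEmbedder(model_name="gpt2")}
    ),
    language_model_head=Gpt2LanguageModelHead(model_name="gpt2"),
    target_namespace="gpt2",  # namespace used when mapping predicted ids back to strings
)
```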
forward#
class NextTokenLM(Model):
| ...
| def forward(
| self,
| tokens: TextFieldTensors,
| target_ids: TextFieldTensors = None
| ) -> Dict[str, torch.Tensor]
Shape: `(batch_size, num_tokens, embedding_dim)`
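A hedged sketch of consuming the output dictionary, continuing from the construction sketch above. The key names (`probabilities`, `top_indices`, `loss`) and the fact that `tokens` / `target_ids` are `TextFieldTensors` built by a dataset reader are assumptions; verify them against your installed version.

```python
# Hedged sketch: `token_tensors` / `target_tensors` are TextFieldTensors produced
# by a dataset reader; the output keys below are assumptions about this model.
output = model(tokens=token_tensors, target_ids=target_tensors)
if "loss" in output:                        # present only when target_ids is given
    output["loss"].backward()               # e.g. to inspect gradients of a pretrained model
best_id = output["top_indices"][:, 0]       # assumed: most probable next-token id per instance
best_prob = output["probabilities"][:, 0]   # assumed: its probability under the head's softmax
```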
get_metrics#
class NextTokenLM(Model):
| ...
| def get_metrics(self, reset: bool = False)
make_output_human_readable#
class NextTokenLM(Model):
| ...
| @overrides
| def make_output_human_readable(
| self,
| output_dict: Dict[str, torch.Tensor]
| ) -> Dict[str, torch.Tensor]
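A hedged sketch of the id-to-string conversion step; the `words` key is an assumption about what this method adds to the output dictionary, and the ids are looked up in the `target_namespace` configured on the model.

```python
# Hedged sketch: the "words" key is an assumption about what this method adds.
readable = model.make_output_human_readable(model(tokens=token_tensors))
print(readable.get("words"))  # predicted next tokens as strings, via target_namespace
```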
default_predictor#
class NextTokenLM(Model):
| ...
| default_predictor = "next_token_lm"
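Because the default predictor is registered as `next_token_lm`, a packaged model can be queried end to end. A hedged usage sketch follows; the archive path is a placeholder, and the `sentence` keyword of the predictor's `predict` method is an assumption about the predictor's interface.

```python
# Hedged usage sketch: the archive path is a placeholder and the `sentence`
# keyword is an assumption about the next_token_lm predictor's interface.
from allennlp.predictors import Predictor
import allennlp_models.lm  # noqa: F401  (registers the model and predictor)

predictor = Predictor.from_path("/path/to/model.tar.gz", predictor_name="next_token_lm")
result = predictor.predict(sentence="AllenNLP is a library for")
```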