class BeamSearchGenerator(Registrable):
 | def __init__(self, beam_search: BeamSearch)

A beam search generator for next token language models.

This is just a wrapper around allennlp.nn.beam_search.BeamSearch with custom logic for handling the state dict.

We need this because the step function that BeamSearch uses must know how to handle different TextFieldTensors, whose form depends on the exact embedder class that the NextTokenLm uses.

So we need a separate BeamSearchGenerator implementation for each kind of text_field_embedder.
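A hypothetical sketch of that one-subclass-per-embedder pattern follows. A toy registry stands in for Registrable (AllenNLP's Registrable also provides register and by_name, but this is not its implementation), and the registered names and subclasses are illustrative only.

```python
from typing import Callable, Dict, Type


class GeneratorBase:
    """Hypothetical stand-in for a Registrable base class."""

    _registry: Dict[str, Type["GeneratorBase"]] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type["GeneratorBase"]], Type["GeneratorBase"]]:
        # Decorator that records a subclass under a string name.
        def decorator(subclass: Type["GeneratorBase"]) -> Type["GeneratorBase"]:
            cls._registry[name] = subclass
            return subclass

        return decorator

    @classmethod
    def by_name(cls, name: str) -> Type["GeneratorBase"]:
        # Look up the subclass registered for a given embedder setup.
        return cls._registry[name]


@GeneratorBase.register("single_id")
class SingleIdGenerator(GeneratorBase):
    """Would handle inputs shaped like {"tokens": {"tokens": ids}}."""


@GeneratorBase.register("transformer")
class TransformerGenerator(GeneratorBase):
    """Would handle inputs shaped like {"tokens": {"token_ids": ids, "mask": mask, ...}}."""


# Pick the generator that matches the model's embedder by name.
generator_cls = GeneratorBase.by_name("transformer")
```

The point of the registry is that a config file can name the generator matching its embedder, rather than one generator trying to handle every TextFieldTensors layout.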


class BeamSearchGenerator(Registrable):
 | ...
 | def validate_text_field_embedder(
 |     self,
 |     text_field_embedder: TextFieldEmbedder
 | )

This should be called after initialization to verify that the model's text_field_embedder is compatible.
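A minimal sketch of the kind of check a subclass might perform, assuming the generator supports exactly one token embedder of one particular type. The stub classes and the plain dict standing in for the model's token embedders are hypothetical, and ValueError is used here for simplicity.

```python
from typing import Dict


class TokenEmbedderStub:
    """Hypothetical stand-in for a single token embedder."""


class TransformerEmbedderStub(TokenEmbedderStub):
    """Hypothetical stand-in for a transformer-based token embedder."""


def validate_text_field_embedder(
    token_embedders: Dict[str, TokenEmbedderStub],
    expected_type: type = TransformerEmbedderStub,
) -> None:
    """Raise if the embedder setup is not one this generator can handle.

    `token_embedders` stands in for the token embedders held by the model's
    text_field_embedder (an assumption about how you would access them).
    """
    if len(token_embedders) != 1:
        raise ValueError(
            f"Expected exactly one token embedder, found {len(token_embedders)}"
        )
    (embedder,) = token_embedders.values()
    if not isinstance(embedder, expected_type):
        raise ValueError(
            f"Expected a {expected_type.__name__}, got {type(embedder).__name__}"
        )


# Passes silently: one embedder of the expected type.
validate_text_field_embedder({"tokens": TransformerEmbedderStub()})
```

Failing early here gives a clear error at model construction time instead of a confusing key error in the middle of decoding.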


class BeamSearchGenerator(Registrable):
 | ...
 | def get_step_state(
 |     self,
 |     inputs: TextFieldTensors
 | ) -> Dict[str, torch.Tensor]

Create a state dictionary for BeamSearch from the TextFieldTensors inputs to the NextTokenLm model.

By default this assumes the TextFieldTensors has a single TokenEmbedder, and just "flattens" the TextFieldTensors by returning the TokenEmbedder sub-dictionary.

If you have TextFieldTensors with more than one TokenEmbedder sub-dictionary, you'll need to override this method in a subclass.
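The default "flattening" can be sketched as below. TextFieldTensors is a two-level dict, {embedder_key: {tensor_name: tensor}}; plain lists stand in for torch.Tensor so the sketch runs without PyTorch, and the key names "tokens", "token_ids" and "mask" are assumptions for illustration.

```python
from typing import Any, Dict

# Two-level dict: {embedder_key: {tensor_name: tensor}}. Lists stand in for tensors.
TextFieldTensors = Dict[str, Dict[str, Any]]


def get_step_state(inputs: TextFieldTensors) -> Dict[str, Any]:
    """Flatten inputs that contain exactly one token-embedder sub-dictionary."""
    if len(inputs) != 1:
        raise ValueError(
            "Default flattening assumes a single token embedder; "
            f"got keys {sorted(inputs)}. Override get_step_state instead."
        )
    (sub_dict,) = inputs.values()
    # Shallow copy so later steps can add keys without mutating the caller's inputs.
    return dict(sub_dict)


inputs = {"tokens": {"token_ids": [[5, 7, 9]], "mask": [[True, True, True]]}}
state = get_step_state(inputs)
# state == {"token_ids": [[5, 7, 9]], "mask": [[True, True, True]]}
```

Returning a shallow copy rather than the sub-dictionary itself is a design choice of this sketch; the description above only says the sub-dictionary is returned.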


class BeamSearchGenerator(Registrable):
 | ...
 | def prepare_step_input(
 |     self,
 |     predictions: torch.Tensor,
 |     state: Dict[str, torch.Tensor]
 | ) -> TextFieldTensors

This is roughly the inverse of get_step_state().

It takes the predictions and state from the current step and returns a TextFieldTensors dictionary that can be fed through the embedder of the NextTokenLm model.

This usually involves adding the predicted tokens to the proper field of the state dict, expanding any mask or other context tensors by 1 in the appropriate dimension, and then unflattening the state so that it looks like a TextFieldTensors dict.
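A minimal sketch of those three steps, assuming a flat state with hypothetical "token_ids" and "mask" keys (one list per row) and a single embedder key "tokens". Lists again stand in for tensors, with one prediction per row.

```python
from typing import Any, Dict, List

TextFieldTensors = Dict[str, Dict[str, Any]]


def prepare_step_input(
    predictions: List[int],
    state: Dict[str, Any],
    embedder_key: str = "tokens",
) -> TextFieldTensors:
    """Append this step's predictions and re-nest the state (hypothetical keys)."""
    # 1. Add each row's predicted token to the end of its token ids.
    token_ids = [row + [pred] for row, pred in zip(state["token_ids"], predictions)]
    # 2. The new token is real (not padding), so each mask row grows by one True.
    mask = [row + [True] for row in state["mask"]]
    new_state = dict(state, token_ids=token_ids, mask=mask)
    # 3. "Unflatten" back into the {embedder_key: {...}} shape the embedder expects.
    return {embedder_key: new_state}


state = {"token_ids": [[5, 7], [5, 8]], "mask": [[True, True], [True, True]]}
step_input = prepare_step_input([3, 4], state)
# step_input == {"tokens": {"token_ids": [[5, 7, 3], [5, 8, 4]],
#                           "mask": [[True, True, True], [True, True, True]]}}
```

With real tensors, the append in step 1 would typically be a torch.cat of the existing ids with the predictions unsqueezed to shape (rows, 1), along the last dimension.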

class BeamSearchGenerator(Registrable):
 | ...
 | def search(
 |     self,
 |     start_predictions: torch.Tensor,
 |     state: Dict[str, torch.Tensor],
 |     step_function: StepFunctionType
 | ) -> Tuple[torch.Tensor, torch.Tensor]

Calls BeamSearch.search and returns the top predicted indices and their corresponding log probabilities.
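To show what the wrapped search does with a step function, here is a simplified beam search for a single sequence. This is not AllenNLP's implementation: BeamSearch additionally batches, tracks backpointers and stops at an end token. The step function signature and toy_step are hypothetical simplifications.

```python
import math
from typing import Callable, Dict, List, Tuple

State = Dict[str, List[int]]
# Hypothetical simplified step function: (last token, state) -> (log-probs over vocab, new state).
StepFn = Callable[[int, State], Tuple[List[float], State]]


def search(
    start_prediction: int,
    state: State,
    step_function: StepFn,
    beam_size: int = 2,
    max_steps: int = 3,
) -> Tuple[List[List[int]], List[float]]:
    """Minimal beam search for one sequence (no batching, no end token)."""
    # Each beam: (tokens predicted so far, summed log-prob, state for the next step).
    beams: List[Tuple[List[int], float, State]] = [([], 0.0, state)]
    for _ in range(max_steps):
        candidates: List[Tuple[List[int], float, State]] = []
        for tokens, score, beam_state in beams:
            last_token = tokens[-1] if tokens else start_prediction
            log_probs, new_state = step_function(last_token, beam_state)
            # Extend this beam with every vocabulary item.
            for token_id, log_prob in enumerate(log_probs):
                candidates.append((tokens + [token_id], score + log_prob, new_state))
        # Keep only the beam_size highest-scoring partial sequences.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    predictions = [tokens for tokens, _, _ in beams]
    log_probabilities = [score for _, score, _ in beams]
    return predictions, log_probabilities


def toy_step(last_token: int, state: State) -> Tuple[List[float], State]:
    # Ignores context: always the same distribution over a 3-token vocabulary.
    new_state = {"token_ids": state["token_ids"] + [last_token]}
    return [math.log(0.5), math.log(0.3), math.log(0.2)], new_state


predictions, log_probs = search(0, {"token_ids": []}, toy_step, beam_size=2, max_steps=2)
# predictions[0] == [0, 0]; log_probs[0] is approximately math.log(0.25)
```

In a BeamSearchGenerator, the step function would usually call get_step_state and prepare_step_input around the model's forward pass, so the beam search itself never needs to know the TextFieldTensors layout.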