beam_search_generator
allennlp_models.lm.util.beam_search_generators.beam_search_generator
BeamSearchGenerator
class BeamSearchGenerator(Registrable):
| def __init__(self, beam_search: BeamSearch)
A beam search generator for next token language models.
This is just a wrapper around allennlp.nn.beam_search.BeamSearch with custom
logic for handling the state dict.
We need this because the step function that BeamSearch uses
needs to know how to handle different TextFieldTensors, whose form
depends on the exact embedder class that the NextTokenLm uses.
So a different BeamSearchGenerator implementation is needed
for each kind of text_field_embedder.
validate_text_field_embedder
class BeamSearchGenerator(Registrable):
| ...
| def validate_text_field_embedder(
| self,
| text_field_embedder: TextFieldEmbedder
| )
This should be called after initialization to verify that the model's
text_field_embedder is compatible.
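A minimal sketch of the kind of check a subclass might run, assuming the embedder exposes its token embedders as a name-to-object mapping. `FakeTextFieldEmbedder`, the `*Stub` embedder classes, and the local `ConfigurationError` are hypothetical stand-ins for illustration, not AllenNLP classes.

```python
from typing import Dict, Tuple, Type


class ConfigurationError(Exception):
    """Hypothetical stand-in for an error raised when the model is built."""


class TokenEmbedderStub:
    """Hypothetical base class for the token embedders in this sketch."""


class SingleIdEmbedderStub(TokenEmbedderStub):
    pass


class TransformerEmbedderStub(TokenEmbedderStub):
    pass


class FakeTextFieldEmbedder:
    """Hypothetical text field embedder holding named token embedders."""

    def __init__(self, token_embedders: Dict[str, TokenEmbedderStub]) -> None:
        self.token_embedders = token_embedders


def validate_text_field_embedder(
    embedder: FakeTextFieldEmbedder,
    supported: Tuple[Type[TokenEmbedderStub], ...],
) -> None:
    # The default flattening logic only works with exactly one token embedder.
    if len(embedder.token_embedders) != 1:
        raise ConfigurationError(
            f"expected exactly one token embedder, got {len(embedder.token_embedders)}"
        )
    (token_embedder,) = embedder.token_embedders.values()
    # This generator only knows how to extend inputs for these embedder types.
    if not isinstance(token_embedder, supported):
        raise ConfigurationError(
            f"unsupported token embedder: {type(token_embedder).__name__}"
        )
```

Running the check right after construction means a mismatched configuration fails immediately rather than partway through decoding.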
get_step_state
class BeamSearchGenerator(Registrable):
| ...
| def get_step_state(
| self,
| inputs: TextFieldTensors
| ) -> Dict[str, torch.Tensor]
Create a state dictionary for BeamSearch from the TextFieldTensors inputs
to the NextTokenLm model.
By default this assumes the TextFieldTensors has a single TokenEmbedder,
and just "flattens" the TextFieldTensors by returning the TokenEmbedder
sub-dictionary.
If you have TextFieldTensors with more than one TokenEmbedder sub-dictionary,
you'll need to override this method in a subclass.
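The default flattening can be sketched with plain dicts standing in for `TextFieldTensors` (a mapping from indexer name to a mapping from argument name to tensor) and nested lists standing in for tensors. This mirrors the behavior described above under those stand-ins; it is not the library's source.

```python
from typing import Dict, List

Tensor = List[List[int]]  # stand-in for a (batch_size, seq_len) torch.Tensor
TextFieldTensors = Dict[str, Dict[str, Tensor]]


def get_step_state(inputs: TextFieldTensors) -> Dict[str, Tensor]:
    # Flattening is only unambiguous when there is a single token indexer.
    if len(inputs) != 1:
        raise ValueError(
            "expected a single TokenEmbedder sub-dictionary; "
            "override get_step_state for multi-embedder inputs"
        )
    (sub_dict,) = inputs.values()
    # Shallow copy, so adding or replacing keys in the returned state
    # does not modify the caller's inputs.
    return dict(sub_dict)
```

Note that the indexer name (`"tokens"` in a typical setup) is dropped here, so the reverse step has to recover it some other way, for example from an attribute set when the generator is constructed (an assumption of this sketch).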
prepare_step_input
class BeamSearchGenerator(Registrable):
| ...
| def prepare_step_input(
| self,
| predictions: torch.Tensor,
| state: Dict[str, torch.Tensor]
| ) -> TextFieldTensors
This is roughly the reverse of get_step_state().
It takes the predictions and state from the current step and returns
a TextFieldTensors dictionary that can be fed through the embedder of the NextTokenLm
model.
This usually involves appending the predicted tokens to the proper field of the state dict,
extending any mask tensors or other context tensors by one position along the sequence
dimension, and then unflattening the state so that it looks like a TextFieldTensors dict.
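A sketch of this reverse step with plain lists standing in for tensors, assuming a transformer-style state with `token_ids` and `mask` keys and an indexer named `"tokens"`. Those key names and the `indexer_name` parameter are hypothetical; the point is the three moves described above: append the predictions, grow the mask by one position, and re-nest the state.

```python
from typing import Dict, List

Tensor = List[List[int]]  # stand-in for a (batch_size, seq_len) torch.Tensor
TextFieldTensors = Dict[str, Dict[str, Tensor]]


def prepare_step_input(
    predictions: List[int],
    state: Dict[str, Tensor],
    indexer_name: str = "tokens",  # hypothetical: the name dropped by flattening
) -> TextFieldTensors:
    token_ids = state["token_ids"]
    if len(predictions) != len(token_ids):
        raise ValueError("need exactly one prediction per batch row")
    # 1. Append each row's predicted token id to the end of its sequence.
    new_ids = [row + [pred] for row, pred in zip(token_ids, predictions)]
    # 2. Grow the mask by one position so the new token is not masked out.
    new_mask = [row + [1] for row in state["mask"]]
    # 3. Unflatten back into {indexer_name: {arg_name: tensor}}.
    return {indexer_name: {"token_ids": new_ids, "mask": new_mask}}
```

With real tensors, the append step would typically be a `torch.cat` along the sequence dimension, with `predictions.unsqueeze(1)` supplying the new column.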
search
class BeamSearchGenerator(Registrable):
| ...
| def search(
| self,
| start_predictions: torch.Tensor,
| state: Dict[str, torch.Tensor],
| step_function: StepFunctionType
| ) -> Tuple[torch.Tensor, torch.Tensor]
Calls BeamSearch.search and returns the top predicted indices and their
corresponding log probabilities.
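`BeamSearch.search` does the real work here. For intuition about what "top predicted indices and corresponding log probabilities" means, the following is a deliberately tiny, pure-Python beam search for a single sequence with no end-of-sequence handling. It is a stand-in that illustrates the technique, not AllenNLP's implementation, and its step-function signature is a simplified, hypothetical version of `StepFunctionType`.

```python
import math
from typing import Callable, Dict, List, Tuple

State = Dict[str, List[int]]
# Hypothetical step function: (last token, state) -> (log prob per vocab id, new state).
StepFn = Callable[[int, State], Tuple[List[float], State]]


def toy_beam_search(
    start: int, state: State, step: StepFn, beam_size: int, max_steps: int
) -> Tuple[List[List[int]], List[float]]:
    # Each hypothesis: (generated tokens, summed log prob, its own state).
    beams: List[Tuple[List[int], float, State]] = [([], 0.0, state)]
    for _ in range(max_steps):
        candidates = []
        for tokens, score, beam_state in beams:
            last = tokens[-1] if tokens else start
            log_probs, new_state = step(last, beam_state)
            # Expand every hypothesis by every vocabulary item.
            for token_id, lp in enumerate(log_probs):
                candidates.append((tokens + [token_id], score + lp, new_state))
        # Keep only the highest-scoring continuations.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    # Predicted indices and summed log probabilities, best first.
    return [b[0] for b in beams], [b[1] for b in beams]


# Usage: a fixed bigram table plays the role of the language model.
BIGRAMS = {
    0: [math.log(0.1), math.log(0.6), math.log(0.3)],
    1: [math.log(0.5), math.log(0.2), math.log(0.3)],
    2: [math.log(0.3), math.log(0.3), math.log(0.4)],
}


def bigram_step(last: int, state: State) -> Tuple[List[float], State]:
    # Record the history in the state, the way a real model would carry context.
    return BIGRAMS[last], {"history": state["history"] + [last]}
```

In the design described above, the step function is presumably where `prepare_step_input` and `get_step_state` meet: it turns the flattened state back into `TextFieldTensors`, runs the model, and flattens the result again for the next step.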