beam_search_generator
allennlp_models.lm.util.beam_search_generators.beam_search_generator
BeamSearchGenerator
class BeamSearchGenerator(Registrable):
| def __init__(self, beam_search: BeamSearch)
A beam search generator for next-token language models.

This is just a wrapper around allennlp.nn.beam_search.BeamSearch with custom logic for handling the state dict.

The reason we need this is that the step function BeamSearch uses needs to know how to handle different TextFieldTensors, the form of which depends on the exact embedder class that the NextTokenLm model uses. So essentially we need a different BeamSearchGenerator implementation for each different text_field_embedder.
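To make the mismatch concrete, here is a hypothetical TextFieldTensors structure next to the flat state dict BeamSearch works with. Plain Python lists stand in for torch tensors, and the "tokens" key and inner field names are illustrative assumptions, not fixed AllenNLP names.

```python
# A hypothetical TextFieldTensors for a single indexer/embedder pair.
# The outer key ("tokens") names the TokenEmbedder; the inner dict holds
# the arrays that embedder consumes. Plain lists stand in for tensors.
text_field_tensors = {
    "tokens": {
        "token_ids": [[2, 15, 7]],        # batch of 1 sequence, 3 token ids
        "mask": [[True, True, True]],
    }
}

# BeamSearch, by contrast, expects a flat {name: tensor} state dict,
# which is why the generator has to flatten and unflatten between steps:
flat_state = text_field_tensors["tokens"]
```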
validate_text_field_embedder
class BeamSearchGenerator(Registrable):
| ...
| def validate_text_field_embedder(
| self,
| text_field_embedder: TextFieldEmbedder
| )
This should be called after initialization to verify that the model's text_field_embedder is compatible.
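As a minimal sketch of what such a check might look like in a concrete subclass: the stub class, its token_embedders attribute, and the expected "tokens" key below are all assumptions for illustration, not AllenNLP's actual API.

```python
class StubTextFieldEmbedder:
    """Stand-in for a TextFieldEmbedder holding one embedder per key (hypothetical)."""
    def __init__(self, token_embedders):
        self.token_embedders = token_embedders  # assumed attribute name

def validate_text_field_embedder(text_field_embedder):
    # A concrete generator might require exactly one embedder, under the
    # key it knows how to flatten/unflatten ("tokens" is assumed here).
    keys = set(text_field_embedder.token_embedders)
    if keys != {"tokens"}:
        raise ValueError(f"incompatible embedder keys: {sorted(keys)}")

validate_text_field_embedder(StubTextFieldEmbedder({"tokens": object()}))  # passes
```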
get_step_state
class BeamSearchGenerator(Registrable):
| ...
| def get_step_state(
| self,
| inputs: TextFieldTensors
| ) -> Dict[str, torch.Tensor]
Create a state dictionary for BeamSearch from the TextFieldTensors inputs to the NextTokenLm model.

By default this assumes the TextFieldTensors has a single TokenEmbedder, and just "flattens" the TextFieldTensors by returning the TokenEmbedder sub-dictionary.

If you have TextFieldTensors with more than one TokenEmbedder sub-dictionary, you'll need to override this class.
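The default "flattening" behavior can be sketched like this, with plain dicts standing in for tensors; the field names are illustrative assumptions.

```python
def get_step_state(inputs):
    # Default behaviour sketched: assume exactly one TokenEmbedder
    # sub-dictionary and return it directly as the flat state dict.
    if len(inputs) != 1:
        raise ValueError("expected a single TokenEmbedder sub-dictionary")
    (state,) = inputs.values()
    return state

inputs = {"tokens": {"token_ids": [[2, 15, 7]], "mask": [[1, 1, 1]]}}
state = get_step_state(inputs)
# state is now the flat dict {"token_ids": ..., "mask": ...}
```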
prepare_step_input
class BeamSearchGenerator(Registrable):
| ...
| def prepare_step_input(
| self,
| predictions: torch.Tensor,
| state: Dict[str, torch.Tensor]
| ) -> TextFieldTensors
This is like the reverse of get_step_state(). It takes the predictions and state from the current step and returns a TextFieldTensors dictionary that can be fed through the embedder of the NextTokenLm model.

This usually involves adding the predicted tokens to the proper field of the state dict, expanding any mask tensors or other context tensors by 1 in the right dimension, and then unflattening the state so that it looks like a TextFieldTensors dict.
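Those three steps, appending predictions, growing the mask, and unflattening, can be sketched as follows. Plain lists stand in for tensors, and the "tokens" key and field names are assumptions carried over from the flattening sketch above.

```python
def prepare_step_input(predictions, state):
    # Append each newly predicted token id to its sequence, grow the mask
    # by one position per sequence, then re-nest everything under the
    # assumed "tokens" key so the result looks like TextFieldTensors again.
    token_ids = [seq + [p] for seq, p in zip(state["token_ids"], predictions)]
    mask = [m + [1] for m in state["mask"]]
    return {"tokens": {"token_ids": token_ids, "mask": mask}}

state = {"token_ids": [[2, 15]], "mask": [[1, 1]]}
step_input = prepare_step_input([7], state)
# step_input == {"tokens": {"token_ids": [[2, 15, 7]], "mask": [[1, 1, 1]]}}
```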
search
class BeamSearchGenerator(Registrable):
| ...
| def search(
| self,
| start_predictions: torch.Tensor,
| state: Dict[str, torch.Tensor],
| step_function: StepFunctionType
| ) -> Tuple[torch.Tensor, torch.Tensor]
Calls BeamSearch.search and returns the top predicted indices and corresponding log probabilities.
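To illustrate the step-function contract this method relies on, here is a tiny greedy (beam size 1, batch size 1) stand-in for BeamSearch.search, not the real algorithm: the step function maps (last prediction, state) to (log probabilities, state), and the search accumulates the chosen indices and their log probabilities.

```python
import math

def greedy_search(start_prediction, state, step_function, num_steps=2):
    # Minimal beam_size=1 stand-in for BeamSearch.search, just to show the
    # contract: step_function(last_prediction, state) -> (log_probs, state).
    last, total_log_prob, chosen = start_prediction, 0.0, []
    for _ in range(num_steps):
        log_probs, state = step_function(last, state)
        best = max(range(len(log_probs)), key=log_probs.__getitem__)
        total_log_prob += log_probs[best]
        last = best
        chosen.append(best)
    return chosen, total_log_prob

# Toy step function over a 3-token vocabulary that always prefers token 2.
def step(last_prediction, state):
    return [math.log(0.1), math.log(0.2), math.log(0.7)], state

tokens, score = greedy_search(0, {}, step)
# tokens == [2, 2]
```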