transformer_beam_search_generator
allennlp_models.lm.util.beam_search_generators.transformer_beam_search_generator
TransformerBeamSearchGenerator
@BeamSearchGenerator.register("transformer")
class TransformerBeamSearchGenerator(BeamSearchGenerator):
| def __init__(self, *args, namespace: str = None, **kwargs) -> None
A `BeamSearchGenerator` for transformer-based `NextTokenLM` models.

This can be used with any `NextTokenLM` that utilizes a single `pretrained_transformer` `TokenEmbedder` for its `text_field_embedder`.
validate_text_field_embedder
class TransformerBeamSearchGenerator(BeamSearchGenerator):
| ...
| @overrides
| def validate_text_field_embedder(
| self,
| text_field_embedder: TextFieldEmbedder
| )
prepare_step_input
class TransformerBeamSearchGenerator(BeamSearchGenerator):
| ...
| @overrides
| def prepare_step_input(
| self,
| predictions: torch.Tensor,
| state: Dict[str, torch.Tensor]
| ) -> TextFieldTensors
Add the predicted tokens (`predictions`) to `state["token_ids"]` and expand `state["mask"]`.