Skip to content

transformer_beam_search_generator

allennlp_models.lm.util.beam_search_generators.transformer_beam_search_generator

[SOURCE]


TransformerBeamSearchGenerator#

@BeamSearchGenerator.register("transformer")
class TransformerBeamSearchGenerator(BeamSearchGenerator):
 | def __init__(self, *args, namespace: str = None, **kwargs) -> None

A BeamSearchGenerator for transformer-based NextTokenLM models.

This can be used with any `NextTokenLM` that uses a single `pretrained_transformer` `TokenEmbedder` for its `text_field_embedder`.

validate_text_field_embedder#

class TransformerBeamSearchGenerator(BeamSearchGenerator):
 | ...
 | def validate_text_field_embedder(
 |     self,
 |     text_field_embedder: TextFieldEmbedder
 | )

prepare_step_input#

class TransformerBeamSearchGenerator(BeamSearchGenerator):
 | ...
 | def prepare_step_input(
 |     self,
 |     predictions: torch.Tensor,
 |     state: Dict[str, torch.Tensor]
 | ) -> TextFieldTensors

Add the predicted tokens (`predictions`) to `state["token_ids"]` and expand `state["mask"]` accordingly.