
beam_search

allennlp.nn.beam_search



StateType

StateType = Dict[str, torch.Tensor]

StepFunctionTypeWithTimestep

StepFunctionTypeWithTimestep = Callable[
    [torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]
]

StepFunctionTypeNoTimestep

StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]

StepFunctionType

StepFunctionType = TypeVar(
    "StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep
)

The type of step function that can be passed to BeamSearch.search.

This can either be StepFunctionTypeWithTimestep or StepFunctionTypeNoTimestep.
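The two step-function shapes can be illustrated with a toy uniform "model". This is a minimal sketch: the vocabulary size, the end index, the "steps" state entry and both function names are hypothetical, and the sketch assumes timesteps are counted from 0.

```python
from typing import Dict, Tuple

import torch

StateType = Dict[str, torch.Tensor]

VOCAB_SIZE = 6  # hypothetical target vocabulary size
END_INDEX = 2   # hypothetical index of the end token


def step_no_timestep(
    last_predictions: torch.Tensor, state: StateType
) -> Tuple[torch.Tensor, StateType]:
    """Toy step function matching StepFunctionTypeNoTimestep."""
    group_size = last_predictions.size(0)  # last_predictions: (group_size,)
    logits = torch.zeros(group_size, VOCAB_SIZE)  # uniform toy "model"
    # Every state tensor keeps its leading group_size dimension.
    new_state = {"steps": state["steps"] + 1}
    return torch.log_softmax(logits, dim=-1), new_state


def step_with_timestep(
    last_predictions: torch.Tensor, state: StateType, timestep: int
) -> Tuple[torch.Tensor, StateType]:
    """Toy step function matching StepFunctionTypeWithTimestep."""
    group_size = last_predictions.size(0)
    logits = torch.zeros(group_size, VOCAB_SIZE)
    if timestep >= 2:
        # Make the end token by far the most likely from the third step on.
        logits[:, END_INDEX] = 10.0
    new_state = {"steps": state["steps"] + 1}
    return torch.log_softmax(logits, dim=-1), new_state
```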

ConstraintStateType

ConstraintStateType = List[List[Dict[str, Any]]]

Sampler

class Sampler(Registrable)

An abstract class that can be used to sample candidates (either nodes or beams) within BeamSearch.

A Sampler just has three methods: init_state(), sample_nodes() and sample_beams().

init_state() takes three arguments:

  • a tensor of starting log probs with shape (batch_size, num_classes),
  • the batch size, an int,
  • and the number of classes, also an int.

It returns a state dictionary with any state tensors needed for subsequent calls to sample_nodes() and sample_beams().

By default this method just returns an empty dictionary.

Both sample_nodes() and sample_beams() should take three arguments:

  • a tensor of normalized log probabilities with shape (batch_size, num_examples),
  • an integer representing the number of samples to take for each example in the batch,
  • and a state dictionary which could contain any tensors needed for the Sampler to keep track of state.

For sample_nodes(), num_examples = num_classes, but for sample_beams(), num_examples = beam_size * per_node_beam_size.

The return value should be a tuple containing:

  • a tensor of log probabilities of the sampled examples with shape (batch_size, num_samples),
  • a tensor of indices of the sampled examples with shape (batch_size, num_samples),
  • and the updated state dictionary.

A default implementation of sample_beams is provided, which just deterministically picks the k examples with highest log probability.
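The contract above can be sketched as a plain class that picks the highest-scoring candidates with torch.topk. DeterministicSamplerSketch is a hypothetical stand-alone illustration (it does not subclass or register with Sampler); it mirrors the described deterministic behavior rather than reproducing the library's code.

```python
from typing import Dict, Tuple

import torch

StateType = Dict[str, torch.Tensor]


class DeterministicSamplerSketch:
    """Hypothetical class following the Sampler contract described above."""

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        # Deterministic selection needs no state.
        return {}

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        # (batch_size, num_classes) -> two tensors of (batch_size, per_node_beam_size)
        selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
        return selected_log_probs, selected_indices, state

    def sample_beams(
        self, log_probs: torch.Tensor, beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        # (batch_size, beam_size * per_node_beam_size) -> (batch_size, beam_size)
        selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
        return selected_log_probs, selected_indices, state
```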

default_implementation

class Sampler(Registrable):
 | ...
 | default_implementation = "deterministic"

init_state

class Sampler(Registrable):
 | ...
 | def init_state(
 |     self,
 |     start_class_log_probabilities: torch.Tensor,
 |     batch_size: int,
 |     num_classes: int
 | ) -> StateType

sample_nodes

class Sampler(Registrable):
 | ...
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

sample_beams

class Sampler(Registrable):
 | ...
 | def sample_beams(
 |     self,
 |     log_probs: torch.Tensor,
 |     beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

DeterministicSampler

@Sampler.register("deterministic")
class DeterministicSampler(Sampler)

A Sampler that just deterministically returns the k nodes or beams with highest log probability.

sample_nodes

class DeterministicSampler(Sampler):
 | ...
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

MultinomialSampler

@Sampler.register("multinomial")
class MultinomialSampler(Sampler):
 | def __init__(
 |     self,
 |     temperature: float = 1.0,
 |     with_replacement: bool = False
 | ) -> None

A Sampler which samples nodes from the given multinomial distribution. Beams are sampled in the default, deterministic way.

Parameters

  • temperature : float, optional (default = 1.0)
    A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
  • with_replacement : bool, optional (default = False)
    Whether to sample with replacement.
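A minimal sketch of multinomial node sampling with a temperature, assuming log_probs is normalized and has shape (batch_size, num_classes). The function name is hypothetical, and returning the untempered log probabilities of the sampled nodes is a choice made for this sketch.

```python
import torch


def multinomial_sample_nodes(
    log_probs: torch.Tensor,
    per_node_beam_size: int,
    temperature: float = 1.0,
    with_replacement: bool = False,
):
    """Sketch of multinomial node sampling; log_probs is (batch_size, num_classes)."""
    if temperature != 1.0:
        # Dividing by the temperature sharpens (< 1.0) or flattens (> 1.0) the distribution.
        probabilities = torch.softmax(log_probs / temperature, dim=-1)
    else:
        probabilities = log_probs.exp()
    # Draw per_node_beam_size class indices for every row of the batch.
    selected_indices = torch.multinomial(
        probabilities, per_node_beam_size, replacement=with_replacement
    )
    # Report the untempered log probabilities of the sampled classes.
    return log_probs.gather(1, selected_indices), selected_indices
```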

sample_nodes

class MultinomialSampler(Sampler):
 | ...
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

TopKSampler

@Sampler.register("top-k")
class TopKSampler(Sampler):
 | def __init__(
 |     self,
 |     k: int = 1,
 |     temperature: float = 1.0,
 |     with_replacement: bool = False
 | )

A Sampler which redistributes the probability mass function for nodes among the top k choices, then samples from that subset after re-normalizing the probabilities.

Beams are sampled in the default, deterministic way.

Parameters

  • k : int, optional (default = 1)
    The number of top choices to be selected from.
  • temperature : float, optional (default = 1.0)
    A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
  • with_replacement : bool, optional (default = False)
    If set to True, samples will be selected with replacement from the top k choices.
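Top-k node sampling can be sketched the same way: restrict each row to its k best classes, re-normalize, sample positions inside that subset, and map them back to class indices. The function name and the requirement per_node_beam_size <= k <= num_classes are assumptions of this sketch.

```python
import torch


def top_k_sample_nodes(
    log_probs: torch.Tensor,
    per_node_beam_size: int,
    k: int = 1,
    temperature: float = 1.0,
    with_replacement: bool = False,
):
    """Sketch of top-k node sampling; log_probs is (batch_size, num_classes)."""
    if not per_node_beam_size <= k <= log_probs.size(1):
        raise ValueError("k must be >= per_node_beam_size and <= num_classes")
    # Keep the k most likely classes of each row: both tensors are (batch_size, k).
    top_k_log_probs, top_k_indices = log_probs.topk(k, dim=-1)
    # Re-normalize over the subset, applying the temperature first.
    top_k_probs = torch.softmax(top_k_log_probs / temperature, dim=-1)
    # Sample positions inside the top-k subset...
    positions = torch.multinomial(top_k_probs, per_node_beam_size, replacement=with_replacement)
    # ...and map them back to class indices.
    selected_indices = top_k_indices.gather(1, positions)
    return log_probs.gather(1, selected_indices), selected_indices
```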

sample_nodes

class TopKSampler(Sampler):
 | ...
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

TopPSampler

@Sampler.register("top-p")
class TopPSampler(Sampler):
 | def __init__(
 |     self,
 |     p: float = 0.9,
 |     temperature: float = 1.0,
 |     with_replacement: bool = False
 | )

A Sampler which redistributes the probability mass function for nodes among the top choices with a cumulative probability of at least p, then samples from that subset after re-normalizing the probabilities.

Beams are sampled in the default, deterministic way.

Parameters

  • p : float, optional (default = 0.9)
    The cumulative probability cutoff threshold. A higher value of p results in more possible examples to sample from. If with_replacement is False and fewer than per_node_beam_size examples fall within the cutoff when sample_nodes is called, the top per_node_beam_size examples are chosen instead.
  • temperature : float, optional (default = 1.0)
    A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
  • with_replacement : bool, optional (default = False)
    If set to True, samples will be selected with replacement from the top choices.
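A sketch of top-p filtering under the same assumptions (hypothetical function name, log_probs of shape (batch_size, num_classes)). Classes are sorted by probability, the smallest prefix whose cumulative probability reaches p is kept, and, when sampling without replacement, the top per_node_beam_size classes are always kept, matching the fallback described above.

```python
import torch


def top_p_sample_nodes(
    log_probs: torch.Tensor,
    per_node_beam_size: int,
    p: float = 0.9,
    temperature: float = 1.0,
    with_replacement: bool = False,
):
    """Sketch of top-p (nucleus) node sampling; log_probs is (batch_size, num_classes)."""
    if per_node_beam_size > log_probs.size(1):
        raise ValueError("per_node_beam_size must be <= num_classes")
    if temperature != 1.0:
        log_probs_t = torch.log_softmax(log_probs / temperature, dim=-1)
    else:
        log_probs_t = log_probs
    # Sort classes from most to least likely.
    sorted_log_probs, sorting_indices = log_probs_t.sort(dim=-1, descending=True)
    sorted_probs = sorted_log_probs.exp()
    # A class is excluded once the classes ranked above it already reach p.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    exclude = mass_before >= p
    if not with_replacement:
        # Fallback: always keep enough classes to sample without replacement.
        exclude[:, :per_node_beam_size] = False
    filtered = sorted_log_probs.masked_fill(exclude, float("-inf"))
    # softmax re-normalizes the probability mass over the kept classes.
    positions = torch.multinomial(
        torch.softmax(filtered, dim=-1), per_node_beam_size, replacement=with_replacement
    )
    selected_indices = sorting_indices.gather(1, positions)
    return log_probs.gather(1, selected_indices), selected_indices
```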

sample_nodes

class TopPSampler(Sampler):
 | ...
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

GumbelSampler

@Sampler.register("gumbel")
class GumbelSampler(Sampler):
 | def __init__(self, temperature: float = 1.0)

A Sampler which uses the Gumbel-Top-k trick to sample without replacement. See Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement, W. Kool, H. van Hoof and M. Welling, 2019.

Parameters

  • temperature : float, optional (default = 1.0)
    A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
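The Gumbel-Top-k trick itself fits in a few lines: perturb the log probabilities with independent Gumbel noise and take the top k, which yields k distinct classes sampled without replacement from the tempered distribution. The helper names are hypothetical. Full Stochastic Beam Search as described by Kool et al. additionally conditions the children's perturbations on their parent's perturbed score, which corresponds to the conditional sampling performed by gumbel_with_max below.

```python
import torch


def sample_gumbel(phi: torch.Tensor) -> torch.Tensor:
    """Draw Gumbel(phi) values: phi - log(-log(U)) with U ~ Uniform(0, 1)."""
    # Clamp away from 0 so the double logarithm stays finite.
    u = torch.rand_like(phi).clamp(min=torch.finfo(phi.dtype).tiny)
    return phi - torch.log(-torch.log(u))


def gumbel_top_k(log_probs: torch.Tensor, k: int, temperature: float = 1.0):
    """Sample k distinct classes per row from softmax(log_probs / temperature)."""
    phi = torch.log_softmax(log_probs / temperature, dim=-1)
    perturbed = sample_gumbel(phi)
    # The top k perturbed values form a sample without replacement.
    top_perturbed, indices = perturbed.topk(k, dim=-1)
    return top_perturbed, indices
```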

init_state

class GumbelSampler(Sampler):
 | ...
 | def init_state(
 |     self,
 |     start_class_log_probabilities: torch.Tensor,
 |     batch_size: int,
 |     num_classes: int
 | ) -> StateType

sample_nodes

class GumbelSampler(Sampler):
 | ...
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

sample_beams

class GumbelSampler(Sampler):
 | ...
 | def sample_beams(
 |     self,
 |     log_probs: torch.Tensor,
 |     beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

Returns the beams with the highest perturbed log probabilities.

gumbel

class GumbelSampler(Sampler):
 | ...
 | def gumbel(self, phi) -> torch.Tensor

Sample Gumbel(phi).

phi should have shape (batch_size, num_classes).

gumbel_with_max

class GumbelSampler(Sampler):
 | ...
 | def gumbel_with_max(self, phi, T) -> torch.Tensor

Sample Gumbel(phi) conditioned on the maximum value being equal to T.

phi should have shape (batch_size, num_classes) and T should have shape (batch_size, 1).
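One way to draw Gumbel(phi) values conditioned on their maximum being T is to sample unconditionally and then shift each value with the transformation from Kool et al. (2019): given the unconditioned row maximum Z, each value G becomes -log(exp(-T) - exp(-Z) + exp(-G)). The sketch below uses this direct form; it is an illustration, not necessarily the library's implementation, and Kool et al. discuss a more numerically stable rewrite.

```python
import torch


def gumbel(phi: torch.Tensor) -> torch.Tensor:
    """Sample Gumbel(phi) elementwise."""
    u = torch.rand_like(phi).clamp(min=torch.finfo(phi.dtype).tiny)
    return phi - torch.log(-torch.log(u))


def gumbel_with_max(phi: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Sample Gumbel(phi) conditioned on the row maximum being T.

    phi: (batch_size, num_classes), T: (batch_size, 1)
    """
    g_phi = gumbel(phi)
    # Unconditioned maximum of each row: (batch_size, 1).
    z, _ = g_phi.max(dim=-1, keepdim=True)
    # Shift every value so that the row maximum becomes exactly T (direct form).
    return -torch.log(torch.exp(-T) - torch.exp(-z) + torch.exp(-g_phi))
```

The argmax of each row is preserved, and every other value ends up strictly below T.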

FinalSequenceScorer

class FinalSequenceScorer(Registrable)

An abstract class that can be used to score the final generated sequences found by beam search. Given the predicted sequences and the corresponding log probabilities of those sequences, the class calculates and returns the final score of the sequences.

The default implementation scores the sequences using the sum of the log probabilities of the sequence, which is passed as input.

default_implementation

class FinalSequenceScorer(Registrable):
 | ...
 | default_implementation = "sequence-log-prob"

score

class FinalSequenceScorer(Registrable):
 | ...
 | def score(
 |     self,
 |     predictions: torch.Tensor,
 |     log_probabilities: torch.Tensor,
 |     end_index: int
 | ) -> torch.Tensor

Score the final predictions found by beam search.

Parameters

  • predictions : torch.Tensor
    A tensor containing the predicted sequences with shape (batch_size, beam_size, max_steps).

  • log_probabilities : torch.Tensor
    A tensor containing the log probabilities of the sequence, defined as the sum of the log probabilities per token, with shape (batch_size, beam_size).

  • end_index : int
    The index of the end symbol.

Returns

  • torch.Tensor
    A tensor of the final sequence scores of shape (batch_size, beam_size).

SequenceLogProbabilityScorer

@FinalSequenceScorer.register("sequence-log-prob")
class SequenceLogProbabilityScorer(FinalSequenceScorer)

A FinalSequenceScorer which scores the sequences by the sum of the log probabilities across the sequence's tokens.

score

class SequenceLogProbabilityScorer(FinalSequenceScorer):
 | ...
 | def score(
 |     self,
 |     predictions: torch.Tensor,
 |     log_probabilities: torch.Tensor,
 |     end_index: int
 | ) -> torch.Tensor

LengthNormalizedSequenceLogProbabilityScorer

@FinalSequenceScorer.register("length-normalized-sequence-log-prob")
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
 | def __init__(self, length_penalty: float = 1.0)

A FinalSequenceScorer which scores the sequences by the average log probability of the tokens in the sequence. It optionally includes a length penalty which promotes or demotes sequences based on their lengths. The final score for a sequence will be (sequence_log_probability) / (sequence_length ** length_penalty). The sequence length here includes the end token.

Parameters

  • length_penalty : float, optional (default = 1.0)
    The length penalty to use. A value of 1.0 means no length penalty is used. A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
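Both scoring rules can be sketched as plain functions. This assumes tokens after the first end token are padding, and counts each sequence's length up to and including its first end token (sequences that never emit the end token get length max_steps). The function names and this exact length-counting rule are assumptions of the sketch.

```python
import torch


def sequence_log_prob_score(predictions, log_probabilities, end_index):
    """Default-style scoring: the summed token log probabilities, unchanged."""
    return log_probabilities


def length_normalized_score(predictions, log_probabilities, end_index, length_penalty=1.0):
    """Divide summed log probs by sequence_length ** length_penalty.

    predictions: (batch_size, beam_size, max_steps)
    log_probabilities: (batch_size, beam_size)
    """
    is_end = predictions == end_index
    has_end = is_end.any(dim=-1)
    # Number of positions before the first end token (max_steps if there is none).
    before_end = (is_end.long().cumsum(dim=-1) == 0).sum(dim=-1)
    # Include the end token itself in the length of finished sequences.
    lengths = torch.where(has_end, before_end + 1, before_end)
    return log_probabilities / lengths.float().pow(length_penalty)
```

Because the summed log probabilities are negative, a larger denominator for longer sequences pulls their scores toward zero, which is why a length_penalty above 1.0 favors longer sequences.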

score

class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
 | ...
 | def score(
 |     self,
 |     predictions: torch.Tensor,
 |     log_probabilities: torch.Tensor,
 |     end_index: int
 | ) -> torch.Tensor

Constraint

class Constraint(Registrable):
 | def __init__(self, vocab: Optional[Vocabulary] = None) -> None

An abstract class that can be used to enforce constraints on the output predictions by manipulating the class log probabilities during beam search.

A Constraint just has three methods that need to be implemented by subclasses: init_state(), apply() and _update_state().

init_state() takes one argument:

  • the batch size, an int

It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent calls to apply() and update_state(). The length of the outer list should be equal to batch_size. Each inner list should be of length 1.

apply() takes two arguments:

  • the constraint state, which is a nested list of dictionaries. The length of the outer list is batch_size, and the length of each inner list is beam_size, except on the first call to apply(), when it is 1.
  • class_log_probabilities, a tensor of shape (batch_size, beam_size, num_classes) that contains the log probabilities for the classes during search. The first time apply() is called, beam_size = 1.

The apply() method should return new class_log_probabilities that enforce the constraint for this step of beam search. For instance, it may prevent a specific class from being selected by setting the corresponding log probability to a very large negative value such as float("-inf") or min_value_of_dtype(class_log_probabilities.dtype).

_update_state() takes two arguments:

  • the copied parent constraint state, which is a nested list of dictionaries. state[i][j] contains the copied state for the parent of last_prediction[i, j]. It is unique to that batch and beam, so it can be directly edited in-place without affecting the others.
  • last_prediction, a tensor of shape (batch_size, beam_size) containing the predictions from the last step of beam search.

The _update_state() function should return a new constraint state: a nested list of dictionaries whose outer list has length batch_size and whose inner lists have length beam_size, one entry for each of the predictions in last_prediction.

init_state

class Constraint(Registrable):
 | ...
 | def init_state(self, batch_size: int) -> ConstraintStateType

apply

class Constraint(Registrable):
 | ...
 | def apply(
 |     self,
 |     state: ConstraintStateType,
 |     class_log_probabilities: torch.Tensor
 | ) -> torch.Tensor

update_state

class Constraint(Registrable):
 | ...
 | def update_state(
 |     self,
 |     state: ConstraintStateType,
 |     last_prediction: torch.Tensor,
 |     last_backpointer: Optional[torch.Tensor] = None
 | ) -> ConstraintStateType

_update_state

class Constraint(Registrable):
 | ...
 | def _update_state(
 |     self,
 |     state: ConstraintStateType,
 |     last_prediction: torch.Tensor
 | ) -> ConstraintStateType

RepeatedNGramBlockingConstraint

@Constraint.register("repeated-ngram-blocking")
class RepeatedNGramBlockingConstraint(Constraint):
 | def __init__(self, ngram_size: int, **kwargs) -> None
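Assuming RepeatedNGramBlockingConstraint blocks any token that would complete an n-gram already generated on the same beam, a hypothetical stand-alone version following the init_state() / apply() / _update_state() protocol described for Constraint might look like this. NGramBlockingSketch and its state keys are illustrative names, not the library's.

```python
from typing import Any, Dict, List

import torch

ConstraintStateType = List[List[Dict[str, Any]]]


class NGramBlockingSketch:
    """Hypothetical constraint following the three-method protocol."""

    def __init__(self, ngram_size: int) -> None:
        self.ngram_size = ngram_size

    def init_state(self, batch_size: int) -> ConstraintStateType:
        # Outer list: batch_size; inner list: length 1 before the first step.
        return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]

    def apply(
        self, state: ConstraintStateType, class_log_probabilities: torch.Tensor
    ) -> torch.Tensor:
        # class_log_probabilities: (batch_size, beam_size, num_classes)
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                blocked = beam["seen_ngrams"].get(tuple(beam["current_prefix"]), [])
                if blocked:
                    class_log_probabilities[i, j, blocked] = float("-inf")
        return class_log_probabilities

    def _update_state(
        self, state: ConstraintStateType, last_prediction: torch.Tensor
    ) -> ConstraintStateType:
        # last_prediction: (batch_size, beam_size); state[i][j] is a private copy of the parent.
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                token = int(last_prediction[i, j])
                prefix = beam["current_prefix"]
                if len(prefix) == self.ngram_size - 1:
                    # Record that `token` completed an n-gram after this prefix.
                    beam["seen_ngrams"].setdefault(tuple(prefix), []).append(token)
                prefix.append(token)
                if len(prefix) == self.ngram_size:
                    prefix.pop(0)
        return state
```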

init_state

class RepeatedNGramBlockingConstraint(Constraint):
 | ...
 | def init_state(self, batch_size: int) -> ConstraintStateType

apply

class RepeatedNGramBlockingConstraint(Constraint):
 | ...
 | def apply(
 |     self,
 |     state: ConstraintStateType,
 |     class_log_probabilities: torch.Tensor
 | ) -> torch.Tensor

BeamSearch

class BeamSearch(Registrable):
 | def __init__(
 |     self,
 |     end_index: int,
 |     max_steps: int = 50,
 |     beam_size: int = 10,
 |     per_node_beam_size: int = None,
 |     sampler: Sampler = None,
 |     min_steps: Optional[int] = None,
 |     final_sequence_scorer: FinalSequenceScorer = None,
 |     constraints: Optional[List[Lazy[Constraint]]] = None,
 |     vocab: Optional[Vocabulary] = None
 | ) -> None

Implements the beam search algorithm for decoding the most likely sequences.

Parameters

  • end_index : int
    The index of the "stop" or "end" token in the target vocabulary.

  • max_steps : int, optional (default = 50)
    The maximum number of decoding steps to take, i.e. the maximum length of the predicted sequences.

  • beam_size : int, optional (default = 10)
    The width of the beam used.

  • per_node_beam_size : int, optional (default = beam_size)
    The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Beam Search Strategies for Neural Machine Translation, Freitag and Al-Onaizan, 2017.

  • sampler : Sampler, optional (default = None)
    An optional Sampler which is used to pick next candidate nodes and beams. If not specified, DeterministicSampler will be used, which just takes the per_node_beam_size most likely nodes and the beam_size most likely beams.

    Using the GumbelSampler, on the other hand, will give you Stochastic Beam Search.

  • min_steps : int, optional (default = None)
    The minimum number of decoding steps to take, i.e. the minimum length of the predicted sequences. This does not include the start or end tokens. If None, no minimum is enforced.

  • final_sequence_scorer : FinalSequenceScorer, optional (default = None)
    An optional FinalSequenceScorer which is used to score the final generated sequences. The output from this module is what is returned by the search method. If not specified, SequenceLogProbabilityScorer will be used, which scores the sequences by the sum of the token log probabilities.

  • constraints : List[Constraint], optional (default = None)
    An optional list of Constraints which should be applied during beam search. If not provided, no constraints will be enforced.

  • vocab : Vocabulary, optional (default = None)
    If constraints is not None, then the Vocabulary will be passed to each constraint during its initialization. Having access to the vocabulary may be useful for certain constraints, e.g., to mask out invalid predictions during structured prediction.

    In a typical AllenNLP configuration file, this parameter does not get an entry under "model"; instead, it is specified as a top-level parameter and then passed in to the model separately.
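To make the interaction between beam_size and per_node_beam_size concrete, here is a sketch of one deterministic expansion step: each beam first keeps its per_node_beam_size best next tokens, and then the best beam_size candidates overall survive. The function expand_beams is hypothetical and covers a step after the first one, when the beam already holds beam_size entries.

```python
import torch


def expand_beams(
    last_log_probs: torch.Tensor,
    class_log_probs: torch.Tensor,
    beam_size: int,
    per_node_beam_size: int,
):
    """One deterministic expansion step.

    last_log_probs:  (batch_size, beam_size), summed log probs of the current beams
    class_log_probs: (batch_size * beam_size, num_classes), output of the step function
    """
    batch_size = last_log_probs.size(0)
    # Stage 1: keep per_node_beam_size candidate tokens for every beam (node).
    top_log_probs, top_indices = class_log_probs.topk(per_node_beam_size, dim=-1)
    # Add each parent's running score to its children.
    summed = top_log_probs.reshape(batch_size, beam_size, per_node_beam_size)
    summed = summed + last_log_probs.unsqueeze(-1)
    # Stage 2: keep the best beam_size of the beam_size * per_node_beam_size candidates.
    flat = summed.reshape(batch_size, beam_size * per_node_beam_size)
    new_log_probs, flat_choice = flat.topk(beam_size, dim=-1)
    # Recover the parent beam of each survivor and the token it added.
    backpointers = torch.div(flat_choice, per_node_beam_size, rounding_mode="floor")
    new_predictions = top_indices.reshape(batch_size, -1).gather(1, flat_choice)
    return new_log_probs, new_predictions, backpointers
```

With per_node_beam_size=1 every surviving beam comes from a different parent, while with per_node_beam_size equal to beam_size a single strong parent can fill the whole beam; this is one way to see the extra diversity mentioned above.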

default_implementation

class BeamSearch(Registrable):
 | ...
 | default_implementation = "beam_search"

search

class BeamSearch(Registrable):
 | ...
 | @torch.no_grad()
 | def search(
 |     self,
 |     start_predictions: torch.Tensor,
 |     start_state: StateType,
 |     step: StepFunctionType
 | ) -> Tuple[torch.Tensor, torch.Tensor]

Given a starting state and a step function, apply beam search to find the most likely target sequences.

Note

If your step function returns -inf for some log probabilities (for example, if you're using a masked log-softmax), then some of the "best" sequences returned may also have -inf log probability. Specifically, this happens when the beam size is larger than the number of actions with finite log probability (non-zero probability) returned by the step function. Therefore, if you're using a mask, you may want to check the results from search and potentially discard sequences with non-finite log probability.
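Discarding non-finite results after search returns can be sketched as follows. The helper name is hypothetical; it returns a list because different batch elements may keep different numbers of beams.

```python
from typing import List, Tuple

import torch


def drop_non_finite(
    predictions: torch.Tensor, log_probs: torch.Tensor
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Per batch element, keep only beams whose score is finite.

    predictions: (batch_size, beam_size, max_steps); log_probs: (batch_size, beam_size)
    """
    keep = torch.isfinite(log_probs)
    # Boolean indexing over the beam dimension of each batch element.
    return [(predictions[b][keep[b]], log_probs[b][keep[b]]) for b in range(predictions.size(0))]
```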

Parameters

  • start_predictions : torch.Tensor
    A tensor containing the initial predictions with shape (batch_size,). Usually the initial predictions are just the index of the "start" token in the target vocabulary.

  • start_state : StateType
    The initial state passed to the step function. Each value of the state dict should be a tensor of shape (batch_size, *), where * means any other number of dimensions.

  • step : StepFunctionType
    A function that is responsible for computing the next most likely tokens, given the current state and the predictions from the last time step. The function should accept two or three arguments:

    • a tensor of shape (group_size,) representing the index of the predicted tokens from the last time step,
    • the current state, a StateType, and
    • optionally, the timestep, an int.

    The group_size will be batch_size * beam_size, except in the initial step, for which it will just be batch_size.

    The function is expected to return a tuple, where the first element is a tensor of shape (group_size, target_vocab_size) containing the log probabilities of the tokens for the next step, and the second element is the updated state. Each tensor in the state should have shape (group_size, *), where * means any other number of dimensions.

Returns

  • Tuple[torch.Tensor, torch.Tensor]
    Tuple of (predictions, final_scores), where predictions has shape (batch_size, beam_size, max_steps) and final_scores has shape (batch_size, beam_size).
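The returned predictions are fixed-length sequences of max_steps tokens. A small hypothetical helper can convert them to Python lists and cut each sequence at its first end token, assuming everything after that token can be ignored.

```python
from typing import List

import torch


def to_token_lists(predictions: torch.Tensor, end_index: int) -> List[List[List[int]]]:
    """Convert (batch_size, beam_size, max_steps) predictions to nested lists,
    cutting each sequence at its first end token."""
    results = []
    for beams in predictions.tolist():
        batch_result = []
        for sequence in beams:
            if end_index in sequence:
                sequence = sequence[: sequence.index(end_index)]
            batch_result.append(sequence)
        results.append(batch_result)
    return results
```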