beam_search

allennlp.nn.beam_search


StateType

StateType = Dict[str, torch.Tensor]

StepFunctionTypeWithTimestep

StepFunctionTypeWithTimestep = Callable[
    [torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]
]

StepFunctionTypeNoTimestep

StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]

StepFunctionType

StepFunctionType = TypeVar(
    "StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep
)

The type of step function that can be passed to BeamSearch.search.

This can either be StepFunctionTypeWithTimestep or StepFunctionTypeNoTimestep.
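
As an illustration of the two variants, here is a minimal sketch. The vocabulary size, both toy step functions and the `accepts_timestep` helper are hypothetical names introduced for this example, and counting parameters is only one assumed way to tell the variants apart.

```python
from inspect import signature
from typing import Dict, Tuple

import torch

StateType = Dict[str, torch.Tensor]
VOCAB_SIZE = 5  # hypothetical vocabulary size, for illustration only


def step_no_timestep(
    last_predictions: torch.Tensor, state: StateType
) -> Tuple[torch.Tensor, StateType]:
    # Uniform log probabilities over the vocabulary: (group_size, VOCAB_SIZE).
    logits = torch.zeros(last_predictions.size(0), VOCAB_SIZE)
    return torch.log_softmax(logits, dim=-1), state


def step_with_timestep(
    last_predictions: torch.Tensor, state: StateType, timestep: int
) -> Tuple[torch.Tensor, StateType]:
    # Forbid token 0 on the first step only, to show one use of the timestep.
    logits = torch.zeros(last_predictions.size(0), VOCAB_SIZE)
    if timestep == 0:
        logits[:, 0] = float("-inf")
    return torch.log_softmax(logits, dim=-1), state


def accepts_timestep(step) -> bool:
    # Hypothetical helper: tell the two variants apart by parameter count.
    return len(signature(step).parameters) >= 3
```

Either function matches one of the two callable types above, so either could be passed as the `step` argument.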

Sampler

class Sampler(Registrable)

An abstract class that can be used to sample candidates (either nodes or beams) within BeamSearch.

A Sampler just has three methods, init_state(), sample_nodes() and sample_beams().

init_state() takes three arguments:

  • a tensor of starting log probs with shape (batch_size, num_classes),
  • the batch size, an int,
  • and the number of classes, also an int.

It returns a state dictionary with any state tensors needed for subsequent calls to sample_nodes() and sample_beams().

By default this method just returns an empty dictionary.

Both sample_nodes() and sample_beams() should take three arguments:

  • a tensor of normalized log probabilities with shape (batch_size, num_examples),
  • an integer representing the number of samples to take for each example in the batch,
  • and a state dictionary which could contain any tensors needed for the Sampler to keep track of state.

For sample_nodes(), num_examples = num_classes, but for sample_beams(), num_examples = beam_size * per_node_beam_size.

The return value should be a tuple containing:

  • a tensor of log probabilities of the sampled examples with shape (batch_size, num_samples),
  • a tensor of indices of the sampled examples with shape (batch_size, num_samples),
  • and the updated state dictionary.

A default implementation of sample_beams is provided, which just deterministically picks the k examples with highest log probability.
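
The deterministic "pick the k best" behaviour can be sketched with `torch.topk`. This is a minimal illustration of the sampling contract described above, assuming made-up inputs; `deterministic_sample` is a hypothetical name, not part of the library.

```python
from typing import Dict, Tuple

import torch

StateType = Dict[str, torch.Tensor]


def deterministic_sample(
    log_probs: torch.Tensor, num_samples: int, state: StateType
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
    # Take the num_samples highest log probabilities in each row.
    # Both returned tensors have shape (batch_size, num_samples).
    selected_log_probs, selected_indices = torch.topk(log_probs, num_samples, dim=-1)
    return selected_log_probs, selected_indices, state


# One batch element with four candidate examples.
log_probs = torch.log_softmax(torch.tensor([[2.0, 0.5, 1.0, -1.0]]), dim=-1)
top_log_probs, top_indices, _ = deterministic_sample(log_probs, 2, {})
```

The state dictionary is passed through unchanged here because a deterministic choice needs no bookkeeping.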

default_implementation

class Sampler(Registrable):
 | ...
 | default_implementation = "deterministic"

init_state

class Sampler(Registrable):
 | ...
 | def init_state(
 |     self,
 |     start_class_log_probabilities: torch.Tensor,
 |     batch_size: int,
 |     num_classes: int
 | ) -> StateType

sample_nodes

class Sampler(Registrable):
 | ...
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

sample_beams

class Sampler(Registrable):
 | ...
 | def sample_beams(
 |     self,
 |     log_probs: torch.Tensor,
 |     beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

DeterministicSampler

@Sampler.register("deterministic")
class DeterministicSampler(Sampler)

A Sampler that just deterministically returns the k nodes or beams with highest log probability.

sample_nodes

class DeterministicSampler(Sampler):
 | ...
 | @overrides
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

MultinomialSampler

@Sampler.register("multinomial")
class MultinomialSampler(Sampler):
 | def __init__(
 |     self,
 |     temperature: float = 1.0,
 |     with_replacement: bool = False
 | ) -> None

A Sampler which samples nodes from the given multinomial distribution. Beams are sampled in the default, deterministic way.

Parameters

  • temperature : float, optional (default = 1.0)
    A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
  • with_replacement : bool, optional (default = False)
    Whether to sample with replacement.
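
The effect of `temperature` and `with_replacement` can be sketched as follows. This is an assumed, simplified rendering of temperature-scaled multinomial sampling, not necessarily the library's exact code; `multinomial_sample` is a hypothetical name.

```python
import torch


def multinomial_sample(log_probs, num_samples, temperature=1.0, with_replacement=False):
    # Rescale by the temperature and renormalize to get sampling probabilities.
    probs = torch.softmax(log_probs / temperature, dim=-1)
    indices = torch.multinomial(probs, num_samples, replacement=with_replacement)
    # Report the original (unscaled) log probabilities of the sampled classes.
    return log_probs.gather(-1, indices), indices


torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(2, 6), dim=-1)
sampled_log_probs, sampled_indices = multinomial_sample(log_probs, 3, temperature=0.5)

# A temperature below 1.0 sharpens the distribution; above 1.0 flattens it.
sharp = torch.softmax(log_probs / 0.5, dim=-1)
flat = torch.softmax(log_probs / 2.0, dim=-1)
```

Comparing the largest entry of `sharp`, `log_probs.exp()` and `flat` in each row shows the sharpening and flattening directly.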

sample_nodes

class MultinomialSampler(Sampler):
 | ...
 | @overrides
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

TopKSampler

@Sampler.register("top-k")
class TopKSampler(Sampler):
 | def __init__(
 |     self,
 |     k: int = 1,
 |     temperature: float = 1.0,
 |     with_replacement: bool = False
 | )

A Sampler which redistributes the probability mass function for nodes among the top k choices, then samples from that subset after re-normalizing the probabilities.

Beams are sampled in the default, deterministic way.

Parameters

  • k : int, optional (default = 1)
    The number of top choices to be selected from.
  • temperature : float, optional (default = 1.0)
    A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
  • with_replacement : bool, optional (default = False)
    If set to True, samples will be selected with replacement from the top k choices.
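
A minimal sketch of top-k sampling under these parameters follows. The function name `top_k_sample` and the example probabilities are hypothetical, and the code is an illustration of the idea rather than the library's implementation.

```python
import torch


def top_k_sample(log_probs, num_samples, k, temperature=1.0, with_replacement=False):
    if not with_replacement and num_samples > k:
        raise ValueError("cannot draw more than k samples without replacement")
    # Restrict each row to its k most likely classes.
    top_log_probs, top_indices = log_probs.topk(k, dim=-1)
    # Re-normalize over the restricted set (with optional temperature scaling).
    probs = torch.softmax(top_log_probs / temperature, dim=-1)
    picks = torch.multinomial(probs, num_samples, replacement=with_replacement)
    # Map positions within the top-k subset back to class indices.
    indices = top_indices.gather(-1, picks)
    return log_probs.gather(-1, indices), indices


torch.manual_seed(0)
log_probs = torch.tensor([[0.1, 0.4, 0.05, 0.3, 0.15]]).log()
sampled_log_probs, sampled_indices = top_k_sample(log_probs, num_samples=2, k=3)
```

With `k=3`, only classes 1, 3 and 4 (the three most likely) can ever be drawn in this example.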

sample_nodes

class TopKSampler(Sampler):
 | ...
 | @overrides
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

TopPSampler

@Sampler.register("top-p")
class TopPSampler(Sampler):
 | def __init__(
 |     self,
 |     p: float = 0.9,
 |     temperature: float = 1.0,
 |     with_replacement: bool = False
 | )

A Sampler which redistributes the probability mass function for nodes among the top choices with a cumulative probability of at least p, then samples from that subset after re-normalizing the probabilities.

Beams are sampled in the default, deterministic way.

Parameters

  • p : float, optional (default = 0.9)
    The cumulative probability cutoff threshold. A higher value of p will result in more possible examples to sample from. If with_replacement is False and the number of possible samples is insufficient to sample without replacement from when calling sample_nodes, then the top per_node_beam_size examples will be chosen.
  • temperature : float, optional (default = 1.0)
    A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
  • with_replacement : bool, optional (default = False)
    If set to True, samples will be selected with replacement from the top choices.
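
The cutoff can be sketched as below: sort the classes, keep the smallest prefix whose cumulative probability reaches p, re-normalize, then sample. `top_p_sample` is a hypothetical name and the code is a simplified illustration under these assumptions, not the library's exact implementation.

```python
import torch


def top_p_sample(log_probs, num_samples, p, with_replacement=False):
    sorted_log_probs, sorted_indices = torch.sort(log_probs, dim=-1, descending=True)
    cumulative = sorted_log_probs.exp().cumsum(dim=-1)
    # Exclude a class once the classes ranked above it already reach mass p.
    exclude = cumulative >= p
    exclude[..., 1:] = exclude[..., :-1].clone()
    exclude[..., 0] = False
    if not with_replacement:
        # Keep enough classes to draw num_samples distinct ones.
        exclude[..., :num_samples] = False
    filtered = sorted_log_probs.masked_fill(exclude, float("-inf"))
    probs = torch.softmax(filtered, dim=-1)  # re-normalized, in sorted order
    picks = torch.multinomial(probs, num_samples, replacement=with_replacement)
    indices = sorted_indices.gather(-1, picks)
    return log_probs.gather(-1, indices), indices, probs


torch.manual_seed(0)
log_probs = torch.tensor([[0.15, 0.5, 0.05, 0.3]]).log()
sampled_log_probs, sampled_indices, nucleus_probs = top_p_sample(
    log_probs, num_samples=2, p=0.7
)
```

Here the two most likely classes (indices 1 and 3, total mass 0.8) form the retained set, and their re-normalized probabilities are 0.625 and 0.375.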

sample_nodes

class TopPSampler(Sampler):
 | ...
 | @overrides
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

GumbelSampler

@Sampler.register("gumbel")
class GumbelSampler(Sampler):
 | def __init__(self, temperature: float = 1.0)

A Sampler which uses the Gumbel-Top-K trick to sample without replacement. See Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement, W Kool, H Van Hoof and M Welling, 2019.

Parameters

  • temperature : float, optional (default = 1.0)
    A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.

init_state

class GumbelSampler(Sampler):
 | ...
 | @overrides
 | def init_state(
 |     self,
 |     start_class_log_probabilities: torch.Tensor,
 |     batch_size: int,
 |     num_classes: int
 | ) -> StateType

start_class_log_probabilities should have shape (batch_size, num_classes).

sample_nodes

class GumbelSampler(Sampler):
 | ...
 | @overrides
 | def sample_nodes(
 |     self,
 |     log_probs: torch.Tensor,
 |     per_node_beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

The temperature coefficient is applied first. log_probs has shape (batch_size * beam_size, num_classes).

sample_beams

class GumbelSampler(Sampler):
 | ...
 | @overrides
 | def sample_beams(
 |     self,
 |     log_probs: torch.Tensor,
 |     beam_size: int,
 |     state: StateType
 | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]

Returns the beams with the highest perturbed log probabilities.

gumbel

class GumbelSampler(Sampler):
 | ...
 | def gumbel(self, phi) -> torch.Tensor

Sample Gumbel(phi).

phi should have shape (batch_size, num_classes).
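
Sampling Gumbel(phi) amounts to adding standard Gumbel noise, -log(-log(U)) with U uniform on (0, 1), to phi. The sketch below assumes that formulation; `sample_gumbel` is a hypothetical name, and the clamp is a precaution added here rather than something the source describes.

```python
import torch


def sample_gumbel(phi: torch.Tensor) -> torch.Tensor:
    # Clamp away from 0 so that log(u) stays finite (precaution in this sketch).
    u = torch.rand_like(phi).clamp_min(torch.finfo(phi.dtype).tiny)
    # Standard Gumbel noise shifted by the location phi.
    return phi - torch.log(-torch.log(u))


torch.manual_seed(0)
phi = torch.log_softmax(torch.tensor([[1.0, 0.0, -1.0]]), dim=-1)
perturbed = sample_gumbel(phi)  # same shape as phi: (batch_size, num_classes)

# With phi = 0 the sample mean should sit near the Euler-Mascheroni constant.
standard = sample_gumbel(torch.zeros(50_000))
```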

gumbel_with_max

class GumbelSampler(Sampler):
 | ...
 | def gumbel_with_max(self, phi, T) -> torch.Tensor

Sample Gumbel(phi) conditioned on the maximum value being equal to T.

phi should have shape (batch_size, num_classes) and T should have shape (batch_size, 1).
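
One numerically stable way to condition the samples on their maximum follows the construction in Kool et al.: draw unconditioned samples, then shift them so that the row-wise maximum equals T. The sketch below is written under that assumption, with hypothetical function names.

```python
import torch


def sample_gumbel(phi: torch.Tensor) -> torch.Tensor:
    u = torch.rand_like(phi).clamp_min(torch.finfo(phi.dtype).tiny)
    return phi - torch.log(-torch.log(u))


def sample_gumbel_with_max(phi: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    g = sample_gumbel(phi)                     # (batch_size, num_classes)
    z = g.max(dim=-1, keepdim=True).values     # (batch_size, 1)
    # Stable form of -log(exp(-T) - exp(-z) + exp(-g)).
    v = T - g + torch.log1p(-torch.exp(g - z))
    return T - torch.relu(v) - torch.log1p(torch.exp(-v.abs()))


torch.manual_seed(0)
phi = torch.log_softmax(torch.randn(4, 6), dim=-1)
T = torch.randn(4, 1)
conditioned = sample_gumbel_with_max(phi, T)
```

For the entry that held the original maximum, `v` is `-inf`, so both correction terms vanish and the result is exactly `T`; every other entry ends up no larger than `T`.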

BeamSearch

class BeamSearch(FromParams):
 | def __init__(
 |     self,
 |     end_index: int,
 |     max_steps: int = 50,
 |     beam_size: int = 10,
 |     per_node_beam_size: int = None,
 |     sampler: Sampler = None
 | ) -> None

Implements the beam search algorithm for decoding the most likely sequences.

Parameters

  • end_index : int
    The index of the "stop" or "end" token in the target vocabulary.

  • max_steps : int, optional (default = 50)
    The maximum number of decoding steps to take, i.e. the maximum length of the predicted sequences.

  • beam_size : int, optional (default = 10)
    The width of the beam used.

  • per_node_beam_size : int, optional (default = beam_size)
    The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Beam Search Strategies for Neural Machine Translation, Freitag and Al-Onaizan, 2017.

  • sampler : Sampler, optional (default = None)
    An optional Sampler which is used to pick next candidate nodes and beams. If not specified, DeterministicSampler will be used, which just takes the per_node_beam_size most likely nodes and the beam_size most likely beams.

    Using the GumbelSampler, on the other hand, will give you Stochastic Beam Search.

search

class BeamSearch(FromParams):
 | ...
 | @torch.no_grad()
 | def search(
 |     self,
 |     start_predictions: torch.Tensor,
 |     start_state: StateType,
 |     step: StepFunctionType
 | ) -> Tuple[torch.Tensor, torch.Tensor]

Given a starting state and a step function, apply beam search to find the most likely target sequences.

Notes

If your step function returns -inf for some log probabilities (for example, if you're using a masked log-softmax), then some of the "best" sequences returned may also have -inf log probability. Specifically, this happens when the beam size is larger than the number of actions with finite log probability (non-zero probability) returned by the step function. Therefore, if you're using a mask, you may want to check the results from search and potentially discard sequences with non-finite log probability.
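
A check of this kind could look like the following sketch. The tensors are made-up stand-ins for the two values returned by search, used only to show the filtering step.

```python
import torch

# Made-up stand-ins for the output of search:
# predictions has shape (batch_size=1, beam_size=2, max_steps=3),
# log_probabilities has shape (batch_size=1, beam_size=2).
predictions = torch.tensor([[[1, 2, 3], [1, 4, 3]]])
log_probabilities = torch.tensor([[-1.2, float("-inf")]])

# True where a beam has a finite log probability.
finite_mask = torch.isfinite(log_probabilities)

# Keep only the finite beams, one tensor per batch element, because the
# number of surviving beams can differ between batch elements.
valid_predictions = [
    beams[mask] for beams, mask in zip(predictions, finite_mask)
]
```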

Parameters

  • start_predictions : torch.Tensor
    A tensor containing the initial predictions with shape (batch_size,). Usually the initial predictions are just the index of the "start" token in the target vocabulary.

  • start_state : StateType
    The initial state passed to the step function. Each value of the state dict should be a tensor of shape (batch_size, *), where * means any other number of dimensions.

  • step : StepFunctionType
    A function that is responsible for computing the next most likely tokens, given the current state and the predictions from the last time step. The function should accept two or three arguments:

    • a tensor of shape (group_size,) representing the index of the predicted tokens from the last time step,
    • the current state, a StateType, and
    • optionally, the timestep, an int.

    The group_size will be batch_size * beam_size, except in the initial step, for which it will just be batch_size.

    The function is expected to return a tuple, where the first element is a tensor of shape (group_size, target_vocab_size) containing the log probabilities of the tokens for the next step, and the second element is the updated state. Each tensor in the state should have shape (group_size, *), where * means any other number of dimensions.
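
The shape contract for the step function can be illustrated with a toy example. The transition table, the `steps_taken` state key and the sizes below are all hypothetical; the sketch only shows how group_size differs between the initial step and later steps.

```python
from typing import Dict, Tuple

import torch

StateType = Dict[str, torch.Tensor]

# Hypothetical 3-token vocabulary; row i holds log P(next token | last token = i).
TRANSITION_LOG_PROBS = torch.tensor(
    [[0.1, 0.6, 0.3], [0.2, 0.2, 0.6], [0.0, 0.0, 1.0]]
).log()


def step(
    last_predictions: torch.Tensor, state: StateType
) -> Tuple[torch.Tensor, StateType]:
    # (group_size,) -> (group_size, target_vocab_size)
    log_probs = TRANSITION_LOG_PROBS[last_predictions]
    # Every state tensor keeps group_size as its first dimension.
    new_state = {"steps_taken": state["steps_taken"] + 1}
    return log_probs, new_state


batch_size, beam_size = 2, 3

# Initial step: group_size == batch_size.
first_log_probs, first_state = step(
    torch.zeros(batch_size, dtype=torch.long),
    {"steps_taken": torch.zeros(batch_size, 1)},
)

# Later steps: group_size == batch_size * beam_size.
group_size = batch_size * beam_size
later_log_probs, later_state = step(
    torch.ones(group_size, dtype=torch.long),
    {"steps_taken": torch.ones(group_size, 1)},
)
```

The zero entries in the last row of the table become -inf log probabilities, which is the situation the note above warns about.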

Returns

  • Tuple[torch.Tensor, torch.Tensor]
    Tuple of (predictions, log_probabilities), where predictions has shape (batch_size, beam_size, max_steps) and log_probabilities has shape (batch_size, beam_size).
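
As a sketch of post-processing the returned predictions, the snippet below cuts each sequence at the first end token. The tensors are made-up stand-ins with the shapes given above, and `trim_at_end` and `END_INDEX` are hypothetical names; it assumes that positions after the end token carry no further content.

```python
from typing import List

import torch


def trim_at_end(sequence: List[int], end_index: int) -> List[int]:
    # Keep the tokens before the first end token; keep all if it never appears.
    if end_index in sequence:
        return sequence[: sequence.index(end_index)]
    return sequence


END_INDEX = 9  # hypothetical end-token index

# Made-up predictions with shape (batch_size=1, beam_size=2, max_steps=4).
predictions = torch.tensor([[[4, 2, 9, 9], [4, 5, 2, 9]]])

trimmed = [
    [trim_at_end(beam, END_INDEX) for beam in batch]
    for batch in predictions.tolist()
]
```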