beam_search
allennlp.nn.beam_search
StateType
StateType = Dict[str, torch.Tensor]
StepFunctionTypeWithTimestep
StepFunctionTypeWithTimestep = Callable[
[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]
]
StepFunctionTypeNoTimestep
StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
StepFunctionType
StepFunctionType = TypeVar(
"StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep
)
The type of step function that can be passed to BeamSearch.search. This can either be StepFunctionTypeWithTimestep or StepFunctionTypeNoTimestep.
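To illustrate the two accepted signatures, here is a toy pair of step functions built on a fixed, randomly initialized bigram table. Everything named here (toy_step, toy_step_with_timestep, VOCAB_SIZE, TRANSITIONS and the "hidden" state key) is hypothetical; a real step function would normally wrap a decoder network. The sketch only shows the shape contract: predictions of shape (group_size,) in, log probabilities of shape (group_size, vocab) and an updated state out.

```python
import torch
from typing import Dict, Tuple

StateType = Dict[str, torch.Tensor]

VOCAB_SIZE = 5  # hypothetical target vocabulary size
torch.manual_seed(0)
# Hypothetical bigram table: row i holds unnormalized scores for the token after token i.
TRANSITIONS = torch.randn(VOCAB_SIZE, VOCAB_SIZE)


def toy_step(last_predictions: torch.Tensor, state: StateType) -> Tuple[torch.Tensor, StateType]:
    """Matches StepFunctionTypeNoTimestep: (predictions, state) -> (log_probs, state)."""
    # last_predictions: (group_size,) -> log_probs: (group_size, VOCAB_SIZE)
    log_probs = torch.log_softmax(TRANSITIONS[last_predictions], dim=-1)
    # Every state tensor keeps group_size as its leading dimension.
    new_state = {"hidden": state["hidden"] + 1}
    return log_probs, new_state


def toy_step_with_timestep(
    last_predictions: torch.Tensor, state: StateType, timestep: int
) -> Tuple[torch.Tensor, StateType]:
    """Matches StepFunctionTypeWithTimestep: also receives the current timestep."""
    log_probs, new_state = toy_step(last_predictions, state)
    if timestep == 0:
        # Hypothetical use of the timestep: forbid token 0 on the first step.
        log_probs = log_probs.clone()
        log_probs[:, 0] = float("-inf")
    return log_probs, new_state
```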
Sampler
class Sampler(Registrable)
An abstract class that can be used to sample candidates (either nodes or beams) within BeamSearch.
A Sampler just has three methods, init_state(), sample_nodes() and sample_beams().
init_state() takes three arguments:
- a tensor of starting log probs with shape (batch_size, num_classes),
- the batch size, an int,
- and the number of classes, also an int.
It returns a state dictionary with any state tensors needed for subsequent calls to sample_nodes() and sample_beams().
By default this method just returns an empty dictionary.
Both sample_nodes() and sample_beams() should take three arguments:
- a tensor of normalized log probabilities with shape (batch_size, num_examples),
- an integer representing the number of samples to take for each example in the batch,
- and a state dictionary which could contain any tensors needed for the Sampler to keep track of state.
For sample_nodes(), num_examples = num_classes, but for sample_beams(), num_examples = beam_size * per_node_beam_size.
The return value should be a tuple containing:
- a tensor of log probabilities of the sampled examples with shape (batch_size, num_samples),
- a tensor of indices of the sampled examples with shape (batch_size, num_samples),
- and the updated state dictionary.
A default implementation of sample_beams() is provided, which just deterministically picks the k examples with highest log probability.
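The shapes above fit together roughly as follows: node sampling picks per_node_beam_size candidates for each of the batch_size * beam_size rows, those candidates are reshaped so each batch entry has beam_size * per_node_beam_size of them, and beam sampling keeps beam_size. The sketch below uses plain torch.topk for both stages (the deterministic behavior described above) with hypothetical sizes and random scores. It is an illustration under those assumptions, not the library's implementation, and it omits the handling of finished sequences.

```python
import torch

batch_size, beam_size, per_node_beam_size, num_classes = 2, 3, 2, 7  # hypothetical sizes
torch.manual_seed(0)

# Running log probability of each beam: (batch_size, beam_size)
beam_log_probs = torch.log_softmax(torch.randn(batch_size, beam_size), dim=-1)
# Next-token log probs from the step function: (batch_size * beam_size, num_classes)
step_log_probs = torch.log_softmax(torch.randn(batch_size * beam_size, num_classes), dim=-1)

# Node stage: per_node_beam_size candidates per beam.
# Both tensors: (batch_size * beam_size, per_node_beam_size)
node_log_probs, node_indices = step_log_probs.topk(per_node_beam_size, dim=-1)

# Add each beam's running score to the scores of its candidates.
summed = node_log_probs + beam_log_probs.reshape(batch_size * beam_size, 1)

# Flatten so that num_examples = beam_size * per_node_beam_size per batch entry.
flat = summed.reshape(batch_size, beam_size * per_node_beam_size)
flat_indices = node_indices.reshape(batch_size, beam_size * per_node_beam_size)

# Beam stage: keep the beam_size best candidates per batch entry.
# Both tensors: (batch_size, beam_size)
new_beam_log_probs, chosen = flat.topk(beam_size, dim=-1)

# Recover which parent beam each survivor came from and which token it appends.
backpointers = torch.div(chosen, per_node_beam_size, rounding_mode="floor")
new_tokens = flat_indices.gather(1, chosen)
```

Integer division by per_node_beam_size works as a backpointer only because each parent beam contributes a contiguous block of per_node_beam_size candidates after the reshape.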
default_implementation
class Sampler(Registrable):
| ...
| default_implementation = "deterministic"
init_state
class Sampler(Registrable):
| ...
| def init_state(
| self,
| start_class_log_probabilities: torch.Tensor,
| batch_size: int,
| num_classes: int
| ) -> StateType
sample_nodes
class Sampler(Registrable):
| ...
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
sample_beams
class Sampler(Registrable):
| ...
| def sample_beams(
| self,
| log_probs: torch.Tensor,
| beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
DeterministicSampler
@Sampler.register("deterministic")
class DeterministicSampler(Sampler)
A Sampler that just deterministically returns the k nodes or beams with highest log probability.
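As a sketch of that behavior, the class below implements the three-method contract with torch.topk and no state. It is a standalone illustration: the name TopKDeterministicSampler is hypothetical, and it neither subclasses nor registers with the library's Sampler, so it mirrors the described behavior rather than reproducing the library's code.

```python
import torch
from typing import Dict, Tuple

StateType = Dict[str, torch.Tensor]


class TopKDeterministicSampler:
    """Hypothetical standalone sampler following the init_state/sample_nodes/sample_beams contract."""

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        # Deterministic selection needs no state between calls.
        return {}

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        # Both outputs: (batch_size, per_node_beam_size), most likely first.
        selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
        return selected_log_probs, selected_indices, state

    def sample_beams(
        self, log_probs: torch.Tensor, beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        # The same rule applied to the flattened beam candidates.
        selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
        return selected_log_probs, selected_indices, state
```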
sample_nodes
class DeterministicSampler(Sampler):
| ...
| @overrides
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
MultinomialSampler
@Sampler.register("multinomial")
class MultinomialSampler(Sampler):
| def __init__(
| self,
| temperature: float = 1.0,
| with_replacement: bool = False
| ) -> None
A Sampler which samples nodes from the given multinomial distribution. Beams are sampled in the default, deterministic way.
Parameters
- temperature : float, optional (default = 1.0)
  A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
- with_replacement : bool, optional (default = False)
  Whether to sample with replacement.
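A minimal sketch of temperature-scaled multinomial node sampling, assuming the temperature is applied by dividing the log probabilities and renormalizing with a softmax. The function name multinomial_sample_nodes is hypothetical. Returning the original, untempered log probabilities of the drawn classes is a design choice made here so that beam scores stay comparable across temperatures.

```python
import torch
from typing import Dict, Tuple

StateType = Dict[str, torch.Tensor]


def multinomial_sample_nodes(
    log_probs: torch.Tensor,
    per_node_beam_size: int,
    state: StateType,
    temperature: float = 1.0,
    with_replacement: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
    # Rescale by the temperature and renormalize: (batch_size, num_classes)
    probabilities = torch.softmax(log_probs / temperature, dim=-1)
    # Draw per_node_beam_size class indices per row: (batch_size, per_node_beam_size)
    selected_indices = torch.multinomial(
        probabilities, per_node_beam_size, replacement=with_replacement
    )
    # Report the original (untempered) log probabilities of the drawn classes.
    return log_probs.gather(1, selected_indices), selected_indices, state
```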
sample_nodes
class MultinomialSampler(Sampler):
| ...
| @overrides
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
TopKSampler
@Sampler.register("top-k")
class TopKSampler(Sampler):
| def __init__(
| self,
| k: int = 1,
| temperature: float = 1.0,
| with_replacement: bool = False
| )
A Sampler which redistributes the probability mass function for nodes among the top k choices, then samples from that subset after re-normalizing the probabilities.
Beams are sampled in the default, deterministic way.
Parameters
- k : int, optional (default = 1)
  The number of top choices to be selected from.
- temperature : float, optional (default = 1.0)
  A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
- with_replacement : bool, optional (default = False)
  If set to True, samples will be selected with replacement from the top k choices.
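The top-k procedure described above can be sketched as: keep the k most likely classes, apply the temperature, renormalize over that subset, sample positions inside it, and map them back to vocabulary indices. The function name top_k_sample_nodes and the exact validity check are assumptions of this sketch, not the library's code.

```python
import torch
from typing import Dict, Tuple

StateType = Dict[str, torch.Tensor]


def top_k_sample_nodes(
    log_probs: torch.Tensor,
    per_node_beam_size: int,
    state: StateType,
    k: int = 1,
    temperature: float = 1.0,
    with_replacement: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
    num_classes = log_probs.size(-1)
    # This sketch requires enough top choices to draw from.
    if not per_node_beam_size <= k <= num_classes:
        raise ValueError("need per_node_beam_size <= k <= num_classes")
    # Keep only the k most likely classes: both (batch_size, k)
    top_log_probs, top_indices = log_probs.topk(k, dim=-1)
    # Apply the temperature and renormalize over the surviving k classes.
    probabilities = torch.softmax(top_log_probs / temperature, dim=-1)
    # Sample positions within the top-k subset: (batch_size, per_node_beam_size)
    sampled = torch.multinomial(probabilities, per_node_beam_size, replacement=with_replacement)
    # Map positions in the subset back to class indices in the full vocabulary.
    indices = top_indices.gather(-1, sampled)
    return log_probs.gather(1, indices), indices, state
```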
sample_nodes
class TopKSampler(Sampler):
| ...
| @overrides
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
TopPSampler
@Sampler.register("top-p")
class TopPSampler(Sampler):
| def __init__(
| self,
| p: float = 0.9,
| temperature: float = 1.0,
| with_replacement: bool = False
| )
A Sampler which redistributes the probability mass function for nodes among the top choices with a cumulative probability of at least p, then samples from that subset after re-normalizing the probabilities.
Beams are sampled in the default, deterministic way.
Parameters
- p : float, optional (default = 0.9)
  The cumulative probability cutoff threshold. A higher value of p will result in more possible examples to sample from. If with_replacement is False and, when calling sample_nodes, too few examples fall within the cutoff to draw per_node_beam_size samples without replacement, then the top per_node_beam_size examples will be chosen.
- temperature : float, optional (default = 1.0)
  A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
- with_replacement : bool, optional (default = False)
  If set to True, samples will be selected with replacement from the top choices.
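A sketch of the top-p (nucleus) selection, assuming the nucleus is formed by sorting classes by probability and keeping every class whose preceding cumulative mass is still below p, so the class that crosses the threshold is included. Without replacement, the first per_node_beam_size sorted classes are always kept, matching the fallback described for p. The name top_p_sample_nodes is hypothetical.

```python
import torch
from typing import Dict, Tuple

StateType = Dict[str, torch.Tensor]


def top_p_sample_nodes(
    log_probs: torch.Tensor,
    per_node_beam_size: int,
    state: StateType,
    p: float = 0.9,
    temperature: float = 1.0,
    with_replacement: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
    # Temperature-scaled, renormalized log probs: (batch_size, num_classes)
    scaled = torch.log_softmax(log_probs / temperature, dim=-1)
    # Sort classes from most to least likely.
    sorted_log_probs, sorting_indices = scaled.sort(dim=-1, descending=True)
    sorted_probs = sorted_log_probs.exp()
    # Mass strictly before each class; exclude it once that mass already reaches p.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    exclusion_mask = mass_before >= p
    if not with_replacement:
        # Always keep at least per_node_beam_size candidates.
        exclusion_mask[..., :per_node_beam_size] = False
    filtered = sorted_log_probs.masked_fill(exclusion_mask, float("-inf"))
    # Renormalize over the nucleus and sample positions in sorted order.
    probabilities = torch.softmax(filtered, dim=-1)
    sampled = torch.multinomial(probabilities, per_node_beam_size, replacement=with_replacement)
    # Map sorted positions back to original class indices.
    indices = sorting_indices.gather(-1, sampled)
    return log_probs.gather(1, indices), indices, state
```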
sample_nodes
class TopPSampler(Sampler):
| ...
| @overrides
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
GumbelSampler
@Sampler.register("gumbel")
class GumbelSampler(Sampler):
| def __init__(self, temperature: float = 1.0)
A Sampler which uses the Gumbel-Top-K trick to sample without replacement. See Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement, W. Kool, H. van Hoof and M. Welling, 2019.
Parameters
- temperature : float, optional (default = 1.0)
  A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
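The Gumbel-top-k trick itself is short: add independent standard Gumbel noise to each log probability and take the indices of the k largest perturbed scores, which yields k distinct classes drawn without replacement from the softmax distribution. The sketch below shows only that core trick under the name gumbel_top_k (hypothetical); the GumbelSampler additionally carries perturbed scores across steps, which is what gumbel_with_max is for.

```python
import torch


def gumbel_top_k(log_probs: torch.Tensor, k: int) -> torch.Tensor:
    """Draw k distinct class indices per row without replacement via Gumbel-top-k."""
    # U ~ Uniform(0, 1), clamped away from 0 so log(U) stays finite.
    uniform = torch.rand_like(log_probs).clamp_min(torch.finfo(log_probs.dtype).tiny)
    # Standard Gumbel noise is -log(-log(U)); add it to every log probability.
    perturbed = log_probs - torch.log(-torch.log(uniform))
    # The indices of the k largest perturbed scores form the sample.
    return perturbed.topk(k, dim=-1).indices
```

Classes with log probability -inf stay at -inf after perturbation, so they are never drawn as long as k does not exceed the number of finite entries.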
init_state
class GumbelSampler(Sampler):
| ...
| @overrides
| def init_state(
| self,
| start_class_log_probabilities: torch.Tensor,
| batch_size: int,
| num_classes: int
| ) -> StateType
The returned state holds tensors of shape (batch_size, num_classes).
sample_nodes
class GumbelSampler(Sampler):
| ...
| @overrides
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
The temperature coefficient is applied first; log_probs has shape (batch_size * beam_size, num_classes).
sample_beams
class GumbelSampler(Sampler):
| ...
| @overrides
| def sample_beams(
| self,
| log_probs: torch.Tensor,
| beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
Returns the beams with the highest perturbed log probabilities.
gumbel
class GumbelSampler(Sampler):
| ...
| def gumbel(self, phi) -> torch.Tensor
Sample Gumbel(phi). phi should have shape (batch_size, num_classes).
gumbel_with_max
class GumbelSampler(Sampler):
| ...
| def gumbel_with_max(self, phi, T) -> torch.Tensor
Sample Gumbel(phi) conditioned on the maximum value being equal to T. phi should have shape (batch_size, num_classes) and T should have shape (batch_size, 1).
BeamSearch
class BeamSearch(FromParams):
| def __init__(
| self,
| end_index: int,
| max_steps: int = 50,
| beam_size: int = 10,
| per_node_beam_size: int = None,
| sampler: Sampler = None
| ) -> None
Implements the beam search algorithm for decoding the most likely sequences.
Parameters
- end_index : int
  The index of the "stop" or "end" token in the target vocabulary.
- max_steps : int, optional (default = 50)
  The maximum number of decoding steps to take, i.e. the maximum length of the predicted sequences.
- beam_size : int, optional (default = 10)
  The width of the beam used.
- per_node_beam_size : int, optional (default = beam_size)
  The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Beam Search Strategies for Neural Machine Translation, Freitag and Al-Onaizan, 2017.
- sampler : Sampler, optional (default = None)
  An optional Sampler which is used to pick next candidate nodes and beams. If not specified, DeterministicSampler will be used, which just takes the per_node_beam_size most likely nodes and the beam_size most likely beams. Using the GumbelSampler, on the other hand, will give you Stochastic Beam Search.
search
class BeamSearch(FromParams):
| ...
| @torch.no_grad()
| def search(
| self,
| start_predictions: torch.Tensor,
| start_state: StateType,
| step: StepFunctionType
| ) -> Tuple[torch.Tensor, torch.Tensor]
Given a starting state and a step function, apply beam search to find the most likely target sequences.
Notes
If your step function returns -inf for some log probabilities (as with a masked log-softmax), then some of the "best" sequences returned may also have -inf log probability. Specifically, this happens when the beam size is larger than the number of actions with finite log probability (non-zero probability) returned by the step function.
Therefore, if you're using a mask, you may want to check the results from search and potentially discard sequences with non-finite log probability.
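For example, assuming predictions and log_probabilities are the two tensors returned by search (with the shapes given under Returns), one way to discard non-finite beams is a boolean mask built with torch.isfinite. The tensor values below are made up for illustration. Because different batch entries can keep different numbers of beams, the surviving beams are collected per entry rather than stacked.

```python
import torch

# Hypothetical search output for batch_size=1, beam_size=3, max_steps=2.
predictions = torch.tensor([[[4, 2], [4, 3], [1, 0]]])
log_probabilities = torch.tensor([[-0.5, -1.25, float("-inf")]])

# True for beams whose score is a finite number: (batch_size, beam_size)
keep = torch.isfinite(log_probabilities)

# Collect the surviving (sequences, scores) pair for each batch entry.
kept = [
    (predictions[b][keep[b]], log_probabilities[b][keep[b]])
    for b in range(predictions.size(0))
]
```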
Parameters
- start_predictions : torch.Tensor
  A tensor containing the initial predictions with shape (batch_size,). Usually the initial predictions are just the index of the "start" token in the target vocabulary.
- start_state : StateType
  The initial state passed to the step function. Each value of the state dict should be a tensor of shape (batch_size, *), where * means any other number of dimensions.
- step : StepFunctionType
  A function that is responsible for computing the next most likely tokens, given the current state and the predictions from the last time step. The function should accept two or three arguments:
    - a tensor of shape (group_size,) representing the index of the predicted tokens from the last time step,
    - the current state, a StateType, and
    - optionally, the timestep, an int.
  The group_size will be batch_size * beam_size, except in the initial step, for which it will just be batch_size.
  The function is expected to return a tuple, where the first element is a tensor of shape (group_size, target_vocab_size) containing the log probabilities of the tokens for the next step, and the second element is the updated state. Each tensor in the state should have shape (group_size, *), where * means any other number of dimensions.
Returns
Tuple[torch.Tensor, torch.Tensor]
Tuple of (predictions, log_probabilities), where predictions has shape (batch_size, beam_size, max_steps) and log_probabilities has shape (batch_size, beam_size).
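To make the search procedure concrete, here is a heavily simplified, standalone beam search loop written against the same step-function contract (two-argument form only). It is a hedged sketch, not the library's implementation: it expands every vocabulary item instead of using per_node_beam_size, always selects deterministically, and handles finished sequences by letting them emit only end_index at zero additional cost. The name simple_beam_search is hypothetical.

```python
import torch
from typing import Callable, Dict, Tuple

StateType = Dict[str, torch.Tensor]
StepFn = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]


def simple_beam_search(
    start_predictions: torch.Tensor,
    start_state: StateType,
    step: StepFn,
    end_index: int,
    beam_size: int = 3,
    max_steps: int = 5,
) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size = start_predictions.size(0)
    # Initial step: group_size == batch_size. log_probs: (batch_size, vocab_size)
    log_probs, state = step(start_predictions, start_state)
    vocab_size = log_probs.size(-1)
    scores, tokens = log_probs.topk(beam_size, dim=-1)  # (batch_size, beam_size)
    history = [tokens]
    # Repeat every state tensor per beam: (batch_size * beam_size, *)
    state = {key: value.repeat_interleave(beam_size, dim=0) for key, value in state.items()}

    for _ in range(max_steps - 1):
        last = history[-1].reshape(batch_size * beam_size)
        if bool((last == end_index).all()):
            break  # every hypothesis has finished
        log_probs, state = step(last, state)  # (batch_size * beam_size, vocab_size)
        # Finished hypotheses may only emit end_index again, at zero extra cost.
        finished = (last == end_index).unsqueeze(-1)
        end_only = torch.full_like(log_probs, float("-inf"))
        end_only[:, end_index] = 0.0
        log_probs = torch.where(finished, end_only, log_probs)
        # Score every (beam, token) extension, flattened per batch entry.
        total = (scores.reshape(-1, 1) + log_probs).reshape(batch_size, beam_size * vocab_size)
        scores, flat_idx = total.topk(beam_size, dim=-1)
        parents = torch.div(flat_idx, vocab_size, rounding_mode="floor")
        tokens = flat_idx % vocab_size
        # Reorder history and state so they follow the surviving parent beams.
        history = [h.gather(1, parents) for h in history] + [tokens]
        rows = (parents + torch.arange(batch_size).unsqueeze(-1) * beam_size).reshape(-1)
        state = {key: value.index_select(0, rows) for key, value in state.items()}

    # predictions: (batch_size, beam_size, steps_taken), scores: (batch_size, beam_size)
    return torch.stack(history, dim=-1), scores
```

Because finished hypotheses stop accumulating negative log probability under this scoring, sequences that end early can outrank longer ones; the early break also means fewer than max_steps columns may be returned.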