beam_search
allennlp.nn.beam_search
StateType¶
StateType = Dict[str, torch.Tensor]
StepFunctionTypeWithTimestep¶
StepFunctionTypeWithTimestep = Callable[
[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]
]
StepFunctionTypeNoTimestep¶
StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
StepFunctionType¶
StepFunctionType = TypeVar(
"StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep
)
The type of step function that can be passed to BeamSearch.search. This can either be StepFunctionTypeWithTimestep or StepFunctionTypeNoTimestep.
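For illustration, a minimal step function conforming to StepFunctionTypeWithTimestep might look like the following sketch. The name take_step and the "logits" state entry are hypothetical, chosen only to show the contract:

import torch
from typing import Dict, Tuple

StateType = Dict[str, torch.Tensor]

def take_step(
    last_predictions: torch.Tensor,  # shape: (group_size,)
    state: StateType,                # each value has shape (group_size, *)
    timestep: int,
) -> Tuple[torch.Tensor, StateType]:
    # A toy "model": read next-token logits stored in the state.
    logits = state["logits"]  # hypothetical state entry, shape (group_size, num_classes)
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs, state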
ConstraintStateType¶
ConstraintStateType = List[List[Dict[str, Any]]]
Sampler¶
class Sampler(Registrable)
An abstract class that can be used to sample candidates (either nodes or beams) within BeamSearch.
A Sampler just has three methods, init_state(), sample_nodes() and sample_beams().

init_state() takes three arguments:
- a tensor of starting log probs with shape (batch_size, num_classes),
- the batch size, an int,
- and the number of classes, also an int.

It returns a state dictionary with any state tensors needed for subsequent calls to sample_nodes() and sample_beams(). By default this method just returns an empty dictionary.
Both sample_nodes() and sample_beams() should take three arguments:
- a tensor of normalized log probabilities with shape (batch_size, num_examples),
- an integer representing the number of samples to take for each example in the batch,
- and a state dictionary which could contain any tensors needed for the Sampler to keep track of state.

For sample_nodes(), num_examples = num_classes, but for sample_beams(), num_examples = beam_size * per_node_beam_size.

The return value should be a tuple containing:
- a tensor of log probabilities of the sampled examples with shape (batch_size, num_samples),
- a tensor of indices of the sampled examples with shape (batch_size, num_samples),
- and the updated state dictionary.

A default implementation of sample_beams is provided, which just deterministically picks the k examples with highest log probability, as sketched below.
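As a rough sketch, that default deterministic behavior amounts to a plain torch.topk call. This mirrors the contract described above, not the library's exact implementation:

import torch

# Select the beam_size highest-scoring entries, deterministically.
log_probs = torch.log_softmax(torch.randn(2, 40), dim=-1)  # (batch_size, beam_size * per_node_beam_size)
beam_size = 5
top_log_probs, top_indices = log_probs.topk(beam_size, dim=-1)
# top_log_probs and top_indices both have shape (batch_size, beam_size);
# the state dictionary would be returned unchanged.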
default_implementation¶
class Sampler(Registrable):
| ...
| default_implementation = "deterministic"
init_state¶
class Sampler(Registrable):
| ...
| def init_state(
| self,
| start_class_log_probabilities: torch.Tensor,
| batch_size: int,
| num_classes: int
| ) -> StateType
sample_nodes¶
class Sampler(Registrable):
| ...
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
sample_beams¶
class Sampler(Registrable):
| ...
| def sample_beams(
| self,
| log_probs: torch.Tensor,
| beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
DeterministicSampler¶
@Sampler.register("deterministic")
class DeterministicSampler(Sampler)
A Sampler that just deterministically returns the k nodes or beams with highest log probability.
sample_nodes¶
class DeterministicSampler(Sampler):
| ...
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
MultinomialSampler¶
@Sampler.register("multinomial")
class MultinomialSampler(Sampler):
| def __init__(
| self,
| temperature: float = 1.0,
| with_replacement: bool = False
| ) -> None
A Sampler which samples nodes from the given multinomial distribution. Beams are sampled in the default, deterministic way.
Parameters¶
- temperature : float, optional (default = 1.0)
  A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
- with_replacement : bool, optional (default = False)
  Whether to sample with replacement.
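The node-sampling idea described above can be sketched with torch.multinomial on the temperature-scaled distribution. This is an illustration of the behavior, not the library's exact code:

import torch

log_probs = torch.log_softmax(torch.randn(2, 10), dim=-1)  # (batch_size, num_classes)
temperature, per_node_beam_size = 0.7, 3
# Rescale by temperature and renormalize, then sample class indices.
probs = torch.softmax(log_probs / temperature, dim=-1)
indices = torch.multinomial(probs, per_node_beam_size, replacement=False)
# Report the original (unscaled) log probabilities of the sampled classes.
sampled_log_probs = log_probs.gather(-1, indices)  # (batch_size, per_node_beam_size)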
sample_nodes¶
class MultinomialSampler(Sampler):
| ...
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
TopKSampler¶
@Sampler.register("top-k")
class TopKSampler(Sampler):
| def __init__(
| self,
| k: int = 1,
| temperature: float = 1.0,
| with_replacement: bool = False
| )
A Sampler which redistributes the probability mass function for nodes among the top k choices, then samples from that subset after re-normalizing the probabilities. Beams are sampled in the default, deterministic way.
Parameters¶
- k : int, optional (default = 1)
  The number of top choices to be selected from.
- temperature : float, optional (default = 1.0)
  A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
- with_replacement : bool, optional (default = False)
  If set to True, samples will be selected with replacement from the top k choices.
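A sketch of the top-k filtering described above, assuming k >= per_node_beam_size (illustrative, not the library's exact code):

import torch

log_probs = torch.log_softmax(torch.randn(2, 10), dim=-1)  # (batch_size, num_classes)
k, temperature, per_node_beam_size = 5, 1.0, 3
# Keep only the k best classes and renormalize among them.
top_k_log_probs, top_k_indices = log_probs.topk(k, dim=-1)
probs = torch.softmax(top_k_log_probs / temperature, dim=-1)
# Sample within the top k, then map back to the original class indices.
sampled = torch.multinomial(probs, per_node_beam_size, replacement=False)
indices = top_k_indices.gather(-1, sampled)
sampled_log_probs = log_probs.gather(-1, indices)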
sample_nodes¶
class TopKSampler(Sampler):
| ...
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
TopPSampler¶
@Sampler.register("top-p")
class TopPSampler(Sampler):
| def __init__(
| self,
| p: float = 0.9,
| temperature: float = 1.0,
| with_replacement: bool = False
| )
A Sampler which redistributes the probability mass function for nodes among the top choices with a cumulative probability of at least p, then samples from that subset after re-normalizing the probabilities. Beams are sampled in the default, deterministic way.
Parameters¶
- p : float, optional (default = 0.9)
  The cumulative probability cutoff threshold. A higher value of p will result in more possible examples to sample from. If with_replacement is False and the number of possible samples is insufficient to sample without replacement from when calling sample_nodes, then the top per_node_beam_size examples will be chosen.
- temperature : float, optional (default = 1.0)
  A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
- with_replacement : bool, optional (default = False)
  If set to True, samples will be selected with replacement from the top choices.
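A sketch of the nucleus (top-p) filtering described above. It is illustrative only; real code must also handle the fallback case where fewer than per_node_beam_size classes survive the cutoff, as noted in the parameters:

import torch

log_probs = torch.log_softmax(torch.randn(2, 10), dim=-1)  # (batch_size, num_classes)
p, per_node_beam_size = 0.9, 3
sorted_log_probs, sorted_indices = log_probs.sort(dim=-1, descending=True)
# Keep the smallest prefix of classes whose cumulative probability reaches p;
# using the exclusive cumulative sum keeps the class that crosses p.
exclusive_cumsum = sorted_log_probs.exp().cumsum(dim=-1) - sorted_log_probs.exp()
filtered = sorted_log_probs.masked_fill(exclusive_cumsum >= p, float("-inf"))
probs = torch.softmax(filtered, dim=-1)
sampled = torch.multinomial(probs, per_node_beam_size, replacement=False)
indices = sorted_indices.gather(-1, sampled)
sampled_log_probs = log_probs.gather(-1, indices)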
sample_nodes¶
class TopPSampler(Sampler):
| ...
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
GumbelSampler¶
@Sampler.register("gumbel")
class GumbelSampler(Sampler):
| def __init__(self, temperature: float = 1.0)
A Sampler which uses the Gumbel-Top-K trick to sample without replacement. See Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement, W. Kool, H. Van Hoof and M. Welling, 2019.
Parameters¶
- temperature : float, optional (default = 1.0)
  A temperature below 1.0 produces a sharper probability distribution and a temperature above 1.0 produces a flatter probability distribution.
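The core of the Gumbel-Top-k trick can be sketched in a few lines: perturb each log probability with independent Gumbel noise, then take the top k. The k largest perturbed scores are k samples without replacement from the underlying distribution. Illustrative only:

import torch

log_probs = torch.log_softmax(torch.randn(2, 10), dim=-1)  # (batch_size, num_classes)
k = 3
# Standard Gumbel noise: -log(-log(U)) with U ~ Uniform(0, 1).
gumbel_noise = -torch.log(-torch.log(torch.rand_like(log_probs)))
perturbed = log_probs + gumbel_noise
# The k largest perturbed scores give k samples without replacement.
perturbed_top, indices = perturbed.topk(k, dim=-1)
sampled_log_probs = log_probs.gather(-1, indices)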
init_state¶
class GumbelSampler(Sampler):
| ...
| def init_state(
| self,
| start_class_log_probabilities: torch.Tensor,
| batch_size: int,
| num_classes: int
| ) -> StateType
sample_nodes¶
class GumbelSampler(Sampler):
| ...
| def sample_nodes(
| self,
| log_probs: torch.Tensor,
| per_node_beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
sample_beams¶
class GumbelSampler(Sampler):
| ...
| def sample_beams(
| self,
| log_probs: torch.Tensor,
| beam_size: int,
| state: StateType
| ) -> Tuple[torch.Tensor, torch.Tensor, StateType]
Returns the beams with the highest perturbed log probabilities.
gumbel¶
class GumbelSampler(Sampler):
| ...
| def gumbel(self, phi) -> torch.Tensor
Sample Gumbel(phi). phi should have shape (batch_size, num_classes).
gumbel_with_max¶
class GumbelSampler(Sampler):
| ...
| def gumbel_with_max(self, phi, T) -> torch.Tensor
Sample Gumbel(phi) conditioned on the maximum value being equal to T. phi should have shape (batch_size, num_classes) and T should have shape (batch_size, 1).
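The conditioning can be done by drawing an unconditional Gumbel sample and shifting it so its maximum equals T, following Kool et al. The direct form below is a sketch; a numerically stable implementation would rewrite the expression using log1p:

import torch

def gumbel_with_max_sketch(phi: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    # Unconditional sample: G_i ~ Gumbel(phi_i).
    gumbel = phi - torch.log(-torch.log(torch.rand_like(phi)))
    Z = gumbel.max(dim=-1, keepdim=True).values
    # Shift the sample so that its maximum becomes exactly T.
    return -torch.log(torch.exp(-T) - torch.exp(-Z) + torch.exp(-gumbel))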
FinalSequenceScorer¶
class FinalSequenceScorer(Registrable)
An abstract class that can be used to score the final generated sequences found by beam search. Given the predicted sequences and the corresponding log probabilities of those sequences, the class calculates and returns the final score of the sequences.
The default implementation scores the sequences using the sum of the log probabilities of the sequence, which is passed as input.
default_implementation¶
class FinalSequenceScorer(Registrable):
| ...
| default_implementation = "sequence-log-prob"
score¶
class FinalSequenceScorer(Registrable):
| ...
| def score(
| self,
| predictions: torch.Tensor,
| log_probabilities: torch.Tensor,
| end_index: int
| ) -> torch.Tensor
Score the final predictions found by beam search.
Parameters¶
- predictions : torch.Tensor
  A tensor containing the initial predictions with shape (batch_size, beam_size, max_steps).
- log_probabilities : torch.Tensor
  A tensor containing the log probabilities of the sequence, defined as the sum of the log probabilities per token, with shape (batch_size, beam_size).
- end_index : int
  The index of the end symbol.

Returns¶
torch.Tensor
  A tensor of the final sequence scores of shape (batch_size, beam_size).
SequenceLogProbabilityScorer¶
@FinalSequenceScorer.register("sequence-log-prob")
class SequenceLogProbabilityScorer(FinalSequenceScorer)
A FinalSequenceScorer which scores the sequences by the sum of the log probabilities across the sequence's tokens.
score¶
class SequenceLogProbabilityScorer(FinalSequenceScorer):
| ...
| def score(
| self,
| predictions: torch.Tensor,
| log_probabilities: torch.Tensor,
| end_index: int
| ) -> torch.Tensor
LengthNormalizedSequenceLogProbabilityScorer¶
@FinalSequenceScorer.register("length-normalized-sequence-log-prob")
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
| def __init__(self, length_penalty: float = 1.0)
A FinalSequenceScorer which scores the sequences by the average log probability of the tokens in the sequence. It optionally includes a length penalty which promotes or demotes sequences based on their lengths. The final score for a sequence will be (sequence_log_probability) / (sequence_length ** length_penalty). The sequence length here includes the end token.
Parameters¶
- length_penalty : float, optional (default = 1.0)
  The length penalty to use. A value of 1.0 means no length penalty is used. A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
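A sketch of this computation, assuming finished sequences are padded with end_index after the end token (a common convention; the library's exact length bookkeeping may differ):

import torch

def length_normalized_score(
    predictions: torch.Tensor,        # (batch_size, beam_size, max_steps)
    log_probabilities: torch.Tensor,  # (batch_size, beam_size)
    end_index: int,
    length_penalty: float = 1.0,
) -> torch.Tensor:
    # Count tokens before the first end token, plus one for the end token itself.
    lengths = (predictions != end_index).long().sum(dim=-1)
    lengths = lengths + (predictions == end_index).any(dim=-1).long()
    return log_probabilities / lengths.float() ** length_penalty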
score¶
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
| ...
| def score(
| self,
| predictions: torch.Tensor,
| log_probabilities: torch.Tensor,
| end_index: int
| ) -> torch.Tensor
Constraint¶
class Constraint(Registrable):
| def __init__(self, vocab: Optional[Vocabulary] = None) -> None
An abstract class that can be used to enforce constraints on the output predictions by manipulating the class log probabilities during beam search.
A Constraint just has three methods that need to be implemented by subclasses: init_state(), apply() and _update_state().

init_state() takes one argument:
- the batch size, an int.

It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent calls to apply() and update_state(). The length of the outer list should be equal to batch_size. Each inner list should be of length 1.

apply() takes two arguments:
- the constraint state, which is a nested list of dictionaries. The length of the outer list is batch_size and the length of each inner list is beam_size, except on the first time apply() is called, when it is 1.
- class_log_probabilities, a tensor of shape (batch_size, beam_size, num_classes) that contains the log probabilities for the classes during search. The first time apply() is called, beam_size = 1.

The apply() method should return new class_log_probabilities that enforce the constraint for this step of beam search. For instance, it may prevent a specific class from being selected by setting the corresponding log probability to a negligible value such as float("-inf") or min_value_of_dtype(class_log_probabilities.dtype).

_update_state() takes two arguments:
- the copied parent constraint state, which is a nested list of dictionaries. state[i][j] contains the copied state for the parent of last_prediction[i, j]. It is unique to that batch and beam, so it can be directly edited in-place without affecting the others.
- last_prediction, a tensor of shape (batch_size, beam_size) containing the predictions from the last step of beam search.

The _update_state() function should return a new constraint state, a nested list of dictionaries of length batch_size and inner list of length beam_size, one for each of the predictions in last_prediction.
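As a minimal illustration of the apply() contract, the hypothetical snippet below blocks a single token id everywhere by pushing its log probability to -inf. A real Constraint subclass would also implement init_state() and _update_state():

import torch

def apply_block_token(
    class_log_probabilities: torch.Tensor,  # (batch_size, beam_size, num_classes)
    banned_token: int,
) -> torch.Tensor:
    # Make the banned class unselectable at this step of the search.
    class_log_probabilities[:, :, banned_token] = float("-inf")
    return class_log_probabilities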
init_state¶
class Constraint(Registrable):
| ...
| def init_state(self, batch_size: int) -> ConstraintStateType
apply¶
class Constraint(Registrable):
| ...
| def apply(
| self,
| state: ConstraintStateType,
| class_log_probabilities: torch.Tensor
| ) -> torch.Tensor
update_state¶
class Constraint(Registrable):
| ...
| def update_state(
| self,
| state: ConstraintStateType,
| last_prediction: torch.Tensor,
| last_backpointer: Optional[torch.Tensor] = None
| ) -> ConstraintStateType
_update_state¶
class Constraint(Registrable):
| ...
| def _update_state(
| self,
| state: ConstraintStateType,
| last_prediction: torch.Tensor
| ) -> ConstraintStateType
RepeatedNGramBlockingConstraint¶
@Constraint.register("repeated-ngram-blocking")
class RepeatedNGramBlockingConstraint(Constraint):
| def __init__(self, ngram_size: int, **kwargs) -> None
init_state¶
class RepeatedNGramBlockingConstraint(Constraint):
| ...
| def init_state(self, batch_size: int) -> ConstraintStateType
apply¶
class RepeatedNGramBlockingConstraint(Constraint):
| ...
| def apply(
| self,
| state: ConstraintStateType,
| class_log_probabilities: torch.Tensor
| ) -> torch.Tensor
BeamSearch¶
class BeamSearch(Registrable):
| def __init__(
| self,
| end_index: int,
| max_steps: int = 50,
| beam_size: int = 10,
| per_node_beam_size: int = None,
| sampler: Sampler = None,
| min_steps: Optional[int] = None,
| final_sequence_scorer: FinalSequenceScorer = None,
| constraints: Optional[List[Lazy[Constraint]]] = None,
| vocab: Optional[Vocabulary] = None
| ) -> None
Implements the beam search algorithm for decoding the most likely sequences.
Parameters¶
- end_index : int
  The index of the "stop" or "end" token in the target vocabulary.
- max_steps : int, optional (default = 50)
  The maximum number of decoding steps to take, i.e. the maximum length of the predicted sequences.
- beam_size : int, optional (default = 10)
  The width of the beam used.
- per_node_beam_size : int, optional (default = beam_size)
  The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Beam Search Strategies for Neural Machine Translation, Freitag and Al-Onaizan, 2017.
- sampler : Sampler, optional (default = None)
  An optional Sampler which is used to pick next candidate nodes and beams. If not specified, DeterministicSampler will be used, which just takes the per_node_beam_size most likely nodes and the beam_size most likely beams. Using the GumbelSampler, on the other hand, will give you Stochastic Beam Search.
- min_steps : int, optional (default = None)
  The minimum number of decoding steps to take, i.e. the minimum length of the predicted sequences. This does not include the start or end tokens. If None, no minimum is enforced.
- final_sequence_scorer : FinalSequenceScorer, optional (default = None)
  An optional FinalSequenceScorer which is used to score the final generated sequences. The output from this module is what is returned by the search method. If not specified, SequenceLogProbabilityScorer will be used, which scores the sequences by the sum of the token log probabilities.
- constraints : List[Constraint], optional (default = None)
  An optional list of Constraints which should be applied during beam search. If not provided, no constraints will be enforced.
- vocab : Vocabulary
  If constraints is not None, then the Vocabulary will be passed to each constraint during its initialization. Having access to the vocabulary may be useful for certain constraints, e.g., to mask out invalid predictions during structured prediction. In a typical AllenNLP configuration file, this parameter does not get an entry under the "model"; it gets specified as a top-level parameter, then is passed in to the model separately.
default_implementation¶
class BeamSearch(Registrable):
| ...
| default_implementation = "beam_search"
search¶
class BeamSearch(Registrable):
| ...
| @torch.no_grad()
| def search(
| self,
| start_predictions: torch.Tensor,
| start_state: StateType,
| step: StepFunctionType
| ) -> Tuple[torch.Tensor, torch.Tensor]
Given a starting state and a step function, apply beam search to find the most likely target sequences.
Note
If your step function returns -inf for some log probabilities (like if you're using a masked log-softmax) then some of the "best" sequences returned may also have -inf log probability. Specifically, this happens when the beam size is larger than the number of actions with finite log probability (non-zero probability) returned by the step function. Therefore, if you're using a mask you may want to check the results from search and potentially discard sequences with non-finite log probability.
Parameters¶
- start_predictions : torch.Tensor
  A tensor containing the initial predictions with shape (batch_size,). Usually the initial predictions are just the index of the "start" token in the target vocabulary.
- start_state : StateType
  The initial state passed to the step function. Each value of the state dict should be a tensor of shape (batch_size, *), where * means any other number of dimensions.
- step : StepFunctionType
  A function that is responsible for computing the next most likely tokens, given the current state and the predictions from the last time step. The function should accept two or three arguments:
  - a tensor of shape (group_size,) representing the index of the predicted tokens from the last time step,
  - the current state, a StateType, and
  - optionally, the timestep, an int.
  The group_size will be batch_size * beam_size, except in the initial step, for which it will just be batch_size. The function is expected to return a tuple, where the first element is a tensor of shape (group_size, target_vocab_size) containing the log probabilities of the tokens for the next step, and the second element is the updated state. The tensor in the state should have shape (group_size, *), where * means any other number of dimensions.
Returns¶
Tuple[torch.Tensor, torch.Tensor]
  Tuple of (predictions, final_scores), where predictions has shape (batch_size, beam_size, max_steps) and final_scores has shape (batch_size, beam_size).
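Putting it together, here is a hedged end-to-end sketch using a toy "model" whose next-token distribution depends only on the last prediction. The transition_logits tensor and the token indices are made up for illustration:

import torch
from allennlp.nn.beam_search import BeamSearch

num_classes, end_index = 5, 4
transition_logits = torch.randn(num_classes, num_classes)  # toy stand-in for a model

def step(last_predictions, state, timestep):
    # last_predictions: (group_size,) -> normalized log probs: (group_size, num_classes)
    log_probs = torch.log_softmax(transition_logits[last_predictions], dim=-1)
    return log_probs, state

beam_search = BeamSearch(end_index=end_index, max_steps=10, beam_size=3)
start_predictions = torch.zeros(2, dtype=torch.long)  # batch_size = 2, "start" token id 0
predictions, final_scores = beam_search.search(start_predictions, {}, step)
# predictions: (2, 3, num_decoded_steps), final_scores: (2, 3)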