allennlp.state_machines¶
This module contains code for using state machines in a model to do transition-based decoding. “Transition-based decoding” is where you start in some state, iteratively transition between states, and have some kind of supervision signal that tells you which end states, or which transition sequences, are “good”.
Typical seq2seq decoding, where you have a fixed vocabulary and no constraints on your output, can be done much more efficiently than we do in this code. This is intended for structured models that have constraints on their outputs.
The key abstractions in this code are the following:
State
represents the current state of decoding, containing a list of all of the actions taken so far and a current score for the state. It also has methods for determining whether the state is “finished” and for combining states for batched computation.
TransitionFunction
is a torch.nn.Module that models the transition function between states. Its main method is take_step, which generates a ranked list of next states given a current state.
DecoderTrainer
is an algorithm for training the transition function with some kind of supervision signal. There are many options for training algorithms and supervision signals; this is an abstract class that is generic over the type of the supervision signal.
There is also a generic BeamSearch class for finding the k highest-scoring transition sequences given a trained TransitionFunction and an initial State.
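As a rough sketch of how these pieces fit together (initial_state and transition_function below are placeholders for a batched State and a trained TransitionFunction produced by your model, not objects this module builds for you), a decoding run with BeamSearch looks roughly like this:

    from allennlp.state_machines.beam_search import BeamSearch

    # Placeholders: `initial_state` is a batched State and `transition_function`
    # is a trained TransitionFunction, both assumed to come from your model.

    beam_search = BeamSearch(beam_size=5)

    # search() returns a dict mapping batch index -> ranked list of finished states.
    best_states = beam_search.search(
        num_steps=20,
        initial_state=initial_state,
        transition_function=transition_function,
        keep_final_unfinished_states=False,
    )

    for batch_index, states in best_states.items():
        # Each finished state keeps track of the transition sequence it took
        # (its action_history) and its score, as described above.
        print(batch_index, states[0].action_history)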
class allennlp.state_machines.beam_search.BeamSearch(beam_size: int, per_node_beam_size: int = None, initial_sequence: torch.Tensor = None, keep_beam_details: bool = False)[source]¶
Bases: allennlp.common.from_params.FromParams, typing.Generic
This class implements beam search over transition sequences given an initial State and a TransitionFunction, returning the highest scoring final states found by the beam (the states will keep track of the transition sequence themselves).
The initial State is assumed to be batched. The value we return from the search is a dictionary from batch indices to ranked finished states.
IMPORTANT: We assume that the TransitionFunction that you are using returns possible next states in sorted order, so we do not do an additional sort inside of BeamSearch.search(). If you’re implementing your own TransitionFunction, you must ensure that you’ve sorted the states that you return.
Parameters
- beam_size : int
The beam size to use.
- per_node_beam_size : int, optional (default = beam_size)
The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Freitag and Al-Onaizan 2017, “Beam Search Strategies for Neural Machine Translation”.
- initial_sequence : torch.Tensor, optional (default = None)
If you provide a (sequence_length,) tensor here, the beam search will be constrained to only sequences that begin with the provided initial_sequence.
- keep_beam_details : bool, optional (default = False)
If True, we store snapshots of each beam in an instance variable beam_snapshots, which is a dict: { batch_index -> [timestep0_histories, …, timestepk_histories] }, where a “timestep history” is just a pair (score, action_history) that was considered at that timestep.
constrained_to(self, initial_sequence: torch.Tensor, keep_beam_details: bool = True) → 'BeamSearch'[source]¶
Return a new BeamSearch instance that’s like this one but with the specified constraint.
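As a small illustrative sketch (the action-index prefix below is an arbitrary placeholder), constraining a search to sequences that start with a fixed prefix:

    import torch
    from allennlp.state_machines.beam_search import BeamSearch

    beam_search = BeamSearch(beam_size=5)

    # Hypothetical prefix of action indices; the returned BeamSearch will only
    # explore sequences that begin with these actions.
    prefix = torch.tensor([1, 4])
    constrained_search = beam_search.constrained_to(prefix)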
search(self, num_steps: int, initial_state: ~StateType, transition_function: allennlp.state_machines.transition_functions.transition_function.TransitionFunction, keep_final_unfinished_states: bool = True) → Dict[int, List[~StateType]][source]¶
Parameters
- num_steps : int
How many steps should we take in our search? This is an upper bound, as it’s possible for the search to run out of valid actions before hitting this number, or for all states on the beam to finish.
- initial_state : StateType
The starting state of our search. This is assumed to be batched, and our beam search is batch-aware - we’ll keep beam_size states around for each instance in the batch.
- transition_function : TransitionFunction
The TransitionFunction object that defines and scores transitions from one state to the next.
- keep_final_unfinished_states : bool, optional (default = True)
If we run out of steps before a state is “finished”, should we return that state in our search results?
Returns
- best_states : Dict[int, List[StateType]]
This is a mapping from batch index to the top states for that instance.
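If the BeamSearch was constructed with keep_beam_details=True, the beam_snapshots structure described above can be inspected after calling search(). A rough sketch of walking that structure (beam_search here is a placeholder for an instance on which search() has already been run):

    for batch_index, timestep_histories in beam_search.beam_snapshots.items():
        for timestep, histories in enumerate(timestep_histories):
            # Each history is a (score, action_history) pair that was
            # considered on the beam at this timestep.
            for score, action_history in histories:
                print(batch_index, timestep, score, action_history)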
class allennlp.state_machines.constrained_beam_search.ConstrainedBeamSearch(beam_size: Optional[int], allowed_sequences: torch.Tensor, allowed_sequence_mask: torch.Tensor, per_node_beam_size: int = None)[source]¶
Bases: object
This class implements beam search over transition sequences given an initial State, a TransitionFunction, and a list of allowed transition sequences. We will do a beam search over the list of allowed sequences and return the highest scoring states found by the beam. This is only actually a beam search if your beam size is smaller than the number of allowed transition sequences; otherwise, we are just scoring and sorting the sequences using a prefix tree.
The initial State is assumed to be batched. The value we return from the search is a dictionary from batch indices to ranked finished states.
IMPORTANT: We assume that the TransitionFunction that you are using returns possible next states in sorted order, so we do not do an additional sort inside of ConstrainedBeamSearch.search(). If you’re implementing your own TransitionFunction, you must ensure that you’ve sorted the states that you return.
Parameters
- beam_size : Optional[int]
The beam size to use. Because this is a constrained beam search, we allow for the case where you just want to evaluate all options in the constrained set. In that case, you don’t need a beam, and you can pass a beam size of None, and we will just evaluate everything. This lets us be more efficient in TransitionFunction.take_step() and skip the sorting that is typically done there.
- allowed_sequences : torch.Tensor
A (batch_size, num_sequences, sequence_length) tensor containing the transition sequences that we will search in. The values in this tensor must match whatever the State keeps in its action_history variable (typically this is action indices).
- allowed_sequence_mask : torch.Tensor
A (batch_size, num_sequences, sequence_length) tensor indicating whether each entry in the allowed_sequences tensor is padding. The allowed sequences could be padded both on the num_sequences dimension and the sequence_length dimension.
- per_node_beam_size : int, optional (default = beam_size)
The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Freitag and Al-Onaizan 2017, “Beam Search Strategies for Neural Machine Translation”.
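As an illustrative sketch of the expected tensor shapes (the action indices and padding below are arbitrary placeholders), constructing a ConstrainedBeamSearch that simply scores every allowed sequence by passing a beam size of None:

    import torch
    from allennlp.state_machines.constrained_beam_search import ConstrainedBeamSearch

    # Shape (batch_size=1, num_sequences=2, sequence_length=3): two allowed
    # action sequences for the single instance in the batch.
    allowed_sequences = torch.tensor([[[1, 2, 3],
                                       [1, 4, 0]]])
    # 1 for real entries, 0 for padding (the second sequence has length 2).
    allowed_sequence_mask = torch.tensor([[[1, 1, 1],
                                           [1, 1, 0]]])

    # beam_size=None: evaluate and rank all allowed sequences.
    constrained_search = ConstrainedBeamSearch(None, allowed_sequences, allowed_sequence_mask)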
search(self, initial_state: allennlp.state_machines.states.state.State, transition_function: allennlp.state_machines.transition_functions.transition_function.TransitionFunction) → Dict[int, List[allennlp.state_machines.states.state.State]][source]¶
Parameters
- initial_state : State
The starting state of our search. This is assumed to be batched, and our beam search is batch-aware - we’ll keep beam_size states around for each instance in the batch.
- transition_function : TransitionFunction
The TransitionFunction object that defines and scores transitions from one state to the next.
Returns
- best_states : Dict[int, List[State]]
This is a mapping from batch index to the top states for that instance.
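Continuing the construction sketch above (initial_state and transition_function are again placeholders for a batched State and a trained TransitionFunction from your model):

    best_states = constrained_search.search(initial_state, transition_function)

    for batch_index, states in best_states.items():
        # States are ranked, so the first one is the highest-scoring
        # allowed sequence for this instance.
        print(batch_index, states[0].action_history)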
allennlp.state_machines.util.construct_prefix_tree(targets: Union[torch.Tensor, List[List[List[int]]]], target_mask: Union[torch.Tensor, NoneType] = None) → List[Dict[Tuple[int, ...], Set[int]]][source]¶
Takes a list of valid target action sequences and creates a mapping from all possible (valid) action prefixes to allowed actions given that prefix. While the method is called construct_prefix_tree, we’re actually returning a map that has as keys the paths to all internal nodes of the trie, and as values all of the outgoing edges from that node.
targets is assumed to be a tensor of shape (batch_size, num_valid_sequences, sequence_length). If the mask is not None, it is assumed to have the same shape, and we will ignore any value in targets that has a value of 0 in the corresponding position in the mask. We assume that the mask has the format 1*0* for each item in targets - that is, once we see our first zero, we stop processing that target.
For example, if targets for one batch instance is the tensor [[1, 2, 3], [1, 4, 5]], the return value for that instance will be: {(): set([1]), (1,): set([2, 4]), (1, 2): set([3]), (1, 4): set([5])}.
This could be used, e.g., to do an efficient constrained beam search, or to efficiently evaluate the probability of all of the target sequences.
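A small runnable sketch of the example above, with the targets wrapped in a batch dimension so the function returns one prefix map per batch instance:

    from allennlp.state_machines import util

    # One batch instance with two valid target action sequences.
    targets = [[[1, 2, 3],
                [1, 4, 5]]]

    prefix_trees = util.construct_prefix_tree(targets)

    # prefix_trees[0] maps each valid prefix to the set of allowed next actions:
    # {(): {1}, (1,): {2, 4}, (1, 2): {3}, (1, 4): {5}}
    print(prefix_trees[0])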