allennlp.state_machines.trainers
class allennlp.state_machines.trainers.decoder_trainer.DecoderTrainer

Bases: ``typing.Generic``

``DecoderTrainer``s define a training regime for transition-based decoders. A ``DecoderTrainer`` assumes an initial ``State``, a ``TransitionFunction`` that can traverse the state space, and some supervision signal. Given these things, the ``DecoderTrainer`` trains the ``TransitionFunction`` to traverse the state space and end up at good end states.

Concrete implementations of this abstract base class could do things like maximum marginal likelihood, SEARN, LaSO, or other structured learning algorithms. If you're just trying to maximize the probability of a single target sequence where the possible outputs are the same at each timestep (as in, e.g., typical machine translation training regimes), there are far more efficient ways to do that than using this API.
decode(self, initial_state: allennlp.state_machines.states.state.State, transition_function: allennlp.state_machines.transition_functions.transition_function.TransitionFunction, supervision: ~SupervisionType) → Dict[str, torch.Tensor]

Takes an initial state object, a means of transitioning from state to state, and a supervision signal, and uses the supervision to train the transition function to pick "good" states.

This function should typically return a ``loss`` key during training, which the ``Model`` will use as its loss. A minimal sketch of a concrete subclass appears after the parameter list.

Parameters
- initial_state : ``State``
  This is the initial state for decoding, typically initialized after running some kind of encoder on some inputs.
- transition_function : ``TransitionFunction``
  This is the transition function that scores all possible actions that can be taken in a given state, and returns a ranked list of next states at each step of decoding.
- supervision : ``SupervisionType``
  This is the supervision that is used to train the ``transition_function`` to pick "good" states. You can use whatever kind of supervision you want (e.g., a single "gold" action sequence, a set of possible "gold" action sequences, a reward function, etc.). We use ``typing.Generics`` to make sure that our static type checker is happy with how you've matched the supervision that you provide in the model to the ``DecoderTrainer`` that you want to use.
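Since ``DecoderTrainer`` is generic in its supervision type, a concrete subclass pins ``SupervisionType`` down when it inherits. The following minimal sketch shows just that subclassing and typing pattern; ``RewardTrainer`` and its reward-tensor supervision are hypothetical and not part of the library:

```python
from typing import Dict

import torch

from allennlp.state_machines.states import State
from allennlp.state_machines.trainers import DecoderTrainer
from allennlp.state_machines.transition_functions import TransitionFunction


class RewardTrainer(DecoderTrainer[torch.Tensor]):
    """Hypothetical trainer whose supervision is a per-instance reward tensor.

    Subclassing ``DecoderTrainer[torch.Tensor]`` pins the generic
    ``SupervisionType`` to ``torch.Tensor``, so a static type checker can
    verify that the model passes reward tensors (and not, say, gold action
    sequences) to this trainer.
    """

    def decode(self,
               initial_state: State,
               transition_function: TransitionFunction,
               supervision: torch.Tensor) -> Dict[str, torch.Tensor]:
        # A real implementation would explore the state space with
        # transition_function.take_step(...), score the finished states
        # against the supervision, and return a dict with a "loss" key.
        raise NotImplementedError
```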
class allennlp.state_machines.trainers.maximum_marginal_likelihood.MaximumMarginalLikelihood(beam_size: int = None)

Bases: ``allennlp.state_machines.trainers.decoder_trainer.DecoderTrainer``
This class trains a decoder by maximizing the marginal likelihood of the targets. That is, during training, we are given a set of acceptable or possible target sequences, and we optimize the sum of the probabilities the model assigns to the items in that set. This allows the model to distribute its probability mass over the set however it chooses, without forcing all of the given target sequences to have high probability. This is helpful, for example, if you have good reason to expect that the correct target sequence is in the set, but aren't sure which of the sequences is actually correct.

This implementation of maximum marginal likelihood requires the model you use to be locally normalized; that is, at each decoding timestep, we assume that the model creates a normalized probability distribution over actions. This assumption is necessary because we do no explicit normalization in our loss function; we just sum the probabilities assigned to all correct target sequences, relying on the local normalization at each timestep to push probability mass from bad actions to good ones. A toy illustration of this loss appears after the parameter list.
Parameters
- beam_size : ``int``, optional (default = None)
  We can optionally run a constrained beam search over the provided targets during decoding. This narrows the set of transition sequences that are marginalized over in the loss function, keeping only the top ``beam_size`` sequences according to the model. If this is ``None``, we will keep all of the provided sequences in the loss computation.
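To make the marginalization concrete, here is a toy sketch (illustrative numbers, not the library's code) of the quantity this trainer optimizes: given locally normalized log-probabilities for each acceptable sequence, the loss is the negative log of their summed probabilities, computed stably in log space:

```python
import torch

# Toy numbers: log-probabilities the (locally normalized) decoder assigned
# to three acceptable target sequences for a single instance.
sequence_log_probs = torch.tensor([-2.3, -1.1, -4.0])

# Marginal likelihood sums the probabilities over the whole set, so the
# loss is -log(sum_i p(y_i)), computed with logsumexp for stability.
loss = -torch.logsumexp(sequence_log_probs, dim=0)
print(loss.item())  # -log(e^-2.3 + e^-1.1 + e^-4.0) ≈ 0.80
```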
decode(self, initial_state: allennlp.state_machines.states.state.State, transition_function: allennlp.state_machines.transition_functions.transition_function.TransitionFunction, supervision: Tuple[torch.Tensor, torch.Tensor]) → Dict[str, torch.Tensor]

Takes an initial state object, a means of transitioning from state to state, and a supervision signal, and uses the supervision to train the transition function to pick "good" states.

This function should typically return a ``loss`` key during training, which the ``Model`` will use as its loss. A usage sketch appears after the parameter list.

Parameters
- initial_state : ``State``
  This is the initial state for decoding, typically initialized after running some kind of encoder on some inputs.
- transition_function : ``TransitionFunction``
  This is the transition function that scores all possible actions that can be taken in a given state, and returns a ranked list of next states at each step of decoding.
- supervision : ``SupervisionType``
  This is the supervision that is used to train the ``transition_function`` to pick "good" states. You can use whatever kind of supervision you want (e.g., a single "gold" action sequence, a set of possible "gold" action sequences, a reward function, etc.). We use ``typing.Generics`` to make sure that our static type checker is happy with how you've matched the supervision that you provide in the model to the ``DecoderTrainer`` that you want to use.
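A hedged usage sketch follows. The target tensor shapes are an assumption on my part (a padded set of acceptable action sequences per instance, plus a matching mask), the ``mml_loss`` helper is hypothetical, and ``initial_state`` and ``transition_function`` stand in for whatever the model's encoder and configured transition function produce:

```python
from typing import Dict

import torch

from allennlp.state_machines.states import State
from allennlp.state_machines.trainers import MaximumMarginalLikelihood
from allennlp.state_machines.transition_functions import TransitionFunction


def mml_loss(initial_state: State,
             transition_function: TransitionFunction,
             targets: torch.Tensor,
             target_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
    # Keep only the top 5 target sequences (according to the model) in the
    # marginalized loss, rather than all provided sequences.
    trainer = MaximumMarginalLikelihood(beam_size=5)
    # Supervision is the (targets, mask) pair; assumed shape for both is
    # (batch_size, num_target_sequences, sequence_length).
    return trainer.decode(initial_state, transition_function, (targets, target_mask))
```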
class allennlp.state_machines.trainers.expected_risk_minimization.ExpectedRiskMinimization(beam_size: int, normalize_by_length: bool, max_decoding_steps: int, max_num_decoded_sequences: int = 1, max_num_finished_states: int = None)

Bases: ``allennlp.state_machines.trainers.decoder_trainer.DecoderTrainer``
This class implements a trainer that minimizes the expected value of a cost function over the space of candidate sequences produced by a decoder. We generate the candidate sequences by performing beam search (one of the two popular ways of getting these sequences, the other being sampling; see "Classical Structured Prediction Losses for Sequence to Sequence Learning" by Edunov et al., 2017 for more details). A toy illustration of this expected cost appears after the parameter list.
Parameters
- beam_size : ``int``
  The beam size to use when searching for candidate sequences.
- normalize_by_length : ``bool``
  Should the log probabilities be normalized by length before renormalizing them? Edunov et al. do this in their work.
- max_decoding_steps : ``int``
  The maximum number of steps we should take during decoding.
- max_num_decoded_sequences : ``int``, optional (default = 1)
  Maximum number of sorted decoded sequences to return. Defaults to 1.
- max_num_finished_states : ``int``, optional (default = None)
  Maximum number of finished states to keep after search. This is to finished states as ``beam_size`` is to unfinished ones. Costs are computed for only this many states per instance. If not set, we will keep all the finished states.
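As a toy illustration (illustrative numbers, not the library's exact code) of what "expected value of a cost function" means here: the log-probabilities of the finished sequences kept from the beam are renormalized over just those sequences, and the loss is the expectation of the costs under that distribution:

```python
import torch

# Toy numbers: log-probabilities and costs for three finished sequences
# kept from the beam for one instance (lower cost is better).
log_probs = torch.tensor([-1.0, -2.0, -3.0])
costs = torch.tensor([0.0, 1.0, 1.0])

# Renormalize over just the beam's finished sequences, then take the
# expected cost under that distribution; this is the quantity minimized.
renormalized = torch.softmax(log_probs, dim=0)
expected_cost = torch.sum(renormalized * costs)
print(expected_cost.item())  # ≈ 0.33
```

If I read the ``normalize_by_length`` parameter correctly, setting it to ``True`` would divide each log-probability by its sequence length before this renormalization.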
decode(self, initial_state: allennlp.state_machines.states.state.State, transition_function: allennlp.state_machines.transition_functions.transition_function.TransitionFunction, supervision: Callable[[~StateType], torch.Tensor]) → Dict[str, torch.Tensor]

Takes an initial state object, a means of transitioning from state to state, and a supervision signal, and uses the supervision to train the transition function to pick "good" states.

This function should typically return a ``loss`` key during training, which the ``Model`` will use as its loss. A usage sketch appears after the parameter list.

Parameters
- initial_state : ``State``
  This is the initial state for decoding, typically initialized after running some kind of encoder on some inputs.
- transition_function : ``TransitionFunction``
  This is the transition function that scores all possible actions that can be taken in a given state, and returns a ranked list of next states at each step of decoding.
- supervision : ``SupervisionType``
  This is the supervision that is used to train the ``transition_function`` to pick "good" states. You can use whatever kind of supervision you want (e.g., a single "gold" action sequence, a set of possible "gold" action sequences, a reward function, etc.). We use ``typing.Generics`` to make sure that our static type checker is happy with how you've matched the supervision that you provide in the model to the ``DecoderTrainer`` that you want to use.
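Here the supervision is a cost function over finished states rather than a tensor. A hedged sketch follows; the ``erm_loss`` helper and the cost logic are hypothetical, and ``initial_state`` and ``transition_function`` are again stand-ins for what the model produces:

```python
from typing import Callable, Dict

import torch

from allennlp.state_machines.states import State
from allennlp.state_machines.trainers import ExpectedRiskMinimization
from allennlp.state_machines.transition_functions import TransitionFunction


def erm_loss(initial_state: State,
             transition_function: TransitionFunction,
             cost_function: Callable[[State], torch.Tensor]) -> Dict[str, torch.Tensor]:
    trainer = ExpectedRiskMinimization(beam_size=10,
                                       normalize_by_length=True,
                                       max_decoding_steps=50)
    # The supervision is the cost function itself: it should map a finished
    # State to a scalar tensor, with lower values for better sequences
    # (e.g., comparing the state's action history to a gold denotation).
    return trainer.decode(initial_state, transition_function, cost_function)
```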