AllenNLP 0.9.0

allennlp.nn.beam_search

class allennlp.nn.beam_search.BeamSearch(end_index: int, max_steps: int = 50, beam_size: int = 10, per_node_beam_size: int = None)

Bases: object

Implements the beam search algorithm for decoding the most likely sequences.

Parameters
end_index : int

The index of the “stop” or “end” token in the target vocabulary.

max_steps : int, optional (default = 50)

The maximum number of decoding steps to take, i.e. the maximum length of the predicted sequences.

beam_size : int, optional (default = 10)

The width of the beam used.

per_node_beam_size : int, optional (default = beam_size)

The maximum number of candidates to consider per node at each step in the search. If not given, this defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See “Beam Search Strategies for Neural Machine Translation” (Freitag and Al-Onaizan, 2017).
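To illustrate how per_node_beam_size interacts with beam_size, here is a minimal plain-Python sketch of a single expansion step. It is not the library's implementation: expand_beam and the toy probabilities are hypothetical, and plain lists stand in for tensors.

```python
import math
from typing import List, Tuple

Candidate = Tuple[List[int], float]  # (token sequence, cumulative log probability)


def expand_beam(
    beam: List[Candidate],
    next_log_probs: List[List[float]],
    beam_size: int,
    per_node_beam_size: int,
) -> List[Candidate]:
    """Hypothetical helper: each node proposes at most per_node_beam_size
    children, then the best beam_size candidates overall are kept."""
    candidates: List[Candidate] = []
    for (tokens, score), log_probs in zip(beam, next_log_probs):
        # Rank this node's next tokens and keep only its top few.
        ranked = sorted(range(len(log_probs)), key=lambda t: log_probs[t], reverse=True)
        for token in ranked[:per_node_beam_size]:
            candidates.append((tokens + [token], score + log_probs[token]))
    # Keep the best beam_size candidates across all nodes.
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:beam_size]


# Two nodes on the beam; the first is much more likely than the second.
beam = [([0], math.log(0.7)), ([1], math.log(0.3))]
next_log_probs = [
    [math.log(0.5), math.log(0.4), math.log(0.1)],  # continuations of node [0]
    [math.log(0.7), math.log(0.2), math.log(0.1)],  # continuations of node [1]
]

wide = expand_beam(beam, next_log_probs, beam_size=2, per_node_beam_size=2)
narrow = expand_beam(beam, next_log_probs, beam_size=2, per_node_beam_size=1)
```

In this toy case, per_node_beam_size=2 lets both surviving candidates descend from the first node, while per_node_beam_size=1 forces one survivor from each node, which is the kind of diversity the parameter description refers to.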

search(self, start_predictions: torch.Tensor, start_state: Dict[str, torch.Tensor], step: Callable[[torch.Tensor, Dict[str, torch.Tensor]], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]) → Tuple[torch.Tensor, torch.Tensor]

Given a starting state and a step function, apply beam search to find the most likely target sequences.

Parameters
start_predictions : torch.Tensor

A tensor containing the initial predictions with shape (batch_size,). Usually the initial predictions are just the index of the “start” token in the target vocabulary.

start_state : StateType

The initial state passed to the step function. Each value of the state dict should be a tensor of shape (batch_size, *), where * means any other number of dimensions.

step : StepFunctionType

A function that computes the log probabilities of the next tokens, given the predictions from the last time step and the current state. The function should accept two arguments: first, a tensor of shape (group_size,) holding the indices of the tokens predicted at the last time step; second, the current state. The group_size will be batch_size * beam_size, except in the initial step, for which it will just be batch_size. The function is expected to return a tuple whose first element is a tensor of shape (group_size, target_vocab_size) containing the log probabilities of the tokens for the next step, and whose second element is the updated state. Each tensor in the state should have shape (group_size, *), where * means any other number of dimensions.
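The step contract described above can be sketched in plain Python. This is only an illustration under stated assumptions: nested lists stand in for torch.Tensor values, and the three-token vocabulary, the TRANSITIONS table, and the "steps_taken" state key are all hypothetical.

```python
import math
from typing import Dict, List, Tuple

# Plain-list stand-in for Dict[str, torch.Tensor]; each value has one row per group member.
State = Dict[str, List[List[float]]]

# Hypothetical vocabulary: 0 = start, 1 = "a", 2 = end.
# TRANSITIONS[prev][next] is the probability of `next` following `prev`.
TRANSITIONS = [
    [0.0, 0.9, 0.1],  # after start
    [0.0, 0.5, 0.5],  # after "a"
    [0.0, 0.0, 1.0],  # after end: stay on end
]


def safe_log(p: float) -> float:
    """Log that maps zero probability to -inf instead of raising."""
    return math.log(p) if p > 0 else -math.inf


def step(last_predictions: List[int], state: State) -> Tuple[List[List[float]], State]:
    """Return (log_probs, new_state) for a group of partial sequences.

    last_predictions has length group_size; log_probs has shape
    (group_size, target_vocab_size); every state value keeps group_size rows.
    """
    log_probs = [[safe_log(p) for p in TRANSITIONS[token]] for token in last_predictions]
    # Update the state row by row so it stays aligned with the group.
    new_state = {"steps_taken": [[row[0] + 1] for row in state["steps_taken"]]}
    return log_probs, new_state


# A group of two partial sequences whose last tokens are start (0) and "a" (1).
log_probs, new_state = step([0, 1], {"steps_taken": [[0.0], [3.0]]})
```

A real step function would typically run a decoder over tensors and apply a log-softmax, but the shape bookkeeping is the same: one row of log probabilities per group member, and one state row per group member.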

Returns
Tuple[torch.Tensor, torch.Tensor]

Tuple of (predictions, log_probabilities), where predictions has shape (batch_size, beam_size, max_steps) and log_probabilities has shape (batch_size, beam_size).
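As a hedged sketch of how the returned tuple might be consumed, the helper below picks the highest-scoring sequence for each batch element and cuts it at the first end token. best_sequences is a hypothetical name, nested lists stand in for the two tensors, and the sketch assumes that everything after the first end_index can be ignored. It selects by score rather than by beam position because the text above does not state how the beams are ordered.

```python
from typing import List


def best_sequences(
    predictions: List[List[List[int]]],     # shape (batch_size, beam_size, max_steps)
    log_probabilities: List[List[float]],   # shape (batch_size, beam_size)
    end_index: int,
) -> List[List[int]]:
    """Hypothetical helper: best sequence per batch element, truncated at end_index."""
    results = []
    for beams, scores in zip(predictions, log_probabilities):
        # Index of the beam with the highest log probability.
        best = max(range(len(scores)), key=lambda k: scores[k])
        tokens = beams[best]
        # Drop the end token and anything after it, if present.
        if end_index in tokens:
            tokens = tokens[: tokens.index(end_index)]
        results.append(tokens)
    return results


# batch_size = 2, beam_size = 2, max_steps = 4, with end_index = 2.
predictions = [
    [[4, 5, 2, 2], [4, 2, 2, 2]],
    [[7, 2, 2, 2], [6, 6, 6, 2]],
]
log_probabilities = [[-1.2, -0.7], [-3.0, -2.0]]

decoded = best_sequences(predictions, log_probabilities, end_index=2)
```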

Notes

If your step function returns -inf for some log probabilities (for example, if you’re using a masked log-softmax), then some of the “best” sequences returned may also have -inf log probability. Specifically, this happens when the beam size is larger than the number of actions with finite log probability (non-zero probability) returned by the step function, because the beam must then be filled with zero-probability candidates. Therefore, if you’re using a mask, you may want to check the results from search and potentially discard sequences with non-finite log probability.
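One possible way to discard such sequences is sketched below in plain Python. drop_non_finite is a hypothetical helper, and nested lists again stand in for the tensors returned by search.

```python
import math
from typing import List, Tuple


def drop_non_finite(
    predictions: List[List[List[int]]],    # shape (batch_size, beam_size, max_steps)
    log_probabilities: List[List[float]],  # shape (batch_size, beam_size)
) -> List[List[Tuple[List[int], float]]]:
    """Hypothetical helper: keep only (sequence, score) pairs with a finite score."""
    kept = []
    for beams, scores in zip(predictions, log_probabilities):
        kept.append(
            [(seq, score) for seq, score in zip(beams, scores) if math.isfinite(score)]
        )
    return kept


# One batch element with beam_size = 3; the last beam has zero probability.
predictions = [[[1, 2, 2], [1, 1, 2], [0, 0, 0]]]
log_probabilities = [[-0.8, -1.5, -math.inf]]

filtered = drop_non_finite(predictions, log_probabilities)
```

Note that after filtering, different batch elements may keep different numbers of sequences, so the result is a ragged list rather than a fixed-shape array.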

© Copyright 2018, Allen Institute for Artificial Intelligence