auto_regressive
allennlp_models.generation.modules.seq_decoders.auto_regressive
AutoRegressiveSeqDecoder
@SeqDecoder.register("auto_regressive_seq_decoder")
class AutoRegressiveSeqDecoder(SeqDecoder):
| def __init__(
| self,
| vocab: Vocabulary,
| decoder_net: DecoderNet,
| target_embedder: Embedding,
| target_namespace: str = "tokens",
| beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
| tie_output_embedding: bool = False,
| scheduled_sampling_ratio: float = 0,
| label_smoothing_ratio: Optional[float] = None,
| tensor_based_metric: Metric = None,
| token_based_metric: Metric = None,
| **kwargs
| ) -> None
An autoregressive decoder that can be used for most seq2seq tasks.
Parameters

- vocab : `Vocabulary`
  Vocabulary containing source and target vocabularies. They may be under the same namespace (`tokens`) or the target tokens can have a different namespace, in which case it needs to be specified as `target_namespace`.
- decoder_net : `DecoderNet`
  Module that contains the implementation of the neural network for decoding output elements.
- target_embedder : `Embedding`
  Embedder for target tokens.
- target_namespace : `str`, optional (default = `'tokens'`)
  If the target side vocabulary is different from the source side's, you need to specify the target's namespace here. If not, we'll assume it is "tokens", which is also the default choice for the source side, and this might cause them to share vocabularies.
- beam_search : `BeamSearch`, optional (default = `Lazy(BeamSearch)`)
  This is used during inference to select the tokens of the decoded output sequence.
- tensor_based_metric : `Metric`, optional (default = `None`)
  A metric to track on validation data that takes raw tensors when it is called. This metric must accept two arguments when called: a batched tensor of predicted token indices, and a batched tensor of gold token indices.
- token_based_metric : `Metric`, optional (default = `None`)
  A metric to track on validation data that takes lists of lists of tokens as input. This metric must accept two arguments when called, both of type `List[List[str]]`. The first is a predicted sequence for each item in the batch and the second is a gold sequence for each item in the batch.
- scheduled_sampling_ratio : `float`, optional (default = `0.0`)
  Defines the ratio between teacher-forced training and use of the model's own predictions as input. If it is zero (teacher forcing only) and `decoder_net` supports parallel decoding, we get the output predictions in a single forward pass of the `decoder_net`.
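A minimal construction sketch may help make these parameters concrete. This is not taken from the library's documentation: the `LstmCellDecoderNet` import path, the `target_tokens` namespace name, and all dimensions are assumptions chosen for illustration; the dimensions are picked so that the embedder's `embedding_dim` matches the decoder net's `target_embedding_dim`.

```python
# A minimal construction sketch (not from this page). The LstmCellDecoderNet
# import path, the "target_tokens" namespace, and all dimensions are
# assumptions chosen for illustration.
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data import Vocabulary
from allennlp.modules.token_embedders import Embedding
from allennlp_models.generation.modules.decoder_nets.lstm_cell import LstmCellDecoderNet
from allennlp_models.generation.modules.seq_decoders.auto_regressive import (
    AutoRegressiveSeqDecoder,
)

# The target namespace must contain the start and end symbols, since the
# decoder looks them up to begin and terminate decoding.
vocab = Vocabulary()
vocab.add_tokens_to_namespace(
    [START_SYMBOL, END_SYMBOL] + [f"tok{i}" for i in range(20)], "target_tokens"
)

decoder = AutoRegressiveSeqDecoder(
    vocab=vocab,
    # target_embedding_dim here must match the target embedder's embedding_dim.
    decoder_net=LstmCellDecoderNet(decoding_dim=64, target_embedding_dim=32),
    target_embedder=Embedding(
        num_embeddings=vocab.get_vocab_size("target_tokens"), embedding_dim=32
    ),
    target_namespace="target_tokens",
)
```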
get_output_dim
class AutoRegressiveSeqDecoder(SeqDecoder):
| ...
| def get_output_dim(self)
take_step
class AutoRegressiveSeqDecoder(SeqDecoder):
| ...
| def take_step(
| self,
| last_predictions: torch.Tensor,
| state: Dict[str, torch.Tensor],
| step: int
| ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]
Take a decoding step. This is called by the beam search class.
Parameters

- last_predictions : `torch.Tensor`
  A tensor of shape `(group_size,)`, which gives the indices of the predictions during the last time step.
- state : `Dict[str, torch.Tensor]`
  A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape `(group_size, *)`, where `*` can be any other number of dimensions.
- step : `int`
  The time step in beam search decoding.
Returns

- `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
  A tuple of `(log_probabilities, updated_state)`, where `log_probabilities` is a tensor of shape `(group_size, num_classes)` containing the predicted log probability of each class for the next step, for each item in the group, while `updated_state` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context.
Notes

We treat the inputs as a batch, even though `group_size` is not necessarily equal to `batch_size`, since the group may contain multiple states for each source sentence in the batch.
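To make the step-function contract concrete, here is a self-contained toy sketch: a function with the same `(last_predictions, state, step)` signature driven by allennlp's `BeamSearch`. It is not the real decoder; the uniform log-probabilities stand in for the output projection, and the state key and sizes are arbitrary.

```python
# Toy illustration of the take_step contract (not the real decoder).
from typing import Dict, Tuple

import torch
from allennlp.nn.beam_search import BeamSearch

NUM_CLASSES, END_INDEX = 6, 3  # arbitrary toy values

def take_step(
    last_predictions: torch.Tensor,   # (group_size,)
    state: Dict[str, torch.Tensor],   # each tensor (group_size, *)
    step: int,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    group_size = last_predictions.size(0)
    # Uniform log-probabilities stand in for the decoder's output projection.
    log_probabilities = torch.zeros(group_size, NUM_CLASSES).log_softmax(dim=-1)
    return log_probabilities, state

beam_search = BeamSearch(end_index=END_INDEX, max_steps=10, beam_size=3)
start_predictions = torch.zeros(2, dtype=torch.long)  # (batch_size,)
state = {"decoder_hidden": torch.zeros(2, 8)}         # (batch_size, *)
top_k_predictions, log_probs = beam_search.search(start_predictions, state, take_step)
print(top_k_predictions.shape)  # (batch_size, beam_size, num_decoded_steps)
```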
get_metrics
class AutoRegressiveSeqDecoder(SeqDecoder):
| ...
| @overrides
| def get_metrics(self, reset: bool = False) -> Dict[str, float]
forward
class AutoRegressiveSeqDecoder(SeqDecoder):
| ...
| @overrides
| def forward(
| self,
| encoder_out: Dict[str, torch.LongTensor],
| target_tokens: TextFieldTensors = None
| ) -> Dict[str, torch.Tensor]
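A hedged inference sketch, continuing the construction example above: with no `target_tokens` and the module in eval mode, `forward` produces beam-search predictions from the encoder output. The `encoder_outputs` and `source_mask` keys mirror the state described under `take_step`; the tensors are random stand-ins and the shapes are assumptions.

```python
# Hedged inference sketch, reusing `decoder` from the construction example.
# The encoder_out keys mirror the state described under take_step; the
# tensors are random stand-ins, and the batch/sequence sizes are arbitrary.
import torch

encoder_out = {
    "encoder_outputs": torch.randn(2, 7, 64),           # (batch, seq_len, decoding_dim)
    "source_mask": torch.ones(2, 7, dtype=torch.bool),  # (batch, seq_len)
}
decoder.eval()  # out of training mode, so forward runs beam search
with torch.no_grad():
    output_dict = decoder(encoder_out)  # no target_tokens -> predictions only
```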
post_process
class AutoRegressiveSeqDecoder(SeqDecoder):
| ...
| @overrides
| def post_process(
| self,
| output_dict: Dict[str, torch.Tensor]
| ) -> Dict[str, torch.Tensor]
This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called `predicted_tokens` to the `output_dict`.
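A hedged usage sketch, continuing the inference example above:

```python
# Continuing the sketch above: post_process adds "predicted_tokens".
output_dict = decoder.post_process(output_dict)
print(output_dict["predicted_tokens"])  # one list of token strings per batch item
```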
indices_to_tokens
class AutoRegressiveSeqDecoder(SeqDecoder):
| ...
| def indices_to_tokens(
| self,
| batch_indeces: numpy.ndarray
| ) -> List[List[str]]
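For illustration, here is a hedged approximation of the conversion that `post_process` describes, not the library's exact implementation: trim each sequence at the first end symbol, then map the remaining indices to token strings.

```python
# Hedged approximation of the index-to-token conversion described under
# post_process; not the library's exact implementation.
from typing import List

import numpy
from allennlp.common.util import END_SYMBOL
from allennlp.data import Vocabulary

def indices_to_tokens(
    vocab: Vocabulary, batch_indeces: numpy.ndarray, namespace: str = "tokens"
) -> List[List[str]]:
    end_index = vocab.get_token_index(END_SYMBOL, namespace)
    all_tokens = []
    for indeces in batch_indeces:
        indeces = list(indeces)
        # Keep only the indices before the first end symbol.
        if end_index in indeces:
            indeces = indeces[: indeces.index(end_index)]
        all_tokens.append(
            [vocab.get_token_from_index(i, namespace=namespace) for i in indeces]
        )
    return all_tokens
```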