seq_decoder

allennlp_models.generation.modules.seq_decoders.seq_decoder


SeqDecoder#

class SeqDecoder(Module,  Registrable):
 | def __init__(self, target_embedder: Embedding) -> None

SeqDecoder is an abstract class representing the entire decoder (embedding and neural network) of a Seq2Seq architecture. It is meant to be used with allennlp.models.encoder_decoder.composed_seq2seq.ComposedSeq2Seq.

Implementations of this abstract class should ideally use a decoder neural network (allennlp.modules.seq2seq_decoders.decoder_net.DecoderNet) for decoding.

The default_implementation, allennlp.modules.seq2seq_decoders.seq_decoder.auto_regressive_seq_decoder.AutoRegressiveSeqDecoder, covers most use cases; you are more likely to use it than to write a new implementation. A minimal subclass sketch appears after the parameter list below.

Parameters

  • target_embedder : Embedding
    Embedder for target tokens. Needed in the base class to enable weight tying.
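
If you do need a custom decoder, every abstract method below must be overridden. A minimal sketch of such a subclass, assuming the import paths shown here (the class name, registration key, and method bodies are illustrative, not part of the library):

```python
from typing import Dict, Optional

import torch
from allennlp.modules.token_embedders import Embedding
from allennlp_models.generation.modules.seq_decoders import SeqDecoder


@SeqDecoder.register("my_seq_decoder")  # hypothetical registration key
class MySeqDecoder(SeqDecoder):
    """Illustrative skeleton; a real decoder would hold a DecoderNet
    and an output projection layer."""

    def __init__(self, target_embedder: Embedding) -> None:
        # The base class stores the embedder, which makes weight tying possible.
        super().__init__(target_embedder)

    def get_output_dim(self) -> int:
        # Dimension of the hidden state fed to the final softmax layer.
        return self.target_embedder.get_output_dim()

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {}

    def forward(
        self,
        encoder_out: Dict[str, torch.LongTensor],
        target_tokens: Optional[Dict[str, torch.LongTensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        raise NotImplementedError  # the decoding loop goes here

    def post_process(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        return output_dict
```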

default_implementation#

class SeqDecoder(Module,  Registrable):
 | ...
 | default_implementation = "auto_regressive_seq_decoder"
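
Because SeqDecoder is Registrable, a configuration that omits the "type" key falls back to this default. The same lookup can be done in Python, a minimal sketch:

```python
from allennlp_models.generation.modules.seq_decoders import SeqDecoder

# Registrable maps registered names to classes; this resolves the default.
decoder_cls = SeqDecoder.by_name("auto_regressive_seq_decoder")
```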

get_output_dim#

class SeqDecoder(Module,  Registrable):
 | ...
 | def get_output_dim(self) -> int

Returns the dimension of each timestep of the hidden state in the layer before the final softmax. Needed to check whether the model is compatible with weight tying between the target embedder and the final layer.
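
The compatibility check amounts to comparing two dimensions, as in this illustrative sketch (not library code; it assumes a constructed `decoder` and relies on the base class storing the embedder as `target_embedder`):

```python
# Weight tying shares the target embedding matrix with the final projection,
# so the two dimensions must match exactly.
if decoder.get_output_dim() != decoder.target_embedder.get_output_dim():
    raise ValueError("Cannot tie weights: output dim != target embedding dim.")
```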

get_metrics#

class SeqDecoder(Module,  Registrable):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]

The decoder is responsible for computing metrics using the target tokens.
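
A sketch of typical usage in a training loop, assuming a constructed `decoder` (the metric name shown is illustrative):

```python
# During an epoch: peek at running values without clearing them.
running = decoder.get_metrics(reset=False)   # e.g. {"BLEU": 0.31}

# At the end of an epoch: read final values and reset the accumulators.
epoch_metrics = decoder.get_metrics(reset=True)
```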

forward#

class SeqDecoder(Module,  Registrable):
 | ...
 | def forward(
 |     self,
 |     encoder_out: Dict[str, torch.LongTensor],
 |     target_tokens: Optional[Dict[str, torch.LongTensor]] = None
 | ) -> Dict[str, torch.Tensor]

Decodes from encoded states to a sequence of outputs, and also computes the loss if target_tokens are given. See the sketch after the parameter list below.

Parameters

  • encoder_out : Dict[str, torch.LongTensor]
    Dictionary with encoded state, ideally containing the encoded vectors and the source mask.
  • target_tokens : Dict[str, torch.LongTensor], optional
    The output of TextField.as_array() applied to the target TextField.
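
A sketch of the two calling modes, assuming a constructed `decoder`, a `target_tokens` dict from the data loader, and the conventional "encoder_outputs" and "source_mask" keys that ComposedSeq2Seq puts in the state:

```python
import torch

batch_size, src_len, hidden = 2, 7, 16
state = {
    "encoder_outputs": torch.randn(batch_size, src_len, hidden),
    "source_mask": torch.ones(batch_size, src_len, dtype=torch.bool),
}

# Training: targets are given, so the output dict includes a "loss" entry.
output_dict = decoder(encoder_out=state, target_tokens=target_tokens)
loss = output_dict["loss"]

# Inference: no targets, so the output dict holds raw predictions instead.
output_dict = decoder(encoder_out=state)
```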

post_process#

class SeqDecoder(Module,  Registrable):
 | ...
 | def post_process(
 |     self,
 |     output_dict: Dict[str, torch.Tensor]
 | ) -> Dict[str, torch.Tensor]

Post-processing for converting raw outputs to predictions during inference. Composing models such as allennlp.models.encoder_decoders.composed_seq2seq.ComposedSeq2Seq can call this method when their decode is called.
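
Illustrative usage during inference, assuming a constructed `decoder` and an encoder `state` as above:

```python
# Run the decoder without targets, then convert raw indices to something
# readable; e.g. AutoRegressiveSeqDecoder adds a "predicted_tokens" entry.
output_dict = decoder(encoder_out=state)
output_dict = decoder.post_process(output_dict)
```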