composed_seq2seq
allennlp_models.generation.models.composed_seq2seq
ComposedSeq2Seq
@Model.register("composed_seq2seq")
class ComposedSeq2Seq(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| source_text_embedder: TextFieldEmbedder,
| encoder: Seq2SeqEncoder,
| decoder: SeqDecoder,
| tied_source_embedder_key: Optional[str] = None,
| initializer: InitializerApplicator = InitializerApplicator(),
| **kwargs
| ) -> None
This ComposedSeq2Seq class is a Model which takes a sequence, encodes it, and then uses the encoded representations to decode another sequence. You can use this as the basis for a neural machine translation system, an abstractive summarization system, or any other common seq2seq problem. The model here is simple, but should be a decent starting place for implementing recent models for these tasks.

The ComposedSeq2Seq class composes separate Seq2SeqEncoder and SeqDecoder classes. These parts are customizable and independent of each other.
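The composition pattern described above can be sketched in plain Python, without AllenNLP. The `ToyEncoder`, `ToyDecoder`, and `ToyComposedSeq2Seq` names below are hypothetical stand-ins for `Seq2SeqEncoder`, `SeqDecoder`, and the real model; the point is only that the model delegates to two independent, swappable parts.

```python
from typing import Dict, List


class ToyEncoder:
    """Stand-in for Seq2SeqEncoder: turns tokens into 'encoded' states."""

    def encode(self, tokens: List[str]) -> List[str]:
        # A real encoder would produce tensors; here we just tag each token.
        return [f"enc({t})" for t in tokens]


class ToyDecoder:
    """Stand-in for SeqDecoder: consumes encoder states, emits outputs."""

    def decode(self, encoder_states: List[str]) -> Dict[str, List[str]]:
        return {"predictions": [s.upper() for s in encoder_states]}


class ToyComposedSeq2Seq:
    """Composes an encoder and a decoder; the two parts never see each other."""

    def __init__(self, encoder: ToyEncoder, decoder: ToyDecoder) -> None:
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source_tokens: List[str]) -> Dict[str, List[str]]:
        # Encode the source, then hand the encoded states to the decoder.
        return self.decoder.decode(self.encoder.encode(source_tokens))


model = ToyComposedSeq2Seq(ToyEncoder(), ToyDecoder())
print(model.forward(["hello", "world"]))
# -> {'predictions': ['ENC(HELLO)', 'ENC(WORLD)']}
```

Because the encoder and decoder only interact through the encoded states, either part can be swapped for a different implementation without touching the other.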
Parameters

- vocab : Vocabulary
  Vocabulary containing source and target vocabularies. They may be under the same namespace (tokens), or the target tokens can have a different namespace, in which case it needs to be specified as target_namespace.
- source_text_embedder : TextFieldEmbedder
  Embedder for source-side sequences.
- encoder : Seq2SeqEncoder
  The encoder of the "encoder/decoder" model.
- decoder : SeqDecoder
  The decoder of the "encoder/decoder" model.
- tied_source_embedder_key : str, optional (default = None)
  If specified, this key is used to obtain the token embedder in source_text_embedder, and its weights are shared/tied with the decoder's target embedding weights.
- initializer : InitializerApplicator, optional (default = InitializerApplicator())
  Used to initialize the model parameters.
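A minimal sketch of what weight tying via tied_source_embedder_key amounts to, using a plain Python dict as a stand-in for an embedding matrix (no AllenNLP or PyTorch involved; all names here are invented for illustration): both "embedders" hold a reference to the same table, so an update made through one is visible through the other.

```python
from typing import Dict, List

# A shared "embedding matrix": token -> vector, here just token -> list of floats.
shared_table: Dict[str, List[float]] = {"hello": [0.1, 0.2]}

# The source embedder is keyed by embedder name, loosely mirroring how a
# token embedder is looked up inside a TextFieldEmbedder by its key.
source_embedder = {"tokens": shared_table}

# "Tying" means the decoder's target embedding is the SAME object, not a copy.
target_embedding = source_embedder["tokens"]

# An update applied to the target embedding (e.g. a gradient step)...
target_embedding["hello"] = [0.3, 0.4]

# ...is immediately seen by the source embedder, because no copy was made.
print(source_embedder["tokens"]["hello"])  # -> [0.3, 0.4]
```

In the real model the shared object is a weight tensor rather than a dict, but the mechanism is the same: one parameter, referenced from both sides.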
forward
class ComposedSeq2Seq(Model):
| ...
| def forward(
| self,
| source_tokens: TextFieldTensors,
| target_tokens: TextFieldTensors = None
| ) -> Dict[str, torch.Tensor]
Make forward pass on the encoder and decoder for producing the entire target sequence.
Parameters

- source_tokens : TextFieldTensors
  The output of TextField.as_array() applied on the source TextField. This will be passed through a TextFieldEmbedder and then through an encoder.
- target_tokens : TextFieldTensors, optional (default = None)
  The output of TextField.as_array() applied on the target TextField. We assume that the target tokens are also represented as a TextField.
Returns
Dict[str, torch.Tensor]
The output tensors from the decoder.
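The optionality of target_tokens follows the usual AllenNLP convention: during training the model receives gold targets and the output dict includes a loss, while at prediction time it does not. A toy sketch of that control flow (the function name and the "loss" computation are invented for illustration):

```python
from typing import Dict, List, Optional


def toy_forward(
    source_tokens: List[str],
    target_tokens: Optional[List[str]] = None,
) -> Dict[str, object]:
    # Pretend "decoding" just echoes the source; a real model would predict.
    predictions = list(source_tokens)
    output: Dict[str, object] = {"predictions": predictions}
    if target_tokens is not None:
        # Only compute a loss when gold targets are available (training time).
        mismatches = sum(p != t for p, t in zip(predictions, target_tokens))
        output["loss"] = float(mismatches)
    return output


print(toy_forward(["a", "b"]))              # no "loss" key at prediction time
print(toy_forward(["a", "b"], ["a", "c"]))  # includes 'loss': 1.0
```

The returned dict is then what make_output_human_readable post-processes into readable predictions.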
make_output_human_readable
class ComposedSeq2Seq(Model):
| ...
| def make_output_human_readable(
| self,
| output_dict: Dict[str, torch.Tensor]
| ) -> Dict[str, torch.Tensor]
Finalize predictions.
get_metrics
class ComposedSeq2Seq(Model):
| ...
| def get_metrics(self, reset: bool = False) -> Dict[str, float]
default_predictor
class ComposedSeq2Seq(Model):
| ...
| default_predictor = "seq2seq"