pytorch_seq2seq_wrapper
allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper
PytorchSeq2SeqWrapper#
class PytorchSeq2SeqWrapper(Seq2SeqEncoder):
| def __init__(
| self,
| module: torch.nn.Module,
| stateful: bool = False
| ) -> None
PyTorch's RNNs have two outputs: the hidden state for every time step, and the hidden state at
the last time step for every layer. We just want the first one as a single output. This
wrapper pulls out that output, and adds a `get_output_dim` method, which is useful if you
want to, e.g., define a linear + softmax layer on top of this to get some distribution over a
set of labels. The linear layer needs to know its input dimension before it is called, and you
can get that from `get_output_dim`.
In order to be wrapped with this wrapper, a class must have the following members:
- `self.input_size: int`
- `self.hidden_size: int`
- `def forward(inputs: PackedSequence, hidden_state: torch.Tensor) -> Tuple[PackedSequence, torch.Tensor]`
- `self.bidirectional: bool` (optional)

This is what PyTorch's RNNs look like - just make sure your class looks like those, and it should work.
Note that we require you to pass a binary mask of shape `(batch_size, sequence_length)`
when you call this module, to avoid subtle bugs around masking. If you already have a
`PackedSequence` you can pass `None` as the second parameter.
We support stateful RNNs, where the final state from each batch is used as the initial
state for the subsequent batch, by passing `stateful=True` to the constructor.
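As a quick illustration (my own sketch, not part of the original docstring), here is how a plain `torch.nn.GRU` can be wrapped and called with a mask. Note that the wrapper expects the wrapped module to be batch-first, which is how the registered encoders below construct theirs.

```python
import torch
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

# Wrap a plain GRU; the wrapper expects the module to be batch-first.
gru = torch.nn.GRU(input_size=10, hidden_size=20, batch_first=True)
encoder = PytorchSeq2SeqWrapper(gru)

inputs = torch.randn(3, 7, 10)             # (batch_size, sequence_length, input_size)
mask = torch.ones(3, 7, dtype=torch.bool)  # (batch_size, sequence_length)
mask[1, 5:] = False                        # second sequence has only 5 real tokens

outputs = encoder(inputs, mask)
print(outputs.shape)             # torch.Size([3, 7, 20]) -- one hidden state per time step
print(encoder.get_output_dim())  # 20
```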
get_input_dim#
class PytorchSeq2SeqWrapper(Seq2SeqEncoder):
| ...
| @overrides
| def get_input_dim(self) -> int
get_output_dim#
class PytorchSeq2SeqWrapper(Seq2SeqEncoder):
| ...
| @overrides
| def get_output_dim(self) -> int
is_bidirectional#
class PytorchSeq2SeqWrapper(Seq2SeqEncoder):
| ...
| @overrides
| def is_bidirectional(self) -> bool
forward#
class PytorchSeq2SeqWrapper(Seq2SeqEncoder):
| ...
| @overrides
| def forward(
| self,
| inputs: torch.Tensor,
| mask: torch.BoolTensor,
| hidden_state: torch.Tensor = None
| ) -> torch.Tensor
GruSeq2SeqEncoder#
@Seq2SeqEncoder.register("gru")
class GruSeq2SeqEncoder(PytorchSeq2SeqWrapper):
| def __init__(
| self,
| input_size: int,
| hidden_size: int,
| num_layers: int = 1,
| bias: bool = True,
| dropout: float = 0.0,
| bidirectional: bool = False,
| stateful: bool = False
| )
Registered as a `Seq2SeqEncoder` with name "gru".
LstmSeq2SeqEncoder#
@Seq2SeqEncoder.register("lstm")
class LstmSeq2SeqEncoder(PytorchSeq2SeqWrapper):
| def __init__(
| self,
| input_size: int,
| hidden_size: int,
| num_layers: int = 1,
| bias: bool = True,
| dropout: float = 0.0,
| bidirectional: bool = False,
| stateful: bool = False
| )
Registered as a `Seq2SeqEncoder` with name "lstm".
RnnSeq2SeqEncoder#
@Seq2SeqEncoder.register("rnn")
class RnnSeq2SeqEncoder(PytorchSeq2SeqWrapper):
| def __init__(
| self,
| input_size: int,
| hidden_size: int,
| num_layers: int = 1,
| nonlinearity: str = "tanh",
| bias: bool = True,
| dropout: float = 0.0,
| bidirectional: bool = False,
| stateful: bool = False
| )
Registered as a `Seq2SeqEncoder` with name "rnn".
AugmentedLstmSeq2SeqEncoder#
@Seq2SeqEncoder.register("augmented_lstm")
class AugmentedLstmSeq2SeqEncoder(PytorchSeq2SeqWrapper):
| def __init__(
| self,
| input_size: int,
| hidden_size: int,
| go_forward: bool = True,
| recurrent_dropout_probability: float = 0.0,
| use_highway: bool = True,
| use_input_projection_bias: bool = True,
| stateful: bool = False
| ) -> None
Registered as a `Seq2SeqEncoder` with name "augmented_lstm".
StackedAlternatingLstmSeq2SeqEncoder#
@Seq2SeqEncoder.register("alternating_lstm")
class StackedAlternatingLstmSeq2SeqEncoder(PytorchSeq2SeqWrapper):
| def __init__(
| self,
| input_size: int,
| hidden_size: int,
| num_layers: int,
| recurrent_dropout_probability: float = 0.0,
| use_highway: bool = True,
| use_input_projection_bias: bool = True,
| stateful: bool = False
| ) -> None
Registered as a `Seq2SeqEncoder` with name "alternating_lstm".
StackedBidirectionalLstmSeq2SeqEncoder#
@Seq2SeqEncoder.register("stacked_bidirectional_lstm")
class StackedBidirectionalLstmSeq2SeqEncoder(PytorchSeq2SeqWrapper):
| def __init__(
| self,
| input_size: int,
| hidden_size: int,
| num_layers: int,
| recurrent_dropout_probability: float = 0.0,
| layer_dropout_probability: float = 0.0,
| use_highway: bool = True,
| stateful: bool = False
| ) -> None
Registered as a `Seq2SeqEncoder` with name "stacked_bidirectional_lstm".