pytorch_seq2vec_wrapper

allennlp.modules.seq2vec_encoders.pytorch_seq2vec_wrapper


PytorchSeq2VecWrapper

class PytorchSeq2VecWrapper(Seq2VecEncoder):
 | def __init__(self, module: torch.nn.modules.RNNBase) -> None

PyTorch's RNNs have two outputs: the hidden state at every time step, and the hidden state at the last time step for every layer. We just want the second one, as a single vector. This wrapper pulls out that output and adds a `get_output_dim` method, which is useful if you want to, e.g., define a linear + softmax layer on top of this to get a distribution over a set of labels. The linear layer needs to know its input dimension before it is called, and you can get that from `get_output_dim`.

Also, there are lots of ways you could imagine going from an RNN hidden state at every time step to a single vector: you could take the last vector at all layers in the stack, do some kind of pooling, take the last vector of the top layer in the stack, or many other options. We just take the final hidden state vector, or in the case of a bidirectional RNN cell, we concatenate the final forward and backward states. TODO(mattg): allow for other ways of wrapping RNNs.

In order to be wrapped with this wrapper, a class must have the following members:

- `self.input_size: int`
- `self.hidden_size: int`
- `def forward(inputs: PackedSequence, hidden_state: torch.Tensor) ->
  Tuple[PackedSequence, torch.Tensor]`
- `self.bidirectional: bool` (optional)

This is what PyTorch's RNNs look like; just make sure your class looks like those, and it should work.

Note that we require you to pass a binary mask of shape `(batch_size, sequence_length)` when you call this module, to avoid subtle bugs around masking. If you already have a `PackedSequence` you can pass `None` as the second parameter.
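
For instance, a minimal sketch of wrapping a standard `torch.nn.LSTM` (the sizes here are arbitrary, and we construct the wrapped module with `batch_first=True`, which is what the built-in encoders below do as well):

```python
import torch
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper

# Wrap a plain PyTorch LSTM. The wrapper feeds it batch-first tensors,
# so the wrapped module should be built with batch_first=True.
lstm = torch.nn.LSTM(
    input_size=50, hidden_size=100, num_layers=2,
    bidirectional=True, batch_first=True,
)
encoder = PytorchSeq2VecWrapper(lstm)

print(encoder.get_input_dim())   # 50
print(encoder.get_output_dim())  # 200, i.e. hidden_size * 2 because the LSTM is bidirectional
```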

get_input_dim

class PytorchSeq2VecWrapper(Seq2VecEncoder):
 | ...
 | def get_input_dim(self) -> int

get_output_dim

class PytorchSeq2VecWrapper(Seq2VecEncoder):
 | ...
 | def get_output_dim(self) -> int

forward

class PytorchSeq2VecWrapper(Seq2VecEncoder):
 | ...
 | def forward(
 |     self,
 |     inputs: torch.Tensor,
 |     mask: torch.BoolTensor,
 |     hidden_state: torch.Tensor = None
 | ) -> torch.Tensor
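
A small self-contained sketch of calling the encoder (the wrapped GRU and all shapes are arbitrary illustrations): the mask marks real tokens with `True` and padding with `False`, and the returned tensor has shape `(batch_size, get_output_dim())`.

```python
import torch
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper

encoder = PytorchSeq2VecWrapper(
    torch.nn.GRU(input_size=8, hidden_size=16, batch_first=True)
)

batch_size, sequence_length = 4, 12
inputs = torch.randn(batch_size, sequence_length, 8)

# Boolean mask: True for real tokens, False for padding.
mask = torch.ones(batch_size, sequence_length, dtype=torch.bool)
mask[:, 9:] = False  # pretend the last three positions are padding

vector = encoder(inputs, mask)
print(vector.shape)  # torch.Size([4, 16])
```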

GruSeq2VecEncoder

@Seq2VecEncoder.register("gru")
class GruSeq2VecEncoder(PytorchSeq2VecWrapper):
 | def __init__(
 |     self,
 |     input_size: int,
 |     hidden_size: int,
 |     num_layers: int = 1,
 |     bias: bool = True,
 |     dropout: float = 0.0,
 |     bidirectional: bool = False
 | )

Registered as a Seq2VecEncoder with name "gru".

LstmSeq2VecEncoder

@Seq2VecEncoder.register("lstm")
class LstmSeq2VecEncoder(PytorchSeq2VecWrapper):
 | def __init__(
 |     self,
 |     input_size: int,
 |     hidden_size: int,
 |     num_layers: int = 1,
 |     bias: bool = True,
 |     dropout: float = 0.0,
 |     bidirectional: bool = False
 | )

Registered as a Seq2VecEncoder with name "lstm".

RnnSeq2VecEncoder

@Seq2VecEncoder.register("rnn")
class RnnSeq2VecEncoder(PytorchSeq2VecWrapper):
 | def __init__(
 |     self,
 |     input_size: int,
 |     hidden_size: int,
 |     num_layers: int = 1,
 |     nonlinearity: str = "tanh",
 |     bias: bool = True,
 |     dropout: float = 0.0,
 |     bidirectional: bool = False
 | )

Registered as a Seq2VecEncoder with name "rnn".

AugmentedLstmSeq2VecEncoder

@Seq2VecEncoder.register("augmented_lstm")
class AugmentedLstmSeq2VecEncoder(PytorchSeq2VecWrapper):
 | def __init__(
 |     self,
 |     input_size: int,
 |     hidden_size: int,
 |     go_forward: bool = True,
 |     recurrent_dropout_probability: float = 0.0,
 |     use_highway: bool = True,
 |     use_input_projection_bias: bool = True
 | ) -> None

Registered as a Seq2VecEncoder with name "augmented_lstm".

StackedAlternatingLstmSeq2VecEncoder

@Seq2VecEncoder.register("alternating_lstm")
class StackedAlternatingLstmSeq2VecEncoder(PytorchSeq2VecWrapper):
 | def __init__(
 |     self,
 |     input_size: int,
 |     hidden_size: int,
 |     num_layers: int,
 |     recurrent_dropout_probability: float = 0.0,
 |     use_highway: bool = True,
 |     use_input_projection_bias: bool = True
 | ) -> None

Registered as a Seq2VecEncoder with name "alternating_lstm".

StackedBidirectionalLstmSeq2VecEncoder

@Seq2VecEncoder.register("stacked_bidirectional_lstm")
class StackedBidirectionalLstmSeq2VecEncoder(PytorchSeq2VecWrapper):
 | def __init__(
 |     self,
 |     input_size: int,
 |     hidden_size: int,
 |     num_layers: int,
 |     recurrent_dropout_probability: float = 0.0,
 |     layer_dropout_probability: float = 0.0,
 |     use_highway: bool = True
 | ) -> None

Registered as a Seq2VecEncoder with name "stacked_bidirectional_lstm".