feedforward
allennlp.modules.feedforward
A feed-forward neural network.
FeedForward¶
class FeedForward(torch.nn.Module, FromParams):
| def __init__(
| self,
| input_dim: int,
| num_layers: int,
| hidden_dims: Union[int, List[int]],
| activations: Union[Activation, List[Activation]],
| dropout: Union[float, List[float]] = 0.0
| ) -> None
This Module is a feed-forward neural network, just a sequence of Linear layers with activation functions in between.
Parameters¶
- input_dim : int
  The dimensionality of the input. We assume the input has shape (batch_size, input_dim).
- num_layers : int
  The number of Linear layers to apply to the input.
- hidden_dims : Union[int, List[int]]
  The output dimension of each of the Linear layers. If this is a single int, we use it for all Linear layers. If it is a List[int], len(hidden_dims) must be num_layers.
- activations : Union[Activation, List[Activation]]
  The activation function to use after each Linear layer. If this is a single function, we use it after all Linear layers. If it is a List[Activation], len(activations) must be num_layers. Each activation must be a torch.nn.Module.
- dropout : Union[float, List[float]], optional (default = 0.0)
  If given, we will apply this amount of dropout after each layer. The semantics of float versus List[float] are the same as for the other parameters.
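The scalar-versus-list convention shared by hidden_dims, activations, and dropout can be sketched in plain Python. The helper name `expand_to_layers` is hypothetical and is not part of AllenNLP; it only illustrates the rule described above, assuming that a list of the wrong length is treated as an error.

```python
from typing import List, TypeVar, Union

T = TypeVar("T")


def expand_to_layers(value: Union[T, List[T]], num_layers: int, name: str) -> List[T]:
    """Hypothetical helper: return one entry of `value` per layer."""
    if not isinstance(value, list):
        # A single value is reused for every layer.
        return [value] * num_layers
    if len(value) != num_layers:
        # A list must supply exactly one entry per layer.
        raise ValueError(f"len({name}) is {len(value)}, expected {num_layers}")
    return list(value)


hidden_dims = expand_to_layers(64, 3, "hidden_dims")       # [64, 64, 64]
dropout = expand_to_layers([0.1, 0.2, 0.3], 3, "dropout")  # [0.1, 0.2, 0.3]
```

Passing a two-element list with `num_layers=3` would raise `ValueError` in this sketch.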
Examples¶
import torch
from allennlp.modules.feedforward import FeedForward

FeedForward(124, 2, [64, 32], torch.nn.ReLU(), 0.2)
#> FeedForward(
#> (_activations): ModuleList(
#> (0): ReLU()
#> (1): ReLU()
#> )
#> (_linear_layers): ModuleList(
#> (0): Linear(in_features=124, out_features=64, bias=True)
#> (1): Linear(in_features=64, out_features=32, bias=True)
#> )
#> (_dropout): ModuleList(
#> (0): Dropout(p=0.2, inplace=False)
#> (1): Dropout(p=0.2, inplace=False)
#> )
#> )
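The in_features and out_features values in the printed module follow from chaining the dimensions: each Linear layer takes the previous layer's output size as its input size. A minimal sketch of that bookkeeping, using a hypothetical helper `layer_shapes` that is not part of the library:

```python
from typing import List, Tuple


def layer_shapes(input_dim: int, hidden_dims: List[int]) -> List[Tuple[int, int]]:
    """Hypothetical helper: (in_features, out_features) for each Linear layer."""
    # The first layer reads the input; every later layer reads the previous output.
    input_dims = [input_dim] + hidden_dims[:-1]
    return list(zip(input_dims, hidden_dims))


shapes = layer_shapes(124, [64, 32])
print(shapes)  # [(124, 64), (64, 32)]
```

Under this reading, the output dimension of the whole network is the last entry of hidden_dims (32 in the example).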
get_output_dim¶
class FeedForward(torch.nn.Module, FromParams):
| ...
| def get_output_dim(self)
get_input_dim¶
class FeedForward(torch.nn.Module, FromParams):
| ...
| def get_input_dim(self)
forward¶
class FeedForward(torch.nn.Module, FromParams):
| ...
| def forward(self, inputs: torch.Tensor) -> torch.Tensor
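As a rough illustration of what a forward pass through such a network involves, the sketch below builds a stand-in module in plain PyTorch. `TinyFeedForward` is a hypothetical class, not the AllenNLP implementation, and the order used here (Linear, then activation, then dropout) is an assumption based on the parameter descriptions above.

```python
import torch


class TinyFeedForward(torch.nn.Module):
    """Hypothetical stand-in for illustration; not the AllenNLP class."""

    def __init__(self, input_dim, hidden_dims, dropout=0.0):
        super().__init__()
        # Each layer's input size is the previous layer's output size.
        input_dims = [input_dim] + hidden_dims[:-1]
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(i, o) for i, o in zip(input_dims, hidden_dims)]
        )
        self.activations = torch.nn.ModuleList([torch.nn.ReLU() for _ in hidden_dims])
        self.dropouts = torch.nn.ModuleList([torch.nn.Dropout(dropout) for _ in hidden_dims])

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        output = inputs
        for linear, activation, drop in zip(self.linears, self.activations, self.dropouts):
            # Assumed order: Linear -> activation -> dropout.
            output = drop(activation(linear(output)))
        return output


model = TinyFeedForward(124, [64, 32], dropout=0.2)
batch = torch.randn(8, 124)  # shape (batch_size, input_dim)
result = model(batch)
print(result.shape)          # torch.Size([8, 32])
```

The input of shape (batch_size, input_dim) comes out with shape (batch_size, 32), where 32 is the last hidden dimension.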