A wrapper that unrolls the second (time) dimension of a tensor
into the first (batch) dimension, applies some other Module,
and then rolls the time dimension back up.
class TimeDistributed(torch.nn.Module):
    def __init__(self, module)
Given an input shaped like (batch_size, time_steps, [rest]) and a
Module that takes inputs like (batch_size, [rest]),
TimeDistributed reshapes the input to be
(batch_size * time_steps, [rest]), applies the contained
Module, then reshapes the output back to (batch_size, time_steps, [rest]).
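The reshaping described above can be sketched by hand. This is only an illustration, not the wrapper's actual code: the Conv1d layer and all sizes are arbitrary choices made for the example. Conv1d is used because it expects (N, C, L) input and would reject the extra time dimension without the fold.

```python
import torch

# Arbitrary example sizes (assumptions, not taken from the documentation).
batch_size, time_steps, channels, length = 2, 3, 4, 10

# Hypothetical wrapped module: expects input shaped (N, C, L).
module = torch.nn.Conv1d(in_channels=channels, out_channels=6, kernel_size=3)

x = torch.randn(batch_size, time_steps, channels, length)

# Fold time into batch: (batch_size * time_steps, channels, length).
flat = x.reshape(-1, *x.shape[2:])

# Apply the module once to every (batch, time) slice.
out_flat = module(flat)

# Split the leading dimension back into (batch_size, time_steps).
out = out_flat.reshape(batch_size, time_steps, *out_flat.shape[1:])

print(out.shape)  # torch.Size([2, 3, 6, 8])
```

Each time step of out should match what the module returns when applied to that time step alone, which is the point of the wrapper.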
Note that while the shapes above put batch_size first, this
Module also works if batch_size is second: we always just combine
the first two dimensions, then split them.
It also reshapes keyword arguments, unless they are not tensors or their name is listed in pass_through.
class TimeDistributed(torch.nn.Module):
    ...
    def forward(
        self,
        *inputs,
        pass_through: List[str] = None,
        **kwargs
    )
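To make the keyword-argument behaviour concrete, here is a simplified re-implementation. It is a sketch under stated assumptions, not the library's code: TimeDistributedSketch and ScaleAndProject are hypothetical names, and the sketch assumes at least one positional tensor input and a single tensor output.

```python
from typing import List, Optional

import torch


class TimeDistributedSketch(torch.nn.Module):
    """Illustrative stand-in for the wrapper; not the real implementation."""

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self._module = module

    @staticmethod
    def _fold(tensor: torch.Tensor) -> torch.Tensor:
        # Combine the first two dimensions into one.
        return tensor.reshape(-1, *tensor.shape[2:])

    def forward(self, *inputs, pass_through: Optional[List[str]] = None, **kwargs):
        pass_through = pass_through or []
        folded_inputs = [self._fold(t) for t in inputs]
        # Leave non-tensors and names listed in pass_through untouched.
        folded_kwargs = {
            name: (
                value
                if name in pass_through or not isinstance(value, torch.Tensor)
                else self._fold(value)
            )
            for name, value in kwargs.items()
        }
        out = self._module(*folded_inputs, **folded_kwargs)
        # Split the combined leading dimension back into the original two.
        lead = inputs[0].shape[:2]
        return out.reshape(*lead, *out.shape[1:])


class ScaleAndProject(torch.nn.Module):  # hypothetical wrapped module
    def forward(self, x, offset, projection, factor=1.0):
        # x, offset: (batch, dim); projection: (dim, dim), shared by all steps.
        return (x + offset) @ projection * factor


wrapper = TimeDistributedSketch(ScaleAndProject())
x = torch.ones(2, 3, 4)
offset = torch.ones(2, 3, 4)   # tensor kwarg: folded like the input
projection = torch.eye(4)      # tensor kwarg: must NOT be folded
out = wrapper(
    x, offset=offset, projection=projection, factor=2.0,
    pass_through=["projection"],
)
print(out.shape)  # torch.Size([2, 3, 4])
```

Here offset is folded along with x, factor is skipped because it is not a tensor, and projection is skipped because its name appears in pass_through. Without that entry the matrix would be flattened to a vector and the matrix product would fail.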