A wrapper that unrolls the second (time) dimension of a tensor into the first (batch) dimension, applies some other Module, and then rolls the time dimension back up.


class TimeDistributed(torch.nn.Module):
 | def __init__(self, module)

Given an input shaped like (batch_size, time_steps, [rest]) and a Module that takes inputs like (batch_size, [rest]), TimeDistributed reshapes the input to be (batch_size * time_steps, [rest]), applies the contained Module, then reshapes the output back so that its first two dimensions are again (batch_size, time_steps).

Note that while the above gives shapes with batch_size first, this Module also works if batch_size is second: it always just combines the first two dimensions, then splits them again.
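The reshaping described above can be illustrated by hand. This is a minimal sketch that assumes a `torch.nn.Linear` as the wrapped module and made-up sizes; it does not use TimeDistributed itself, only the same fold, apply, unfold steps.

```python
import torch

# Illustrative sizes (assumptions, not taken from the library).
batch_size, time_steps, features = 2, 3, 4
inputs = torch.randn(batch_size, time_steps, features)

# A module that expects inputs shaped (batch_size, [rest]).
module = torch.nn.Linear(features, 5)

# Fold the time dimension into the batch dimension: (2, 3, 4) -> (6, 4).
flat_inputs = inputs.reshape(-1, *inputs.shape[2:])

# Apply the module once to every (batch, time) position: (6, 4) -> (6, 5).
flat_outputs = module(flat_inputs)

# Split the first dimension back into (batch_size, time_steps): (6, 5) -> (2, 3, 5).
outputs = flat_outputs.reshape(batch_size, time_steps, *flat_outputs.shape[1:])
```

Because only the first two dimensions are merged and then split, the same steps apply unchanged when the tensor is laid out as (time_steps, batch_size, [rest]).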

Tensor-valued keyword arguments are reshaped in the same way. Keyword arguments that are not tensors, or whose names appear in the optional pass_through iterable, are passed to the contained Module unchanged.


class TimeDistributed(torch.nn.Module):
 | ...
 | def forward(
 |     self,
 |     *inputs,
 |     pass_through: List[str] = None,
 |     **kwargs
 | )