util

allennlp.modules.transformer.util

FloatT

FloatT = Union[torch.FloatTensor]

IntT

IntT = Union[torch.IntTensor]

BoolT

BoolT = Union[torch.BoolTensor]

apply_mask

def apply_mask(
    values: torch.FloatTensor,
    mask: Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor]
) -> torch.FloatTensor

Parameters

  • values : torch.FloatTensor
    Shape batch_size x num_attention_heads x source_seq_len x target_seq_len
  • mask : torch.BoolTensor (per the signature, torch.IntTensor and torch.FloatTensor are also accepted)
    Shape batch_size x target_seq_len OR batch_size x 1 x 1 x target_seq_len
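
The following is a minimal usage sketch, not part of the library documentation; it assumes apply_mask works additively, pushing masked positions toward a very large negative value so that they receive essentially zero weight after a softmax.

import torch
from allennlp.modules.transformer.util import apply_mask

# Raw attention scores:
# batch_size x num_attention_heads x source_seq_len x target_seq_len.
scores = torch.randn(2, 4, 5, 5)

# Padding mask, shape batch_size x target_seq_len (True = attend, False = ignore).
mask = torch.tensor([
    [True, True, True, False, False],
    [True, True, True, True, True],
])

masked_scores = apply_mask(scores, mask)

# Masked positions now hold a very large negative value (assumed behavior),
# so they get essentially zero weight after the softmax over the last dimension.
attention_probs = masked_scores.softmax(dim=-1)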

get_extended_attention_mask

def get_extended_attention_mask(
    attention_mask: torch.Tensor,
    input_shape: Tuple[int, ...],
    dtype: torch.dtype,
    is_decoder: bool = False
) -> torch.Tensor

Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

Parameters

  • attention_mask : torch.Tensor
    Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  • input_shape : Tuple[int, ...]
    The shape of the input to the model.
  • dtype : torch.dtype
    The datatype of the resulting mask.
  • is_decoder : bool, optional (default = False)
    If this is for a decoder stack.

Returns

  • torch.Tensor
    The extended attention mask, with the same dtype as the dtype argument.
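
Below is a minimal usage sketch, not taken from the library documentation; it assumes the returned tensor is an additive mask (zeros at positions to attend to, large negative values at positions to ignore) that is added to raw attention scores before the softmax.

import torch
from allennlp.modules.transformer.util import get_extended_attention_mask

# 1 = attend, 0 = ignore (e.g. padding); shape batch_size x seq_len.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
input_shape = tuple(attention_mask.shape)

# Encoder-style mask, broadcastable against attention scores of shape
# batch_size x num_attention_heads x seq_len x seq_len.
extended_mask = get_extended_attention_mask(
    attention_mask, input_shape, dtype=torch.float32
)

# Decoder-style mask: additionally blocks attention to future positions.
causal_mask = get_extended_attention_mask(
    attention_mask, input_shape, dtype=torch.float32, is_decoder=True
)

# Typical use (assumed here): add the extended mask to raw attention scores
# before the softmax, so ignored positions end up with ~0 probability.
scores = torch.randn(2, 4, 4, 4)
masked_scores = scores + extended_mask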