util

allennlp.modules.transformer.util

apply_mask

def apply_mask(
    values: torch.FloatTensor,
    mask: Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor]
) -> torch.FloatTensor

Parameters

  • values : torch.FloatTensor
    Shape batch_size x num_attention_heads x source_seq_len x target_seq_len
  • mask : Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor]
    Shape batch_size x target_seq_len OR batch_size x 1 x 1 x target_seq_len.
    A 2D mask is expanded to batch_size x 1 x 1 x target_seq_len and broadcast
    across the attention heads and source positions.
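A minimal usage sketch (the tensor sizes and the trailing softmax are illustrative, not part of this module): positions where the mask is 0/False are pushed toward the minimum value of the values' dtype, so a subsequent softmax assigns them negligible weight.

import torch
from allennlp.modules.transformer.util import apply_mask

batch_size, num_heads, seq_len = 2, 4, 5

# Raw attention scores for each (head, source position, target position).
scores = torch.randn(batch_size, num_heads, seq_len, seq_len)

# 2D padding mask: True for real tokens, False for padding.
mask = torch.tensor(
    [[True, True, True, False, False],
     [True, True, True, True, True]]
)

masked_scores = apply_mask(scores, mask)

# Masked positions now hold a very large negative value, so they
# receive (near-)zero weight after the softmax that typically follows.
attention_probs = masked_scores.softmax(dim=-1)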