util
allennlp.modules.transformer.util
apply_mask¶
def apply_mask(
values: torch.FloatTensor,
mask: Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor]
) -> torch.FloatTensor
Parameters¶
- values :
torch.FloatTensor
Shape: batch_size x num_attention_heads x source_seq_len x target_seq_len
- mask :
Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor]
Shape: batch_size x target_seq_len
OR batch_size x 1 x 1 x target_seq_len
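The idea behind this kind of attention masking is to add a very large negative number to the scores at masked positions (where the mask is 0), so those positions receive approximately zero weight after a softmax. Below is a minimal NumPy sketch of that technique; it is a hypothetical re-implementation for illustration, not the library's actual torch code, and the tensor names are assumptions.

```python
import numpy as np

def apply_mask_sketch(values: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Push scores at masked (mask == 0) positions toward the dtype minimum
    so they contribute ~0 probability after a softmax.

    values: (batch, num_heads, source_seq_len, target_seq_len)
    mask:   (batch, target_seq_len) or (batch, 1, 1, target_seq_len)
    """
    if mask.ndim == 2:
        # Broadcast a 2D mask up to (batch, 1, 1, target_seq_len) so it
        # applies across every head and every source position.
        mask = mask[:, None, None, :]
    mask = mask.astype(values.dtype)
    # Positions with mask == 1 are unchanged; mask == 0 positions get
    # the most negative representable value added.
    return values + (1.0 - mask) * np.finfo(values.dtype).min

# Usage: batch of 1, 2 heads, 3 source positions, 4 target positions,
# with the last two target positions masked out (e.g. padding).
scores = np.zeros((1, 2, 3, 4), dtype=np.float32)
mask = np.array([[1, 1, 0, 0]])
masked = apply_mask_sketch(scores, mask)
```

After this, `masked[..., :2]` is unchanged and `masked[..., 2:]` sits near float32's minimum, so a subsequent softmax over the last axis effectively ignores the masked positions.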