util
allennlp.modules.transformer.util
FloatT¶
FloatT = Union[torch.FloatTensor]
IntT¶
IntT = Union[torch.IntTensor]
BoolT¶
BoolT = Union[torch.BoolTensor]
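These aliases simply name the corresponding torch tensor types so they can be used in annotations. A minimal sketch (the helper function below is hypothetical, not part of the library):

```python
import torch
from allennlp.modules.transformer.util import FloatT, BoolT

# Hypothetical helper showing the aliases used as type annotations.
def zero_out(values: FloatT, mask: BoolT) -> FloatT:
    return values * mask
```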
apply_mask¶
def apply_mask(
values: torch.FloatTensor,
mask: Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor]
) -> torch.FloatTensor
Parameters¶
- values :
torch.FloatTensor
Shape: batch_size x num_attention_heads x source_seq_len x target_seq_len
- mask :
torch.BoolTensor
Shape: batch_size x target_seq_len OR batch_size x 1 x 1 x target_seq_len
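A minimal usage sketch (tensor values below are illustrative). It assumes the usual masked-attention semantics, where positions with a zero in the mask are suppressed in the returned scores, typically by adding a large negative value before a softmax:

```python
import torch
from allennlp.modules.transformer.util import apply_mask

batch_size, num_heads, source_len, target_len = 2, 4, 5, 5

# Raw attention scores: batch_size x num_attention_heads x source_seq_len x target_seq_len
scores = torch.randn(batch_size, num_heads, source_len, target_len)

# Padding mask: 1 = attend, 0 = ignore; shape batch_size x target_seq_len
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])

masked_scores = apply_mask(scores, mask)

# Softmax over the target dimension: masked positions end up with ~0 probability.
probs = masked_scores.softmax(dim=-1)
```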
get_extended_attention_mask¶
def get_extended_attention_mask(
attention_mask: torch.Tensor,
input_shape: Tuple[int, ...],
dtype: torch.dtype,
is_decoder: bool = False
) -> torch.Tensor
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Parameters¶
- attention_mask :
torch.Tensor
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
- input_shape :
Tuple[int, ...]
The shape of the input to the model.
- dtype :
torch.dtype
The datatype of the resulting mask.
- is_decoder :
bool, optional (default = False)
If this is for a decoder stack.
Returns¶
torch.Tensor
The extended attention mask, with the same dtype as attention_mask.dtype.
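A short usage sketch for building an extended mask from a 2D padding mask (the tensor values and the shape noted in the comment are illustrative):

```python
import torch
from allennlp.modules.transformer.util import get_extended_attention_mask

# Padding mask for a batch of 2 sequences of length 5: 1 = attend, 0 = ignore.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

extended = get_extended_attention_mask(
    attention_mask,
    input_shape=attention_mask.shape,
    dtype=torch.float32,
)

# The result broadcasts against batch_size x num_heads x seq_len x seq_len
# attention scores; for a non-decoder 2D mask it is expected to have a
# shape like [2, 1, 1, 5], with ignored positions carrying a large negative value.
print(extended.shape)
```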