Add support for bool/byte attn_mask tensor in MultiheadAttention/Transformer modules (#33763)
Summary:
Add the support to accept float, byte, and bool tensors for `attn_mask`. No breakage is expected.
- If a bool tensor is provided, positions with `True` are not allowed to attend while `False` values will be unchanged.
- If a byte tensor is provided, it will be converted to a bool tensor. Positions with non-zero values are not allowed to attend while zero values will be unchanged.
- If a float tensor is provided, it will be added to the attention weight.

Note: the behavior of the float mask tensor is slightly different from the first two options because it is added to the attention weight rather than applied with the `masked_fill_` function. Also, converting a byte tensor to a bool tensor within `multi_head_attention_forward` causes extra overhead. Therefore, a bool mask is recommended here.

For `key_padding_mask`:
- If a bool tensor is provided, the positions with the value of `True` will be ignored while positions with the value of `False` will be unchanged.
- If a byte tensor is provided, it will be converted to a bool tensor. The positions with non-zero values will be ignored while zero positions will be unchanged.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33763

Differential Revision: D20925358

Pulled By: zhangguanheng66

fbshipit-source-id: de174056be183cdad0f3de8024ee0a3c5eb364c9
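For illustration only (not part of this commit): a minimal sketch of the three accepted `attn_mask` dtypes and a bool `key_padding_mask` passed to `nn.MultiheadAttention`, assuming a PyTorch build that includes this change. The shapes and values below are arbitrary.

import torch
import torch.nn as nn

embed_dim, num_heads, seq_len, bsz = 8, 2, 5, 3
mha = nn.MultiheadAttention(embed_dim, num_heads)
q = k = v = torch.rand(seq_len, bsz, embed_dim)   # (L, N, E) == (S, N, E) here

# Bool mask (recommended): True = not allowed to attend.
bool_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
out_bool, _ = mha(q, k, v, attn_mask=bool_mask)

# Float mask: added to the attention weights (-inf blocks attention, 0.0 is a no-op).
float_mask = torch.zeros(seq_len, seq_len).masked_fill(bool_mask, float('-inf'))
out_float, _ = mha(q, k, v, attn_mask=float_mask)

# Byte mask: still accepted, but converted to bool internally and a deprecation warning is raised.
out_byte, _ = mha(q, k, v, attn_mask=bool_mask.byte())

# key_padding_mask follows the same convention: True / non-zero key positions are ignored.
key_padding = torch.zeros(bsz, seq_len, dtype=torch.bool)
key_padding[:, -1] = True   # ignore the last key position in every sequence
out_pad, _ = mha(q, k, v, key_padding_mask=key_padding)

The bool path is the one recommended above because the byte path pays for an extra uint8-to-bool conversion inside `multi_head_attention_forward`.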
This commit is contained in:
parent
9854df673c
commit
b607c83a26
@@ -3057,7 +3057,7 @@ class TestNN(NNTestCase):
             return (src_indices < src_lengths).int().detach()

         def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False,
-                                        saved_kv=False, same_embed_dim=False):
+                                        saved_kv=False, same_embed_dim=False, byte_mask=False):
             for _ in range(100):
                 batch_sz, seq_len = [random.randint(2, 10) for r in range(2)]
                 d_head = random.randint(3, 10)
@@ -3085,16 +3085,20 @@ class TestNN(NNTestCase):
                     seq_mask = np.random.randint(0, 2, (1, seq_len))
                     key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1)
                     key_padding_mask_tensor = torch.from_numpy(key_padding_mask)
-
+                    if byte_mask:
+                        key_padding_mask_tensor = key_padding_mask_tensor.byte()
                 decoder_state = np.random.rand(batch_sz, d_model)
                 K = np.random.rand(*dims)
                 V = K
                 Q = np.expand_dims(decoder_state, 1)
                 attn_mask = np.random.randint(0 , 2, size=(1, seq_len))
                 attn_mask_tensor = torch.from_numpy(attn_mask).float()
-                attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
-                attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
-                attn_mask_tensor = attn_mask_tensor.double()
+                if byte_mask:
+                    attn_mask_tensor = (attn_mask_tensor == 0).byte()
+                else:
+                    attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
+                    attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
+                    attn_mask_tensor = attn_mask_tensor.double()

                 decoder_state_tensor = torch.from_numpy(decoder_state).to(torch.get_default_dtype())
                 source_hid_tensor = torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1)
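The helper above encodes the same mask either as a byte/bool tensor via `(attn_mask_tensor == 0)` or as a float tensor filled with `-inf`/`0.0`. A small standalone check, a sketch rather than part of the test suite, that the two encodings give the same attention output:

import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, bsz, embed_dim, num_heads = 6, 4, 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads).eval()
q = k = v = torch.rand(seq_len, bsz, embed_dim)

attn_mask = torch.from_numpy(np.random.randint(0, 2, size=(seq_len, seq_len)))
bool_mask = attn_mask == 0                  # True -> position may not be attended
bool_mask.fill_diagonal_(False)             # keep at least one attendable key per query
float_mask = torch.zeros(seq_len, seq_len).masked_fill_(bool_mask, float('-inf'))

with torch.no_grad():
    out_bool, _ = mha(q, k, v, attn_mask=bool_mask)
    out_float, _ = mha(q, k, v, attn_mask=float_mask)
print(torch.allclose(out_bool, out_float))  # expected: True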
@@ -3239,6 +3243,10 @@ class TestNN(NNTestCase):
             _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
                                         saved_kv=True, same_embed_dim=True)

+        def test_multihead_attn_all_arguments4():
+            _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
+                                        saved_kv=True, same_embed_dim=True, byte_mask=True)
+
         test_multihead_attn_add_zero_attn()  # Test MultiheadAttention with add_zero_attn
         test_multihead_attn_add_bias_kv()  # Test MultiheadAttention with add_bias_kv
         test_multihead_attn_no_masking()  # Test MultiheadAttention without masking
@@ -3249,6 +3257,7 @@ class TestNN(NNTestCase):
         with self.assertRaisesRegex(AssertionError, "bias cannot be added to static key."):
             test_multihead_attn_all_arguments2()  # Test MultiheadAttention with all the argument.
         test_multihead_attn_all_arguments3()  # Test MultiheadAttention with all the argument.
+        test_multihead_attn_all_arguments4()  # Test MultiheadAttention with all the argument.

     def test_multihead_attn_3d_attn_mask(self):
         embed_dim = 8
@@ -3770,8 +3770,7 @@ def multi_head_attention_forward(query, # type: Tensor
             be ignored by the attention. This is an binary mask. When the value is True,
             the corresponding value on the attention layer will be filled with -inf.
         need_weights: output attn_output_weights.
-        attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
-            (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
+        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
             the batches while a 3D mask allows to specify a different mask for the entries of each batch.
         use_separate_proj_weight: the function accept the proj. weights for query, key,
             and value in different forms. If false, in_proj_weight will be used, which is
@@ -3788,10 +3787,17 @@ def multi_head_attention_forward(query, # type: Tensor
           the embedding dimension.
         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
-        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+          will be unchanged. If a BoolTensor is provided, the positions with the
+          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
         - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
           3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
-          S is the source sequence length.
+          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+          is provided, it will be added to the attention weight.
         - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
         - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
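A sketch of the shapes documented here (not part of the diff): a 2D mask of shape `(L, S)` is broadcast over the batch, while a 3D mask must be `(N*num_heads, L, S)` so every batch entry and head can get its own mask.

import torch
import torch.nn as nn

L, S, N, E, num_heads = 4, 6, 3, 8, 2
mha = nn.MultiheadAttention(E, num_heads)
query = torch.rand(L, N, E)
key = value = torch.rand(S, N, E)

mask_2d = torch.zeros(L, S, dtype=torch.bool)                # shared by every batch entry
mask_3d = torch.zeros(N * num_heads, L, S, dtype=torch.bool)
mask_3d[0, :, -1] = True    # hide the last key position from head 0 of batch entry 0 only

out_2d, _ = mha(query, key, value, attn_mask=mask_2d)
out_3d, _ = mha(query, key, value, attn_mask=mask_3d)
print(out_2d.shape, out_3d.shape)    # both torch.Size([4, 3, 8])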
@@ -3906,6 +3912,13 @@ def multi_head_attention_forward(query, # type: Tensor
     q = q * scaling

     if attn_mask is not None:
+        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
+            attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
+            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
+        if attn_mask.dtype == torch.uint8:
+            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+            attn_mask = attn_mask.to(torch.bool)
+
         if attn_mask.dim() == 2:
             attn_mask = attn_mask.unsqueeze(0)
             if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
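A sketch (not from the diff) of the checks added here: an unsupported integer dtype trips the new assertion, and a byte mask still works but emits the deprecation warning before being converted to bool.

import warnings
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.rand(5, 3, 8)

# Byte (uint8) mask: accepted, but warns and is converted to bool internally.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mha(x, x, x, attn_mask=torch.zeros(5, 5, dtype=torch.uint8))
print(any("deprecated" in str(w.message) for w in caught))   # True

# Other integer dtypes fail the dtype assertion.
try:
    mha(x, x, x, attn_mask=torch.zeros(5, 5, dtype=torch.long))
except AssertionError as err:
    print(err)   # Only float, byte, and bool types are supported for attn_mask, not torch.int64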
@@ -3917,6 +3930,11 @@ def multi_head_attention_forward(query, # type: Tensor
                 raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
         # attn_mask's dim is 3 now.

+    # convert ByteTensor key_padding_mask to bool
+    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+        key_padding_mask = key_padding_mask.to(torch.bool)
+
     if bias_k is not None and bias_v is not None:
         if static_k is None and static_v is None:
             k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
@@ -3967,7 +3985,11 @@ def multi_head_attention_forward(query, # type: Tensor
     assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

     if attn_mask is not None:
-        attn_output_weights += attn_mask
+        if attn_mask.dtype == torch.bool:
+            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
+        else:
+            attn_output_weights += attn_mask
+

     if key_padding_mask is not None:
         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
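A sketch of the two application paths above on standalone scores: a bool mask goes through `masked_fill_` (`-inf` where `True`), a float mask is simply added; after the softmax both yield the same attention weights.

import torch
import torch.nn.functional as F

scores = torch.rand(2, 4, 4)     # (bsz * num_heads, tgt_len, src_len)
bool_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)

filled = scores.clone().masked_fill_(bool_mask, float('-inf'))                      # bool path
added = scores.clone() + torch.zeros(4, 4).masked_fill_(bool_mask, float('-inf'))   # float path

print(torch.allclose(F.softmax(filled, dim=-1), F.softmax(added, dim=-1)))  # True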
@@ -835,8 +835,7 @@ class MultiheadAttention(Module):
             be ignored by the attention. This is an binary mask. When the value is True,
             the corresponding value on the attention layer will be filled with -inf.
         need_weights: output attn_output_weights.
-        attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
-            (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
+        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
             the batches while a 3D mask allows to specify a different mask for the entries of each batch.

     Shape:
@@ -847,10 +846,17 @@ class MultiheadAttention(Module):
           the embedding dimension.
         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
-        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+          If a ByteTensor is provided, the non-zero positions will be ignored while the position
+          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
         - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
           3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
-          S is the source sequence length.
+          S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
+          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+          is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+          is provided, it will be added to the attention weight.

         - Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
@@ -89,14 +89,15 @@ class Transformer(Module):
         - tgt_key_padding_mask: :math:`(N, T)`.
         - memory_key_padding_mask: :math:`(N, S)`.

-        Note: [src/tgt/memory]_mask should be filled with
-        float('-inf') for the masked positions and float(0.0) else. These masks
-        ensure that predictions for position i depend only on the unmasked positions
-        j and are applied identically for each sequence in a batch.
-        [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
-        that should be masked with float('-inf') and False values will be unchanged.
-        This mask ensures that no information will be taken from position i if
-        it is masked, and has a separate mask for each sequence in a batch.
+        Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
+        positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+        while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+        are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+        is provided, it will be added to the attention weight.
+        [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
+        the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
+        positions will be unchanged. If a BoolTensor is provided, the positions with the
+        value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.

         - output: :math:`(T, N, E)`.
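For illustration only (not part of this commit): `nn.Transformer.generate_square_subsequent_mask` still produces the additive float encoding (`-inf`/`0.0`), while an upper-triangular bool mask expresses the same causal constraint under the convention described above.

import torch
import torch.nn as nn

model = nn.Transformer(d_model=16, nhead=2, num_encoder_layers=1, num_decoder_layers=1)
src = torch.rand(7, 3, 16)   # (S, N, E)
tgt = torch.rand(5, 3, 16)   # (T, N, E)

float_tgt_mask = model.generate_square_subsequent_mask(5)           # float: -inf above the diagonal
bool_tgt_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), 1)   # bool: True above the diagonal

out_float = model(src, tgt, tgt_mask=float_tgt_mask)
out_bool = model(src, tgt, tgt_mask=bool_tgt_mask)
print(out_float.shape, out_bool.shape)    # both torch.Size([5, 3, 16])

The two calls use equivalent masks; any difference between the outputs comes only from dropout being active in training mode.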