Add fake_impl for _native_multi_head_attention (#163700)
Test Plan: See added test in test_export.py

Differential Revision: D83099187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163700
Approved by: https://github.com/angelayi
This commit is contained in:
parent 7bad9c5a64
commit 21a41edd4f
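Usage sketch (not part of the diff): the fake impl added below is what lets FakeTensorMode, and therefore torch.export, trace through the op without running the real kernel. The snippet is a minimal sketch assuming a PyTorch build that includes this commit; the shapes mirror the qkv/proj Linear layers used in the added test.

# Minimal sketch (assumes a build containing this commit); shapes follow the added test.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

embed_dim, num_heads, bs, sl = 64, 4, 16, 8

with FakeTensorMode():
    q = torch.randn(bs, sl, embed_dim)
    qkv_weight = torch.randn(3 * embed_dim, embed_dim)  # nn.Linear(embed_dim, 3 * embed_dim).weight
    qkv_bias = torch.randn(3 * embed_dim)
    proj_weight = torch.randn(embed_dim, embed_dim)     # nn.Linear(embed_dim, embed_dim).weight
    proj_bias = torch.randn(embed_dim)

    # Dispatches to the fake/meta kernel registered in this commit, so no real math runs.
    out, attn = torch._native_multi_head_attention(
        q, q, q, embed_dim, num_heads,
        qkv_weight, qkv_bias, proj_weight, proj_bias,
        None,  # key_padding_mask
        need_weights=False,
        average_attn_weights=False,
        mask_type=1,
    )
    print(out.shape)   # torch.Size([16, 8, 64])
    print(attn.shape)  # torch.Size([0]) -- empty placeholder since need_weights=False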
@@ -318,7 +318,7 @@ timm_vovnet,pass,0
-torch_multimodal_clip,pass,3
+torch_multimodal_clip,pass,0
@@ -1087,6 +1087,93 @@ graph():
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(gm(*args), m(*args))

+    # stride() is called for an undefined tensor
+    @testing.expectedFailureCppRuntimeNonStrict
+    def test_native_multi_attention_head(self):
+        embed_dim = 64
+        num_heads = 4
+        bs = 16
+        sl = 8
+        device = "cpu"
+
+        q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
+        k = q
+        v = q
+
+        qkv = torch.nn.Linear(
+            embed_dim, 3 * embed_dim, device=device, dtype=torch.float32
+        )
+        proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
+
+        class NativeMHA(torch.nn.Module):
+            def __init__(
+                self,
+                embed_dim,
+                num_heads,
+                qkv,
+                proj,
+                need_weights,
+                average_attn_weights,
+                mask_type,
+            ):
+                super().__init__()
+                self.qkv = qkv
+                self.proj = proj
+                self.embed_dim = embed_dim
+                self.num_heads = num_heads
+                self.need_weights = need_weights
+                self.average_attn_weights = average_attn_weights
+                self.mask_type = mask_type
+
+            def forward(self, q, k, v, key_padding_mask):
+                return torch._native_multi_head_attention(
+                    q,
+                    k,
+                    v,
+                    self.embed_dim,
+                    self.num_heads,
+                    self.qkv.weight,
+                    self.qkv.bias,
+                    self.proj.weight,
+                    self.proj.bias,
+                    key_padding_mask,
+                    need_weights=False,
+                    average_attn_weights=False,
+                    mask_type=1,  # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
+                )
+
+        for mask_type in (0, 1):
+            for need_weights in (True, False):
+                for average_attn_weights in (True, False):
+                    npt = NativeMHA(
+                        embed_dim=embed_dim,
+                        num_heads=num_heads,
+                        qkv=qkv,
+                        proj=proj,
+                        need_weights=need_weights,
+                        average_attn_weights=average_attn_weights,
+                        mask_type=mask_type,
+                    )
+                    sample_input = (q, k, v, None)
+
+                    ep = export(
+                        npt,
+                        args=sample_input,
+                        dynamic_shapes={
+                            "q": {
+                                0: Dim("dim0_q", max=1024),
+                            },
+                            "k": {
+                                0: Dim("dim0_k", max=1024),
+                            },
+                            "v": {
+                                0: Dim("dim0_v", max=1024),
+                            },
+                            "key_padding_mask": None,
+                        },
+                    )
+                    self.assertEqual(ep.module()(*sample_input), npt(*sample_input))
+
     def test_unused_constant(self):
         class M(torch.nn.Module):
             def forward(self, x):
@@ -7790,6 +7790,56 @@ def _create_unary_float_meta_func(func):
     return _f


+# Implementation follows the CUDA implementation native_multi_head_attention_cuda
+@register_meta(aten._native_multi_head_attention.default)
+def native_multi_head_attention_fake(
+    query,
+    key,
+    value,
+    embed_dim,
+    num_head,
+    qkv_weight,
+    qkv_bias,
+    proj_weight,
+    proj_bias,
+    mask=None,
+    need_weights=True,
+    average_attn_weights=True,
+    mask_type=None,
+):
+    if query.is_nested or key.is_nested or value.is_nested:
+        raise NotImplementedError(
+            "_native_multi_head_attention fake implementation does not support nested tensors"
+        )
+
+    if query.numel() == 0:
+        return (query.new_empty(query.shape), query.new_empty(0))
+
+    B = query.size(0)  # B: batch size
+    T = query.size(1)  # T: target sequence length
+
+    # In native_multi_head_attention_cuda we have
+    # proj = transform0213_gemm_nt_bias(attn_ctx, proj_weight, proj_bias, query),
+    # which does attn_ctx @ proj_weight.T + proj_bias,
+    # so the last dim of the output shape is proj_weight.size(0)
+    output_dim = proj_weight.size(0)
+    output = query.new_empty(B, T, output_dim)
+
+    if need_weights:
+        if average_attn_weights:
+            # When averaging attention weights, shape is [B, T, T] (averaged over heads)
+            # T = query seq len, S = key/value seq len
+            attn_weights = query.new_empty(B, T, T)
+        else:
+            # When not averaging, shape is [B, num_head, T, T]
+            # T = query seq len, S = key/value seq len
+            attn_weights = query.new_empty(B, num_head, T, T)
+    else:
+        attn_weights = query.new_empty(0)
+
+    return (output, attn_weights)
+
+
 def _create_binary_float_meta_func(func):
     @register_meta(func)
     @out_wrapper()
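Shape-check sketch (not part of the commit): since @register_meta registers the kernel above for meta-device dispatch, the three need_weights / average_attn_weights branches can be exercised on meta tensors without any real computation. This assumes a PyTorch build that includes this change.

# Sketch only: exercises the meta kernel above via meta-device tensors.
import torch

B, T, E, H = 2, 5, 64, 4
q = torch.empty(B, T, E, device="meta")
qkv_w = torch.empty(3 * E, E, device="meta")
qkv_b = torch.empty(3 * E, device="meta")
proj_w = torch.empty(E, E, device="meta")
proj_b = torch.empty(E, device="meta")

cases = [
    (True, True, (B, T, T)),      # averaged over heads
    (True, False, (B, H, T, T)),  # per-head attention weights
    (False, True, (0,)),          # empty placeholder when weights are not requested
]
for need_weights, average, expected in cases:
    out, attn = torch._native_multi_head_attention(
        q, q, q, E, H, qkv_w, qkv_b, proj_w, proj_b, None,
        need_weights=need_weights,
        average_attn_weights=average,
        mask_type=1,
    )
    assert out.shape == (B, T, E)   # last dim comes from proj_weight.size(0)
    assert attn.shape == expected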