diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
index 1ec72cc22b8..f6a0bff369a 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
@@ -318,7 +318,7 @@ timm_vovnet,pass,0
 
 
 
-torch_multimodal_clip,pass,3
+torch_multimodal_clip,pass,0
 
 
 
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 72be959d9e7..35ca85b9161 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -1087,6 +1087,93 @@ graph():
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(gm(*args), m(*args))
 
+    # stride() is called on an undefined tensor
+    @testing.expectedFailureCppRuntimeNonStrict
+    def test_native_multi_attention_head(self):
+        embed_dim = 64
+        num_heads = 4
+        bs = 16
+        sl = 8
+        device = "cpu"
+
+        q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
+        k = q
+        v = q
+
+        qkv = torch.nn.Linear(
+            embed_dim, 3 * embed_dim, device=device, dtype=torch.float32
+        )
+        proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
+
+        class NativeMHA(torch.nn.Module):
+            def __init__(
+                self,
+                embed_dim,
+                num_heads,
+                qkv,
+                proj,
+                need_weights,
+                average_attn_weights,
+                mask_type,
+            ):
+                super().__init__()
+                self.qkv = qkv
+                self.proj = proj
+                self.embed_dim = embed_dim
+                self.num_heads = num_heads
+                self.need_weights = need_weights
+                self.average_attn_weights = average_attn_weights
+                self.mask_type = mask_type
+
+            def forward(self, q, k, v, key_padding_mask):
+                return torch._native_multi_head_attention(
+                    q,
+                    k,
+                    v,
+                    self.embed_dim,
+                    self.num_heads,
+                    self.qkv.weight,
+                    self.qkv.bias,
+                    self.proj.weight,
+                    self.proj.bias,
+                    key_padding_mask,
+                    need_weights=self.need_weights,
+                    average_attn_weights=self.average_attn_weights,
+                    mask_type=self.mask_type,  # 1 => src_key_padding_mask, 0 => src_mask
+                )
+
+        for mask_type in (0, 1):
+            for need_weights in (True, False):
+                for average_attn_weights in (True, False):
+                    npt = NativeMHA(
+                        embed_dim=embed_dim,
+                        num_heads=num_heads,
+                        qkv=qkv,
+                        proj=proj,
+                        need_weights=need_weights,
+                        average_attn_weights=average_attn_weights,
+                        mask_type=mask_type,
+                    )
+                    sample_input = (q, k, v, None)
+
+                    ep = export(
+                        npt,
+                        args=sample_input,
+                        dynamic_shapes={
+                            "q": {
+                                0: Dim("dim0_q", max=1024),
+                            },
+                            "k": {
+                                0: Dim("dim0_k", max=1024),
+                            },
+                            "v": {
+                                0: Dim("dim0_v", max=1024),
+                            },
+                            "key_padding_mask": None,
+                        },
+                    )
+                    self.assertEqual(ep.module()(*sample_input), npt(*sample_input))
+
     def test_unused_constant(self):
         class M(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 8435b80f34c..1b067df9f4d 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -7790,6 +7790,56 @@ def _create_unary_float_meta_func(func):
     return _f
 
 
+# Implementation follows the CUDA implementation native_multi_head_attention_cuda
+@register_meta(aten._native_multi_head_attention.default)
+def native_multi_head_attention_fake(
+    query,
+    key,
+    value,
+    embed_dim,
+    num_head,
+    qkv_weight,
+    qkv_bias,
+    proj_weight,
+    proj_bias,
+    mask=None,
+    need_weights=True,
+    average_attn_weights=True,
+    mask_type=None,
+):
+    if query.is_nested or key.is_nested or value.is_nested:
+        raise NotImplementedError(
+            "_native_multi_head_attention fake implementation does not support nested tensors"
+        )
+
+    if query.numel() == 0:
+        return (query.new_empty(query.shape), query.new_empty(0))
+
+    B = query.size(0)  # B: batch size
+    T = query.size(1)  # T: target sequence length
+
+    # In native_multi_head_attention_cuda we have
+    #   proj = transform0213_gemm_nt_bias(attn_ctx, proj_weight, proj_bias, query),
+    # which computes attn_ctx @ proj_weight.T + proj_bias,
+    # so the last dim of the output shape is proj_weight.size(0).
+    output_dim = proj_weight.size(0)
+    output = query.new_empty(B, T, output_dim)
+
+    if need_weights:
+        if average_attn_weights:
+            # When averaging attention weights, shape is [B, T, T]
+            # (averaged over heads); T = query seq len
+            attn_weights = query.new_empty(B, T, T)
+        else:
+            # When not averaging, shape is [B, num_head, T, T]
+            # T = query seq len
+            attn_weights = query.new_empty(B, num_head, T, T)
+    else:
+        attn_weights = query.new_empty(0)
+
+    return (output, attn_weights)
+
+
 def _create_binary_float_meta_func(func):
     @register_meta(func)
     @out_wrapper()
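
Not part of the diff: a minimal sketch (assuming this meta registration has landed) of how the fake kernel's output shapes can be checked under FakeTensorMode, mirroring the shapes computed by native_multi_head_attention_fake above:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    embed_dim, num_heads, B, T = 64, 4, 16, 8
    with FakeTensorMode():
        # Tensors created inside the mode are fake; no real data is allocated.
        q = torch.empty(B, T, embed_dim)
        qkv_weight = torch.empty(3 * embed_dim, embed_dim)
        qkv_bias = torch.empty(3 * embed_dim)
        proj_weight = torch.empty(embed_dim, embed_dim)
        proj_bias = torch.empty(embed_dim)
        # Dispatches to the registered fake implementation; no real kernel runs.
        out, attn = torch._native_multi_head_attention(
            q, q, q, embed_dim, num_heads,
            qkv_weight, qkv_bias, proj_weight, proj_bias,
            None, need_weights=True, average_attn_weights=False, mask_type=1,
        )
        assert out.shape == (B, T, embed_dim)        # [B, T, proj_weight.size(0)]
        assert attn.shape == (B, num_heads, T, T)    # not averaged over heads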