From 31e59766e7e7b51e8dddd4a6967891ac01f4d37b Mon Sep 17 00:00:00 2001
From: Joel Schlosser
Date: Tue, 13 Feb 2024 18:29:52 -0500
Subject: [PATCH] Fix meta registration for _flash_attention_forward() (#119812)

The meta registration wrongly assumes 4D inputs, while the underlying op
allows 3D inputs for the `mha_varlen_fwd()` case.

Testing: I added `detach()`es so the NJT test `test_sdpa_compile()` won't
fail for a view-related reason; it should pass now with this fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119812
Approved by: https://github.com/drisspg
---
 test/test_nestedtensor.py      | 16 ++++++++--------
 torch/_meta_registrations.py   | 15 +++++++++------
 torch/nested/_internal/sdpa.py |  3 +--
 3 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 355ae719eb1..8cf64b59f9a 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -11,13 +11,14 @@ import numpy as np
 import torch
 import torch.nn
 import torch.nn.functional as F
-from torch.testing._internal.common_cuda import SM80OrLater
+from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater
 from torch.testing._internal.common_device_type import (
     dtypes,
     dtypesIfCUDA,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,
+    skipCUDAIf,
     skipMeta,
     PYTORCH_CUDA_MEMCHECK,
 )
@@ -3827,11 +3828,10 @@ class TestNestedTensorSubclass(TestCase):
         if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
             check_forward_backward()
 
-    # This requires NT -> NT views to work in inductor, which is a TODO
-    @unittest.expectedFailure  # noqa: E301
+    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
     @onlyCUDA
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
-                 SM80OrLater else [torch.float16, torch.float32])
+    @dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater
+              else [torch.float16, torch.float32]))
     def test_sdpa_compile(self, device, dtype):
         batch_size = 1
         emb_dims = 1024
@@ -3857,9 +3857,9 @@ class TestNestedTensorSubclass(TestCase):
         k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
         v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
 
-        q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
-        k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
-        v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
+        q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
+        k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
+        v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
 
         # High Precision Math Reference
         q_d1_f32 = q_d1.to(torch.float32)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 86685968ac9..cf800f22743 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5354,12 +5354,15 @@ def meta__flash_attention_forward(
     return_debug_mask: bool,
     scale: Optional[float] = None,
 ):
-    batch_size = query.size(0)
-    max_seqlen_batch_q = query.size(1)
-    num_heads = query.size(2)
-    head_dim = query.size(3)
-
-    max_seqlen_batch_k = key.size(1)
+    # NB: there are two underlying paths:
+    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
+    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
+    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
+    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
+    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
+    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
+    num_heads = query.size(-2)
+    head_dim = query.size(-1)
 
     # Cuda Path
     attention = torch.empty_like(query)
diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py
index 5a2ce0ada12..204d4aa5a3f 100644
--- a/torch/nested/_internal/sdpa.py
+++ b/torch/nested/_internal/sdpa.py
@@ -1,5 +1,4 @@
 import logging
-import math
 from typing import Optional, Tuple
 
 import torch
@@ -604,7 +603,7 @@ def _pad_last_dim(
 
 # TODO: coalesce with torch/nn/utils/attention.py
 def _calculate_scale(query, scale):
-    softmax_scale = scale if scale is not None else math.sqrt(1.0 / query.size(-1))
+    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
     return softmax_scale
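
Note: below is a minimal sketch (not part of the patch) of the shape logic the
updated meta registration encodes, covering both the dense 4D layout and the
varlen 3D packed layout. The helper name `_fake_flash_fwd_shapes` and the
example sizes are invented for illustration; only the branching on
`cum_seq_q` / `cum_seq_k` mirrors the lines added above.

    import torch

    def _fake_flash_fwd_shapes(query, key, cum_seq_q=None, cum_seq_k=None,
                               max_q=None, max_k=None):
        # Dense path: query/key are (batch_size, seqlen, num_heads, head_dim).
        # Varlen path: query/key are (total, num_heads, head_dim) and the
        # per-batch-item offsets into `total` live in cum_seq_q / cum_seq_k.
        batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
        max_seqlen_q = query.size(1) if cum_seq_q is None else max_q
        max_seqlen_k = key.size(1) if cum_seq_k is None else max_k
        num_heads = query.size(-2)  # negative indexing works for both 3D and 4D
        head_dim = query.size(-1)
        return batch_size, max_seqlen_q, max_seqlen_k, num_heads, head_dim

    # Dense: 4D inputs
    q = torch.empty(2, 8, 4, 64, device="meta")
    k = torch.empty(2, 8, 4, 64, device="meta")
    print(_fake_flash_fwd_shapes(q, k))  # (2, 8, 8, 4, 64)

    # Varlen: 3D packed inputs plus cumulative sequence-length offsets
    q_packed = torch.empty(10, 4, 64, device="meta")
    k_packed = torch.empty(10, 4, 64, device="meta")
    cum = torch.tensor([0, 6, 10], dtype=torch.int32)
    print(_fake_flash_fwd_shapes(q_packed, k_packed, cum, cum, 6, 6))  # (2, 6, 6, 4, 64)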