diff --git a/test/test_transformers.py b/test/test_transformers.py
index 86adfc26ece..5183e77931c 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -45,12 +45,15 @@ from torch._dynamo.testing import CompileCounterWithBackend
 from torch.testing._internal.common_methods_invocations import wrapper_set_seed
 from torch.testing._internal.common_cuda import (
-    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    IS_JETSON,
+    SM80OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_FUSED_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     SM90OrLater,
-    tf32_on_and_off
+    tf32_on_and_off,
+    tf32_enabled,
 )
 
 if not IS_FBCODE:
@@ -64,6 +67,7 @@ if TEST_FAIRSEQ:
 SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
 Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])
 
+
 @contextlib.contextmanager
 def use_deterministic_algorithims(mode: bool, warn_only: bool):
     r"""
@@ -2998,17 +3002,30 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("seq_len_k", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False, True])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float,
                                                        dtype: torch.dtype, scale: str):
@@ -3097,17 +3114,30 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 152, 512])
-    @parametrize("seq_len_k", [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 37, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 152, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 37, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                  seq_len_k: int, head_dim: int, is_causal: bool,
                                                                  dropout_p: float, dtype: torch.dtype,
@@ -3137,7 +3167,6 @@ class TestSDPACudaOnly(NNTestCase):
 
         attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)
 
-        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
         query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
         attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
 
@@ -3204,7 +3233,10 @@ class TestSDPACudaOnly(NNTestCase):
             fudge_factors=fudge_factors,
         )
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 143, 2048])
@@ -3216,6 +3248,7 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("scale", [None, "l1"])
     @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
+    @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: List[int]):
@@ -3327,7 +3360,10 @@ class TestSDPACudaOnly(NNTestCase):
             fudge_factors=fudge_factors,
         )
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [256, 1024])
     @parametrize("seq_len_k", [256, 1024])
@@ -3337,6 +3373,7 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
+    @tf32_enabled()
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
@@ -3479,7 +3516,6 @@ class TestSDPACudaOnly(NNTestCase):
             }
         )
 
-    @skipIfRocm  # Nested Tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index ae47eae9c1f..2fa3801661e 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -154,6 +154,23 @@ def tf32_on(self, tf32_precision=1e-5):
         self.precision = old_precision
 
 
+@contextlib.contextmanager
+def tf32_enabled():
+    """
+    Context manager to temporarily enable TF32 for CUDA operations.
+    Restores the previous TF32 state after exiting the context.
+    """
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        with torch.backends.cudnn.flags(
+            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
+        ):
+            yield
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+
+
 # This is a wrapper that wraps a test to run this test twice, one with
 # allow_tf32=True, another with allow_tf32=False. When running with
 # allow_tf32=True, it will use reduced precision as specified by the