Make sure all SDPA tests are ran with tensor cores enabled (#135592)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135592
Approved by: https://github.com/eqy
This commit is contained in:
drisspg 2024-10-29 10:24:15 -07:00 committed by PyTorch MergeBot
parent c81d4fd0a8
commit 80c7c7178e
2 changed files with 75 additions and 22 deletions

View File

@ -45,12 +45,15 @@ from torch._dynamo.testing import CompileCounterWithBackend
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._internal.common_cuda import (
IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
IS_JETSON,
SM80OrLater,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
SM90OrLater,
tf32_on_and_off
tf32_on_and_off,
tf32_enabled,
)
if not IS_FBCODE:
@ -64,6 +67,7 @@ if TEST_FAIRSEQ:
SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])
@contextlib.contextmanager
def use_deterministic_algorithims(mode: bool, warn_only: bool):
r"""
@ -2998,17 +3002,30 @@ class TestSDPACudaOnly(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
@unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
else [4, 8, 256, 512])
@parametrize("seq_len_k", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
else [4, 8, 256, 512])
@parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
else [8, 16, 32, 64])
@parametrize(
"seq_len_q",
[8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
)
@parametrize(
"seq_len_k",
[8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
)
@parametrize(
"head_dim",
[8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
)
@parametrize("is_causal", [False, True])
@parametrize("dropout_p", [0.0, 0.22])
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
else [torch.float16, torch.float32])
@parametrize(
"dtype",
(
[torch.float16, torch.bfloat16, torch.float32]
if MEM_EFF_CAPABILITY_MATCHES_SM80
else [torch.float16, torch.float32]
),
)
@parametrize("scale", [None, "l1"])
@tf32_enabled()
def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str):
@ -3097,17 +3114,30 @@ class TestSDPACudaOnly(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
@unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
else [8, 152, 512])
@parametrize("seq_len_k", [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
else [8, 37, 512])
@parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
else [8, 16, 32, 64])
@parametrize(
"seq_len_q",
[8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 152, 512],
)
@parametrize(
"seq_len_k",
[8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 37, 512],
)
@parametrize(
"head_dim",
[8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
)
@parametrize("is_causal", [False])
@parametrize("dropout_p", [0.0, 0.22])
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
else [torch.float16, torch.float32])
@parametrize(
"dtype",
(
[torch.float16, torch.bfloat16, torch.float32]
if MEM_EFF_CAPABILITY_MATCHES_SM80
else [torch.float16, torch.float32]
),
)
@parametrize("scale", [None, "l1"])
@tf32_enabled()
def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
seq_len_k: int, head_dim: int, is_causal: bool,
dropout_p: float, dtype: torch.dtype,
@ -3137,7 +3167,6 @@ class TestSDPACudaOnly(NNTestCase):
attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
@ -3204,7 +3233,10 @@ class TestSDPACudaOnly(NNTestCase):
fudge_factors=fudge_factors,
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Does not support SDPA or pre-SM80 hardware",
)
@unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [4, 143, 2048])
@ -3216,6 +3248,7 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("scale", [None, "l1"])
@parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False])
@parametrize("n_heads", [[16, 8], [10, 2]])
@tf32_enabled()
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str, enable_gqa: bool, n_heads: List[int]):
@ -3327,7 +3360,10 @@ class TestSDPACudaOnly(NNTestCase):
fudge_factors=fudge_factors,
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Does not support SDPA or pre-SM80 hardware",
)
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [256, 1024])
@parametrize("seq_len_k", [256, 1024])
@ -3337,6 +3373,7 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("dtype", [torch.float16])
@parametrize("scale", [None, "l1"])
@parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
@tf32_enabled()
def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
seq_len_q: int, seq_len_k: int,
head_dim: int,
@ -3479,7 +3516,6 @@ class TestSDPACudaOnly(NNTestCase):
}
)
@skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if

View File

@ -154,6 +154,23 @@ def tf32_on(self, tf32_precision=1e-5):
self.precision = old_precision
@contextlib.contextmanager
def tf32_enabled():
"""
Context manager to temporarily enable TF32 for CUDA operations.
Restores the previous TF32 state after exiting the context.
"""
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
try:
torch.backends.cuda.matmul.allow_tf32 = True
with torch.backends.cudnn.flags(
enabled=None, benchmark=None, deterministic=None, allow_tf32=True
):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
# This is a wrapper that wraps a test to run this test twice, one with
# allow_tf32=True, another with allow_tf32=False. When running with
# allow_tf32=True, it will use reduced precision as specified by the