Make sure all SDPA tests are run with tensor cores enabled (#135592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135592 Approved by: https://github.com/eqy
This commit is contained in:
parent
c81d4fd0a8
commit
80c7c7178e
@@ -45,12 +45,15 @@ from torch._dynamo.testing import CompileCounterWithBackend
 from torch.testing._internal.common_methods_invocations import wrapper_set_seed
 from torch.testing._internal.common_cuda import (
-    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    IS_JETSON,
+    SM80OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_FUSED_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     SM90OrLater,
-    tf32_on_and_off
+    tf32_on_and_off,
+    tf32_enabled,
 )
 
 if not IS_FBCODE:
@@ -64,6 +67,7 @@ if TEST_FAIRSEQ:
 SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
+Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])
 
 
 @contextlib.contextmanager
 def use_deterministic_algorithims(mode: bool, warn_only: bool):
     r"""
@@ -2998,17 +3002,30 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("seq_len_k", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False, True])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                        scale: str):
@@ -3097,17 +3114,30 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 152, 512])
-    @parametrize("seq_len_k", [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 37, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 152, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 37, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                  seq_len_k: int, head_dim: int, is_causal: bool,
                                                                  dropout_p: float, dtype: torch.dtype,
@@ -3137,7 +3167,6 @@ class TestSDPACudaOnly(NNTestCase):
         attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)
 
         higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
         query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
         attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
@@ -3204,7 +3233,10 @@ class TestSDPACudaOnly(NNTestCase):
             fudge_factors=fudge_factors,
         )
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 143, 2048])
@@ -3216,6 +3248,7 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("scale", [None, "l1"])
     @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
+    @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: List[int]):
@@ -3327,7 +3360,10 @@ class TestSDPACudaOnly(NNTestCase):
             fudge_factors=fudge_factors,
         )
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [256, 1024])
     @parametrize("seq_len_k", [256, 1024])
@@ -3337,6 +3373,7 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
+    @tf32_enabled()
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
@@ -3479,7 +3516,6 @@ class TestSDPACudaOnly(NNTestCase):
             }
         )
 
-
     @skipIfRocm  # Nested Tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
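For reference, the pattern this commit applies throughout the SDPA tests looks roughly like the sketch below. It is a minimal, standalone illustration rather than code from the diff: the test class name, shapes, and tolerances are made up, and only `tf32_enabled`, `parametrize`, and the PyTorch testing helpers come from the changed files.

```python
# Minimal sketch (hypothetical test, assumed shapes/tolerances): @tf32_enabled()
# is stacked under the @parametrize decorators so the fused SDPA kernel runs with
# tensor cores allowed while being checked against a higher-precision reference.
import unittest

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.testing._internal.common_cuda import tf32_enabled
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
class ToySDPATest(TestCase):  # hypothetical class, for illustration only
    @parametrize("dtype", [torch.float16, torch.float32])
    @tf32_enabled()  # the decorator this commit adds to the real SDPA tests
    def test_sdpa_vs_math_ref(self, dtype):
        q, k, v = (torch.randn(2, 4, 64, 32, device="cuda", dtype=dtype) for _ in range(3))
        out = scaled_dot_product_attention(q, k, v)
        ref = scaled_dot_product_attention(q.double(), k.double(), v.double())
        # With TF32 enabled, float32 matmuls lose precision, so the comparison
        # needs looser ("fudged") tolerances than a strict default check.
        self.assertEqual(out, ref.to(dtype), atol=2e-2, rtol=2e-2)


instantiate_parametrized_tests(ToySDPATest)

if __name__ == "__main__":
    run_tests()
```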
@@ -154,6 +154,23 @@ def tf32_on(self, tf32_precision=1e-5):
         self.precision = old_precision
 
 
+@contextlib.contextmanager
+def tf32_enabled():
+    """
+    Context manager to temporarily enable TF32 for CUDA operations.
+    Restores the previous TF32 state after exiting the context.
+    """
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        with torch.backends.cudnn.flags(
+            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
+        ):
+            yield
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+
+
 # This is a wrapper that wraps a test to run this test twice, one with
 # allow_tf32=True, another with allow_tf32=False. When running with
 # allow_tf32=True, it will use reduced precision as specified by the
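A quick sanity sketch of what the new helper does, based only on the behavior visible in the hunk above: it turns on `torch.backends.cuda.matmul.allow_tf32` and the cuDNN TF32 flag for the duration of the block and restores the previous matmul setting afterwards.

```python
# Sketch: using tf32_enabled() directly as a context manager
# (the test changes above use it as a decorator instead).
import torch
from torch.testing._internal.common_cuda import tf32_enabled

prev = torch.backends.cuda.matmul.allow_tf32
with tf32_enabled():
    # TF32 is allowed here, so float32 GEMMs may use tensor cores on Ampere+ GPUs.
    assert torch.backends.cuda.matmul.allow_tf32
# The previous allow_tf32 setting is restored on exit.
assert torch.backends.cuda.matmul.allow_tf32 == prev
```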