Make sure all SDPA tests are run with tensor cores enabled (#135592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135592 Approved by: https://github.com/eqy
This commit is contained in:
parent c81d4fd0a8
commit 80c7c7178e
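What the change amounts to: the SDPA accuracy tests below are routed through a new tf32_enabled helper from torch.testing._internal.common_cuda, so the fused kernels are compared against the math reference with tensor cores (TF32) turned on. As a rough, standalone illustration of the switches involved (a minimal sketch, not code from this diff; shapes are made up and a CUDA build is assumed):

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))

    # The two TF32 switches the helper in this diff flips and then restores.
    old_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = True            # TF32 for float32 matmuls
    try:
        with torch.backends.cudnn.flags(allow_tf32=True):   # TF32 for cuDNN kernels
            out = F.scaled_dot_product_attention(q, k, v)
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_matmul_tf32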
@@ -45,12 +45,15 @@ from torch._dynamo.testing import CompileCounterWithBackend

 from torch.testing._internal.common_methods_invocations import wrapper_set_seed
 from torch.testing._internal.common_cuda import (
-    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    IS_JETSON,
+    SM80OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_FUSED_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     SM90OrLater,
-    tf32_on_and_off
+    tf32_on_and_off,
+    tf32_enabled,
 )

 if not IS_FBCODE:
@@ -64,6 +67,7 @@ if TEST_FAIRSEQ:
 SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
 Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])

+
 @contextlib.contextmanager
 def use_deterministic_algorithims(mode: bool, warn_only: bool):
     r"""
@@ -2998,17 +3002,30 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("seq_len_k", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False, True])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                        scale: str):
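A note on the new @tf32_enabled() line: because the helper (added in the common_cuda hunk at the end of this diff) is built with contextlib.contextmanager, the object it returns also works as a decorator, so stacking it above a test method wraps every invocation of that test in the TF32-enabled context. A minimal sketch of that pattern with generic names, not code from this diff:

    import contextlib

    @contextlib.contextmanager
    def flag_enabled():
        print("enter")      # flip some global flag here
        try:
            yield
        finally:
            print("exit")   # restore it on the way out

    @flag_enabled()         # context managers from @contextmanager double as decorators
    def run_case():
        print("test body")

    run_case()              # prints: enter / test body / exit
    run_case()              # the context is recreated on every call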
@@ -3097,17 +3114,30 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 152, 512])
-    @parametrize("seq_len_k", [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 37, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 152, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 37, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                  seq_len_k: int, head_dim: int, is_causal: bool,
                                                                  dropout_p: float, dtype: torch.dtype,
@@ -3137,7 +3167,6 @@ class TestSDPACudaOnly(NNTestCase):

         attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)

-
         higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
         query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
         attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
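The higher_precision_dtype lines above are the heart of the comparison strategy: the fused kernel runs in the test dtype while a math reference runs on float32/float64 clones, and whether TF32 is on changes how tight the tolerances can be (the real tests also compare gradients and compute per-case fudge factors). A stripped-down version of the pattern, with a hypothetical clone_to helper standing in for query_key_value_clones, which is defined elsewhere in the test file, and loose placeholder tolerances:

    import torch
    import torch.nn.functional as F

    def clone_to(t, dtype):
        # Hypothetical stand-in: detached, dtype-cast, grad-tracking copy.
        return t.detach().clone().to(dtype).requires_grad_(True)

    dtype = torch.float16
    q, k, v = (torch.randn(2, 8, 64, 32, device="cuda", dtype=dtype, requires_grad=True)
               for _ in range(3))

    higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
    q_ref, k_ref, v_ref = (clone_to(t, higher_precision_dtype) for t in (q, k, v))

    out = F.scaled_dot_product_attention(q, k, v)
    out_ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref)
    torch.testing.assert_close(out, out_ref.to(dtype), atol=2e-2, rtol=1e-2)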
@@ -3204,7 +3233,10 @@ class TestSDPACudaOnly(NNTestCase):
             fudge_factors=fudge_factors,
         )

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 143, 2048])
@@ -3216,6 +3248,7 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("scale", [None, "l1"])
     @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
+    @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: List[int]):
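The enable_gqa and n_heads parameters above exercise grouped-query attention, where the query tensor carries more heads than the key/value tensors and each KV head is shared by a group of query heads. A minimal sketch of the call shape (the enable_gqa keyword exists on recent PyTorch releases; shapes here are illustrative, not the test's values):

    import torch
    import torch.nn.functional as F

    batch, q_heads, kv_heads, seq, head_dim = 2, 16, 8, 128, 64   # q_heads must be a multiple of kv_heads
    q = torch.randn(batch, q_heads, seq, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, kv_heads, seq, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, kv_heads, seq, head_dim, device="cuda", dtype=torch.float16)

    # Each group of q_heads // kv_heads query heads attends against one shared KV head.
    out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    print(out.shape)   # torch.Size([2, 16, 128, 64]) -- output keeps the query head count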
@@ -3327,7 +3360,10 @@ class TestSDPACudaOnly(NNTestCase):
             fudge_factors=fudge_factors,
         )

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [256, 1024])
     @parametrize("seq_len_k", [256, 1024])
@@ -3337,6 +3373,7 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
+    @tf32_enabled()
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
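For readers unfamiliar with the _cudagraph variant above: it replays the fused attention kernels inside a captured CUDA graph, which means static input/output buffers, a warm-up on a side stream, and in-place updates before each replay. A generic capture/replay sketch, not the test's actual harness:

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16) for _ in range(3))

    # Warm up on a side stream so capture does not record first-use initialization work.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            F.scaled_dot_product_attention(q, k, v)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = F.scaled_dot_product_attention(q, k, v)   # 'out' becomes a graph-owned static output

    q.copy_(torch.randn_like(q))   # update the static inputs in place...
    g.replay()                     # ...then replay the captured kernels; 'out' is refreshed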
@@ -3479,7 +3516,6 @@ class TestSDPACudaOnly(NNTestCase):
             }
         )

-
     @skipIfRocm  # Nested Tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
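The fused_kernel parametrization above pins which SDPA backend each case exercises. Outside the test harness, the public way to restrict backend dispatch is the torch.nn.attention.sdpa_kernel context manager (a minimal sketch; older code used torch.backends.cuda.sdp_kernel instead):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q, k, v = (torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16) for _ in range(3))

    # Only the listed backends may be selected inside the block; if none supports the
    # inputs, the call errors out instead of silently falling back.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

The final hunk below adds the tf32_enabled helper itself, alongside the existing tf32_on/tf32_on_and_off utilities it is imported with above.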
@@ -154,6 +154,23 @@ def tf32_on(self, tf32_precision=1e-5):
         self.precision = old_precision


+@contextlib.contextmanager
+def tf32_enabled():
+    """
+    Context manager to temporarily enable TF32 for CUDA operations.
+    Restores the previous TF32 state after exiting the context.
+    """
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        with torch.backends.cudnn.flags(
+            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
+        ):
+            yield
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+
+
 # This is a wrapper that wraps a test to run this test twice, one with
 # allow_tf32=True, another with allow_tf32=False. When running with
 # allow_tf32=True, it will use reduced precision as specified by the
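The new helper is a plain context manager as well as a decorator, so ad-hoc checks can use it directly. A small usage sketch (assuming a CUDA build; the import path matches the one added to the test file above, and the module is an internal testing utility rather than public API):

    import torch
    import torch.nn.functional as F
    from torch.testing._internal.common_cuda import tf32_enabled

    q, k, v = (torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float32) for _ in range(3))

    with tf32_enabled():
        # Inside the block, matmul and cuDNN TF32 are forced on; the prior state is restored on exit.
        out_tf32 = F.scaled_dot_product_attention(q, k, v)

    out_full = F.scaled_dot_product_attention(q, k, v)
    # Any difference reflects TF32's reduced matmul precision on the chosen backend.
    print((out_tf32 - out_full).abs().max())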