[Inductor UT] Reuse test_fused_attention.py for Intel GPU. (#154110)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154110
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/EikanWang
Author: xinan.lin (2025-05-23 19:49:30 -07:00), committed by PyTorch MergeBot
parent 8fe7ec6721
commit 2dfc0e3327


@@ -15,7 +15,7 @@ from torch.testing._internal.common_cuda import (
SM80OrLater,
)
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU
def checkpoint_wrapper(fn):
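The hunk above widens the inductor_utils import so the rest of the file can stop hard-coding "cuda". As a rough illustration of the device-agnostic style this enables (a sketch only: the helper names GPU_TYPE, HAS_CUDA, and HAS_XPU come from the import above, while the test body below is made up):

# Sketch only: shows the device-agnostic pattern the import change enables,
# not actual contents of test_fused_attention.py.
import torch
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_XPU

def sdpa_smoke_check():
    if not (HAS_CUDA or HAS_XPU):
        return  # no GPU backend present, nothing to exercise
    # GPU_TYPE resolves to "cuda" or "xpu", so the same call covers both backends.
    q, k, v = (
        torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half)
        for _ in range(3)
    )
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    assert out.shape == q.shape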
@@ -61,6 +61,10 @@ class TestSDPAPatternRewriterTemplate(TestCase):
args2 = self._clone_inputs(args1)
for training in [False, True] if check_train else [False]:
if training and self.device == "xpu":
# Intel GPU has not implemented SDPA backward mode yet.
# TODO: remove this when sdpa backward is implemented for XPU.
continue
for x in itertools.chain(args1[:], args2[:]):
if isinstance(x, torch.Tensor) and x.is_floating_point():
x.requires_grad = training
@@ -120,7 +124,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
for dtype in [torch.float, torch.half]:
atol = 0.001
rtol = 1.3e-6 if dtype == torch.float else 0.7
if self.device == "cpu" and dtype == torch.half:
if self.device in ["cpu", "xpu"] and dtype == torch.half:
atol = 2e-3
rtol = 1e-2
self._check_common(dot_prod_attention, dtype=dtype, atol=atol, rtol=rtol)
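The looser atol/rtol for half precision on "cpu" and "xpu" feed into the numeric comparison done by _check_common. A hedged sketch of the same selection logic, assuming a torch.testing.assert_close style comparison (the helper name tolerances_for is made up):

# Sketch of the dtype/device dependent tolerance choice shown in the hunk above.
import torch

def tolerances_for(device: str, dtype: torch.dtype) -> tuple[float, float]:
    atol = 0.001
    rtol = 1.3e-6 if dtype == torch.float else 0.7
    if device in ["cpu", "xpu"] and dtype == torch.half:
        # fp16 on these backends drifts further from the reference than the
        # CUDA fused kernels, so the test accepts a looser match.
        atol, rtol = 2e-3, 1e-2
    return atol, rtol

# usage: atol, rtol = tolerances_for("xpu", torch.half)
#        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)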
@@ -144,10 +148,10 @@ class TestSDPAPatternRewriterTemplate(TestCase):
.matmul(value)
)
for dtype in [torch.float, torch.half]:
for dtype in [torch.half]:
atol = 0.001
rtol = 1.3e-6 if dtype == torch.float else 0.7
if self.device == "cpu" and dtype == torch.half:
if self.device in ["cpu", "xpu"] and dtype == torch.half:
atol = 2e-3
rtol = 1e-2
with torch.no_grad():
@@ -160,6 +164,11 @@ class TestSDPAPatternRewriterTemplate(TestCase):
)
def _test_insignificant_strides(self):
if self.device == "xpu":
self.skipTest(
"The operator 'aten::_scaled_dot_product_efficient_attention'"
" is not currently implemented for the XPU device. "
)
f32 = torch.float32
# repro taken from https://github.com/pytorch/pytorch/issues/124289
@@ -229,7 +238,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
)
return _scaled_dot_product_efficient_attention
kwargs = aot_graph_input_parser(forward, device="cuda")
kwargs = aot_graph_input_parser(forward, device=GPU_TYPE)
# runs successfully
out_eager = forward(**kwargs)
out_c = torch.compile(forward)(**kwargs)
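Passing device=GPU_TYPE keeps the eager-versus-compiled check above device-agnostic. A minimal sketch of that comparison flow (the forward body here is invented; only the structure mirrors the test):

# Sketch only: compares eager and torch.compile outputs on the generic GPU device.
# Assumes a CUDA or XPU device is actually present.
import torch
from torch.testing._internal.inductor_utils import GPU_TYPE

def forward(x):
    return torch.nn.functional.softmax(x, dim=-1)

x = torch.randn(8, 8, device=GPU_TYPE)
out_eager = forward(x)
out_compiled = torch.compile(forward)(x)
torch.testing.assert_close(out_eager, out_compiled)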
@@ -389,9 +398,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
)
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
)
self._check_common(
checkpoint_wrapper(sfdp_pattern_7),
@@ -421,9 +430,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
self._check_common(sfdp_pattern_8, args, atol=2e-3)
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
)
self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3)
@@ -455,9 +464,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
atol=2e-3,
)
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
)
self._check_common(
checkpoint_wrapper(sfdp_pattern_9),
@@ -488,9 +497,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
self._check_common(sfdp_pattern_10, args, atol=2e-3)
args = (
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
)
self._check_common(checkpoint_wrapper(sfdp_pattern_10), args, atol=2e-3)
@@ -969,84 +978,66 @@ class TestSDPAPatternRewriterTemplate(TestCase):
self._check_common(dot_prod_attention, check_train=False, has_dropout=True)
if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:
if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION):
class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
device = "cuda"
test_sdpa_rewriter_1_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
)
class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate):
device = GPU_TYPE
test_sdpa_rewriter_1_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
test_sdpa_rewriter_1_freezing = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1_freezing
)
test_insignificant_strides = (
TestSDPAPatternRewriterTemplate._test_insignificant_strides
)
test_pattern_fails_with_reuse_cuda = (
test_pattern_fails_with_reuse_gpu = (
TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
)
test_sdpa_rewriter_2_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
)
test_sdpa_rewriter_3_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
)
test_sdpa_rewriter_4_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
)
test_sdpa_rewriter_5_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
)
test_sdpa_rewriter_6_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
)
test_sdpa_rewriter_7_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
)
test_sdpa_rewriter_8_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
)
test_sdpa_rewriter_9_cuda = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
)
test_sdpa_rewriter_10_cuda = (
test_sdpa_rewriter_2_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
test_sdpa_rewriter_3_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
test_sdpa_rewriter_4_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
test_sdpa_rewriter_5_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
test_sdpa_rewriter_6_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
test_sdpa_rewriter_7_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
test_sdpa_rewriter_8_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
test_sdpa_rewriter_9_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
test_sdpa_rewriter_10_gpu = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_10
)
test_pattern_fails_with_tensor_factor_cuda = (
test_pattern_fails_with_tensor_factor_gpu = (
TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
)
test_pattern_fails_with_unsupported_mask_cuda = (
test_pattern_fails_with_unsupported_mask_gpu = (
TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
)
test_sdpa_rewriter_11_cuda = (
test_sdpa_rewriter_11_gpu = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
)
test_sdpa_rewriter_12_cuda = (
test_sdpa_rewriter_12_gpu = (
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
)
test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
test_sdpa_rewriter_13_cuda = functools.partialmethod(
test_sdpa_prev_13_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
test_sdpa_prev_14_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
test_sdpa_prev_15_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
test_sdpa_rewriter_13_gpu = functools.partialmethod(
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
)
test_sdpa_rewriter_14_cuda = functools.partialmethod(
test_sdpa_rewriter_14_gpu = functools.partialmethod(
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
)
test_sdpa_rewriter_15_cuda = functools.partialmethod(
test_sdpa_rewriter_15_gpu = functools.partialmethod(
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
)
test_sdpa_rewriter_17_cuda = functools.partialmethod(
test_sdpa_rewriter_17_gpu = functools.partialmethod(
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
)
test_sdpa_rewriter_19_cuda = functools.partialmethod(
test_sdpa_rewriter_19_gpu = functools.partialmethod(
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
)
test_sdpa_rewriter_20_cuda = functools.partialmethod(
test_sdpa_rewriter_20_gpu = functools.partialmethod(
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_20
)
class SDPAPatternRewriterCudaDynamicTests(SDPAPatternRewriterCudaTests):
class SDPAPatternRewriterGpuDynamicTests(SDPAPatternRewriterGpuTests):
use_static_shapes = False
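The renamed classes above follow the file's usual template pattern: device-independent _test_* helpers live on TestSDPAPatternRewriterTemplate, and a per-device subclass re-exports them under test_*_gpu names only when a suitable backend exists. A minimal sketch of that pattern, with invented class and test names (only the structure mirrors the file):

# Sketch of the template / per-device binding pattern; names are illustrative.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_XPU

class RewriterTemplate(TestCase):
    device = None  # set by the concrete subclass

    def _test_basic(self):
        x = torch.randn(4, 4, device=self.device)
        self.assertEqual(x.device.type, self.device)

if HAS_XPU or HAS_CUDA:
    class RewriterGpuTests(RewriterTemplate):
        device = GPU_TYPE
        # Re-exporting the underscored helper as test_* lets the runner
        # collect it only when this device is available.
        test_basic_gpu = RewriterTemplate._test_basic

if __name__ == "__main__":
    run_tests()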