mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor UT] Reuse test_fused_attention.py for Intel GPU. (#154110)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154110 Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/EikanWang
This commit is contained in:
parent
8fe7ec6721
commit
2dfc0e3327
|
|
@ -15,7 +15,7 @@ from torch.testing._internal.common_cuda import (
|
|||
SM80OrLater,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU
|
||||
|
||||
|
||||
def checkpoint_wrapper(fn):
|
||||
|
|
@ -61,6 +61,10 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
args2 = self._clone_inputs(args1)
|
||||
|
||||
for training in [False, True] if check_train else [False]:
|
||||
if training and self.device == "xpu":
|
||||
# Intel GPU have not implemented sdpa backward yet mode.
|
||||
# TODO: remove this when sdpa backward is implemented for XPU.
|
||||
continue
|
||||
for x in itertools.chain(args1[:], args2[:]):
|
||||
if isinstance(x, torch.Tensor) and x.is_floating_point():
|
||||
x.requires_grad = training
|
||||
|
|
@ -120,7 +124,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
for dtype in [torch.float, torch.half]:
|
||||
atol = 0.001
|
||||
rtol = 1.3e-6 if dtype == torch.float else 0.7
|
||||
if self.device == "cpu" and dtype == torch.half:
|
||||
if self.device in ["cpu", "xpu"] and dtype == torch.half:
|
||||
atol = 2e-3
|
||||
rtol = 1e-2
|
||||
self._check_common(dot_prod_attention, dtype=dtype, atol=atol, rtol=rtol)
|
||||
|
|
@ -144,10 +148,10 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
.matmul(value)
|
||||
)
|
||||
|
||||
for dtype in [torch.float, torch.half]:
|
||||
for dtype in [torch.half]:
|
||||
atol = 0.001
|
||||
rtol = 1.3e-6 if dtype == torch.float else 0.7
|
||||
if self.device == "cpu" and dtype == torch.half:
|
||||
if self.device in ["cpu", "xpu"] and dtype == torch.half:
|
||||
atol = 2e-3
|
||||
rtol = 1e-2
|
||||
with torch.no_grad():
|
||||
|
|
@ -160,6 +164,11 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
)
|
||||
|
||||
def _test_insignificant_strides(self):
|
||||
if self.device == "xpu":
|
||||
self.skipTest(
|
||||
"The operator 'aten::_scaled_dot_product_efficient_attention'"
|
||||
" is not currently implemented for the XPU device. "
|
||||
)
|
||||
f32 = torch.float32
|
||||
|
||||
# repro taken from https://github.com/pytorch/pytorch/issues/124289
|
||||
|
|
@ -229,7 +238,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
)
|
||||
return _scaled_dot_product_efficient_attention
|
||||
|
||||
kwargs = aot_graph_input_parser(forward, device="cuda")
|
||||
kwargs = aot_graph_input_parser(forward, device=GPU_TYPE)
|
||||
# runs successfully
|
||||
out_eager = forward(**kwargs)
|
||||
out_c = torch.compile(forward)(**kwargs)
|
||||
|
|
@ -389,9 +398,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
)
|
||||
|
||||
args = (
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
)
|
||||
self._check_common(
|
||||
checkpoint_wrapper(sfdp_pattern_7),
|
||||
|
|
@ -421,9 +430,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
self._check_common(sfdp_pattern_8, args, atol=2e-3)
|
||||
|
||||
args = (
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
)
|
||||
self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3)
|
||||
|
||||
|
|
@ -455,9 +464,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
atol=2e-3,
|
||||
)
|
||||
args = (
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
)
|
||||
self._check_common(
|
||||
checkpoint_wrapper(sfdp_pattern_9),
|
||||
|
|
@ -488,9 +497,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
self._check_common(sfdp_pattern_10, args, atol=2e-3)
|
||||
|
||||
args = (
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
|
||||
)
|
||||
self._check_common(checkpoint_wrapper(sfdp_pattern_10), args, atol=2e-3)
|
||||
|
||||
|
|
@ -969,84 +978,66 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
|||
self._check_common(dot_prod_attention, check_train=False, has_dropout=True)
|
||||
|
||||
|
||||
if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:
|
||||
if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION):
|
||||
|
||||
class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
|
||||
device = "cuda"
|
||||
test_sdpa_rewriter_1_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
|
||||
)
|
||||
class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate):
|
||||
device = GPU_TYPE
|
||||
test_sdpa_rewriter_1_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
|
||||
test_sdpa_rewriter_1_freezing = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1_freezing
|
||||
)
|
||||
test_insignificant_strides = (
|
||||
TestSDPAPatternRewriterTemplate._test_insignificant_strides
|
||||
)
|
||||
test_pattern_fails_with_reuse_cuda = (
|
||||
test_pattern_fails_with_reuse_gpu = (
|
||||
TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
|
||||
)
|
||||
test_sdpa_rewriter_2_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
|
||||
)
|
||||
test_sdpa_rewriter_3_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
|
||||
)
|
||||
test_sdpa_rewriter_4_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
|
||||
)
|
||||
test_sdpa_rewriter_5_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
|
||||
)
|
||||
test_sdpa_rewriter_6_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
|
||||
)
|
||||
test_sdpa_rewriter_7_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
|
||||
)
|
||||
test_sdpa_rewriter_8_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
|
||||
)
|
||||
test_sdpa_rewriter_9_cuda = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
|
||||
)
|
||||
test_sdpa_rewriter_10_cuda = (
|
||||
test_sdpa_rewriter_2_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
|
||||
test_sdpa_rewriter_3_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
|
||||
test_sdpa_rewriter_4_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
|
||||
test_sdpa_rewriter_5_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
|
||||
test_sdpa_rewriter_6_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
|
||||
test_sdpa_rewriter_7_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
|
||||
test_sdpa_rewriter_8_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
|
||||
test_sdpa_rewriter_9_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
|
||||
test_sdpa_rewriter_10_gpu = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_10
|
||||
)
|
||||
test_pattern_fails_with_tensor_factor_cuda = (
|
||||
test_pattern_fails_with_tensor_factor_gpu = (
|
||||
TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
|
||||
)
|
||||
test_pattern_fails_with_unsupported_mask_cuda = (
|
||||
test_pattern_fails_with_unsupported_mask_gpu = (
|
||||
TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
|
||||
)
|
||||
test_sdpa_rewriter_11_cuda = (
|
||||
test_sdpa_rewriter_11_gpu = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
|
||||
)
|
||||
test_sdpa_rewriter_12_cuda = (
|
||||
test_sdpa_rewriter_12_gpu = (
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
|
||||
)
|
||||
test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
|
||||
test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
|
||||
test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
|
||||
test_sdpa_rewriter_13_cuda = functools.partialmethod(
|
||||
test_sdpa_prev_13_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
|
||||
test_sdpa_prev_14_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
|
||||
test_sdpa_prev_15_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
|
||||
test_sdpa_rewriter_13_gpu = functools.partialmethod(
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
|
||||
)
|
||||
test_sdpa_rewriter_14_cuda = functools.partialmethod(
|
||||
test_sdpa_rewriter_14_gpu = functools.partialmethod(
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
|
||||
)
|
||||
test_sdpa_rewriter_15_cuda = functools.partialmethod(
|
||||
test_sdpa_rewriter_15_gpu = functools.partialmethod(
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
|
||||
)
|
||||
test_sdpa_rewriter_17_cuda = functools.partialmethod(
|
||||
test_sdpa_rewriter_17_gpu = functools.partialmethod(
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
|
||||
)
|
||||
test_sdpa_rewriter_19_cuda = functools.partialmethod(
|
||||
test_sdpa_rewriter_19_gpu = functools.partialmethod(
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
|
||||
)
|
||||
test_sdpa_rewriter_20_cuda = functools.partialmethod(
|
||||
test_sdpa_rewriter_20_gpu = functools.partialmethod(
|
||||
TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_20
|
||||
)
|
||||
|
||||
class SDPAPatternRewriterCudaDynamicTests(SDPAPatternRewriterCudaTests):
|
||||
class SDPAPatternRewriterGpuDynamicTests(SDPAPatternRewriterGpuTests):
|
||||
use_static_shapes = False
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user