[Inductor UT] Reuse test_fused_attention.py for Intel GPU. (#154110)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154110
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/EikanWang
This commit is contained in:
parent 8fe7ec6721
commit 2dfc0e3327
test/inductor/test_fused_attention.py

@@ -15,7 +15,7 @@ from torch.testing._internal.common_cuda import (
     SM80OrLater,
 )
 from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU
 
 
 def checkpoint_wrapper(fn):
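For context: GPU_TYPE is the device string the shared tests use in place of a
hard-coded "cuda". A minimal sketch of the dispatch idea, assuming only the
public cuda/xpu availability checks (the actual definition lives in
torch.testing._internal.inductor_utils and may resolve differently):

    import torch

    # Probe the two GPU backends this test file now targets.
    HAS_CUDA = torch.cuda.is_available()
    HAS_XPU = hasattr(torch, "xpu") and torch.xpu.is_available()

    # One device string, so a single test body runs on NVIDIA or Intel GPUs.
    GPU_TYPE = "xpu" if HAS_XPU else "cuda"

    if HAS_CUDA or HAS_XPU:
        x = torch.randn(4, 4, device=GPU_TYPE)  # device-generic allocation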
@@ -61,6 +61,10 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         args2 = self._clone_inputs(args1)
 
         for training in [False, True] if check_train else [False]:
+            if training and self.device == "xpu":
+                # Intel GPU has not implemented sdpa backward mode yet.
+                # TODO: remove this when sdpa backward is implemented for XPU.
+                continue
             for x in itertools.chain(args1[:], args2[:]):
                 if isinstance(x, torch.Tensor) and x.is_floating_point():
                     x.requires_grad = training
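The `continue` matters because the training iteration is the one that flips
requires_grad on and ultimately runs backward, which is the SDPA kernel piece
XPU lacks. A condensed, illustrative sketch of the comparison this loop feeds
into (not the real _check_common, which also counts pattern-matcher hits):

    import torch

    def run_and_compare(fn, args, training: bool):
        # Mirror the loop above: autograd only in the training pass.
        args = [
            a.detach().clone().requires_grad_(training and a.is_floating_point())
            for a in args
        ]
        out_eager = fn(*args)
        out_compiled = torch.compile(fn)(*args)
        torch.testing.assert_close(out_eager, out_compiled, atol=2e-3, rtol=1e-2)
        if training:
            # Backward is what needs sdpa's backward kernel -- the part
            # not yet implemented on XPU, hence the skip above.
            (out_eager.sum() + out_compiled.sum()).backward()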
@@ -120,7 +124,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         for dtype in [torch.float, torch.half]:
             atol = 0.001
             rtol = 1.3e-6 if dtype == torch.float else 0.7
-            if self.device == "cpu" and dtype == torch.half:
+            if self.device in ["cpu", "xpu"] and dtype == torch.half:
                 atol = 2e-3
                 rtol = 1e-2
             self._check_common(dot_prod_attention, dtype=dtype, atol=atol, rtol=rtol)
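Pulled out as a standalone helper, the tolerance selection above reads as
follows (values copied from the hunk; fused fp16 kernels accumulate
differently from the eager reference, and CPU and XPU both need the slack):

    import torch

    def tolerances(device: str, dtype: torch.dtype) -> tuple[float, float]:
        atol = 0.001
        # fp32 gets a near-machine-epsilon rtol; fp16 is compared loosely.
        rtol = 1.3e-6 if dtype == torch.float else 0.7
        if device in ["cpu", "xpu"] and dtype == torch.half:
            atol, rtol = 2e-3, 1e-2
        return atol, rtol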
@@ -144,10 +148,10 @@ class TestSDPAPatternRewriterTemplate(TestCase):
             .matmul(value)
         )
 
-        for dtype in [torch.float, torch.half]:
+        for dtype in [torch.half]:
             atol = 0.001
             rtol = 1.3e-6 if dtype == torch.float else 0.7
-            if self.device == "cpu" and dtype == torch.half:
+            if self.device in ["cpu", "xpu"] and dtype == torch.half:
                 atol = 2e-3
                 rtol = 1e-2
             with torch.no_grad():
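Why half precision needs those widened tolerances: fp16 carries roughly three
decimal digits of mantissa, as torch.finfo shows:

    import torch

    print(torch.finfo(torch.half).eps)   # 0.0009765625, ~3 decimal digits
    print(torch.finfo(torch.float).eps)  # ~1.19e-07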
@@ -160,6 +164,11 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         )
 
     def _test_insignificant_strides(self):
+        if self.device == "xpu":
+            self.skipTest(
+                "The operator 'aten::_scaled_dot_product_efficient_attention'"
+                " is not currently implemented for the XPU device. "
+            )
         f32 = torch.float32
 
         # repro taken from https://github.com/pytorch/pytorch/issues/124289
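self.skipTest raises unittest.SkipTest at call time, so the case is reported
as skipped rather than silently passing; that is the right tool when support
depends on which device the suite lands on. A minimal standalone illustration
(toy class, not the real test):

    import unittest

    class Example(unittest.TestCase):
        device = "xpu"  # illustrative; the real class sets device = GPU_TYPE

        def test_efficient_attention(self):
            if self.device == "xpu":
                # Runtime skip: counted as 's' in the test summary.
                self.skipTest("efficient attention not implemented on XPU")
            ...  # body needing aten::_scaled_dot_product_efficient_attention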
@@ -229,7 +238,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
             )
             return _scaled_dot_product_efficient_attention
 
-        kwargs = aot_graph_input_parser(forward, device="cuda")
+        kwargs = aot_graph_input_parser(forward, device=GPU_TYPE)
         # runs successfully
         out_eager = forward(**kwargs)
         out_c = torch.compile(forward)(**kwargs)
@@ -389,9 +398,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         )
 
         args = (
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
         )
         self._check_common(
             checkpoint_wrapper(sfdp_pattern_7),
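The same (2, 8, 4, 16) half-precision q/k/v triple recurs through the
sfdp_pattern_7..10 hunks below. A hypothetical factory (not in the file;
assumes the GPU_TYPE import from the first hunk) that would keep the device
generic while cutting the repetition:

    import functools

    import torch

    # make_qkv is a hypothetical name; torch.randn accepts a size tuple as
    # its first positional argument, so partial application works directly.
    make_qkv = functools.partial(
        torch.randn, (2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half
    )
    args = (make_qkv(), make_qkv(), make_qkv())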
@@ -421,9 +430,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         self._check_common(sfdp_pattern_8, args, atol=2e-3)
 
         args = (
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
         )
         self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3)
 
@@ -455,9 +464,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
             atol=2e-3,
         )
         args = (
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
         )
         self._check_common(
             checkpoint_wrapper(sfdp_pattern_9),
@@ -488,9 +497,9 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         self._check_common(sfdp_pattern_10, args, atol=2e-3)
 
         args = (
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
-            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
+            torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half),
         )
         self._check_common(checkpoint_wrapper(sfdp_pattern_10), args, atol=2e-3)
 
@@ -969,84 +978,66 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         self._check_common(dot_prod_attention, check_train=False, has_dropout=True)
 
 
-if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:
+if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION):
 
-    class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
-        device = "cuda"
-        test_sdpa_rewriter_1_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
-        )
+    class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate):
+        device = GPU_TYPE
+        test_sdpa_rewriter_1_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
         test_sdpa_rewriter_1_freezing = (
             TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1_freezing
         )
         test_insignificant_strides = (
             TestSDPAPatternRewriterTemplate._test_insignificant_strides
         )
-        test_pattern_fails_with_reuse_cuda = (
+        test_pattern_fails_with_reuse_gpu = (
             TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
         )
-        test_sdpa_rewriter_2_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
-        )
-        test_sdpa_rewriter_3_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
-        )
-        test_sdpa_rewriter_4_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
-        )
-        test_sdpa_rewriter_5_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
-        )
-        test_sdpa_rewriter_6_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
-        )
-        test_sdpa_rewriter_7_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
-        )
-        test_sdpa_rewriter_8_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
-        )
-        test_sdpa_rewriter_9_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
-        )
-        test_sdpa_rewriter_10_cuda = (
+        test_sdpa_rewriter_2_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
+        test_sdpa_rewriter_3_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
+        test_sdpa_rewriter_4_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
+        test_sdpa_rewriter_5_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
+        test_sdpa_rewriter_6_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
+        test_sdpa_rewriter_7_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
+        test_sdpa_rewriter_8_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
+        test_sdpa_rewriter_9_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
+        test_sdpa_rewriter_10_gpu = (
             TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_10
         )
-        test_pattern_fails_with_tensor_factor_cuda = (
+        test_pattern_fails_with_tensor_factor_gpu = (
             TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
         )
-        test_pattern_fails_with_unsupported_mask_cuda = (
+        test_pattern_fails_with_unsupported_mask_gpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
         )
-        test_sdpa_rewriter_11_cuda = (
+        test_sdpa_rewriter_11_gpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
         )
-        test_sdpa_rewriter_12_cuda = (
+        test_sdpa_rewriter_12_gpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
         )
-        test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
-        test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
-        test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
+        test_sdpa_prev_13_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
+        test_sdpa_prev_14_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
+        test_sdpa_prev_15_gpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
-        test_sdpa_rewriter_13_cuda = functools.partialmethod(
+        test_sdpa_rewriter_13_gpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
         )
-        test_sdpa_rewriter_14_cuda = functools.partialmethod(
+        test_sdpa_rewriter_14_gpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
         )
-        test_sdpa_rewriter_15_cuda = functools.partialmethod(
+        test_sdpa_rewriter_15_gpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
         )
-        test_sdpa_rewriter_17_cuda = functools.partialmethod(
+        test_sdpa_rewriter_17_gpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
         )
-        test_sdpa_rewriter_19_cuda = functools.partialmethod(
+        test_sdpa_rewriter_19_gpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
         )
-        test_sdpa_rewriter_20_cuda = functools.partialmethod(
+        test_sdpa_rewriter_20_gpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_20
         )
 
-    class SDPAPatternRewriterCudaDynamicTests(SDPAPatternRewriterCudaTests):
+    class SDPAPatternRewriterGpuDynamicTests(SDPAPatternRewriterGpuTests):
         use_static_shapes = False
 
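The mechanism that makes this rename-only reuse work: the template keeps its
implementations underscore-prefixed so unittest discovery ignores them, and
each device class republishes them under a test_* name, with
functools.partialmethod pre-binding arguments such as dtype=torch.half. A
self-contained toy version of the pattern (names are illustrative, not the
real suite):

    import functools
    import unittest

    class _Template(unittest.TestCase):
        device = None  # concrete subclasses fill this in

        # Leading underscore: never collected as a test on its own.
        def _test_add(self, dtype=float):
            self.assertEqual(dtype(1) + dtype(2), dtype(3))

    class CpuTests(_Template):
        device = "cpu"
        test_add_cpu = _Template._test_add  # republished verbatim

    class GpuTests(_Template):
        device = "gpu"
        # partialmethod pre-binds kwargs, like dtype=torch.half above.
        test_add_gpu = functools.partialmethod(_Template._test_add, dtype=int)

    if __name__ == "__main__":
        unittest.main()

Subclassing the concrete class again, as SDPAPatternRewriterGpuDynamicTests
does above, reruns every published test with one flag overridden, here
use_static_shapes = False.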