Enable Intel GPU on 4 unit test cases (#165405)
As part of https://github.com/pytorch/pytorch/issues/114850, this PR ports several ATen unit tests to Intel GPU. Intel GPU support is enabled with the following methods, keeping the original code style as much as possible:
1. Replaced onlyCUDA with onlyOn(['cuda', 'xpu']) for supported tests.
2. Added allow_xpu=True for supported test classes in test parameterization.
3. Used torch.accelerator to extend CUDA-specific tests to XPU where needed.
4. Enabled 'xpu' for some test paths.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165405
Approved by: https://github.com/guangyey, https://github.com/ezyang
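As a quick illustration of patterns 1 and 2, a minimal device-generic test might look like the sketch below. This is not code from the PR: the class and test names are hypothetical, and onlyOn is assumed to accept a list of device strings, as it is used in the diffs further down.

# Hypothetical device-generic test; for illustration only, not part of this PR.
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyOn,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestXpuPortingExample(TestCase):
    # Pattern 1: instead of @onlyCUDA, restrict the test to CUDA and XPU
    # (assumes onlyOn accepts a list, as used in this PR).
    @onlyOn(["cuda", "xpu"])
    def test_add(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)


# Pattern 2: allow_xpu=True lets the device-type framework generate an XPU
# variant of the class (e.g. TestXpuPortingExampleXPU) alongside the CUDA one.
instantiate_device_type_tests(TestXpuPortingExample, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()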
This commit is contained in:
parent 4e6afa8c07
commit 81fa4a204c
@@ -10,13 +10,14 @@ from torch.nn import MultiheadAttention
 from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,
-    onlyCUDAAndPRIVATEUSE1,
+    onlyOn,
 )
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize as parametrize_test,
     run_tests,
+    TEST_CUDA,
     TEST_NUMPY,
     TEST_WITH_CROSSREF,
 )
@@ -32,8 +33,9 @@ if TEST_NUMPY:


 class TestMultiheadAttentionNN(NNTestCase):
-    _do_cuda_memory_leak_check = True
-    _do_cuda_non_default_stream = True
+    if TEST_CUDA:
+        _do_cuda_memory_leak_check = True
+        _do_cuda_non_default_stream = True

     @unittest.skipIf(not TEST_NUMPY, "numpy not found")
     @parametrize_test("average_attn_weights", [True, False])
@@ -834,8 +836,13 @@ class TestMultiheadAttentionNNDeviceType(NNTestCase):
         and key padding mask (mask type 1) are provided at the same time on CPU and CUDA and PrivateUse1
         """
         device = device.rstrip(":0123456789")
-        if device not in ["cpu", "cuda", torch._C._get_privateuse1_backend_name()]:
-            self.skipTest("Fastpath only runs on CPU and CUDA and PrivateUse1.")
+        if device not in [
+            "cpu",
+            "cuda",
+            "xpu",
+            torch._C._get_privateuse1_backend_name(),
+        ]:
+            self.skipTest("Fastpath only runs on CPU and CUDA and XPU and PrivateUse1.")

         with torch.autocast(device_type=device, enabled=False):
             embed_dim = 16
@@ -869,7 +876,7 @@ class TestMultiheadAttentionNNDeviceType(NNTestCase):
             # If mock was called, fastpath was taken
             self.assertTrue(fastpath_mock.called)

-    @onlyCUDAAndPRIVATEUSE1
+    @onlyOn(["cuda", "xpu", torch._C._get_privateuse1_backend_name()])
     @dtypes(torch.half, torch.float, torch.double)
     def test_multihead_attention_dtype(self, device, dtype):
         embed_dim = 128
@ -884,7 +891,7 @@ class TestMultiheadAttentionNNDeviceType(NNTestCase):
|
||||||
self.assertEqual(q.size(), out[0].size())
|
self.assertEqual(q.size(), out[0].size())
|
||||||
self.assertEqual(dtype, out[0].dtype)
|
self.assertEqual(dtype, out[0].dtype)
|
||||||
|
|
||||||
@onlyCUDAAndPRIVATEUSE1
|
@onlyOn(["cuda", "xpu", torch._C._get_privateuse1_backend_name()])
|
||||||
@dtypes(torch.half, torch.float, torch.double)
|
@dtypes(torch.half, torch.float, torch.double)
|
||||||
def test_multihead_attention_dtype_batch_first(self, device, dtype):
|
def test_multihead_attention_dtype_batch_first(self, device, dtype):
|
||||||
embed_dim = 128
|
embed_dim = 128
|
||||||
|
|
|
||||||
|
|
@@ -44,6 +44,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     skipIfTorchDynamo,
+    TEST_XPU,
     TestCase,
 )
 from torch.testing._internal.logging_utils import logs_to_string
@@ -3204,6 +3205,9 @@ class TestGuardsExpressions(TestCase):
         self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
         self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))

+    @unittest.skipIf(
+        TEST_XPU, "Skipped on XPU"
+    )  # https://github.com/intel/torch-xpu-ops/issues/2169
     @skipIfTorchDynamo("Attempt to trace generator")
     @torch.fx.experimental._config.patch("use_duck_shape", False)
     def test_size_comparison_no_recompile(self):
@@ -14,11 +14,12 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
     dtypes,
     dtypesIfCUDA,
+    dtypesIfXPU,
     instantiate_device_type_tests,
     largeTensorTest,
     onlyCPU,
-    onlyCUDA,
     onlyNativeDeviceTypes,
+    onlyOn,
 )
 from torch.testing._internal.common_dtype import (
     all_types,
@@ -271,6 +272,7 @@ class TestShapeOps(TestCase):
     @onlyNativeDeviceTypes
     @dtypes(*all_types())
     @dtypesIfCUDA(*all_types_and(torch.half))
+    @dtypesIfXPU(*all_types_and(torch.half))
     def test_trace(self, device, dtype):
         def test(shape):
             tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
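The test_trace hunk above pairs @dtypesIfCUDA with a matching @dtypesIfXPU. A rough, hypothetical sketch of how these decorators pick the dtype list per device (the test body and class name are illustrative only, not code from this PR):

# Hypothetical sketch of per-device dtype selection; not code from this PR.
import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    dtypesIfXPU,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_dtype import all_types, all_types_and
from torch.testing._internal.common_utils import run_tests, TestCase


class TestDtypeSelectionExample(TestCase):
    @dtypes(*all_types())                      # default dtype list (e.g. on CPU)
    @dtypesIfCUDA(*all_types_and(torch.half))  # CUDA additionally covers float16
    @dtypesIfXPU(*all_types_and(torch.half))   # XPU mirrors the CUDA list
    def test_sum(self, device, dtype):
        t = torch.ones(3, device=device, dtype=dtype)
        self.assertEqual(t.sum().item(), 3)


instantiate_device_type_tests(TestDtypeSelectionExample, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()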
@@ -568,7 +570,7 @@ class TestShapeOps(TestCase):
             np_fn = partial(np.flip, axis=flip_dim)
             self.compare_with_numpy(torch_fn, np_fn, data)

-    @onlyCUDA  # CPU is too slow
+    @onlyOn(["cuda", "xpu"])  # CPU is too slow
     @largeTensorTest("17GB")  # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB
     @largeTensorTest(
         "81GB", "cpu"
@@ -715,6 +717,7 @@ class TestShapeOps(TestCase):
         )
         if (
             self.device_type == "cuda"
+            or self.device_type == "xpu"
             or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE
         ):
             self.assertRaisesRegex(
@@ -37,6 +37,7 @@ from torch.testing._internal.common_utils import (
     NOTEST_CPU,
     IS_WINDOWS,
     TEST_WITH_TORCHDYNAMO,
+    TEST_XPU,
 )
 from torch._dynamo.testing import CompileCounterWithBackend

@@ -4630,12 +4631,15 @@ if NOTEST_CPU:
 else:
     device_types = ("cpu", "cuda", "mps")

+if TEST_XPU:
+    device_types += ("xpu", )
+
 instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
 instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types, allow_mps=True)
-instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types, allow_mps=True)
+instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types, allow_mps=True, allow_xpu=True)
 instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
 instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
-instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
+instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types, allow_xpu=True)
 instantiate_device_type_tests(TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True)

 if __name__ == '__main__':
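Pattern 3 from the commit message (torch.accelerator) does not show up in the hunks above, but the idea is to replace CUDA-only calls with the device-agnostic accelerator API where a test touches backend state. A rough sketch, assuming a recent PyTorch that ships torch.accelerator; the helper name is purely illustrative:

# Hypothetical helper showing the torch.accelerator pattern; not from this PR.
import torch


def reset_accelerator_state() -> None:
    # Before (CUDA-only):
    #     torch.cuda.synchronize()
    #     torch.cuda.empty_cache()
    # After (works for CUDA and XPU alike):
    if torch.accelerator.is_available():
        torch.accelerator.synchronize()
        acc = torch.accelerator.current_accelerator()
        if acc is not None and acc.type == "cuda":
            torch.cuda.empty_cache()
        elif acc is not None and acc.type == "xpu":
            torch.xpu.empty_cache()


if __name__ == "__main__":
    reset_accelerator_state()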