**MOTIVATION**

We recently integrated support for Intel Gaudi devices (identified as "hpu") into the common_device_type framework in https://github.com/pytorch/pytorch/pull/126970. That integration allows tests to be instantiated automatically for Gaudi devices once the relevant library is loaded. Building on it, this pull request adapts selected CUDA tests to also run on Gaudi devices, and we have confirmed that these modifications do not interfere with the existing tests on CUDA devices. Other accelerators (e.g., xpu) can opt in by adding their device to the devices list.

**CHANGES**

- Create a separate class for the test functions that run on CUDA devices
- Extend those tests to also cover HPUs
- Use instantiate_device_type_tests with targeted attributes to generate device-specific test instances within the new classes (sketched below)
- Apply the skipIfHPU decorator to bypass tests that are not yet compatible with HPU devices

We previously submitted these changes in https://github.com/pytorch/pytorch/pull/140131 but deleted that PR due to merge conflicts and other issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144387
Approved by: https://github.com/ankurneog, https://github.com/EikanWang, https://github.com/yanboliang, https://github.com/guangyey
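For reference, a minimal sketch of the pattern the CHANGES describe (the class and test names are illustrative, not taken from this PR, and skipIfHPU is assumed to be importable from torch.testing._internal.common_utils):

import torch
from torch._dynamo.test_case import TestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import skipIfHPU


class MyFeatureTestsDevice(TestCase):
    def test_add(self, device):
        # instantiate_device_type_tests passes the device string
        # ("cuda:0", "hpu:0", ...) into each generated test.
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)

    @skipIfHPU
    def test_not_yet_on_hpu(self, device):
        # Bypassed on HPU until the underlying functionality is supported there.
        ...


# Emits MyFeatureTestsDeviceCUDA / MyFeatureTestsDeviceHPU test classes,
# restricted to the device types named in only_for.
instantiate_device_type_tests(MyFeatureTestsDevice, globals(), only_for=["cuda", "hpu"])

Existing CUDA runs are unaffected because the CUDA instances are still generated; the only_for list simply adds HPU alongside them.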
205 lines · 7.3 KiB · Python
# Owner(s): ["module: dynamo"]

import os
import unittest
from unittest.mock import patch

import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string
from torch._dynamo.test_case import TestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.inductor_utils import HAS_CUDA


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


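# Device-independent checks of the torch._dynamo.debug_utils helpers; the
# device-parametrized tests live in TestDebugUtilsDevice below.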
class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64

        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )

        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )

        _, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )

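    # generate_env_vars_string should reproduce torch-related environment
    # variables as `os.environ[...] = ...` lines and leave unrelated ones out.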
    @patch.dict(os.environ, {"TORCHINDUCTOR_MAX_AUTOTUNE": "1", "TEST_ENV": "1"})
    def test_generate_env_vars_string(self):
        env_strings = generate_env_vars_string()
        self.assertIn(
            """os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1'
""",
            env_strings,
        )
        self.assertIn(
            """import os
""",
            env_strings,
        )
        self.assertNotIn(
            """TEST_ENV
""",
            env_strings,
        )


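# Each test here takes a `device` argument and is instantiated per device
# type (CUDA, HPU) by instantiate_device_type_tests at the bottom of the file.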
class TestDebugUtilsDevice(TestCase):
    def test_aot_graph_parser(self, device):
        def forward(
            self,
            primals_1: "f32[1001, 6]",
            primals_2: "f32[1001]",
            primals_3: "f32[1001, 64]",
            primals_4: "f32[4190]",
            primals_5: "f32[4190]",
            primals_6: "f32[1739, 4190]",
            primals_48: "f32[6144, 4191]",
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
            lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant0
            )
            _tensor_constant0 = None
            index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(  # noqa: F841
                primals_48, [None, lift_fresh_copy]
            )
            lift_fresh_copy = None

            _tensor_constant1: "i64[6]" = self._tensor_constant1
            lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant1
            )
            _tensor_constant1 = None
            index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy_1]
            )
            primals_48 = lift_fresh_copy_1 = None
            permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default(
                primals_2, index_1, permute
            )
            primals_2 = permute = None
            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
            sub = None
            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
            exp = None

            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
                [6144, 1001],
                1,
                dtype=torch.int32,
                layout=torch.strided,
                device=device,
                pin_memory=False,
            )

            iota: "i32[1001]" = torch.ops.prims.iota.default(
                1001,
                start=0,
                step=1,
                dtype=torch.int32,
                device=device,
                requires_grad=False,
            )

            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
            full_default = iota = None

            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
                6144,
                start=0,
                step=1001,
                dtype=torch.int32,
                device=device,
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
            mul = None
            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
            div = None
            _embedding_bag = torch.ops.aten._embedding_bag.default(
                primals_3, view, iota_1, False, 0, False, view_1
            )

            return _embedding_bag

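        # aot_graph_input_parser reads the "f32[...]"/"i64[...]" annotations on
        # forward and synthesizes matching inputs (including the tensor
        # constants accessed through `self`) on the requested device.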
        kwargs = aot_graph_input_parser(forward, device=device)
        # runs successfully
        forward(**kwargs)

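    # Symbolic shapes: sym_shapes pins named symbols (s0 -> 10), and any
    # symbol not listed falls back to default_sym_shape (s1 -> 5).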
    def test_sym_aot_graph_parser(self, device):
        def forward(
            self,
            primals_1: "f32[1001, 6]",  # noqa: F821
            primals_2: "f32[s0]",  # noqa: F821
            primals_3: "Sym(s0)",  # noqa: F821,
            primals_4: "f32[s1]",  # noqa: F821,
            primals_5: "Sym(s1)",  # noqa: F821,
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0

        kwargs = aot_graph_input_parser(
            forward, device=device, sym_shapes={"s0": 10}, default_sym_shape=5
        )

        self.assertEqual(list(kwargs["primals_2"].shape), [10])
        self.assertEqual(kwargs["primals_3"], 10)

        self.assertEqual(list(kwargs["primals_4"].shape), [5])
        self.assertEqual(kwargs["primals_5"], 5)


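# Generate per-device-type copies of the test classes (e.g.
# TestDebugUtilsDeviceCUDA); only_for limits TestDebugUtilsDevice to the
# devices in the list below, which is where new accelerators can opt in.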
instantiate_device_type_tests(TestDebugUtils, globals())

devices = ["cuda", "hpu"]
instantiate_device_type_tests(TestDebugUtilsDevice, globals(), only_for=devices)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()