[Inductor XPU] Refine test_mkldnn_pattern_matcher.py to be reusable for XPU. (#150286)

This PR extracts the device-agnostic test cases from TestPatternMatcher and TestDynamicPatternMatcher into newly created TestPatternMatcherGeneric and TestDynamicPatternMatcherGeneric classes, and uses instantiate_device_type_tests to make them reusable across multiple devices (instantiated for CPU, with XPU enabled via allow_xpu=True).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150286
Approved by: https://github.com/jansel
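
For context on the mechanism: instantiate_device_type_tests stamps out one concrete class per enabled device type from a generic template class, injects a `device` string into each test method, and drops the generic class from the module namespace so only the per-device copies are collected. A minimal sketch of the pattern (the class and test names below are illustrative, not part of this PR):

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


class MyGenericTest(TestCase):
    def test_tensor_on_device(self, device):
        # `device` is supplied by the instantiated per-device class,
        # e.g. "cpu" for MyGenericTestCPU or "xpu" for MyGenericTestXPU.
        x = torch.ones(4, device=device)
        self.assertEqual(x.device.type, torch.device(device).type)


# Generates MyGenericTestCPU and, on builds where XPU is available,
# MyGenericTestXPU; this mirrors the allow_xpu=True / only_for=("cpu")
# calls this PR adds at the bottom of test_mkldnn_pattern_matcher.py.
instantiate_device_type_tests(
    MyGenericTest, globals(), allow_xpu=True, only_for=("cpu")
)

if __name__ == "__main__":
    run_tests()
```

This is why the first file below can now reference a generated class named TestPatternMatcherGenericCPU even though only TestPatternMatcherGeneric is defined.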
Author: xinan.lin
Date: 2025-04-06 23:50:14 -07:00
Committed by: PyTorch MergeBot
Parent: f8aa6404ac
Commit: 58ede0cca3
2 changed files with 133 additions and 68 deletions

test/inductor/test_cpu_cpp_wrapper.py

@@ -189,7 +189,7 @@ if RUN_CPU:
         BaseTest(
             "test_conv2d_unary",
             "cpu",
-            test_mkldnn_pattern_matcher.TestPatternMatcher(),
+            test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(),
             condition=torch.backends.mkldnn.is_available(),
             slow=True,
         ),
@@ -220,9 +220,9 @@
         ],
         BaseTest("test_polar"),
         BaseTest(
-            "test_linear_binary",
+            "test_linear_binary_cpu",
             "",
-            test_mkldnn_pattern_matcher.TestPatternMatcher(),
+            test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(),
             torch.backends.mkldnn.is_available()
             and torch.ops.mkldnn._is_mkldnn_bf16_supported(),
         ),
@@ -359,7 +359,9 @@ if RUN_CPU:
         BaseTest("test_view_as_complex"),
         BaseTest("test_view_as_real"),
         BaseTest(
-            "test_woq_int4", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher()
+            "test_woq_int4",
+            "cpu",
+            test_mkldnn_pattern_matcher.TestPatternMatcher(),
         ),
     ]:
         make_test_case(

test/inductor/test_mkldnn_pattern_matcher.py

@@ -13,6 +13,7 @@ from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
 from torch.nn import functional as F
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_quantization import (
     _generate_qdq_quantized_model,
     skipIfNoDynamoSupport,
@@ -33,7 +34,11 @@ from torch.testing._internal.common_utils import (
     TEST_MKL,
     xfailIfACL,
 )
-from torch.testing._internal.inductor_utils import _check_has_dynamic_shape, HAS_CPU
+from torch.testing._internal.inductor_utils import (
+    _check_has_dynamic_shape,
+    clone_preserve_strides_offset,
+    HAS_CPU,
+)


 # The dict value is match_nodes(computation_op+unary_op)
@@ -91,7 +96,7 @@ def get_default_quantizer(is_qat, is_dynamic):
     return quantizer


-def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
+def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"):
    # this function is to decide how many kernels are generated
    # while testing conv2d/3d/deconv2d
    # the assumption is:
@@ -103,11 +108,14 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
    # and force the output to have same stride with eager.
    # So there will be a to_contiguous for output if eager output is contiguouse
     mod = copy.deepcopy(mod)
+    mod = mod.to(device=device)
+    input = input.clone()
+    input = input.to(device)
     if dtype == torch.float32:
         maybe_autocast = contextlib.nullcontext()
     else:
-        maybe_autocast = torch.amp.autocast("cpu", dtype=dtype)
+        maybe_autocast = torch.amp.autocast(device_type=device, dtype=dtype)
     with torch.no_grad(), maybe_autocast:
         output = mod(input)
     input_kernel, output_kernel = 0, 0
@@ -155,26 +163,33 @@ class TestPatternMatcherBase(TestCase):
         quantizer=None,
         compile_options={},  # noqa: B006
     ):
+        if not hasattr(self, "device"):
+            has_xpu = any(
+                isinstance(input, torch.Tensor) and input.device.type == "xpu"
+                for input in inputs
+            )
+            device = "xpu" if has_xpu else "cpu"
+        else:
+            device = self.device
+        mod = mod.to(device=device)
+        if device != "cpu":
+            inputs = tuple(
+                clone_preserve_strides_offset(x, device=device) for x in inputs
+            )
         counters.clear()
         torch._dynamo.reset()
-        has_xpu = any(
-            isinstance(input, torch.Tensor) and input.device.type == "xpu"
-            for input in inputs
-        )
-        device_type = "xpu" if has_xpu else "cpu"
         if check_autocast == torch.bfloat16 and (
-            torch.ops.mkldnn._is_mkldnn_bf16_supported() or has_xpu
+            torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu"
         ):
             maybe_autocast = torch.amp.autocast(
-                device_type=device_type, dtype=torch.bfloat16
+                device_type=device, dtype=torch.bfloat16
             )
             atol, rtol = 1e-2, 1e-2
         elif check_autocast == torch.float16 and (
-            torch.ops.mkldnn._is_mkldnn_fp16_supported() or has_xpu
+            torch.ops.mkldnn._is_mkldnn_fp16_supported() or device == "xpu"
         ):
-            maybe_autocast = torch.amp.autocast(
-                device_type=device_type, dtype=torch.float16
-            )
+            maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16)
             atol, rtol = 1e-2, 1e-2
         else:
             assert check_autocast == torch.float32
@@ -233,8 +248,8 @@
             torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)


-class TestPatternMatcher(TestPatternMatcherBase):
-    def _test_conv_unary_cpu_base(self, dim=4):
+class TestPatternMatcherGeneric(TestPatternMatcherBase):
+    def _test_conv_unary_base(self, dim=4):
         assert dim == 4 or dim == 5

         class M(torch.nn.Module):
@@ -304,23 +319,27 @@ class TestPatternMatcher(TestPatternMatcherBase):
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
             generated_kernel_count = cal_conv_generated_kernel_number(
-                mod, v, dtype, dim
+                mod, v, dtype, dim, self.device
             )
             self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv2d_unary_cpu(self):
-        self._test_conv_unary_cpu_base(dim=4)
+    def test_conv2d_unary(self, device):
+        self.device = device
+        self._test_conv_unary_base(dim=4)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv3d_unary_cpu(self):
-        self._test_conv_unary_cpu_base(dim=5)
+    def test_conv3d_unary(self, device):
+        self.device = device
+        self._test_conv_unary_base(dim=5)

-    def test_linear_unary(self):
+    def test_linear_unary(self, device):
+        self.device = device
         class M(torch.nn.Module):
             def __init__(
                 self,
@@ -374,7 +393,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
         self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)

     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
-    def test_linear_fp32(self):
+    def test_linear_fp32(self, device):
+        self.device = device
         class M(torch.nn.Module):
             def __init__(self, bias):
                 super().__init__()
@@ -396,7 +417,9 @@
         self._test_common(mod, (v,), matcher_check_fn)

     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
-    def test_linear_input_non_contiguous_3D_wo_bias(self):
+    def test_linear_input_non_contiguous_3D_wo_bias(self, device):
+        self.device = device
         # Activation is 3D, non-contiguous and without Bias
         class M(torch.nn.Module):
             def __init__(self):
@@ -438,17 +461,19 @@
             )
             torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

-    def test_linear_add_bias(self):
+    def test_linear_add_bias(self, device):
+        self.device = device
         class M(torch.nn.Module):
-            def __init__(self, dtype, unary_fn, cast_bias):
+            def __init__(self, device, dtype, unary_fn, cast_bias):
                 super().__init__()
                 self.linear1 = torch.nn.Linear(10, 64, bias=False)
-                self.bias1 = torch.randn(64)
+                self.bias1 = torch.randn(64, device=device)
                 self.linear2 = torch.nn.Linear(10, 64, bias=False)
-                self.bias2 = torch.randn(64)
+                self.bias2 = torch.randn(64, device=device)
                 if cast_bias:
-                    self.bias1 = self.bias1.to(dtype=dtype)
-                    self.bias2 = self.bias2.to(dtype=dtype)
+                    self.bias1 = self.bias1.to(dtype=dtype, device=device)
+                    self.bias2 = self.bias2.to(dtype=dtype, device=device)
                 self.unary_fn = unary_fn

             def forward(self, x):
@@ -464,7 +489,7 @@
         options = itertools.product(unary_list, dtypes)
         for unary_fn, dtype in options:
             metrics.reset()
-            fold_mod = M(dtype, unary_fn, cast_bias=True).eval()
+            fold_mod = M(self.device, dtype, unary_fn, cast_bias=True).eval()
             v = torch.randn(2, 10)

             def folder_matcher_check_fn():
@@ -495,7 +520,7 @@
            # we won't fold the bias if bias is not same dtype with weight
            # https://github.com/pytorch/pytorch/pull/129138
             metrics.reset()
-            mod = M(dtype, unary_fn, cast_bias=False).eval()
+            mod = M(self.device, dtype, unary_fn, cast_bias=False).eval()

             def matcher_check_fn():
                 self.assertEqual(
@@ -575,20 +600,22 @@
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
             generated_kernel_count = cal_conv_generated_kernel_number(
-                mod, v, dtype, dim
+                mod, v, dtype, dim, self.device
             )
             self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv_transpose2d_unary_cpu(self):
+    def test_conv_transpose2d_unary(self, device):
+        self.device = device
         self._test_conv_transpose_unary_base(dim=4)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv_transpose3d_unary_cpu(self):
+    def test_conv_transpose3d_unary(self, device):
+        self.device = device
         self._test_conv_transpose_unary_base(dim=5)

     def _test_conv_binary_base(self, dim=4):
@@ -669,20 +696,22 @@
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
             generated_kernel_count = cal_conv_generated_kernel_number(
-                mod, v, dtype, dim
+                mod, v, dtype, dim, self.device
             )
             self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv2d_binary(self):
+    def test_conv2d_binary(self, device):
+        self.device = device
         self._test_conv_binary_base(dim=4)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv3d_binary(self):
+    def test_conv3d_binary(self, device):
+        self.device = device
         self._test_conv_binary_base(dim=5)

     def _test_conv_binary_broadcast_shapes_base(self, dim=4):
@@ -788,7 +817,9 @@
     def test_conv3d_binary_broadcast_shapes_cpu(self):
         self._test_conv_binary_broadcast_shapes_base(dim=5)

-    def test_linear_binary(self):
+    def test_linear_binary(self, device):
+        self.device = device
         class M(torch.nn.Module):
             def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
                 super().__init__()
@@ -939,7 +970,9 @@
         self._test_common(mod, (x1, x2), matcher_check_fn)

-    def test_multi_linear_share_same_input(self):
+    def test_multi_linear_share_same_input(self, device):
+        self.device = device
        # llama pattern.
         class M(torch.nn.Module):
             def __init__(
@@ -979,6 +1012,8 @@
         v = torch.randn(2, 4, 16).to(dtype)
         self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)

+
+class TestPatternMatcher(TestPatternMatcherBase):
     def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
         class M(torch.nn.Module):
             def __init__(
@@ -4119,30 +4154,42 @@ class TestPatternMatcher(TestPatternMatcherBase):
         self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)


-# When testing kernel counts, unspecializing float causes wobbling of our tests because
-# we end up reusing the same compiled region across tests. Thus we purposely specialize floats
-# here since we primarily care about number of kernels generated in the absence of compile
-# caching.
-@dynamo_config.patch(
-    {
-        "dynamic_shapes": True,
-        "assume_static_by_default": False,
-        "specialize_float": True,
-    }
-)
-class TestDynamicPatternMatcher(TestPatternMatcherBase):
-    _test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base
-    test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu
-    test_conv3d_unary_dynamic_shapes = TestPatternMatcher.test_conv3d_unary_cpu
-    _test_conv_binary_base = TestPatternMatcher._test_conv_binary_base
-    test_conv2d_binary_dynamic_shapes = TestPatternMatcher.test_conv2d_binary
-    test_conv3d_binary_dynamic_shapes = TestPatternMatcher.test_conv3d_binary
-    test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary
+class TestDynamicPatternMatcherGeneric(TestPatternMatcherBase):
+    def setUp(self):
+        TestCase.setUp(self)
+        self.ctx_stack = contextlib.ExitStack()
+        self.ctx_stack.enter_context(
+            # When testing kernel counts, unspecializing float causes wobbling of our tests because
+            # we end up reusing the same compiled region across tests. Thus we purposely specialize floats
+            # here since we primarily care about number of kernels generated in the absence of compile
+            # caching.
+            dynamo_config.patch(
+                {
+                    "dynamic_shapes": True,
+                    "assume_static_by_default": False,
+                    "specialize_float": True,
+                }
+            )
+        )
+
+    def tearDown(self):
+        TestCase.tearDown(self)
+        self.ctx_stack.close()
+
+    _test_conv_unary_base = TestPatternMatcherGeneric._test_conv_unary_base
+    test_conv2d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_unary
+    test_conv3d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_unary
+    _test_conv_binary_base = TestPatternMatcherGeneric._test_conv_binary_base
+    test_conv2d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_binary
+    test_conv3d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_binary
+    test_linear_unary_dynamic_shapes = TestPatternMatcherGeneric.test_linear_unary
     test_linear_input_non_contiguous_3D_wo_bias_dynamic_shapes = (
-        TestPatternMatcher.test_linear_input_non_contiguous_3D_wo_bias
+        TestPatternMatcherGeneric.test_linear_input_non_contiguous_3D_wo_bias
     )

-    def test_conv_transpose2d_dynamic_shapes(self):
+    def test_conv_transpose2d_dynamic_shapes(self, device):
+        self.device = device
        # We don't support conv_transpose2d for now.
         class M(torch.nn.Module):
             def __init__(self) -> None:
@@ -4163,7 +4210,9 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
         self._test_common(mod, (v,), matcher_check_fn)

-    def test_multi_linear_share_same_input_dynamic(self):
+    def test_multi_linear_share_same_input_dynamic(self, device):
+        self.device = device
        # llama pattern.
         class M(torch.nn.Module):
             def __init__(
@@ -4206,6 +4255,15 @@
         v = torch.randn(2, 4, 16).to(dtype)
         self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)

+
+@dynamo_config.patch(
+    {
+        "dynamic_shapes": True,
+        "assume_static_by_default": False,
+        "specialize_float": True,
+    }
+)
+class TestDynamicPatternMatcher(TestPatternMatcherBase):
     @xfailIfACL
     def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None):
         r"""
@@ -4367,8 +4425,13 @@
         )


+instantiate_device_type_tests(
+    TestPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu")
+)
+instantiate_device_type_tests(
+    TestDynamicPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu")
+)
 instantiate_parametrized_tests(TestPatternMatcher)

 if __name__ == "__main__":
-    if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
+    if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available():
         run_tests()