mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor XPU] Refine test_mkldnn_pattern_matcher.py to be reusable for XPU. (#150286)
This PR extracts some test cases from TestPatternMatcher into a newly created TestPatternMatcherGeneric, and uses instantiate_device_type_tests to make them reusable across multiple devices. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150286 Approved by: https://github.com/jansel
This commit is contained in:
parent
f8aa6404ac
commit
58ede0cca3
|
|
@ -189,7 +189,7 @@ if RUN_CPU:
|
|||
BaseTest(
|
||||
"test_conv2d_unary",
|
||||
"cpu",
|
||||
test_mkldnn_pattern_matcher.TestPatternMatcher(),
|
||||
test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(),
|
||||
condition=torch.backends.mkldnn.is_available(),
|
||||
slow=True,
|
||||
),
|
||||
|
|
@ -220,9 +220,9 @@ if RUN_CPU:
|
|||
],
|
||||
BaseTest("test_polar"),
|
||||
BaseTest(
|
||||
"test_linear_binary",
|
||||
"test_linear_binary_cpu",
|
||||
"",
|
||||
test_mkldnn_pattern_matcher.TestPatternMatcher(),
|
||||
test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(),
|
||||
torch.backends.mkldnn.is_available()
|
||||
and torch.ops.mkldnn._is_mkldnn_bf16_supported(),
|
||||
),
|
||||
|
|
@ -359,7 +359,9 @@ if RUN_CPU:
|
|||
BaseTest("test_view_as_complex"),
|
||||
BaseTest("test_view_as_real"),
|
||||
BaseTest(
|
||||
"test_woq_int4", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher()
|
||||
"test_woq_int4",
|
||||
"cpu",
|
||||
test_mkldnn_pattern_matcher.TestPatternMatcher(),
|
||||
),
|
||||
]:
|
||||
make_test_case(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from torch._inductor.test_case import run_tests, TestCase
|
|||
from torch._inductor.utils import run_and_get_code
|
||||
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_quantization import (
|
||||
_generate_qdq_quantized_model,
|
||||
skipIfNoDynamoSupport,
|
||||
|
|
@ -33,7 +34,11 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_MKL,
|
||||
xfailIfACL,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import _check_has_dynamic_shape, HAS_CPU
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
_check_has_dynamic_shape,
|
||||
clone_preserve_strides_offset,
|
||||
HAS_CPU,
|
||||
)
|
||||
|
||||
|
||||
# The dict value is match_nodes(computation_op+unary_op)
|
||||
|
|
@ -91,7 +96,7 @@ def get_default_quantizer(is_qat, is_dynamic):
|
|||
return quantizer
|
||||
|
||||
|
||||
def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
|
||||
def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"):
|
||||
# this function is to decide how many kernels are generated
|
||||
# while testing conv2d/3d/deconv2d
|
||||
# the assumption is:
|
||||
|
|
@ -103,11 +108,14 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
|
|||
# and force the output to have the same stride as eager.
|
||||
# So there will be a to_contiguous for output if eager output is contiguous
|
||||
mod = copy.deepcopy(mod)
|
||||
mod = mod.to(device=device)
|
||||
input = input.clone()
|
||||
input = input.to(device)
|
||||
|
||||
if dtype == torch.float32:
|
||||
maybe_autocast = contextlib.nullcontext()
|
||||
else:
|
||||
maybe_autocast = torch.amp.autocast("cpu", dtype=dtype)
|
||||
maybe_autocast = torch.amp.autocast(device_type=device, dtype=dtype)
|
||||
with torch.no_grad(), maybe_autocast:
|
||||
output = mod(input)
|
||||
input_kernel, output_kernel = 0, 0
|
||||
|
|
@ -155,26 +163,33 @@ class TestPatternMatcherBase(TestCase):
|
|||
quantizer=None,
|
||||
compile_options={}, # noqa: B006
|
||||
):
|
||||
if not hasattr(self, "device"):
|
||||
has_xpu = any(
|
||||
isinstance(input, torch.Tensor) and input.device.type == "xpu"
|
||||
for input in inputs
|
||||
)
|
||||
device = "xpu" if has_xpu else "cpu"
|
||||
else:
|
||||
device = self.device
|
||||
|
||||
mod = mod.to(device=device)
|
||||
if device != "cpu":
|
||||
inputs = tuple(
|
||||
clone_preserve_strides_offset(x, device=device) for x in inputs
|
||||
)
|
||||
counters.clear()
|
||||
torch._dynamo.reset()
|
||||
has_xpu = any(
|
||||
isinstance(input, torch.Tensor) and input.device.type == "xpu"
|
||||
for input in inputs
|
||||
)
|
||||
device_type = "xpu" if has_xpu else "cpu"
|
||||
if check_autocast == torch.bfloat16 and (
|
||||
torch.ops.mkldnn._is_mkldnn_bf16_supported() or has_xpu
|
||||
torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu"
|
||||
):
|
||||
maybe_autocast = torch.amp.autocast(
|
||||
device_type=device_type, dtype=torch.bfloat16
|
||||
device_type=device, dtype=torch.bfloat16
|
||||
)
|
||||
atol, rtol = 1e-2, 1e-2
|
||||
elif check_autocast == torch.float16 and (
|
||||
torch.ops.mkldnn._is_mkldnn_fp16_supported() or has_xpu
|
||||
torch.ops.mkldnn._is_mkldnn_fp16_supported() or device == "xpu"
|
||||
):
|
||||
maybe_autocast = torch.amp.autocast(
|
||||
device_type=device_type, dtype=torch.float16
|
||||
)
|
||||
maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16)
|
||||
atol, rtol = 1e-2, 1e-2
|
||||
else:
|
||||
assert check_autocast == torch.float32
|
||||
|
|
@ -233,8 +248,8 @@ class TestPatternMatcherBase(TestCase):
|
|||
torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
class TestPatternMatcher(TestPatternMatcherBase):
|
||||
def _test_conv_unary_cpu_base(self, dim=4):
|
||||
class TestPatternMatcherGeneric(TestPatternMatcherBase):
|
||||
def _test_conv_unary_base(self, dim=4):
|
||||
assert dim == 4 or dim == 5
|
||||
|
||||
class M(torch.nn.Module):
|
||||
|
|
@ -304,23 +319,27 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
|
||||
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
|
||||
generated_kernel_count = cal_conv_generated_kernel_number(
|
||||
mod, v, dtype, dim
|
||||
mod, v, dtype, dim, self.device
|
||||
)
|
||||
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
@skipIfRocm
|
||||
def test_conv2d_unary_cpu(self):
|
||||
self._test_conv_unary_cpu_base(dim=4)
|
||||
def test_conv2d_unary(self, device):
|
||||
self.device = device
|
||||
self._test_conv_unary_base(dim=4)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
@skipIfRocm
|
||||
def test_conv3d_unary_cpu(self):
|
||||
self._test_conv_unary_cpu_base(dim=5)
|
||||
def test_conv3d_unary(self, device):
|
||||
self.device = device
|
||||
self._test_conv_unary_base(dim=5)
|
||||
|
||||
def test_linear_unary(self, device):
|
||||
self.device = device
|
||||
|
||||
def test_linear_unary(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -374,7 +393,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
|
||||
|
||||
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
|
||||
def test_linear_fp32(self):
|
||||
def test_linear_fp32(self, device):
|
||||
self.device = device
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, bias):
|
||||
super().__init__()
|
||||
|
|
@ -396,7 +417,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
self._test_common(mod, (v,), matcher_check_fn)
|
||||
|
||||
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
|
||||
def test_linear_input_non_contiguous_3D_wo_bias(self):
|
||||
def test_linear_input_non_contiguous_3D_wo_bias(self, device):
|
||||
self.device = device
|
||||
|
||||
# Activation is 3D, non-contiguous and without Bias
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -438,17 +461,19 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
)
|
||||
torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_linear_add_bias(self):
|
||||
def test_linear_add_bias(self, device):
|
||||
self.device = device
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, dtype, unary_fn, cast_bias):
|
||||
def __init__(self, device, dtype, unary_fn, cast_bias):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 64, bias=False)
|
||||
self.bias1 = torch.randn(64)
|
||||
self.bias1 = torch.randn(64, device=device)
|
||||
self.linear2 = torch.nn.Linear(10, 64, bias=False)
|
||||
self.bias2 = torch.randn(64)
|
||||
self.bias2 = torch.randn(64, device=device)
|
||||
if cast_bias:
|
||||
self.bias1 = self.bias1.to(dtype=dtype)
|
||||
self.bias2 = self.bias2.to(dtype=dtype)
|
||||
self.bias1 = self.bias1.to(dtype=dtype, device=device)
|
||||
self.bias2 = self.bias2.to(dtype=dtype, device=device)
|
||||
self.unary_fn = unary_fn
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -464,7 +489,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
options = itertools.product(unary_list, dtypes)
|
||||
for unary_fn, dtype in options:
|
||||
metrics.reset()
|
||||
fold_mod = M(dtype, unary_fn, cast_bias=True).eval()
|
||||
fold_mod = M(self.device, dtype, unary_fn, cast_bias=True).eval()
|
||||
v = torch.randn(2, 10)
|
||||
|
||||
def folder_matcher_check_fn():
|
||||
|
|
@ -495,7 +520,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
# we won't fold the bias if bias is not same dtype with weight
|
||||
# https://github.com/pytorch/pytorch/pull/129138
|
||||
metrics.reset()
|
||||
mod = M(dtype, unary_fn, cast_bias=False).eval()
|
||||
mod = M(self.device, dtype, unary_fn, cast_bias=False).eval()
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
|
|
@ -575,20 +600,22 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
|
||||
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
|
||||
generated_kernel_count = cal_conv_generated_kernel_number(
|
||||
mod, v, dtype, dim
|
||||
mod, v, dtype, dim, self.device
|
||||
)
|
||||
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
@skipIfRocm
|
||||
def test_conv_transpose2d_unary_cpu(self):
|
||||
def test_conv_transpose2d_unary(self, device):
|
||||
self.device = device
|
||||
self._test_conv_transpose_unary_base(dim=4)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
@skipIfRocm
|
||||
def test_conv_transpose3d_unary_cpu(self):
|
||||
def test_conv_transpose3d_unary(self, device):
|
||||
self.device = device
|
||||
self._test_conv_transpose_unary_base(dim=5)
|
||||
|
||||
def _test_conv_binary_base(self, dim=4):
|
||||
|
|
@ -669,20 +696,22 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
|
||||
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
|
||||
generated_kernel_count = cal_conv_generated_kernel_number(
|
||||
mod, v, dtype, dim
|
||||
mod, v, dtype, dim, self.device
|
||||
)
|
||||
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
@skipIfRocm
|
||||
def test_conv2d_binary(self):
|
||||
def test_conv2d_binary(self, device):
|
||||
self.device = device
|
||||
self._test_conv_binary_base(dim=4)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
@skipIfRocm
|
||||
def test_conv3d_binary(self):
|
||||
def test_conv3d_binary(self, device):
|
||||
self.device = device
|
||||
self._test_conv_binary_base(dim=5)
|
||||
|
||||
def _test_conv_binary_broadcast_shapes_base(self, dim=4):
|
||||
|
|
@ -788,7 +817,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
def test_conv3d_binary_broadcast_shapes_cpu(self):
|
||||
self._test_conv_binary_broadcast_shapes_base(dim=5)
|
||||
|
||||
def test_linear_binary(self):
|
||||
def test_linear_binary(self, device):
|
||||
self.device = device
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
|
||||
super().__init__()
|
||||
|
|
@ -939,7 +970,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
|
||||
self._test_common(mod, (x1, x2), matcher_check_fn)
|
||||
|
||||
def test_multi_linear_share_same_input(self):
|
||||
def test_multi_linear_share_same_input(self, device):
|
||||
self.device = device
|
||||
|
||||
# llama pattern.
|
||||
class M(torch.nn.Module):
|
||||
def __init__(
|
||||
|
|
@ -979,6 +1012,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
v = torch.randn(2, 4, 16).to(dtype)
|
||||
self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
class TestPatternMatcher(TestPatternMatcherBase):
|
||||
def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(
|
||||
|
|
@ -4119,30 +4154,42 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)
|
||||
|
||||
|
||||
# When testing kernel counts, unspecializing float causes wobbling of our tests because
|
||||
# we end up reusing the same compiled region across tests. Thus we purposely specialize floats
|
||||
# here since we primarily care about number of kernels generated in the absence of compile
|
||||
# caching.
|
||||
@dynamo_config.patch(
|
||||
{
|
||||
"dynamic_shapes": True,
|
||||
"assume_static_by_default": False,
|
||||
"specialize_float": True,
|
||||
}
|
||||
)
|
||||
class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
||||
_test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base
|
||||
test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu
|
||||
test_conv3d_unary_dynamic_shapes = TestPatternMatcher.test_conv3d_unary_cpu
|
||||
_test_conv_binary_base = TestPatternMatcher._test_conv_binary_base
|
||||
test_conv2d_binary_dynamic_shapes = TestPatternMatcher.test_conv2d_binary
|
||||
test_conv3d_binary_dynamic_shapes = TestPatternMatcher.test_conv3d_binary
|
||||
test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary
|
||||
class TestDynamicPatternMatcherGeneric(TestPatternMatcherBase):
|
||||
def setUp(self):
|
||||
TestCase.setUp(self)
|
||||
self.ctx_stack = contextlib.ExitStack()
|
||||
self.ctx_stack.enter_context(
|
||||
# When testing kernel counts, unspecializing float causes wobbling of our tests because
|
||||
# we end up reusing the same compiled region across tests. Thus we purposely specialize floats
|
||||
# here since we primarily care about number of kernels generated in the absence of compile
|
||||
# caching.
|
||||
dynamo_config.patch(
|
||||
{
|
||||
"dynamic_shapes": True,
|
||||
"assume_static_by_default": False,
|
||||
"specialize_float": True,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
TestCase.tearDown(self)
|
||||
self.ctx_stack.close()
|
||||
|
||||
_test_conv_unary_base = TestPatternMatcherGeneric._test_conv_unary_base
|
||||
test_conv2d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_unary
|
||||
test_conv3d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_unary
|
||||
_test_conv_binary_base = TestPatternMatcherGeneric._test_conv_binary_base
|
||||
test_conv2d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_binary
|
||||
test_conv3d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_binary
|
||||
test_linear_unary_dynamic_shapes = TestPatternMatcherGeneric.test_linear_unary
|
||||
test_linear_input_non_contiguous_3D_wo_bias_dynamic_shapes = (
|
||||
TestPatternMatcher.test_linear_input_non_contiguous_3D_wo_bias
|
||||
TestPatternMatcherGeneric.test_linear_input_non_contiguous_3D_wo_bias
|
||||
)
|
||||
|
||||
def test_conv_transpose2d_dynamic_shapes(self):
|
||||
def test_conv_transpose2d_dynamic_shapes(self, device):
|
||||
self.device = device
|
||||
|
||||
# We don't support conv_transpose2d for now.
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
@ -4163,7 +4210,9 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
|||
|
||||
self._test_common(mod, (v,), matcher_check_fn)
|
||||
|
||||
def test_multi_linear_share_same_input_dynamic(self):
|
||||
def test_multi_linear_share_same_input_dynamic(self, device):
|
||||
self.device = device
|
||||
|
||||
# llama pattern.
|
||||
class M(torch.nn.Module):
|
||||
def __init__(
|
||||
|
|
@ -4206,6 +4255,15 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
|||
v = torch.randn(2, 4, 16).to(dtype)
|
||||
self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
@dynamo_config.patch(
|
||||
{
|
||||
"dynamic_shapes": True,
|
||||
"assume_static_by_default": False,
|
||||
"specialize_float": True,
|
||||
}
|
||||
)
|
||||
class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
||||
@xfailIfACL
|
||||
def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None):
|
||||
r"""
|
||||
|
|
@ -4367,8 +4425,13 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
|||
)
|
||||
|
||||
|
||||
instantiate_device_type_tests(
|
||||
TestPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu")
|
||||
)
|
||||
instantiate_device_type_tests(
|
||||
TestDynamicPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu")
|
||||
)
|
||||
instantiate_parametrized_tests(TestPatternMatcher)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
|
||||
if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available():
|
||||
run_tests()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user