Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
[Inductor XPU] Refine test_mkldnn_pattern_matcher.py to be reusable for XPU. (#150286)
This PR extracts some test cases from TestPatternMatcher into a newly created TestPatternMatcherGeneric and uses instantiate_device_type_tests to make them reusable across multiple devices.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150286
Approved by: https://github.com/jansel
This commit is contained in:
Parent: f8aa6404ac
Commit: 58ede0cca3
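For readers unfamiliar with the mechanism the description refers to, here is a minimal, self-contained sketch of the instantiate_device_type_tests pattern this PR adopts. The MyGenericTests class and test_add method are illustrative names, not part of this change; only the two torch.testing._internal imports and the instantiation call mirror what the patched file actually uses.

# Minimal sketch of the device-generic test pattern (illustrative class/test names).
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


class MyGenericTests(TestCase):
    # Tests are written against a generic `device` argument instead of a hard-coded "cpu".
    def test_add(self, device):
        x = torch.ones(4, device=device)
        y = torch.ones(4, device=device)
        self.assertEqual(x + y, torch.full((4,), 2.0, device=device))


# Generates one concrete class per available device type (e.g. MyGenericTestsCPU, and an
# XPU variant when allow_xpu=True on an XPU build); each test name gets a device suffix,
# e.g. test_add_cpu. The generic class itself is removed from the module namespace.
instantiate_device_type_tests(MyGenericTests, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()

This device-suffixed naming is why the cpp-wrapper test list in the hunks below now refers to TestPatternMatcherGenericCPU and to test_linear_binary_cpu.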
@@ -189,7 +189,7 @@ if RUN_CPU:
         BaseTest(
             "test_conv2d_unary",
             "cpu",
-            test_mkldnn_pattern_matcher.TestPatternMatcher(),
+            test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(),
             condition=torch.backends.mkldnn.is_available(),
             slow=True,
         ),
@@ -220,9 +220,9 @@ if RUN_CPU:
         ],
         BaseTest("test_polar"),
         BaseTest(
-            "test_linear_binary",
+            "test_linear_binary_cpu",
             "",
-            test_mkldnn_pattern_matcher.TestPatternMatcher(),
+            test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(),
             torch.backends.mkldnn.is_available()
             and torch.ops.mkldnn._is_mkldnn_bf16_supported(),
         ),
@@ -359,7 +359,9 @@ if RUN_CPU:
         BaseTest("test_view_as_complex"),
         BaseTest("test_view_as_real"),
         BaseTest(
-            "test_woq_int4", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher()
+            "test_woq_int4",
+            "cpu",
+            test_mkldnn_pattern_matcher.TestPatternMatcher(),
         ),
     ]:
         make_test_case(
@@ -13,6 +13,7 @@ from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
 from torch.nn import functional as F
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_quantization import (
     _generate_qdq_quantized_model,
     skipIfNoDynamoSupport,
@@ -33,7 +34,11 @@ from torch.testing._internal.common_utils import (
     TEST_MKL,
     xfailIfACL,
 )
-from torch.testing._internal.inductor_utils import _check_has_dynamic_shape, HAS_CPU
+from torch.testing._internal.inductor_utils import (
+    _check_has_dynamic_shape,
+    clone_preserve_strides_offset,
+    HAS_CPU,
+)
 
 
 # The dict value is match_nodes(computation_op+unary_op)
@@ -91,7 +96,7 @@ def get_default_quantizer(is_qat, is_dynamic):
     return quantizer
 
 
-def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
+def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"):
     # this function is to decide how many kernels are generated
     # while testing conv2d/3d/deconv2d
     # the assumption is:
@@ -103,11 +108,14 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
     # and force the output to have same stride with eager.
     # So there will be a to_contiguous for output if eager output is contiguouse
     mod = copy.deepcopy(mod)
+    mod = mod.to(device=device)
     input = input.clone()
+    input = input.to(device)
 
     if dtype == torch.float32:
         maybe_autocast = contextlib.nullcontext()
     else:
-        maybe_autocast = torch.amp.autocast("cpu", dtype=dtype)
+        maybe_autocast = torch.amp.autocast(device_type=device, dtype=dtype)
     with torch.no_grad(), maybe_autocast:
         output = mod(input)
     input_kernel, output_kernel = 0, 0
@@ -155,26 +163,33 @@ class TestPatternMatcherBase(TestCase):
         quantizer=None,
         compile_options={},  # noqa: B006
     ):
+        if not hasattr(self, "device"):
+            has_xpu = any(
+                isinstance(input, torch.Tensor) and input.device.type == "xpu"
+                for input in inputs
+            )
+            device = "xpu" if has_xpu else "cpu"
+        else:
+            device = self.device
+
+        mod = mod.to(device=device)
+        if device != "cpu":
+            inputs = tuple(
+                clone_preserve_strides_offset(x, device=device) for x in inputs
+            )
         counters.clear()
         torch._dynamo.reset()
-        has_xpu = any(
-            isinstance(input, torch.Tensor) and input.device.type == "xpu"
-            for input in inputs
-        )
-        device_type = "xpu" if has_xpu else "cpu"
         if check_autocast == torch.bfloat16 and (
-            torch.ops.mkldnn._is_mkldnn_bf16_supported() or has_xpu
+            torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu"
         ):
             maybe_autocast = torch.amp.autocast(
-                device_type=device_type, dtype=torch.bfloat16
+                device_type=device, dtype=torch.bfloat16
             )
             atol, rtol = 1e-2, 1e-2
         elif check_autocast == torch.float16 and (
-            torch.ops.mkldnn._is_mkldnn_fp16_supported() or has_xpu
+            torch.ops.mkldnn._is_mkldnn_fp16_supported() or device == "xpu"
         ):
-            maybe_autocast = torch.amp.autocast(
-                device_type=device_type, dtype=torch.float16
-            )
+            maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16)
             atol, rtol = 1e-2, 1e-2
         else:
             assert check_autocast == torch.float32
@@ -233,8 +248,8 @@ class TestPatternMatcherBase(TestCase):
         torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
 
 
-class TestPatternMatcher(TestPatternMatcherBase):
-    def _test_conv_unary_cpu_base(self, dim=4):
+class TestPatternMatcherGeneric(TestPatternMatcherBase):
+    def _test_conv_unary_base(self, dim=4):
         assert dim == 4 or dim == 5
 
         class M(torch.nn.Module):
@@ -304,23 +319,27 @@ class TestPatternMatcher(TestPatternMatcherBase):
 
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
             generated_kernel_count = cal_conv_generated_kernel_number(
-                mod, v, dtype, dim
+                mod, v, dtype, dim, self.device
             )
             self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv2d_unary_cpu(self):
-        self._test_conv_unary_cpu_base(dim=4)
+    def test_conv2d_unary(self, device):
+        self.device = device
+        self._test_conv_unary_base(dim=4)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv3d_unary_cpu(self):
-        self._test_conv_unary_cpu_base(dim=5)
+    def test_conv3d_unary(self, device):
+        self.device = device
+        self._test_conv_unary_base(dim=5)
 
-    def test_linear_unary(self):
+    def test_linear_unary(self, device):
+        self.device = device
+
         class M(torch.nn.Module):
             def __init__(
                 self,
@@ -374,7 +393,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
             self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
 
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
-    def test_linear_fp32(self):
+    def test_linear_fp32(self, device):
+        self.device = device
+
         class M(torch.nn.Module):
             def __init__(self, bias):
                 super().__init__()
@@ -396,7 +417,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
         self._test_common(mod, (v,), matcher_check_fn)
 
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
-    def test_linear_input_non_contiguous_3D_wo_bias(self):
+    def test_linear_input_non_contiguous_3D_wo_bias(self, device):
+        self.device = device
+
         # Activation is 3D, non-contiguous and without Bias
         class M(torch.nn.Module):
             def __init__(self):
@@ -438,17 +461,19 @@ class TestPatternMatcher(TestPatternMatcherBase):
             )
             torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
 
-    def test_linear_add_bias(self):
+    def test_linear_add_bias(self, device):
+        self.device = device
+
         class M(torch.nn.Module):
-            def __init__(self, dtype, unary_fn, cast_bias):
+            def __init__(self, device, dtype, unary_fn, cast_bias):
                 super().__init__()
                 self.linear1 = torch.nn.Linear(10, 64, bias=False)
-                self.bias1 = torch.randn(64)
+                self.bias1 = torch.randn(64, device=device)
                 self.linear2 = torch.nn.Linear(10, 64, bias=False)
-                self.bias2 = torch.randn(64)
+                self.bias2 = torch.randn(64, device=device)
                 if cast_bias:
-                    self.bias1 = self.bias1.to(dtype=dtype)
-                    self.bias2 = self.bias2.to(dtype=dtype)
+                    self.bias1 = self.bias1.to(dtype=dtype, device=device)
+                    self.bias2 = self.bias2.to(dtype=dtype, device=device)
                 self.unary_fn = unary_fn
 
             def forward(self, x):
@@ -464,7 +489,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
         options = itertools.product(unary_list, dtypes)
         for unary_fn, dtype in options:
             metrics.reset()
-            fold_mod = M(dtype, unary_fn, cast_bias=True).eval()
+            fold_mod = M(self.device, dtype, unary_fn, cast_bias=True).eval()
             v = torch.randn(2, 10)
 
             def folder_matcher_check_fn():
@@ -495,7 +520,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
             # we won't fold the bias if bias is not same dtype with weight
             # https://github.com/pytorch/pytorch/pull/129138
             metrics.reset()
-            mod = M(dtype, unary_fn, cast_bias=False).eval()
+            mod = M(self.device, dtype, unary_fn, cast_bias=False).eval()
 
             def matcher_check_fn():
                 self.assertEqual(
@@ -575,20 +600,22 @@ class TestPatternMatcher(TestPatternMatcherBase):
 
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
             generated_kernel_count = cal_conv_generated_kernel_number(
-                mod, v, dtype, dim
+                mod, v, dtype, dim, self.device
             )
             self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv_transpose2d_unary_cpu(self):
+    def test_conv_transpose2d_unary(self, device):
+        self.device = device
         self._test_conv_transpose_unary_base(dim=4)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv_transpose3d_unary_cpu(self):
+    def test_conv_transpose3d_unary(self, device):
+        self.device = device
         self._test_conv_transpose_unary_base(dim=5)
 
     def _test_conv_binary_base(self, dim=4):
@@ -669,20 +696,22 @@ class TestPatternMatcher(TestPatternMatcherBase):
 
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
             generated_kernel_count = cal_conv_generated_kernel_number(
-                mod, v, dtype, dim
+                mod, v, dtype, dim, self.device
             )
             self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv2d_binary(self):
+    def test_conv2d_binary(self, device):
+        self.device = device
         self._test_conv_binary_base(dim=4)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
-    def test_conv3d_binary(self):
+    def test_conv3d_binary(self, device):
+        self.device = device
         self._test_conv_binary_base(dim=5)
 
     def _test_conv_binary_broadcast_shapes_base(self, dim=4):
@@ -788,7 +817,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
     def test_conv3d_binary_broadcast_shapes_cpu(self):
         self._test_conv_binary_broadcast_shapes_base(dim=5)
 
-    def test_linear_binary(self):
+    def test_linear_binary(self, device):
+        self.device = device
+
         class M(torch.nn.Module):
             def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
                 super().__init__()
@@ -939,7 +970,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
 
         self._test_common(mod, (x1, x2), matcher_check_fn)
 
-    def test_multi_linear_share_same_input(self):
+    def test_multi_linear_share_same_input(self, device):
+        self.device = device
+
         # llama pattern.
         class M(torch.nn.Module):
             def __init__(
@@ -979,6 +1012,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
         v = torch.randn(2, 4, 16).to(dtype)
         self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)
 
+
+class TestPatternMatcher(TestPatternMatcherBase):
     def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
         class M(torch.nn.Module):
             def __init__(
@@ -4119,30 +4154,42 @@ class TestPatternMatcher(TestPatternMatcherBase):
         self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)
 
 
-# When testing kernel counts, unspecializing float causes wobbling of our tests because
-# we end up reusing the same compiled region across tests. Thus we purposely specialize floats
-# here since we primarily care about number of kernels generated in the absence of compile
-# caching.
-@dynamo_config.patch(
-    {
-        "dynamic_shapes": True,
-        "assume_static_by_default": False,
-        "specialize_float": True,
-    }
-)
-class TestDynamicPatternMatcher(TestPatternMatcherBase):
-    _test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base
-    test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu
-    test_conv3d_unary_dynamic_shapes = TestPatternMatcher.test_conv3d_unary_cpu
-    _test_conv_binary_base = TestPatternMatcher._test_conv_binary_base
-    test_conv2d_binary_dynamic_shapes = TestPatternMatcher.test_conv2d_binary
-    test_conv3d_binary_dynamic_shapes = TestPatternMatcher.test_conv3d_binary
-    test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary
+class TestDynamicPatternMatcherGeneric(TestPatternMatcherBase):
+    def setUp(self):
+        TestCase.setUp(self)
+        self.ctx_stack = contextlib.ExitStack()
+        self.ctx_stack.enter_context(
+            # When testing kernel counts, unspecializing float causes wobbling of our tests because
+            # we end up reusing the same compiled region across tests. Thus we purposely specialize floats
+            # here since we primarily care about number of kernels generated in the absence of compile
+            # caching.
+            dynamo_config.patch(
+                {
+                    "dynamic_shapes": True,
+                    "assume_static_by_default": False,
+                    "specialize_float": True,
+                }
+            )
+        )
+
+    def tearDown(self):
+        TestCase.tearDown(self)
+        self.ctx_stack.close()
+
+    _test_conv_unary_base = TestPatternMatcherGeneric._test_conv_unary_base
+    test_conv2d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_unary
+    test_conv3d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_unary
+    _test_conv_binary_base = TestPatternMatcherGeneric._test_conv_binary_base
+    test_conv2d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_binary
+    test_conv3d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_binary
+    test_linear_unary_dynamic_shapes = TestPatternMatcherGeneric.test_linear_unary
     test_linear_input_non_contiguous_3D_wo_bias_dynamic_shapes = (
-        TestPatternMatcher.test_linear_input_non_contiguous_3D_wo_bias
+        TestPatternMatcherGeneric.test_linear_input_non_contiguous_3D_wo_bias
     )
 
-    def test_conv_transpose2d_dynamic_shapes(self):
+    def test_conv_transpose2d_dynamic_shapes(self, device):
+        self.device = device
+
         # We don't support conv_transpose2d for now.
         class M(torch.nn.Module):
             def __init__(self) -> None:
@@ -4163,7 +4210,9 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
 
         self._test_common(mod, (v,), matcher_check_fn)
 
-    def test_multi_linear_share_same_input_dynamic(self):
+    def test_multi_linear_share_same_input_dynamic(self, device):
+        self.device = device
+
         # llama pattern.
         class M(torch.nn.Module):
             def __init__(
@@ -4206,6 +4255,15 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
         v = torch.randn(2, 4, 16).to(dtype)
         self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)
 
+
+@dynamo_config.patch(
+    {
+        "dynamic_shapes": True,
+        "assume_static_by_default": False,
+        "specialize_float": True,
+    }
+)
+class TestDynamicPatternMatcher(TestPatternMatcherBase):
     @xfailIfACL
     def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None):
         r"""
@@ -4367,8 +4425,13 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
         )
 
 
+instantiate_device_type_tests(
+    TestPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu")
+)
+instantiate_device_type_tests(
+    TestDynamicPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu")
+)
 instantiate_parametrized_tests(TestPatternMatcher)
 
 if __name__ == "__main__":
-    if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
+    if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available():
         run_tests()
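As a usage note, here is a small sketch of how the newly instantiated per-device classes can be targeted once the test file is imported as a module. It assumes the standard naming behaviour of instantiate_device_type_tests (device-suffixed class and test names) and that test/inductor is on sys.path; the class name below is the CPU variant already referenced by the cpp-wrapper hunk above.

# Illustrative only: select the generated CPU variant of the generic tests.
import unittest

import test_mkldnn_pattern_matcher as m  # assumes test/inductor is on sys.path

# After instantiate_device_type_tests runs at import time, the module exposes
# per-device classes such as TestPatternMatcherGenericCPU, whose tests carry a
# device suffix (e.g. test_conv2d_unary_cpu).
suite = unittest.defaultTestLoader.loadTestsFromName(
    "TestPatternMatcherGenericCPU", module=m
)
unittest.TextTestRunner(verbosity=2).run(suite)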