Revert "cpp_wrapper: Fix even more tests (#147225)"

This reverts commit d25acac357.

Reverted https://github.com/pytorch/pytorch/pull/147225 on behalf of https://github.com/yangw-dev because it broke test/inductor/test_benchmark_fusion internally ([comment](https://github.com/pytorch/pytorch/pull/147225#issuecomment-2761944564))
PyTorch MergeBot 2025-03-28 17:07:52 +00:00
parent e691fcae0e
commit cf7447ae99
5 changed files with 50 additions and 78 deletions


@@ -10,12 +10,7 @@ from torch._inductor.test_operators import realize
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import slowTest
from torch.testing._internal.inductor_utils import (
get_func_call,
get_kernel_launch,
HAS_CPU,
HAS_CUDA,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
# Make the helper files in test/ importable
@@ -29,7 +24,6 @@ from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inducto
check_model,
check_model_cuda,
copy_tests,
skip_if_cpp_wrapper,
)
from torch._inductor import config
from torch._inductor.scheduler import Scheduler
@@ -132,7 +126,7 @@ class BenchmarkFusionTestTemplate:
self.common(f, (a, b))
@config.patch(max_autotune_gemm_backends="TRITON")
@torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
def test_avoid_register_spilling(self):
if self.device != "cuda":
raise unittest.SkipTest("CUDA only")
@@ -163,8 +157,8 @@ class BenchmarkFusionTestTemplate:
return
# should be multiple triton invocations
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 2, exactly=True
FileCheck().check("async_compile.wait").check_count(
".run", 2, exactly=True
).run(out_code[0])
with config.patch(
@@ -177,17 +171,9 @@ class BenchmarkFusionTestTemplate:
_, out_code2 = run_and_get_code(foo_c, m, inp)
for c in out_code[0], out_code2[0]:
FileCheck().check(get_func_call()).check(
"device_guard" if config.cpp_wrapper else "DeviceGuard"
).check_count("empty_strided", 1, exactly=True).check_regex(
r"output_handles\[[0-9]+\] = buf[0-9]+\.release\(\)"
if config.cpp_wrapper
else r"buf[0-9]+ = buf[0-9]+; del buf[0-9]+"
).check(
"" if config.cpp_wrapper else "return"
).run(
c
)
FileCheck().check("async_compile.wait").check("DeviceGuard").check_count(
"empty_strided_cuda", 1, exactly=True
).check_regex("buf[0-9]* = buf[0-9]*; del buf[0-9]*").check("return").run(c)
def test_tield_kernel_fusion(self):
def f(x):
@@ -210,7 +196,6 @@ if HAS_CUDA:
@unittest.skipIf(
torch.cuda.device_count() < 2, "The test need at least 2 devices"
)
@skip_if_cpp_wrapper("This tests triton scheduling directly")
def test_benchmark_on_non_zero_device(self):
hit_count = 0
with torch.cuda.device("cuda:0"):
@@ -280,7 +265,9 @@ if HAS_CUDA:
res, code = run_and_get_code(foo_c, m, inp)
torch._dynamo.reset()
with config.patch(benchmark_epilogue_fusion=False):
with unittest.mock.patch.object(
torch._inductor.config, "benchmark_epilogue_fusion", False
):
foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
with torch.no_grad():
res2, code2 = run_and_get_code(foo_c, m, inp)
@@ -289,34 +276,32 @@ if HAS_CUDA:
return code, code2
@fresh_inductor_cache()
@config.patch(max_autotune_gemm_backends="TRITON")
@torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
def test_equivalent_template_code(self):
code, code2 = self._equivalent_output_code_impl(256)
for out_code in [code, code2]:
FileCheck().check(get_func_call()).check_count(
"empty_strided", 1, exactly=True
).check("triton_tem_fused_addmm_relu_0").check_count(
".reset()" if config.cpp_wrapper else "del", 3, exactly=True
FileCheck().check("def call").check_count(
"empty_strided_cuda", 1, exactly=True
).check("triton_tem_fused_addmm_relu_0.run").check_count(
"del", 3, exactly=True
).check(
"" if config.cpp_wrapper else "return"
"return"
).run(
out_code[0]
)
@fresh_inductor_cache()
@config.patch(max_autotune_gemm_backends="ATEN")
@torch._inductor.config.patch(max_autotune_gemm_backends="ATEN")
def test_equivalent_extern_code(self):
torch._dynamo.reset()
code, code2 = self._equivalent_output_code_impl(512, 1, False)
for out_code in [code, code2]:
FileCheck().check(get_func_call()).check_count(
"empty_strided", 1, exactly=True
).check("" if config.cpp_wrapper else "extern_kernels.").check_count(
".reset()" if config.cpp_wrapper else "del", 3, exactly=True
).check(
"" if config.cpp_wrapper else "return"
FileCheck().check("def call").check_count(
"empty_strided_cuda", 1, exactly=True
).check("extern_kernels.").check_count("del", 3, exactly=True).check(
"return"
).run(
out_code[0]
)


@@ -2801,12 +2801,7 @@ main()
loss.backward()
torch._inductor.config.triton.cudagraphs = False
if inductor_config.cpp_wrapper:
self.assertIn("skipping cudagraphs", stderr_msgs.getvalue())
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
else:
self.assertNotIn("skipping cudagraphs", stderr_msgs.getvalue())
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
self.assertFalse("skipping cudagraphs" in stderr_msgs.getvalue())
def test_cudagraphs_cpu_graph(self):
from torch._dynamo.testing import reduce_to_scalar_loss
@@ -2839,10 +2834,7 @@ main()
opt_bwd()
self.assertEqual(counters["compiled_autograd"]["captures"], 1)
self.assertEqual(
counters["inductor"]["cudagraph_skips"],
2 if inductor_config.cpp_wrapper else 0,
)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self):
@@ -2935,10 +2927,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
# into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
# In the future, we can consider having a cpu scalar movement pass sometime after we trace
# into the custom C++ autograd::Function (like in AOTDispatcher)
self.assertEqual(
counters["inductor"]["cudagraph_skips"],
2 if inductor_config.cpp_wrapper else 1,
)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
def test_logs(self):
logs, ctx = logs_to_string(


@@ -46,14 +46,7 @@ from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import (
get_func_call,
get_kernel_launch,
GPU_TYPE,
HAS_CPU,
HAS_CUDA,
HAS_GPU,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU
torch.set_float32_matmul_precision("high")
@@ -61,6 +54,14 @@ if HAS_CUDA:
torch.cuda.memory._set_allocator_settings("expandable_segments:False")
def _get_func_call() -> str:
return "void inductor_entry_impl(" if config.cpp_wrapper else "def call("
def _get_kernel_launch() -> str:
return "call_triton_" if config.cpp_wrapper else ".run("
def benchmark_choice(choice, args, out, expected_out, timings):
result = choice.benchmark(*args, out=out)
if expected_out is not None:
@@ -898,8 +899,8 @@ class TestMaxAutotune(TestCase):
# mm kernel, and cos kernel
count = 2 if using_triton_mm else 1
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), count, exactly=True
FileCheck().check(_get_func_call()).check_count(
_get_kernel_launch(), count, exactly=True
).run(code[0])
def f(x, y):
@@ -911,8 +912,8 @@ class TestMaxAutotune(TestCase):
f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f)
_, code = run_and_get_code(f_c, inps[0], inps[1])
self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25)
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 2, exactly=True
FileCheck().check(_get_func_call()).check_count(
_get_kernel_launch(), 2, exactly=True
).run(code[0])
def f(x, y):
@@ -1361,21 +1362,21 @@ class TestPrologueFusion(TestCase):
)
def check_code(self, code_str, num_kernels, num_allocs, num_deallocs):
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(),
FileCheck().check(_get_func_call()).check_count(
_get_kernel_launch(),
num_kernels,
exactly=True,
).run(code_str)
if num_allocs is not None:
FileCheck().check(get_func_call()).check_count(
FileCheck().check(_get_func_call()).check_count(
"empty_strided", num_allocs, exactly=True
).run(code_str)
# skip the deallocation check when using cpp_wrapper; most deallocations happen
# outside of our control via RAIIAtenTensorHandle
if num_deallocs is not None and not config.cpp_wrapper:
FileCheck().check(get_func_call()).check_count(
FileCheck().check(_get_func_call()).check_count(
"del", num_deallocs, exactly=True
).run(code_str)
@@ -1515,8 +1516,8 @@ class TestPrologueFusion(TestCase):
out, code = run_and_get_code(torch.compile(multi_use), x, y)
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 2, exactly=True
FileCheck().check(_get_func_call()).check_count(
_get_kernel_launch(), 2, exactly=True
).run(code[0])
self.assertEqual(out, multi_use(x, y), atol=0.05, rtol=0.05)
@@ -1525,8 +1526,8 @@ class TestPrologueFusion(TestCase):
x = torch.rand([128, 128], device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(resolve_pending), x)
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 1, exactly=True
FileCheck().check(_get_func_call()).check_count(
_get_kernel_launch(), 1, exactly=True
).run(code[0])
self.assertEqual(out, resolve_pending(x), atol=0.05, rtol=0.05)
@@ -1549,8 +1550,8 @@ class TestPrologueFusion(TestCase):
x = torch.rand([128, 128], dtype=torch.float16, device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(test_multiple_fusions), x)
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 1, exactly=True
FileCheck().check(_get_func_call()).check_count(
_get_kernel_launch(), 1, exactly=True
).run(code[0])
self.assertEqual(out, test_multiple_fusions(x), atol=0.05, rtol=0.05)


@@ -10100,6 +10100,9 @@ class CommonTemplate:
for x in (torch.randn(2, 3), torch.randn(2, 2), torch.randn(3, 2)):
self.common(fn, (x,))
@skip_if_cpp_wrapper(
"cannot currently handle fallback ops with return types containing list[Tensor]"
)
def test_kwargs(self):
if self.device == GPU_TYPE:
raise unittest.SkipTest("histogramdd only supports cpu")


@@ -210,12 +210,6 @@ def maybe_skip_size_asserts(op):
else:
return contextlib.nullcontext()
def get_func_call() -> str:
return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call("
def get_kernel_launch() -> str:
return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run("
def clone_preserve_strides_offset(x, device=None):
if not isinstance(x, torch.Tensor):
return x
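
Note: the `get_func_call()` / `get_kernel_launch()` helpers removed from inductor_utils.py above existed so the tests could match the generated wrapper code whether or not `config.cpp_wrapper` is enabled; this revert switches the tests back to hard-coded Python-wrapper strings or to module-local copies of the same helpers. A minimal sketch of the pattern those helpers supported (the toy model, the CUDA device, and the exact checks are illustrative assumptions, not taken from this commit):

```python
# Illustrative sketch only, not part of this commit. Assumes a CUDA build of
# PyTorch; the toy model and the checked counts are assumptions for the example.
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def get_func_call() -> str:
    # Same definition as the helper removed from inductor_utils.py above.
    return "void inductor_entry_impl(" if config.cpp_wrapper else "def call("


def get_kernel_launch() -> str:
    return "call_triton_" if config.cpp_wrapper else ".run("


def f(x, y):
    return (x @ y).relu()


x = torch.randn(64, 64, device="cuda")
y = torch.randn(64, 64, device="cuda")
_, code = run_and_get_code(torch.compile(f), x, y)

# Expect one wrapper entry point and at least one kernel launch in the
# generated wrapper code, regardless of the cpp_wrapper setting.
FileCheck().check(get_func_call()).check(get_kernel_launch()).run(code[0])
```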