commit f813d64f54 (parent f0abbabac1)

cpp_wrapper: Fix even more tests (#147225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147225
Approved by: https://github.com/desertfire
ghstack dependencies: #150671, #150672
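The recurring change in the hunks below is to stop hard-coding the Python-wrapper spellings ("def call(", ".run(", "del") in FileCheck assertions and to query the active wrapper instead, via the new get_func_call()/get_kernel_launch() helpers. A minimal sketch of that pattern (the model and shapes are illustrative, not taken from the commit; assumes a CUDA-capable build):

# Sketch: wrapper-agnostic checks over Inductor's generated code.
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.inductor_utils import get_func_call, get_kernel_launch


def f(x, y):
    return (x @ y).relu()


x = torch.randn(64, 64, device="cuda")
y = torch.randn(64, 64, device="cuda")
_, code = run_and_get_code(torch.compile(f), x, y)

# get_func_call() resolves to "def call(" for the default Python wrapper and to
# "void inductor_entry_impl(" when torch._inductor.config.cpp_wrapper is set;
# get_kernel_launch() likewise picks ".run(" vs. "call_triton_".
FileCheck().check(get_func_call()).check(get_kernel_launch()).run(code[0])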
@@ -10,7 +10,7 @@ from torch._inductor.test_operators import realize
 from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import slowTest
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+from torch.testing._internal.inductor_utils import get_func_call, HAS_CPU, HAS_CUDA


 # Make the helper files in test/ importable
@@ -24,6 +24,7 @@ from inductor.test_torchinductor import (  # @manual=fbcode//caffe2/test/inducto
     check_model,
     check_model_cuda,
     copy_tests,
+    skip_if_cpp_wrapper,
 )
 from torch._inductor import config
 from torch._inductor.scheduler import Scheduler
@@ -126,7 +127,7 @@ class BenchmarkFusionTestTemplate:

         self.common(f, (a, b))

-    @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
+    @config.patch(max_autotune_gemm_backends="TRITON")
     def test_avoid_register_spilling(self):
         if self.device != "cuda":
             raise unittest.SkipTest("CUDA only")
@@ -196,6 +197,7 @@ if HAS_CUDA:
         @unittest.skipIf(
             torch.cuda.device_count() < 2, "The test need at least 2 devices"
         )
+        @skip_if_cpp_wrapper("This tests triton scheduling directly")
         def test_benchmark_on_non_zero_device(self):
             hit_count = 0
             with torch.cuda.device("cuda:0"):
@@ -265,9 +267,7 @@ if HAS_CUDA:
                 res, code = run_and_get_code(foo_c, m, inp)

             torch._dynamo.reset()
-            with unittest.mock.patch.object(
-                torch._inductor.config, "benchmark_epilogue_fusion", False
-            ):
+            with config.patch(benchmark_epilogue_fusion=False):
                 foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
                 with torch.no_grad():
                     res2, code2 = run_and_get_code(foo_c, m, inp)
@@ -276,32 +276,34 @@ if HAS_CUDA:
             return code, code2

         @fresh_inductor_cache()
-        @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
+        @config.patch(max_autotune_gemm_backends="TRITON")
         def test_equivalent_template_code(self):
             code, code2 = self._equivalent_output_code_impl(256)
             for out_code in [code, code2]:
-                FileCheck().check("def call").check_count(
-                    "empty_strided_cuda", 1, exactly=True
-                ).check("triton_tem_fused_addmm_relu_0.run").check_count(
-                    "del", 3, exactly=True
+                FileCheck().check(get_func_call()).check_count(
+                    "empty_strided", 1, exactly=True
+                ).check("triton_tem_fused_addmm_relu_0").check_count(
+                    ".reset()" if config.cpp_wrapper else "del", 3, exactly=True
                 ).check(
-                    "return"
+                    "" if config.cpp_wrapper else "return"
                 ).run(
                     out_code[0]
                 )

         @fresh_inductor_cache()
-        @torch._inductor.config.patch(max_autotune_gemm_backends="ATEN")
+        @config.patch(max_autotune_gemm_backends="ATEN")
         def test_equivalent_extern_code(self):
             torch._dynamo.reset()

             code, code2 = self._equivalent_output_code_impl(512, 1, False)

             for out_code in [code, code2]:
-                FileCheck().check("def call").check_count(
-                    "empty_strided_cuda", 1, exactly=True
-                ).check("extern_kernels.").check_count("del", 3, exactly=True).check(
-                    "return"
+                FileCheck().check(get_func_call()).check_count(
+                    "empty_strided", 1, exactly=True
+                ).check("" if config.cpp_wrapper else "extern_kernels.").check_count(
+                    ".reset()" if config.cpp_wrapper else "del", 3, exactly=True
+                ).check(
+                    "" if config.cpp_wrapper else "return"
                 ).run(
                     out_code[0]
                 )
@@ -2801,7 +2801,12 @@ main()
             loss.backward()
             torch._inductor.config.triton.cudagraphs = False

-        self.assertFalse("skipping cudagraphs" in stderr_msgs.getvalue())
+        if inductor_config.cpp_wrapper:
+            self.assertIn("skipping cudagraphs", stderr_msgs.getvalue())
+            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
+        else:
+            self.assertNotIn("skipping cudagraphs", stderr_msgs.getvalue())
+            self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)

     def test_cudagraphs_cpu_graph(self):
         from torch._dynamo.testing import reduce_to_scalar_loss
@@ -2834,7 +2839,10 @@ main()
             opt_bwd()

         self.assertEqual(counters["compiled_autograd"]["captures"], 1)
-        self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
+        self.assertEqual(
+            counters["inductor"]["cudagraph_skips"],
+            2 if inductor_config.cpp_wrapper else 0,
+        )

     @unittest.skipIf(not HAS_CUDA, "requires cuda")
     def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self):
@@ -2927,7 +2935,10 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
         # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
         # In the future, we can consider having a cpu scalar movement pass sometime after we trace
         # into the custom C++ autograd::Function (like in AOTDispatcher)
-        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
+        self.assertEqual(
+            counters["inductor"]["cudagraph_skips"],
+            2 if inductor_config.cpp_wrapper else 1,
+        )

     def test_logs(self):
         logs, ctx = logs_to_string(
@@ -46,7 +46,14 @@ from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU
+from torch.testing._internal.inductor_utils import (
+    get_func_call,
+    get_kernel_launch,
+    GPU_TYPE,
+    HAS_CPU,
+    HAS_CUDA,
+    HAS_GPU,
+)


 torch.set_float32_matmul_precision("high")
@@ -54,14 +61,6 @@ if HAS_CUDA:
     torch.cuda.memory._set_allocator_settings("expandable_segments:False")


-def _get_func_call() -> str:
-    return "void inductor_entry_impl(" if config.cpp_wrapper else "def call("
-
-
-def _get_kernel_launch() -> str:
-    return "call_triton_" if config.cpp_wrapper else ".run("
-
-
 def benchmark_choice(choice, args, out, expected_out, timings):
     result = choice.benchmark(*args, out=out)
     if expected_out is not None:
@@ -899,8 +898,8 @@ class TestMaxAutotune(TestCase):

         # mm kernel, and cos kernel
         count = 2 if using_triton_mm else 1
-        FileCheck().check(_get_func_call()).check_count(
-            _get_kernel_launch(), count, exactly=True
+        FileCheck().check(get_func_call()).check_count(
+            get_kernel_launch(), count, exactly=True
         ).run(code[0])

         def f(x, y):
@@ -912,8 +911,8 @@ class TestMaxAutotune(TestCase):
         f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f)
         _, code = run_and_get_code(f_c, inps[0], inps[1])
         self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25)
-        FileCheck().check(_get_func_call()).check_count(
-            _get_kernel_launch(), 2, exactly=True
+        FileCheck().check(get_func_call()).check_count(
+            get_kernel_launch(), 2, exactly=True
         ).run(code[0])

         def f(x, y):
@@ -1362,21 +1361,21 @@ class TestPrologueFusion(TestCase):
         )

     def check_code(self, code_str, num_kernels, num_allocs, num_deallocs):
-        FileCheck().check(_get_func_call()).check_count(
-            _get_kernel_launch(),
+        FileCheck().check(get_func_call()).check_count(
+            get_kernel_launch(),
             num_kernels,
             exactly=True,
         ).run(code_str)

         if num_allocs is not None:
-            FileCheck().check(_get_func_call()).check_count(
+            FileCheck().check(get_func_call()).check_count(
                 "empty_strided", num_allocs, exactly=True
             ).run(code_str)

         # skip the deallocation check when using cpp_wrapper; most deallocations happen
         # outside of our control via RAIIAtenTensorHandle
         if num_deallocs is not None and not config.cpp_wrapper:
-            FileCheck().check(_get_func_call()).check_count(
+            FileCheck().check(get_func_call()).check_count(
                 "del", num_deallocs, exactly=True
             ).run(code_str)

@@ -1557,8 +1556,8 @@ class TestPrologueFusion(TestCase):

         out, code = run_and_get_code(torch.compile(multi_use), x, y)

-        FileCheck().check(_get_func_call()).check_count(
-            _get_kernel_launch(), 2, exactly=True
+        FileCheck().check(get_func_call()).check_count(
+            get_kernel_launch(), 2, exactly=True
         ).run(code[0])
         self.assertEqual(out, multi_use(x, y), atol=0.05, rtol=0.05)

@@ -1567,8 +1566,8 @@ class TestPrologueFusion(TestCase):

         x = torch.rand([128, 128], device=GPU_TYPE)
         out, code = run_and_get_code(torch.compile(resolve_pending), x)
-        FileCheck().check(_get_func_call()).check_count(
-            _get_kernel_launch(), 1, exactly=True
+        FileCheck().check(get_func_call()).check_count(
+            get_kernel_launch(), 1, exactly=True
         ).run(code[0])
         self.assertEqual(out, resolve_pending(x), atol=0.05, rtol=0.05)

@@ -1591,8 +1590,8 @@ class TestPrologueFusion(TestCase):

         x = torch.rand([128, 128], dtype=torch.float16, device=GPU_TYPE)
         out, code = run_and_get_code(torch.compile(test_multiple_fusions), x)
-        FileCheck().check(_get_func_call()).check_count(
-            _get_kernel_launch(), 1, exactly=True
+        FileCheck().check(get_func_call()).check_count(
+            get_kernel_launch(), 1, exactly=True
         ).run(code[0])
         self.assertEqual(out, test_multiple_fusions(x), atol=0.05, rtol=0.05)

@@ -10124,9 +10124,6 @@ class CommonTemplate:
         for x in (torch.randn(2, 3), torch.randn(2, 2), torch.randn(3, 2)):
             self.common(fn, (x,))

-    @skip_if_cpp_wrapper(
-        "cannot currently handle fallback ops with return types containing list[Tensor]"
-    )
     def test_kwargs(self):
         if self.device == GPU_TYPE:
             raise unittest.SkipTest("histogramdd only supports cpu")
@@ -210,6 +210,12 @@ def maybe_skip_size_asserts(op):
     else:
         return contextlib.nullcontext()

+def get_func_call() -> str:
+    return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call("
+
+def get_kernel_launch() -> str:
+    return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run("
+
 def clone_preserve_strides_offset(x, device=None):
     if not isinstance(x, torch.Tensor):
         return x
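Both helpers read the config flag that switches Inductor to C++ wrapper codegen, which is what the conditional expectations in the hunks above (".reset()" vs. "del", "void inductor_entry_impl(" vs. "def call(") key off. A rough sketch of exercising that mode (my own illustration, not part of the commit; assumes a CUDA device):

# Sketch: enable cpp_wrapper so the C++ entry point is emitted.
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.inductor_utils import get_func_call

with config.patch(cpp_wrapper=True):
    _, code = run_and_get_code(torch.compile(torch.relu), torch.randn(32, device="cuda"))
    # Under cpp_wrapper the generated source is C++, so get_func_call() now
    # returns "void inductor_entry_impl(" and should appear in the output.
    assert get_func_call() in code[0]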