[inductor] pre grad graph bisecting (#166344)

A few things to note:
1. Customers like vllm use a custom backend (e.g. VllmBackend), split the graph, and call standalone_compile for each split. If we let the bisector override the backend, we won't bisect thru the custom backend. `test_configs.bisect_keep_custom_backend_for_inductor` is used to keep the custom backend if we are bisecting for inductor.
2. pre_grad_graph bisecting and lowering bisecting so far does not compose well with each other since an issue may be just captured by the first one we try. `test_configs.bisect_pre_grad_graph` is used to enable the 'pre_grad_graph' bisecting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166344
Approved by: https://github.com/eellison
This commit is contained in:
Shunting Zhang 2025-10-31 17:43:55 -07:00 committed by PyTorch MergeBot
parent 1aef88c72d
commit 4cc64d6234
5 changed files with 81 additions and 1 deletions

View File

@ -275,6 +275,59 @@ class TestCompilerBisector(TestCase):
self.assertEqual(out.backend, "eager") self.assertEqual(out.backend, "eager")
self.assertEqual(out.subsystem, None) self.assertEqual(out.subsystem, None)
@config.patch(
{
"test_configs.bisect_pre_grad_graph": True,
"test_configs.bisect_keep_custom_backend_for_inductor": True,
}
)
def test_bisect_pre_grad_graph(self):
def f(x):
for i in range(5):
x = x + 1
return x.relu()
class MyBackend:
def __call__(self, gm, example_inputs):
node_idx = 0
def node_to_graph_id(node):
nonlocal node_idx
out = 0 if node_idx < 3 else 1
node_idx += 1
return out
split_gm = torch.fx.passes.split_module.split_module(
gm, None, node_to_graph_id, keep_original_order=True
)
for name, submod in split_gm.named_modules():
if "submod_" in name:
# the test case is simple enough that using
# the original example_inputs works for sub
# moule
submod.forward = torch._inductor.standalone_compile(
submod,
example_inputs,
dynamic_shapes="from_example_inputs",
options={},
)
return split_gm
def test_fn():
torch._dynamo.reset()
x = torch.randn(1024, device="cuda")
with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):
opt_f = torch.compile(f, backend=MyBackend())
return torch.allclose(opt_f(x), f(x))
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "inductor")
self.assertEqual(out.subsystem, "pre_grad_graph")
self.assertEqual(out.bisect_number, 1)
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

View File

@ -2644,7 +2644,16 @@ def compile(
from torch._inductor.compiler_bisector import CompilerBisector from torch._inductor.compiler_bisector import CompilerBisector
if bisect_backend := CompilerBisector.get_backend(): if bisect_backend := CompilerBisector.get_backend():
backend = bisect_backend import torch._inductor.config as inductor_config
# don't override the backend for use cases like vllm
# which leverages their custom backend.
if not (
inductor_config.test_configs.bisect_keep_custom_backend_for_inductor
and bisect_backend == "inductor"
and not isinstance(backend, str)
):
backend = bisect_backend
guard_filter_fn = None guard_filter_fn = None
if options and isinstance(options, dict): if options and isinstance(options, dict):

View File

@ -2448,6 +2448,11 @@ def compile_fx(
# Some arguments trigger a recursive call to compile_fx. Handle these # Some arguments trigger a recursive call to compile_fx. Handle these
# short circuits first, before anything else # short circuits first, before anything else
from torch._inductor.compiler_bisector import CompilerBisector
if CompilerBisector.disable_subsystem("inductor", "pre_grad_graph"):
return model_
if config_patches: if config_patches:
with config.patch(config_patches): with config.patch(config_patches):
return compile_fx( return compile_fx(

View File

@ -491,6 +491,13 @@ class CompilerBisector:
Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure. Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure.
""" """
# TODO graph bisecting is not well composed with lowering
# bisector so far. Use a config to opt-in
import torch._inductor.config as inductor_config
if inductor_config.test_configs.bisect_pre_grad_graph:
BACKENDS["inductor"].insert(0, BisectSubsystem("pre_grad_graph"))
if not cli_interface: if not cli_interface:
bisection_enabled_orig = cls.bisection_enabled bisection_enabled_orig = cls.bisection_enabled
cls.delete_bisect_status() cls.delete_bisect_status()
@ -502,6 +509,9 @@ class CompilerBisector:
cls.delete_bisect_status() cls.delete_bisect_status()
cls.in_process_cache = None cls.in_process_cache = None
if BACKENDS["inductor"][0].name == "pre_grad_graph":
del BACKENDS["inductor"][0]
cleanup_handler = atexit.register(cleanup) cleanup_handler = atexit.register(cleanup)
class DisableBisect: class DisableBisect:

View File

@ -2150,6 +2150,9 @@ class test_configs:
"TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", "" "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", ""
) )
bisect_pre_grad_graph = False
bisect_keep_custom_backend_for_inductor = False
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403 from torch.utils._config_typing import * # noqa: F401, F403