[inductor] pre grad graph bisecting (#166344)

A few things to note:
1. Customers like vllm use a custom backend (e.g. VllmBackend), split the graph, and call standalone_compile for each split. If we let the bisector override the backend, we won't bisect through the custom backend. `test_configs.bisect_keep_custom_backend_for_inductor` keeps the custom backend in place while bisecting for inductor.
2. pre_grad_graph bisecting and lowering bisecting do not compose well with each other yet, since an issue may simply be captured by whichever one we try first. `test_configs.bisect_pre_grad_graph` opts in to 'pre_grad_graph' bisecting (a usage sketch follows this list).
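For illustration, a minimal sketch of opting in to both configs and driving the bisector. The repro function and tensor shape are illustrative; `CompilerBisector.do_bisect` and the two `test_configs` entries come from this PR and the existing bisector:

    import torch
    from torch._inductor import config
    from torch._inductor.compiler_bisector import CompilerBisector

    def f(x):
        return (x + 1).relu()

    def repro():
        # do_bisect expects a check that returns True on success, False on failure
        torch._dynamo.reset()
        x = torch.randn(1024)
        return torch.allclose(torch.compile(f)(x), f(x))

    with config.patch(
        {
            "test_configs.bisect_pre_grad_graph": True,
            "test_configs.bisect_keep_custom_backend_for_inductor": True,
        }
    ):
        out = CompilerBisector.do_bisect(repro)
        # out.backend / out.subsystem / out.bisect_number describe the culprit,
        # e.g. subsystem "pre_grad_graph" when the pre-grad graph bisect isolates it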

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166344
Approved by: https://github.com/eellison
Shunting Zhang 2025-10-31 17:43:55 -07:00, committed by PyTorch MergeBot
parent 1aef88c72d
commit 4cc64d6234
5 changed files with 81 additions and 1 deletion


@@ -275,6 +275,59 @@ class TestCompilerBisector(TestCase):
         self.assertEqual(out.backend, "eager")
         self.assertEqual(out.subsystem, None)
 
+    @config.patch(
+        {
+            "test_configs.bisect_pre_grad_graph": True,
+            "test_configs.bisect_keep_custom_backend_for_inductor": True,
+        }
+    )
+    def test_bisect_pre_grad_graph(self):
+        def f(x):
+            for i in range(5):
+                x = x + 1
+            return x.relu()
+
+        class MyBackend:
+            def __call__(self, gm, example_inputs):
+                node_idx = 0
+
+                def node_to_graph_id(node):
+                    nonlocal node_idx
+                    out = 0 if node_idx < 3 else 1
+                    node_idx += 1
+                    return out
+
+                split_gm = torch.fx.passes.split_module.split_module(
+                    gm, None, node_to_graph_id, keep_original_order=True
+                )
+
+                for name, submod in split_gm.named_modules():
+                    if "submod_" in name:
+                        # the test case is simple enough that using
+                        # the original example_inputs works for the
+                        # submodule
+                        submod.forward = torch._inductor.standalone_compile(
+                            submod,
+                            example_inputs,
+                            dynamic_shapes="from_example_inputs",
+                            options={},
+                        )
+                return split_gm
+
+        def test_fn():
+            torch._dynamo.reset()
+
+            x = torch.randn(1024, device="cuda")
+            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):
+                opt_f = torch.compile(f, backend=MyBackend())
+                return torch.allclose(opt_f(x), f(x))
+
+        out = CompilerBisector.do_bisect(test_fn)
+        self.assertEqual(out.backend, "inductor")
+        self.assertEqual(out.subsystem, "pre_grad_graph")
+        self.assertEqual(out.bisect_number, 1)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
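For orientation: MyBackend above deliberately mimics the vllm-style flow from note 1. It splits the captured graph into two submodules with torch.fx.passes.split_module and compiles each split separately via torch._inductor.standalone_compile, so the bisector must work through a custom backend rather than plain inductor.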


@@ -2644,7 +2644,16 @@ def compile(
     from torch._inductor.compiler_bisector import CompilerBisector
 
     if bisect_backend := CompilerBisector.get_backend():
-        backend = bisect_backend
+        import torch._inductor.config as inductor_config
+
+        # don't override the backend for use cases like vllm,
+        # which leverage their own custom backend.
+        if not (
+            inductor_config.test_configs.bisect_keep_custom_backend_for_inductor
+            and bisect_backend == "inductor"
+            and not isinstance(backend, str)
+        ):
+            backend = bisect_backend
 
     guard_filter_fn = None
     if options and isinstance(options, dict):

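As a usage note, the isinstance check above is what distinguishes the two kinds of backends; MyVllmStyleBackend is a hypothetical callable backend in the spirit of the test above:

    # a callable backend is kept when bisect_keep_custom_backend_for_inductor is set
    torch.compile(f, backend=MyVllmStyleBackend())

    # a plain string backend can still be overridden by the bisector
    torch.compile(f, backend="inductor")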

@@ -2448,6 +2448,11 @@ def compile_fx(
     # Some arguments trigger a recursive call to compile_fx. Handle these
     # short circuits first, before anything else
+    from torch._inductor.compiler_bisector import CompilerBisector
+
+    if CompilerBisector.disable_subsystem("inductor", "pre_grad_graph"):
+        return model_
+
     if config_patches:
         with config.patch(config_patches):
             return compile_fx(

@@ -491,6 +491,13 @@ class CompilerBisector:
         Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure.
         """
+        # TODO: graph bisecting does not compose well with the lowering
+        # bisector so far. Use a config to opt in.
+        import torch._inductor.config as inductor_config
+
+        if inductor_config.test_configs.bisect_pre_grad_graph:
+            BACKENDS["inductor"].insert(0, BisectSubsystem("pre_grad_graph"))
+
         if not cli_interface:
             bisection_enabled_orig = cls.bisection_enabled
             cls.delete_bisect_status()
@@ -502,6 +509,9 @@
                 cls.delete_bisect_status()
                 cls.in_process_cache = None
 
+                if BACKENDS["inductor"][0].name == "pre_grad_graph":
+                    del BACKENDS["inductor"][0]
+
             cleanup_handler = atexit.register(cleanup)
 
         class DisableBisect:
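Design note: inserting at index 0 makes 'pre_grad_graph' the first inductor subsystem the bisector tries, and the cleanup hook removes the entry again so that repeated do_bisect calls do not accumulate duplicates in BACKENDS["inductor"].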


@@ -2150,6 +2150,9 @@ class test_configs:
         "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", ""
     )
 
+    bisect_pre_grad_graph = False
+    bisect_keep_custom_backend_for_inductor = False
+
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
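Both flags default to False, so pre-grad-graph bisecting and custom-backend preservation remain opt-in, matching the notes at the top; tests enable them via config.patch as shown above.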