[inductor] pre grad graph bisecting (#166344)
A few things to note:

1. Customers like vllm use a custom backend (e.g. VllmBackend) that splits the graph and calls standalone_compile for each split. If we let the bisector override the backend, we won't bisect through the custom backend. `test_configs.bisect_keep_custom_backend_for_inductor` keeps the custom backend in place while bisecting for inductor.
2. pre_grad_graph bisecting and lowering bisecting do not yet compose well with each other, since an issue may be captured by whichever one we try first. `test_configs.bisect_pre_grad_graph` opts in to 'pre_grad_graph' bisecting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166344
Approved by: https://github.com/eellison
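As a usage note, the sketch below shows how the two opt-in flags could be wired around CompilerBisector.do_bisect. The toy function f, the 1024-element input, and the allclose check are illustrative placeholders and not part of this PR; only the two test_configs flags, config.patch, and do_bisect come from the code being changed here.

import torch
import torch._dynamo
from torch._inductor import config as inductor_config
from torch._inductor.compiler_bisector import CompilerBisector


def repro() -> bool:
    # do_bisect expects a callable that returns True on success, False on failure.
    torch._dynamo.reset()

    def f(x):
        return (x + 1).relu()

    x = torch.randn(1024)  # placeholder input; a real repro would use the failing workload
    opt_f = torch.compile(f)  # a custom splitting backend could be passed via backend=...
    return torch.allclose(opt_f(x), f(x))


with inductor_config.patch(
    {
        # opt in to bisecting over the pre-grad graph (off by default)
        "test_configs.bisect_pre_grad_graph": True,
        # keep a user-supplied backend when the bisector lands on "inductor"
        "test_configs.bisect_keep_custom_backend_for_inductor": True,
    }
):
    result = CompilerBisector.do_bisect(repro)
    if result is not None:
        # the result identifies the culprit, e.g. its backend, subsystem, and bisect number
        print(result.backend, result.subsystem, result.bisect_number)

With the second flag set, torch.compile leaves a non-string backend (such as one that splits the graph and calls standalone_compile per split) in place even while the bisector targets inductor, so the bisection exercises the same compilation path as the real workload.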
This commit is contained in:
parent 1aef88c72d
commit 4cc64d6234
@@ -275,6 +275,59 @@ class TestCompilerBisector(TestCase):
         self.assertEqual(out.backend, "eager")
         self.assertEqual(out.subsystem, None)
 
+    @config.patch(
+        {
+            "test_configs.bisect_pre_grad_graph": True,
+            "test_configs.bisect_keep_custom_backend_for_inductor": True,
+        }
+    )
+    def test_bisect_pre_grad_graph(self):
+        def f(x):
+            for i in range(5):
+                x = x + 1
+            return x.relu()
+
+        class MyBackend:
+            def __call__(self, gm, example_inputs):
+                node_idx = 0
+
+                def node_to_graph_id(node):
+                    nonlocal node_idx
+                    out = 0 if node_idx < 3 else 1
+                    node_idx += 1
+                    return out
+
+                split_gm = torch.fx.passes.split_module.split_module(
+                    gm, None, node_to_graph_id, keep_original_order=True
+                )
+
+                for name, submod in split_gm.named_modules():
+                    if "submod_" in name:
+                        # the test case is simple enough that using
+                        # the original example_inputs works for the
+                        # submodule
+                        submod.forward = torch._inductor.standalone_compile(
+                            submod,
+                            example_inputs,
+                            dynamic_shapes="from_example_inputs",
+                            options={},
+                        )
+
+                return split_gm
+
+        def test_fn():
+            torch._dynamo.reset()
+
+            x = torch.randn(1024, device="cuda")
+            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):
+                opt_f = torch.compile(f, backend=MyBackend())
+                return torch.allclose(opt_f(x), f(x))
+
+        out = CompilerBisector.do_bisect(test_fn)
+        self.assertEqual(out.backend, "inductor")
+        self.assertEqual(out.subsystem, "pre_grad_graph")
+        self.assertEqual(out.bisect_number, 1)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -2644,7 +2644,16 @@ def compile(
     from torch._inductor.compiler_bisector import CompilerBisector
 
     if bisect_backend := CompilerBisector.get_backend():
-        backend = bisect_backend
+        import torch._inductor.config as inductor_config
+
+        # don't override the backend for use cases like vllm
+        # which leverages their custom backend.
+        if not (
+            inductor_config.test_configs.bisect_keep_custom_backend_for_inductor
+            and bisect_backend == "inductor"
+            and not isinstance(backend, str)
+        ):
+            backend = bisect_backend
 
     guard_filter_fn = None
     if options and isinstance(options, dict):
@@ -2448,6 +2448,11 @@ def compile_fx(
     # Some arguments trigger a recursive call to compile_fx. Handle these
     # short circuits first, before anything else
 
+    from torch._inductor.compiler_bisector import CompilerBisector
+
+    if CompilerBisector.disable_subsystem("inductor", "pre_grad_graph"):
+        return model_
+
     if config_patches:
         with config.patch(config_patches):
             return compile_fx(
@@ -491,6 +491,13 @@ class CompilerBisector:
         Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure.
         """
 
+        # TODO graph bisecting is not well composed with lowering
+        # bisector so far. Use a config to opt-in
+        import torch._inductor.config as inductor_config
+
+        if inductor_config.test_configs.bisect_pre_grad_graph:
+            BACKENDS["inductor"].insert(0, BisectSubsystem("pre_grad_graph"))
+
         if not cli_interface:
             bisection_enabled_orig = cls.bisection_enabled
             cls.delete_bisect_status()
@@ -502,6 +509,9 @@
                 cls.delete_bisect_status()
                 cls.in_process_cache = None
 
+                if BACKENDS["inductor"][0].name == "pre_grad_graph":
+                    del BACKENDS["inductor"][0]
+
             cleanup_handler = atexit.register(cleanup)
 
             class DisableBisect:
@@ -2150,6 +2150,9 @@ class test_configs:
         "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", ""
     )
 
+    bisect_pre_grad_graph = False
+    bisect_keep_custom_backend_for_inductor = False
+
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403