[inductor] pre grad graph bisecting (#166344)
A few things to note:

1. Customers like vllm use a custom backend (e.g. VllmBackend) that splits the graph and calls standalone_compile for each split. If the bisector overrides the backend, we would not bisect through the custom backend. `test_configs.bisect_keep_custom_backend_for_inductor` keeps the custom backend in place while bisecting inductor.
2. pre_grad_graph bisecting and lowering bisecting do not compose well with each other yet, since an issue may be captured entirely by whichever one is tried first. `test_configs.bisect_pre_grad_graph` opts in to the 'pre_grad_graph' bisecting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166344
Approved by: https://github.com/eellison
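As a rough illustration of how the two new configs are meant to be used together, here is a minimal sketch. The names `MySplitBackend`, `f`, and `test_fn`, as well as the input sizes, are illustrative stand-ins and not code from this PR; a real vllm-style backend would split the graph and compile each submodule, as the new test_bisect_pre_grad_graph in this diff does.

import torch
import torch._dynamo
import torch._inductor
from torch._inductor import config as inductor_config
from torch._inductor.compiler_bisector import CompilerBisector


def f(x):
    return (x + 1).relu()


class MySplitBackend:
    # Illustrative stand-in for a vllm-style custom backend. A real
    # backend would split `gm` into submodules and standalone_compile
    # each split; here the whole graph is compiled in one piece to
    # keep the sketch short.
    def __call__(self, gm, example_inputs):
        return torch._inductor.standalone_compile(
            gm,
            example_inputs,
            dynamic_shapes="from_example_inputs",
            options={},
        )


def test_fn():
    # do_bisect expects True on success and False on failure; a real
    # repro would exercise the suspected-bad compilation path here.
    torch._dynamo.reset()
    x = torch.randn(8)
    opt_f = torch.compile(f, backend=MySplitBackend())
    return torch.allclose(opt_f(x), f(x))


with inductor_config.patch(
    {
        # opt in to the pre_grad_graph subsystem when bisecting inductor
        "test_configs.bisect_pre_grad_graph": True,
        # keep MySplitBackend in place instead of overriding it with "inductor"
        "test_configs.bisect_keep_custom_backend_for_inductor": True,
    }
):
    result = CompilerBisector.do_bisect(test_fn)
    # inspect result.backend / result.subsystem / result.bisect_number,
    # as the new test below does
    print(result)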
This commit is contained in:
parent
1aef88c72d
commit
4cc64d6234
@@ -275,6 +275,59 @@ class TestCompilerBisector(TestCase):
         self.assertEqual(out.backend, "eager")
         self.assertEqual(out.subsystem, None)
 
+    @config.patch(
+        {
+            "test_configs.bisect_pre_grad_graph": True,
+            "test_configs.bisect_keep_custom_backend_for_inductor": True,
+        }
+    )
+    def test_bisect_pre_grad_graph(self):
+        def f(x):
+            for i in range(5):
+                x = x + 1
+            return x.relu()
+
+        class MyBackend:
+            def __call__(self, gm, example_inputs):
+                node_idx = 0
+
+                def node_to_graph_id(node):
+                    nonlocal node_idx
+                    out = 0 if node_idx < 3 else 1
+                    node_idx += 1
+                    return out
+
+                split_gm = torch.fx.passes.split_module.split_module(
+                    gm, None, node_to_graph_id, keep_original_order=True
+                )
+
+                for name, submod in split_gm.named_modules():
+                    if "submod_" in name:
+                        # the test case is simple enough that using
+                        # the original example_inputs works for the
+                        # submodule
+                        submod.forward = torch._inductor.standalone_compile(
+                            submod,
+                            example_inputs,
+                            dynamic_shapes="from_example_inputs",
+                            options={},
+                        )
+
+                return split_gm
+
+        def test_fn():
+            torch._dynamo.reset()
+
+            x = torch.randn(1024, device="cuda")
+            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):
+                opt_f = torch.compile(f, backend=MyBackend())
+                return torch.allclose(opt_f(x), f(x))
+
+        out = CompilerBisector.do_bisect(test_fn)
+        self.assertEqual(out.backend, "inductor")
+        self.assertEqual(out.subsystem, "pre_grad_graph")
+        self.assertEqual(out.bisect_number, 1)
+
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -2644,7 +2644,16 @@ def compile(
     from torch._inductor.compiler_bisector import CompilerBisector
 
     if bisect_backend := CompilerBisector.get_backend():
-        backend = bisect_backend
+        import torch._inductor.config as inductor_config
+
+        # don't override the backend for use cases like vllm
+        # which leverage their own custom backend.
+        if not (
+            inductor_config.test_configs.bisect_keep_custom_backend_for_inductor
+            and bisect_backend == "inductor"
+            and not isinstance(backend, str)
+        ):
+            backend = bisect_backend
 
     guard_filter_fn = None
     if options and isinstance(options, dict):
@@ -2448,6 +2448,11 @@ def compile_fx(
     # Some arguments trigger a recursive call to compile_fx. Handle these
     # short circuits first, before anything else
+    from torch._inductor.compiler_bisector import CompilerBisector
+
+    if CompilerBisector.disable_subsystem("inductor", "pre_grad_graph"):
+        return model_
+
     if config_patches:
         with config.patch(config_patches):
             return compile_fx(
@@ -491,6 +491,13 @@ class CompilerBisector:
         Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure.
         """
+        # TODO: graph bisecting does not compose well with the lowering
+        # bisector yet. Use a config to opt in.
+        import torch._inductor.config as inductor_config
+
+        if inductor_config.test_configs.bisect_pre_grad_graph:
+            BACKENDS["inductor"].insert(0, BisectSubsystem("pre_grad_graph"))
+
         if not cli_interface:
             bisection_enabled_orig = cls.bisection_enabled
             cls.delete_bisect_status()
@@ -502,6 +509,9 @@ class CompilerBisector:
                 cls.delete_bisect_status()
                 cls.in_process_cache = None
 
+                if BACKENDS["inductor"][0].name == "pre_grad_graph":
+                    del BACKENDS["inductor"][0]
+
             cleanup_handler = atexit.register(cleanup)
 
 class DisableBisect:
@@ -2150,6 +2150,9 @@ class test_configs:
         "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", ""
     )
 
+    bisect_pre_grad_graph = False
+    bisect_keep_custom_backend_for_inductor = False
+
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403