From 3c8c509a9c28a1a9da5e525269f1a62a0c5200f6 Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 18 Aug 2025 15:42:17 +0000 Subject: [PATCH] [export] Fix custom ops in subgraphs (#160004) Fixes https://github.com/pytorch/pytorch/issues/159995 Currently there are two problems with extern kernels in subgraphs: 1. They don't get serialized to the extern kernel json file because we only look at the toplevel graph. 2. Since the scope of each extern_kernel list is within its own subgraph, the indices referencing the operator are messed up because each subgraph will start counting from 0. So, this PR moves the extern_kernels list to a global view (under virtualized) so that we can count the extern kernels across subgraphs and the toplevel graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160004 Approved by: https://github.com/ydwu4 --- test/inductor/test_aot_inductor.py | 43 +++++++++++++++++++++ test/inductor/test_aot_inductor_arrayref.py | 1 + torch/_inductor/codegen/cpp_wrapper_cpu.py | 2 +- torch/_inductor/compile_fx.py | 13 ++++--- torch/_inductor/graph.py | 2 - torch/_inductor/ir.py | 2 +- torch/_inductor/virtualized.py | 16 ++++++++ 7 files changed, 69 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 81a218d5c42..0889c948de0 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6772,6 +6772,49 @@ class AOTInductorTestsTemplate: # compare against eager self.assertEqual(optimized(**model_kwargs), model(**model_kwargs)) + def test_custom_op_in_subgraph(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo_add1", + "(Tensor a) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo_add1", "CompositeExplicitAutograd", lib=lib) + @torch.library.register_fake("mylib::foo_add1", lib=lib) + def foo_add1_impl(a: torch.Tensor) -> torch.Tensor: + 
return a + 1 + + torch.library.define( + "mylib::foo_add2", + "(Tensor a) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo_add2", "CompositeExplicitAutograd", lib=lib) + @torch.library.register_fake("mylib::foo_add2", lib=lib) + def foo_add2_impl(a: torch.Tensor) -> torch.Tensor: + return a + 2 + + class M(torch.nn.Module): + def forward(self, x): + return torch.cond( + x.shape[0] < 5, + torch.ops.mylib.foo_add1, + torch.ops.mylib.foo_add2, + (x,), + ) + + list_example_inputs = [ + (torch.ones(6, device=self.device),), + (torch.ones(3, device=self.device),), + ] + self.check_model_with_multiple_inputs( + M(), list_example_inputs, dynamic_shapes=({0: Dim.DYNAMIC},) + ) + def test_clamp_decomposition(self): class Model1(torch.nn.Module): def forward(self, x): diff --git a/test/inductor/test_aot_inductor_arrayref.py b/test/inductor/test_aot_inductor_arrayref.py index 9ba1121a539..492ad9c23c5 100644 --- a/test/inductor/test_aot_inductor_arrayref.py +++ b/test/inductor/test_aot_inductor_arrayref.py @@ -70,6 +70,7 @@ CPU_TEST_FAILURES = { "test_cond_with_multiple_outputs": fail_minimal_arrayref_interface(), "test_cond_with_parameters": fail_minimal_arrayref_interface(), "test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(), + "test_custom_op_in_subgraph": fail_minimal_arrayref_interface(), "test_cond_share_predicte": fail_stack_allocation(is_skip=True), "test_cond_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(), "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(), diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 9b1b0ac075e..0869db93111 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -2604,7 +2604,7 @@ if (!custom_op_wrapper) { "AtenTensorHandle", tensor_call_args, force_mutable=True ) - 
extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1 + extern_kernel_node_index = len(V.extern_kernel_nodes) - 1 self.writeline( f"aoti_torch_proxy_executor_call_function(proxy_executor, " f"{extern_kernel_node_index}, " diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 3d614d6795b..1d194a8f404 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1390,7 +1390,10 @@ class _InProcessFxCompile(FxCompile): is_backward=is_backward, is_const_graph=True, ) - with V.set_graph_handler(const_graph): + with ( + V.set_graph_handler(const_graph), + V.set_extern_kernel_nodes([]), + ): assert cpp_wrapper, "AOT mode only supports C++ wrapper" const_graph.run() const_wrapper_code, const_kernel_code = ( @@ -1425,7 +1428,7 @@ class _InProcessFxCompile(FxCompile): # We are going to start code generating runtime asserts, so make sure # you don't start adding new ones in the lowering process graph.freeze_runtime_asserts() - with V.set_graph_handler(graph): + with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]): graph.run(*example_inputs) output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = [] if graph.graph_outputs is not None: @@ -1472,11 +1475,9 @@ class _InProcessFxCompile(FxCompile): ) serialized_extern_kernel_nodes = None - if graph.extern_kernel_nodes: + if V.extern_kernel_nodes: serialized_extern_kernel_nodes = ( - graph.extern_node_serializer( - graph.extern_kernel_nodes - ) + graph.extern_node_serializer(V.extern_kernel_nodes) ) output_code_log.debug( "Serialized Extern Kernel Nodes: \n%s", diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 31be050ab28..f42ff44a312 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -392,8 +392,6 @@ class GraphLowering(torch.fx.Interpreter): self.inplaced_to_remove: OrderedSet[str] = OrderedSet() self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] self.wrapper_code: PythonWrapperCodegen = None # 
type: ignore[assignment] - # See `ProxyExecutor Design Note` in ir.py for more details - self.extern_kernel_nodes: list[ir.ExternKernelNode] = [] from torch._inductor.extern_node_serializer import extern_node_json_serializer diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 2601ed32499..44521a23dfd 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7656,7 +7656,7 @@ class FallbackKernel(ExternKernelAlloc): ), ) - V.graph.extern_kernel_nodes.append(node) + V.extern_kernel_nodes.append(node) return [*args, *ordered_kwargs] diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 6144f7c6f18..ea1073f88b7 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -80,6 +80,7 @@ if TYPE_CHECKING: from torch._inductor.codegen.cpp_utils import LocalBufferContext from torch._inductor.debug import DebugContext from torch._inductor.graph import GraphLowering + from torch._inductor.ir import ExternKernelNode from torch._inductor.loop_body import InterpreterShim from torch._subclasses import FakeTensorMode @@ -183,6 +184,9 @@ _ops: Virtualized[OpsHandler[Any]] = Virtualized( "ops", cast(type[OpsHandler[Any]], MockHandler) ) _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) +_extern_kernel_nodes: Virtualized[list[ExternKernelNode]] = Virtualized( + "extern_kernel_nodes", NullHandler +) _real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler) _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) _kernel: Virtualized[NullKernelHandler] = Virtualized( @@ -343,6 +347,9 @@ class _V: ) get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler + set_extern_kernel_nodes: Callable[[list[ExternKernelNode]], Any] = ( + _extern_kernel_nodes._set_handler + ) set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler get_real_inputs: Callable[[], 
Any] = _real_inputs._get_handler set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler @@ -368,6 +375,15 @@ class _V: """The graph currently being generated""" return _graph._get_handler() + @property + def extern_kernel_nodes(self) -> list[ExternKernelNode]: + """ + The extern_kernel_nodes needed for the entire graph, including the + subgraphs. + See `ProxyExecutor Design Note` in ir.py for more details + """ + return _extern_kernel_nodes._get_handler() + @property def real_inputs(self): """non-fake example inputs"""