diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 882f58f3c37..a404e15a977 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -55,7 +55,7 @@ torch.fx.node.Node.append(self, x: 'Node') -> None torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str] torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None torch.fx.node.Node.prepend(self, x: 'Node') -> None -torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = >, propagate_meta: bool = False) -> List[Node] +torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Optional[Callable[[Node], bool]] = None, propagate_meta: bool = False) -> List[Node] torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 66d73feec00..8309b22db30 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2759,6 +2759,7 @@ class _NodeBase: ) -> None: ... def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ... def _prepend(self, n: FxNode) -> None: ... + def _replace_input_with(self, old_input: FxNode, new_input: FxNode) -> None: ... def _remove_from_list(self) -> None: ... def __lt__(self, n: Self) -> _bool: ... def __gt__(self, n: Self) -> _bool: ... diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index a7f7dee5355..91df5d42823 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -1274,17 +1274,8 @@ def maybe_inline_graph_saved_tensors_hooks( else: # Keep usages of bw_g_input in inserted unpacked hook graph. # Replace other usages of bw_g_input with unpack_saved_tensor_n. - from torch._C import _fx_map_arg - - def maybe_replace_node(n): - return unpack_saved_tensor_n if n == bw_g_input else n - for use_node in original_bw_g_input_users: - new_args = _fx_map_arg(use_node.args, maybe_replace_node) - new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node) - assert isinstance(new_args, tuple) - assert isinstance(new_kwargs, dict) - use_node._update_args_kwargs(new_args, new_kwargs) + use_node._replace_input_with(bw_g_input, unpack_saved_tensor_n) bw_g.erase_node(bw_unpack_out_n) # Changing forward graph outputs, diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp index d60f43c9556..1669f79af72 100644 --- a/torch/csrc/fx/node.cpp +++ b/torch/csrc/fx/node.cpp @@ -365,6 +365,43 @@ static PyObject* NodeBase__remove_from_list( Py_RETURN_NONE; } +static PyObject* NodeBase__replace_input_with( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs) { + if (nargs != 2) { + PyErr_SetString( + PyExc_TypeError, + "_replace_input_with() requires exactly 2 arguments (old_input, new_input)"); + return nullptr; + } + PyObject* old_input = args[0]; + PyObject* new_input = args[1]; + auto replace_fn = [old_input, new_input](PyObject* maybe_node) { + if (maybe_node == old_input) { + return Py_NewRef(new_input); + } + return Py_NewRef(maybe_node); + }; + + auto node = reinterpret_cast(self); + try { + THPObjectPtr new_args(map_aggregate(node->_args, replace_fn)); + if (!new_args) { + return nullptr; + } + THPObjectPtr new_kwargs(map_aggregate(node->_kwargs, replace_fn)); + if (!new_kwargs) { + return nullptr; + } + + PyObject* update_args[2] = {new_args.get(), new_kwargs.get()}; + return NodeBase__update_args_kwargs(self, update_args, 2); + } catch (const PythonError& e) { + return nullptr; + } +} + static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) { if (self_ == arg) { Py_RETURN_NONE; @@ -514,6 +551,10 @@ static PyMethodDef NodeBase_methods[] = { (PyCFunction)(void*)(NodeBase__remove_from_list), METH_NOARGS, "Internal method: do not call directly."}, + {"_replace_input_with", + (PyCFunction)(void*)(NodeBase__replace_input_with), + METH_FASTCALL, + "Internal method: replace occurrences of one input Node with another."}, {"_prepend", (PyCFunction)(void*)(NodeBase__prepend), METH_O, diff --git a/torch/fx/node.py b/torch/fx/node.py index 6f8bc730860..466c704fb92 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -658,7 +658,7 @@ class Node(_NodeBase): def replace_all_uses_with( self, replace_with: "Node", - delete_user_cb: Callable[["Node"], bool] = lambda user: True, + delete_user_cb: Optional[Callable[["Node"], bool]] = None, *, propagate_meta: bool = False, ) -> list["Node"]: @@ -686,32 +686,18 @@ class Node(_NodeBase): ) for k, v in self.meta.items(): replace_with.meta[k] = v - to_process = list(self.users) - skipped = [] - m = self.graph.owning_module + to_process = [*self.users] + replace_hooks = getattr(self.graph.owning_module, "_replace_hooks", None) + result = [] for use_node in to_process: - if not delete_user_cb(use_node): - skipped.append(use_node) + if delete_user_cb is not None and not delete_user_cb(use_node): continue - - def maybe_replace_node(n: Node) -> Node: - if n == self: - return replace_with - else: - return n - - if getattr(m, "_replace_hooks", None): - for replace_hook in m._replace_hooks: + result.append(use_node) + if replace_hooks: + for replace_hook in replace_hooks: replace_hook(old=self, new=replace_with.name, user=use_node) - - new_args = _fx_map_arg(use_node.args, maybe_replace_node) - new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node) - assert isinstance(new_args, tuple) - assert isinstance(new_kwargs, dict) - use_node._update_args_kwargs(new_args, new_kwargs) - - assert len(self.users) - len(skipped) == 0 - return [n for n in to_process if n not in skipped] + use_node._replace_input_with(self, replace_with) + return result @compatibility(is_backward_compatible=False) def is_impure(self, impure_random: bool = True) -> bool: @@ -842,19 +828,12 @@ class Node(_NodeBase): new_input (Node): The new input node to replace ``old_input``. """ - def maybe_replace_node(n: Node) -> Node: - return new_input if n == old_input else n - m = self.graph.owning_module if getattr(m, "_replace_hooks", None): for replace_hook in m._replace_hooks: replace_hook(old=old_input, new=new_input.name, user=self) - new_args = _fx_map_arg(self.args, maybe_replace_node) - new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node) - assert isinstance(new_args, tuple) - assert isinstance(new_kwargs, dict) - self._update_args_kwargs(new_args, new_kwargs) + self._replace_input_with(old_input, new_input) def _rename(self, candidate: str) -> None: if candidate == self.name: