[fx] Optimize torch.fx.Node.replace_all_uses_with (#165889)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165889
Approved by: https://github.com/aorenste
This commit is contained in:
Jason Ansel 2025-10-24 13:11:51 -07:00 committed by PyTorch MergeBot
parent 1e2e7cb18b
commit 78bcfcf870
5 changed files with 55 additions and 43 deletions

View File

@ -55,7 +55,7 @@ torch.fx.node.Node.append(self, x: 'Node') -> None
torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str]
torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
torch.fx.node.Node.prepend(self, x: 'Node') -> None
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = <function <lambda>>, propagate_meta: bool = False) -> List[Node]
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Optional[Callable[[Node], bool]] = None, propagate_meta: bool = False) -> List[Node]
torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None
torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None

View File

@ -2759,6 +2759,7 @@ class _NodeBase:
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _replace_input_with(self, old_input: FxNode, new_input: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...

View File

@ -1274,17 +1274,8 @@ def maybe_inline_graph_saved_tensors_hooks(
else:
# Keep usages of bw_g_input in inserted unpacked hook graph.
# Replace other usages of bw_g_input with unpack_saved_tensor_n.
from torch._C import _fx_map_arg
def maybe_replace_node(n):
return unpack_saved_tensor_n if n == bw_g_input else n
for use_node in original_bw_g_input_users:
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
use_node._update_args_kwargs(new_args, new_kwargs)
use_node._replace_input_with(bw_g_input, unpack_saved_tensor_n)
bw_g.erase_node(bw_unpack_out_n)
# Changing forward graph outputs,

View File

@ -365,6 +365,43 @@ static PyObject* NodeBase__remove_from_list(
Py_RETURN_NONE;
}
static PyObject* NodeBase__replace_input_with(
PyObject* self,
PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 2) {
PyErr_SetString(
PyExc_TypeError,
"_replace_input_with() requires exactly 2 arguments (old_input, new_input)");
return nullptr;
}
PyObject* old_input = args[0];
PyObject* new_input = args[1];
auto replace_fn = [old_input, new_input](PyObject* maybe_node) {
if (maybe_node == old_input) {
return Py_NewRef(new_input);
}
return Py_NewRef(maybe_node);
};
auto node = reinterpret_cast<NodeBase*>(self);
try {
THPObjectPtr new_args(map_aggregate(node->_args, replace_fn));
if (!new_args) {
return nullptr;
}
THPObjectPtr new_kwargs(map_aggregate(node->_kwargs, replace_fn));
if (!new_kwargs) {
return nullptr;
}
PyObject* update_args[2] = {new_args.get(), new_kwargs.get()};
return NodeBase__update_args_kwargs(self, update_args, 2);
} catch (const PythonError& e) {
return nullptr;
}
}
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
@ -514,6 +551,10 @@ static PyMethodDef NodeBase_methods[] = {
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_replace_input_with",
(PyCFunction)(void*)(NodeBase__replace_input_with),
METH_FASTCALL,
"Internal method: replace occurrences of one input Node with another."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,

View File

@ -658,7 +658,7 @@ class Node(_NodeBase):
def replace_all_uses_with(
self,
replace_with: "Node",
delete_user_cb: Callable[["Node"], bool] = lambda user: True,
delete_user_cb: Optional[Callable[["Node"], bool]] = None,
*,
propagate_meta: bool = False,
) -> list["Node"]:
@ -686,32 +686,18 @@ class Node(_NodeBase):
)
for k, v in self.meta.items():
replace_with.meta[k] = v
to_process = list(self.users)
skipped = []
m = self.graph.owning_module
to_process = [*self.users]
replace_hooks = getattr(self.graph.owning_module, "_replace_hooks", None)
result = []
for use_node in to_process:
if not delete_user_cb(use_node):
skipped.append(use_node)
if delete_user_cb is not None and not delete_user_cb(use_node):
continue
def maybe_replace_node(n: Node) -> Node:
if n == self:
return replace_with
else:
return n
if getattr(m, "_replace_hooks", None):
for replace_hook in m._replace_hooks:
result.append(use_node)
if replace_hooks:
for replace_hook in replace_hooks:
replace_hook(old=self, new=replace_with.name, user=use_node)
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
use_node._update_args_kwargs(new_args, new_kwargs)
assert len(self.users) - len(skipped) == 0
return [n for n in to_process if n not in skipped]
use_node._replace_input_with(self, replace_with)
return result
@compatibility(is_backward_compatible=False)
def is_impure(self, impure_random: bool = True) -> bool:
@ -842,19 +828,12 @@ class Node(_NodeBase):
new_input (Node): The new input node to replace ``old_input``.
"""
def maybe_replace_node(n: Node) -> Node:
return new_input if n == old_input else n
m = self.graph.owning_module
if getattr(m, "_replace_hooks", None):
for replace_hook in m._replace_hooks:
replace_hook(old=old_input, new=new_input.name, user=self)
new_args = _fx_map_arg(self.args, maybe_replace_node)
new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
self._update_args_kwargs(new_args, new_kwargs)
self._replace_input_with(old_input, new_input)
def _rename(self, candidate: str) -> None:
if candidate == self.name: