[fx] Optimize torch.fx.Node.replace_all_uses_with (#165889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165889
Approved by: https://github.com/aorenste
This commit is contained in:
parent 1e2e7cb18b
commit 78bcfcf870
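For context, torch.fx.Node.replace_all_uses_with rewires every user of a node so it consumes a replacement node instead, and returns the list of users that were updated; this commit moves the per-user rewriting into C++. A minimal usage sketch (the module and node names are illustrative, not part of this commit):

import operator
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y * 2

gm = torch.fx.symbolic_trace(M())
graph = gm.graph

# Find the `x + 1` node and splice in `x - 1` everywhere it is used.
add_node = next(n for n in graph.nodes if n.target is operator.add)
with graph.inserting_after(add_node):
    sub_node = graph.call_function(operator.sub, add_node.args)
updated_users = add_node.replace_all_uses_with(sub_node)  # list of rewritten users
graph.erase_node(add_node)
gm.recompile()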
@@ -55,7 +55,7 @@ torch.fx.node.Node.append(self, x: 'Node') -> None
 torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str]
 torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
 torch.fx.node.Node.prepend(self, x: 'Node') -> None
-torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = <function <lambda>>, propagate_meta: bool = False) -> List[Node]
+torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Optional[Callable[[Node], bool]] = None, propagate_meta: bool = False) -> List[Node]
 torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None
 torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
 torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
@@ -2759,6 +2759,7 @@ class _NodeBase:
     ) -> None: ...
     def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
     def _prepend(self, n: FxNode) -> None: ...
+    def _replace_input_with(self, old_input: FxNode, new_input: FxNode) -> None: ...
     def _remove_from_list(self) -> None: ...
     def __lt__(self, n: Self) -> _bool: ...
     def __gt__(self, n: Self) -> _bool: ...
@@ -1274,17 +1274,8 @@ def maybe_inline_graph_saved_tensors_hooks(
         else:
             # Keep usages of bw_g_input in inserted unpacked hook graph.
             # Replace other usages of bw_g_input with unpack_saved_tensor_n.
-            from torch._C import _fx_map_arg
-
-            def maybe_replace_node(n):
-                return unpack_saved_tensor_n if n == bw_g_input else n
-
             for use_node in original_bw_g_input_users:
-                new_args = _fx_map_arg(use_node.args, maybe_replace_node)
-                new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
-                assert isinstance(new_args, tuple)
-                assert isinstance(new_kwargs, dict)
-                use_node._update_args_kwargs(new_args, new_kwargs)
+                use_node._replace_input_with(bw_g_input, unpack_saved_tensor_n)
             bw_g.erase_node(bw_unpack_out_n)
 
         # Changing forward graph outputs,
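The hunk above collapses the hand-rolled map-over-args pattern into a single call to the new C-level Node._replace_input_with. A rough pure-Python equivalent of that helper, for orientation only (it mirrors the deleted code; map_arg and _update_args_kwargs are the existing torch.fx.node APIs used there):

from torch.fx.node import Node, map_arg

def replace_input_with_py(node: Node, old_input: Node, new_input: Node) -> None:
    # Swap every occurrence of old_input inside node.args/node.kwargs for new_input.
    def maybe_replace(n: Node) -> Node:
        # Identity comparison, matching the C-level pointer check.
        return new_input if n is old_input else n

    new_args = map_arg(node.args, maybe_replace)
    new_kwargs = map_arg(node.kwargs, maybe_replace)
    assert isinstance(new_args, tuple)
    assert isinstance(new_kwargs, dict)
    node._update_args_kwargs(new_args, new_kwargs)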
@@ -365,6 +365,43 @@ static PyObject* NodeBase__remove_from_list(
   Py_RETURN_NONE;
 }
 
+static PyObject* NodeBase__replace_input_with(
+    PyObject* self,
+    PyObject* const* args,
+    Py_ssize_t nargs) {
+  if (nargs != 2) {
+    PyErr_SetString(
+        PyExc_TypeError,
+        "_replace_input_with() requires exactly 2 arguments (old_input, new_input)");
+    return nullptr;
+  }
+  PyObject* old_input = args[0];
+  PyObject* new_input = args[1];
+  auto replace_fn = [old_input, new_input](PyObject* maybe_node) {
+    if (maybe_node == old_input) {
+      return Py_NewRef(new_input);
+    }
+    return Py_NewRef(maybe_node);
+  };
+
+  auto node = reinterpret_cast<NodeBase*>(self);
+  try {
+    THPObjectPtr new_args(map_aggregate(node->_args, replace_fn));
+    if (!new_args) {
+      return nullptr;
+    }
+    THPObjectPtr new_kwargs(map_aggregate(node->_kwargs, replace_fn));
+    if (!new_kwargs) {
+      return nullptr;
+    }
+
+    PyObject* update_args[2] = {new_args.get(), new_kwargs.get()};
+    return NodeBase__update_args_kwargs(self, update_args, 2);
+  } catch (const PythonError& e) {
+    return nullptr;
+  }
+}
+
 static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
   if (self_ == arg) {
     Py_RETURN_NONE;
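The new NodeBase__replace_input_with walks a node's args and kwargs with map_aggregate entirely at the C++ level, so bulk rewrites no longer pay for a Python closure and two _fx_map_arg traversals per user. A hedged micro-benchmark sketch of the pattern this targets (node counts and names are made up; absolute timings will vary by build):

import operator
import timeit
import torch.fx

# One node with many users: the case replace_all_uses_with has to handle in bulk.
g = torch.fx.Graph()
x = g.placeholder("x")
base = g.call_function(operator.neg, (x,))
users = [g.call_function(operator.add, (base, i)) for i in range(1000)]
g.output(users)

with g.inserting_after(base):
    replacement = g.call_function(operator.abs, (x,))

def swap_back_and_forth():
    base.replace_all_uses_with(replacement)
    replacement.replace_all_uses_with(base)

elapsed = timeit.timeit(swap_back_and_forth, number=100)
print(f"200 bulk rewrites over 1000 users: {elapsed:.4f}s")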
@@ -514,6 +551,10 @@ static PyMethodDef NodeBase_methods[] = {
      (PyCFunction)(void*)(NodeBase__remove_from_list),
      METH_NOARGS,
      "Internal method: do not call directly."},
+    {"_replace_input_with",
+     (PyCFunction)(void*)(NodeBase__replace_input_with),
+     METH_FASTCALL,
+     "Internal method: replace occurrences of one input Node with another."},
     {"_prepend",
      (PyCFunction)(void*)(NodeBase__prepend),
      METH_O,
@@ -658,7 +658,7 @@ class Node(_NodeBase):
     def replace_all_uses_with(
         self,
         replace_with: "Node",
-        delete_user_cb: Callable[["Node"], bool] = lambda user: True,
+        delete_user_cb: Optional[Callable[["Node"], bool]] = None,
         *,
         propagate_meta: bool = False,
     ) -> list["Node"]:
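The default for delete_user_cb changes from an always-True lambda to None, so the common no-callback case can skip a Python call per user; passing a predicate still filters which users are rewritten. An illustrative sketch of the callback semantics (the returned list contains only the users that were actually rewritten):

import operator
import torch
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
old_node = g.call_function(operator.neg, (x,))
user_a = g.call_function(operator.abs, (old_node,))
g.output((old_node, user_a))  # the output node is also a user of old_node

with g.inserting_after(old_node):
    new_node = g.call_function(torch.relu, (x,))

# Only rewire call_function users; the output node keeps referencing old_node.
rewritten = old_node.replace_all_uses_with(
    new_node,
    delete_user_cb=lambda user: user.op == "call_function",
)
assert rewritten == [user_a]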
@@ -686,32 +686,18 @@ class Node(_NodeBase):
             )
             for k, v in self.meta.items():
                 replace_with.meta[k] = v
-        to_process = list(self.users)
-        skipped = []
-        m = self.graph.owning_module
+        to_process = [*self.users]
+        replace_hooks = getattr(self.graph.owning_module, "_replace_hooks", None)
+        result = []
         for use_node in to_process:
-            if not delete_user_cb(use_node):
-                skipped.append(use_node)
+            if delete_user_cb is not None and not delete_user_cb(use_node):
                 continue
-
-            def maybe_replace_node(n: Node) -> Node:
-                if n == self:
-                    return replace_with
-                else:
-                    return n
-
-            if getattr(m, "_replace_hooks", None):
-                for replace_hook in m._replace_hooks:
+            result.append(use_node)
+            if replace_hooks:
+                for replace_hook in replace_hooks:
                     replace_hook(old=self, new=replace_with.name, user=use_node)
-
-            new_args = _fx_map_arg(use_node.args, maybe_replace_node)
-            new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
-            assert isinstance(new_args, tuple)
-            assert isinstance(new_kwargs, dict)
-            use_node._update_args_kwargs(new_args, new_kwargs)
-
-        assert len(self.users) - len(skipped) == 0
-        return [n for n in to_process if n not in skipped]
+            use_node._replace_input_with(self, replace_with)
+        return result
 
     @compatibility(is_backward_compatible=False)
     def is_impure(self, impure_random: bool = True) -> bool:
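Behaviorally the loop is unchanged: users filtered out by delete_user_cb keep their reference to self, _replace_hooks (when present on the owning module) still fire per rewritten user, and the return value is the list of rewritten users, now accumulated directly instead of reconstructed from a skipped list. The propagate_meta path in the surrounding context is also untouched; a small illustrative example of it (assumes the replacement node's meta starts out empty, which the assert in this function requires):

import operator
import torch
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
old_node = g.call_function(operator.neg, (x,))
g.output(old_node)
old_node.meta["source"] = "original"

with g.inserting_after(old_node):
    new_node = g.call_function(torch.abs, (x,))

# propagate_meta copies old_node.meta onto the replacement before rewiring its users.
old_node.replace_all_uses_with(new_node, propagate_meta=True)
assert new_node.meta["source"] == "original"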
@@ -842,19 +828,12 @@ class Node(_NodeBase):
             new_input (Node): The new input node to replace ``old_input``.
         """
 
-        def maybe_replace_node(n: Node) -> Node:
-            return new_input if n == old_input else n
-
         m = self.graph.owning_module
         if getattr(m, "_replace_hooks", None):
             for replace_hook in m._replace_hooks:
                 replace_hook(old=old_input, new=new_input.name, user=self)
 
-        new_args = _fx_map_arg(self.args, maybe_replace_node)
-        new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
-        assert isinstance(new_args, tuple)
-        assert isinstance(new_kwargs, dict)
-        self._update_args_kwargs(new_args, new_kwargs)
+        self._replace_input_with(old_input, new_input)
 
     def _rename(self, candidate: str) -> None:
         if candidate == self.name:
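replace_input_with is the single-edge counterpart of replace_all_uses_with: it rewrites occurrences of one specific input within self.args/self.kwargs, and after this change it delegates to the same C-level helper. A minimal illustrative example:

import operator
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
y = g.placeholder("y")
add = g.call_function(operator.add, (x, x))
g.output(add)

# Every occurrence of x in add's arguments becomes y; other nodes are untouched.
add.replace_input_with(x, y)
assert add.args == (y, y)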