mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[fx] Optimize torch.fx.Node.replace_all_uses_with (#165889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165889 Approved by: https://github.com/aorenste
This commit is contained in:
parent
1e2e7cb18b
commit
78bcfcf870
|
|
@ -55,7 +55,7 @@ torch.fx.node.Node.append(self, x: 'Node') -> None
|
||||||
torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str]
|
torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str]
|
||||||
torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
|
torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
|
||||||
torch.fx.node.Node.prepend(self, x: 'Node') -> None
|
torch.fx.node.Node.prepend(self, x: 'Node') -> None
|
||||||
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = <function <lambda>>, propagate_meta: bool = False) -> List[Node]
|
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Optional[Callable[[Node], bool]] = None, propagate_meta: bool = False) -> List[Node]
|
||||||
torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None
|
torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None
|
||||||
torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
|
torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
|
||||||
torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
|
torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
|
||||||
|
|
|
||||||
|
|
@ -2759,6 +2759,7 @@ class _NodeBase:
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
|
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
|
||||||
def _prepend(self, n: FxNode) -> None: ...
|
def _prepend(self, n: FxNode) -> None: ...
|
||||||
|
def _replace_input_with(self, old_input: FxNode, new_input: FxNode) -> None: ...
|
||||||
def _remove_from_list(self) -> None: ...
|
def _remove_from_list(self) -> None: ...
|
||||||
def __lt__(self, n: Self) -> _bool: ...
|
def __lt__(self, n: Self) -> _bool: ...
|
||||||
def __gt__(self, n: Self) -> _bool: ...
|
def __gt__(self, n: Self) -> _bool: ...
|
||||||
|
|
|
||||||
|
|
@ -1274,17 +1274,8 @@ def maybe_inline_graph_saved_tensors_hooks(
|
||||||
else:
|
else:
|
||||||
# Keep usages of bw_g_input in inserted unpacked hook graph.
|
# Keep usages of bw_g_input in inserted unpacked hook graph.
|
||||||
# Replace other usages of bw_g_input with unpack_saved_tensor_n.
|
# Replace other usages of bw_g_input with unpack_saved_tensor_n.
|
||||||
from torch._C import _fx_map_arg
|
|
||||||
|
|
||||||
def maybe_replace_node(n):
|
|
||||||
return unpack_saved_tensor_n if n == bw_g_input else n
|
|
||||||
|
|
||||||
for use_node in original_bw_g_input_users:
|
for use_node in original_bw_g_input_users:
|
||||||
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
|
use_node._replace_input_with(bw_g_input, unpack_saved_tensor_n)
|
||||||
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
|
|
||||||
assert isinstance(new_args, tuple)
|
|
||||||
assert isinstance(new_kwargs, dict)
|
|
||||||
use_node._update_args_kwargs(new_args, new_kwargs)
|
|
||||||
bw_g.erase_node(bw_unpack_out_n)
|
bw_g.erase_node(bw_unpack_out_n)
|
||||||
|
|
||||||
# Changing forward graph outputs,
|
# Changing forward graph outputs,
|
||||||
|
|
|
||||||
|
|
@ -365,6 +365,43 @@ static PyObject* NodeBase__remove_from_list(
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject* NodeBase__replace_input_with(
|
||||||
|
PyObject* self,
|
||||||
|
PyObject* const* args,
|
||||||
|
Py_ssize_t nargs) {
|
||||||
|
if (nargs != 2) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_TypeError,
|
||||||
|
"_replace_input_with() requires exactly 2 arguments (old_input, new_input)");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
PyObject* old_input = args[0];
|
||||||
|
PyObject* new_input = args[1];
|
||||||
|
auto replace_fn = [old_input, new_input](PyObject* maybe_node) {
|
||||||
|
if (maybe_node == old_input) {
|
||||||
|
return Py_NewRef(new_input);
|
||||||
|
}
|
||||||
|
return Py_NewRef(maybe_node);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto node = reinterpret_cast<NodeBase*>(self);
|
||||||
|
try {
|
||||||
|
THPObjectPtr new_args(map_aggregate(node->_args, replace_fn));
|
||||||
|
if (!new_args) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
THPObjectPtr new_kwargs(map_aggregate(node->_kwargs, replace_fn));
|
||||||
|
if (!new_kwargs) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* update_args[2] = {new_args.get(), new_kwargs.get()};
|
||||||
|
return NodeBase__update_args_kwargs(self, update_args, 2);
|
||||||
|
} catch (const PythonError& e) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
|
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
|
||||||
if (self_ == arg) {
|
if (self_ == arg) {
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
|
|
@ -514,6 +551,10 @@ static PyMethodDef NodeBase_methods[] = {
|
||||||
(PyCFunction)(void*)(NodeBase__remove_from_list),
|
(PyCFunction)(void*)(NodeBase__remove_from_list),
|
||||||
METH_NOARGS,
|
METH_NOARGS,
|
||||||
"Internal method: do not call directly."},
|
"Internal method: do not call directly."},
|
||||||
|
{"_replace_input_with",
|
||||||
|
(PyCFunction)(void*)(NodeBase__replace_input_with),
|
||||||
|
METH_FASTCALL,
|
||||||
|
"Internal method: replace occurrences of one input Node with another."},
|
||||||
{"_prepend",
|
{"_prepend",
|
||||||
(PyCFunction)(void*)(NodeBase__prepend),
|
(PyCFunction)(void*)(NodeBase__prepend),
|
||||||
METH_O,
|
METH_O,
|
||||||
|
|
|
||||||
|
|
@ -658,7 +658,7 @@ class Node(_NodeBase):
|
||||||
def replace_all_uses_with(
|
def replace_all_uses_with(
|
||||||
self,
|
self,
|
||||||
replace_with: "Node",
|
replace_with: "Node",
|
||||||
delete_user_cb: Callable[["Node"], bool] = lambda user: True,
|
delete_user_cb: Optional[Callable[["Node"], bool]] = None,
|
||||||
*,
|
*,
|
||||||
propagate_meta: bool = False,
|
propagate_meta: bool = False,
|
||||||
) -> list["Node"]:
|
) -> list["Node"]:
|
||||||
|
|
@ -686,32 +686,18 @@ class Node(_NodeBase):
|
||||||
)
|
)
|
||||||
for k, v in self.meta.items():
|
for k, v in self.meta.items():
|
||||||
replace_with.meta[k] = v
|
replace_with.meta[k] = v
|
||||||
to_process = list(self.users)
|
to_process = [*self.users]
|
||||||
skipped = []
|
replace_hooks = getattr(self.graph.owning_module, "_replace_hooks", None)
|
||||||
m = self.graph.owning_module
|
result = []
|
||||||
for use_node in to_process:
|
for use_node in to_process:
|
||||||
if not delete_user_cb(use_node):
|
if delete_user_cb is not None and not delete_user_cb(use_node):
|
||||||
skipped.append(use_node)
|
|
||||||
continue
|
continue
|
||||||
|
result.append(use_node)
|
||||||
def maybe_replace_node(n: Node) -> Node:
|
if replace_hooks:
|
||||||
if n == self:
|
for replace_hook in replace_hooks:
|
||||||
return replace_with
|
|
||||||
else:
|
|
||||||
return n
|
|
||||||
|
|
||||||
if getattr(m, "_replace_hooks", None):
|
|
||||||
for replace_hook in m._replace_hooks:
|
|
||||||
replace_hook(old=self, new=replace_with.name, user=use_node)
|
replace_hook(old=self, new=replace_with.name, user=use_node)
|
||||||
|
use_node._replace_input_with(self, replace_with)
|
||||||
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
|
return result
|
||||||
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
|
|
||||||
assert isinstance(new_args, tuple)
|
|
||||||
assert isinstance(new_kwargs, dict)
|
|
||||||
use_node._update_args_kwargs(new_args, new_kwargs)
|
|
||||||
|
|
||||||
assert len(self.users) - len(skipped) == 0
|
|
||||||
return [n for n in to_process if n not in skipped]
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
def is_impure(self, impure_random: bool = True) -> bool:
|
def is_impure(self, impure_random: bool = True) -> bool:
|
||||||
|
|
@ -842,19 +828,12 @@ class Node(_NodeBase):
|
||||||
new_input (Node): The new input node to replace ``old_input``.
|
new_input (Node): The new input node to replace ``old_input``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def maybe_replace_node(n: Node) -> Node:
|
|
||||||
return new_input if n == old_input else n
|
|
||||||
|
|
||||||
m = self.graph.owning_module
|
m = self.graph.owning_module
|
||||||
if getattr(m, "_replace_hooks", None):
|
if getattr(m, "_replace_hooks", None):
|
||||||
for replace_hook in m._replace_hooks:
|
for replace_hook in m._replace_hooks:
|
||||||
replace_hook(old=old_input, new=new_input.name, user=self)
|
replace_hook(old=old_input, new=new_input.name, user=self)
|
||||||
|
|
||||||
new_args = _fx_map_arg(self.args, maybe_replace_node)
|
self._replace_input_with(old_input, new_input)
|
||||||
new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
|
|
||||||
assert isinstance(new_args, tuple)
|
|
||||||
assert isinstance(new_kwargs, dict)
|
|
||||||
self._update_args_kwargs(new_args, new_kwargs)
|
|
||||||
|
|
||||||
def _rename(self, candidate: str) -> None:
|
def _rename(self, candidate: str) -> None:
|
||||||
if candidate == self.name:
|
if candidate == self.name:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user