[fx] Move Node._update_args_kwargs to C++ (#148260)

Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```
after:
```
24303536 function calls (23503339 primitive calls) in 10.726 seconds
```
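
For reference, the "N function calls (M primitive calls) in S seconds" format above is cProfile's. A minimal sketch of how to reproduce the measurement (assuming only that `torch.fx` is importable):
```python
import cProfile
import functools
import operator

import torch.fx as fx

# Tracing this lambda builds ~100k call_function nodes, each of which
# exercises Node.__init__ and _update_args_kwargs once.
cProfile.run(
    "fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))"
)
```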

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148260
Approved by: https://github.com/oulgen
ghstack dependencies: #148243
Jason Ansel authored on 2025-03-02 10:46:18 -08:00; committed by PyTorch MergeBot
parent edaff88f69
commit 0135f57f4a
4 changed files with 212 additions and 113 deletions

Changed file 1 of 4: compile-time benchmark expected results (CSV)

```diff
@@ -1,32 +1,32 @@
-add_loop_eager,compile_time_instruction_count,2993000000,0.015
+add_loop_eager,compile_time_instruction_count,2958000000,0.015
-add_loop_eager_dynamic,compile_time_instruction_count,6349000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,6250000000,0.025
-add_loop_inductor,compile_time_instruction_count,28630000000,0.015
+add_loop_inductor,compile_time_instruction_count,28450000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,45240000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44690000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,24960000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,24770000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,960700000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,959200000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18120000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17950000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16340000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16030000000,0.015
@@ -34,32 +34,32 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9874000000,0.2
-update_hint_regression,compile_time_instruction_count,1699000000,0.02
+update_hint_regression,compile_time_instruction_count,1683000000,0.02
-sum_floordiv_regression,compile_time_instruction_count,1061000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1054000000,0.015
-symint_sum,compile_time_instruction_count,3194000000,0.015
+symint_sum,compile_time_instruction_count,3167000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2018000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2010000000,0.015
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5792000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5776000000,0.015
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,8703000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,8521000000,0.015
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3783000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3735000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10170000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10070000000,0.015
```


Changed file 2 of 4: type stubs for the _NodeBase extension type (.pyi)

```diff
@@ -2505,6 +2505,15 @@ class _NodeBase:
     _erased: _bool
     _prev: FxNode
     _next: FxNode
+    def __init__(
+        self,
+        graph: Any,
+        name: str,
+        op: str,
+        target: Any,
+        return_type: Any,
+    ) -> None: ...
+    def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...

 class _NodeIter(Iterator):
     def __init__(self, root: FxNode, reversed: _bool) -> None: ...
```
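
One observable consequence of the stub above: the listed attributes are now C-level members on `_NodeBase` rather than entries in each instance's `__dict__`. A quick sanity check, as a sketch assuming a build that includes this change:
```python
import torch
import torch.fx as fx

g = fx.Graph()
n = g.call_function(torch.add, (g.placeholder("x"), 1))
print("name" in n.__dict__)  # False: backed by the C++ member descriptor
print(n.name)                # still reads like a normal attribute
```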

Changed file 3 of 4: C++ implementation of _NodeBase

```diff
@@ -6,6 +6,8 @@
 namespace {

+struct NodeBase;
+
 // Thrown to exit out of a C++ function and return an error to Python.
 class PythonError : public std::exception {};
@@ -153,6 +155,18 @@ struct NodeBase {
   bool _erased;
   NodeBase* _prev;
   NodeBase* _next;
+  PyObject* graph;
+  PyObject* name;
+  PyObject* op;
+  PyObject* target;
+  PyObject* type;
+  PyObject* _input_nodes;
+  PyObject* _args;
+  PyObject* _kwargs;
+  PyObject* users;
+  PyObject* _repr_fn;
+  PyObject* meta;
+  PyObject* _sort_key;
 };

 static PyObject* NodeBase_new(
@@ -166,11 +180,31 @@ static PyObject* NodeBase_new(
 }

 static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
+  PyObject* graph = nullptr;
+  PyObject* name = nullptr;
+  PyObject* op = nullptr;
+  PyObject* target = nullptr;
+  PyObject* type = nullptr;
+  if (!PyArg_ParseTuple(args, "OOOOO", &graph, &name, &op, &target, &type)) {
+    return -1;
+  }
   self->_erased = false;
   Py_INCREF(self);
   self->_prev = self;
   Py_INCREF(self);
   self->_next = self;
+  self->graph = Py_NewRef(graph);
+  self->name = Py_NewRef(name);
+  self->op = Py_NewRef(op);
+  self->target = Py_NewRef(target);
+  self->type = Py_NewRef(type);
+  self->_input_nodes = PyDict_New();
+  self->_args = nullptr; // set with _update_args_kwargs
+  self->_kwargs = nullptr; // set with _update_args_kwargs
+  self->users = PyDict_New();
+  self->_repr_fn = Py_NewRef(Py_None);
+  self->meta = PyDict_New();
+  self->_sort_key = PyTuple_New(0);
   return 0;
 }
@@ -179,18 +213,54 @@ static struct PyMemberDef NodeBase_members[] = {
     {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
     {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
     {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
+    {"graph", T_OBJECT_EX, offsetof(NodeBase, graph), 0, nullptr},
+    {"name", T_OBJECT_EX, offsetof(NodeBase, name), 0, nullptr},
+    {"op", T_OBJECT_EX, offsetof(NodeBase, op), 0, nullptr},
+    {"target", T_OBJECT_EX, offsetof(NodeBase, target), 0, nullptr},
+    {"type", T_OBJECT_EX, offsetof(NodeBase, type), 0, nullptr},
+    {"_input_nodes", T_OBJECT_EX, offsetof(NodeBase, _input_nodes), 0, nullptr},
+    {"_args", T_OBJECT_EX, offsetof(NodeBase, _args), 0, nullptr},
+    {"_kwargs", T_OBJECT_EX, offsetof(NodeBase, _kwargs), 0, nullptr},
+    {"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
+    {"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
+    {"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
+    {"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
     {nullptr} /* Sentinel */
 };

 static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
   Py_VISIT(self->_prev);
   Py_VISIT(self->_next);
+  Py_VISIT(self->graph);
+  Py_VISIT(self->name);
+  Py_VISIT(self->op);
+  Py_VISIT(self->target);
+  Py_VISIT(self->type);
+  Py_VISIT(self->_input_nodes);
+  Py_VISIT(self->_args);
+  Py_VISIT(self->_kwargs);
+  Py_VISIT(self->users);
+  Py_VISIT(self->_repr_fn);
+  Py_VISIT(self->meta);
+  Py_VISIT(self->_sort_key);
   return 0;
 }

 static int NodeBase_clear(NodeBase* self) {
   Py_CLEAR(self->_prev);
   Py_CLEAR(self->_next);
+  Py_CLEAR(self->graph);
+  Py_CLEAR(self->name);
+  Py_CLEAR(self->op);
+  Py_CLEAR(self->target);
+  Py_CLEAR(self->type);
+  Py_CLEAR(self->_input_nodes);
+  Py_CLEAR(self->_args);
+  Py_CLEAR(self->_kwargs);
+  Py_CLEAR(self->users);
+  Py_CLEAR(self->_repr_fn);
+  Py_CLEAR(self->meta);
+  Py_CLEAR(self->_sort_key);
   return 0;
 }
```
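
The `tp_traverse`/`tp_clear` additions matter because every new `PyObject*` member can participate in reference cycles, and FX nodes are cyclic by construction (the `_prev`/`_next` list, plus `users` and `_input_nodes` pointing at each other). A small illustration from the Python side:
```python
import torch.fx as fx

g = fx.symbolic_trace(lambda x: x + 1).graph
n = next(iter(g.nodes))
# The node list is a circular doubly linked list, so only the cyclic
# garbage collector (via tp_traverse/tp_clear) can ever reclaim it.
assert n._next._prev is n
```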
```diff
@@ -200,6 +270,69 @@ static void NodeBase_dealloc(PyObject* self) {
   Py_TYPE(self)->tp_free(self);
 }

+static PyObject* NodeBase__update_args_kwargs(
+    PyObject* self,
+    PyObject* const* args,
+    Py_ssize_t nargs) {
+  // Verify argument count
+  if (nargs != 2) {
+    PyErr_SetString(
+        PyExc_TypeError,
+        "_update_args_kwargs() requires exactly 2 arguments (new_args, new_kwargs)");
+    return nullptr;
+  }
+  auto node = reinterpret_cast<NodeBase*>(self);
+  auto input_nodes = node->_input_nodes;
+  if (PyDict_GET_SIZE(input_nodes) > 0) {
+    // Clear other.users containing us and input_nodes
+    PyObject *key = nullptr, *value = nullptr; // borrowed
+    Py_ssize_t pos = 0;
+    while (PyDict_Next(input_nodes, &pos, &key, &value)) {
+      // key.users.pop(self); a KeyError here is intentionally ignored, but
+      // the pending exception must be cleared before the next C-API call
+      if (PyDict_DelItem(reinterpret_cast<NodeBase*>(key)->users, self) < 0) {
+        PyErr_Clear();
+      }
+    }
+    PyDict_Clear(input_nodes);
+  }
+  auto visit_fn = [self, input_nodes](PyObject* x) {
+    if (is_node(x)) {
+      // self._input_nodes.setdefault(x)
+      if (!PyDict_SetDefault(input_nodes, x, Py_None)) {
+        throw PythonError();
+      }
+      // x.users.setdefault(self)
+      if (!PyDict_SetDefault(
+              reinterpret_cast<NodeBase*>(x)->users, self, Py_None)) {
+        throw PythonError();
+      }
+    }
+    return Py_NewRef(x);
+  };
+  // We do three things in a single pass of the args:
+  // - Normalize list->immutable_list, dict->immutable_dict, etc.
+  // - Populate self._input_nodes
+  // - Populate arg.users[self] for each arg
+  try {
+    Py_CLEAR(node->_args);
+    node->_args = map_aggregate(args[0], visit_fn);
+    Py_CLEAR(node->_kwargs);
+    node->_kwargs = map_aggregate(args[1], visit_fn);
+    Py_RETURN_NONE;
+  } catch (const PythonError& e) {
+    return nullptr;
+  }
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
+static PyMethodDef NodeBase_methods[] = {
+    {"_update_args_kwargs",
+     (PyCFunction)(void*)(NodeBase__update_args_kwargs),
+     METH_FASTCALL,
+     "Internal method: do not call directly."},
+    {nullptr, nullptr, 0, nullptr} // Sentinel
+};
+
 PyTypeObject NodeBaseType = {
     PyVarObject_HEAD_INIT(nullptr, 0)
     "torch._C._NodeBase", /* tp_name */
@@ -229,7 +362,7 @@ PyTypeObject NodeBaseType = {
     0, /* tp_weaklistoffset */
     nullptr, /* tp_iter */
     nullptr, /* tp_iternext */
-    nullptr, /* tp_methods */
+    NodeBase_methods, /* tp_methods */
     NodeBase_members, /* tp_members */
     nullptr, /* tp_getset */
     nullptr, /* tp_base */
```
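
The C++ routine preserves the contract of the Python method it replaces (shown removed further down): one pass over the new arguments normalizes containers to their immutable variants and rebuilds the `users`/`_input_nodes` indices. A sketch of that contract as seen from Python:
```python
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
n = g.call_function(torch.add, (x, [1, 2]))
assert x in n._input_nodes and n in x.users  # bookkeeping, both directions
print(type(n.args[1]).__name__)              # "immutable_list", not "list"
```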

Changed file 4 of 4: the Python Node class (torch.fx)

```diff
@@ -229,14 +229,38 @@ class Node(_NodeBase):
     _args: tuple["Argument", ...]
     _kwargs: dict[str, "Argument"]
     graph: "Graph"
+    # unique name of value being created
     name: str
+    # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
     op: str
+    # for method/module/function, the name of the method/module/function/attr
+    # being invoked, e.g. add, layer1, or torch.add
     target: "Target"
+    # All `Node`-valued inputs. Key is the Node, value is don't-care.
+    # The public API for this is `all_input_nodes`; this private attribute
+    # should not be accessed directly.
     _input_nodes: dict["Node", None]
+    # All of the nodes that use the value produced by this Node.
+    # Note one user may correspond to several uses, e.g. the node for ``x + x``
+    # would appear once here, but represents two uses.
+    # Is a dict to act as an "ordered set". Keys are significant, values don't-care.
     users: dict["Node", None]
+    # Type expression representing the output value of this node.
+    # This should contain the same class of Type objects that would appear
+    # as type annotations for function inputs/outputs.
+    #
+    # For placeholder nodes, this value will be used to type-annotate the
+    # generated function parameters.
+    # For the return node, this value will be used to type-annotate the
+    # generated function return type. (Note this is a special case. ``return``
+    # does not produce a value, it's more of a notation. Thus, this value
+    # describes the type of args[0] in the ``return`` node.)
     type: Optional[Any]
     _sort_key: Any
+    # If set, use this fn to print this node
     _repr_fn: Optional[Callable[["Node"], str]]
+    # Dictionary to store metadata that passes need to do their
+    # transformations. This metadata is preserved across node copies.
     meta: dict[str, Any]

     @compatibility(is_backward_compatible=True)
@@ -276,7 +300,6 @@ class Node(_NodeBase):
                 annotation of values in the generated code or for other types
                 of analyses.
         """
-        assert op in _legal_ops
         if op == "call_function":
             if not callable(target):
                 raise ValueError(
```
```diff
@@ -284,75 +307,38 @@ class Node(_NodeBase):
                     "but a Callable is expected"
                 )
         else:
+            assert op in _legal_ops
             if not isinstance(target, str):
                 raise ValueError(
                     f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} "
                     "but a str is expected"
                 )
-        super().__init__()
-
-        # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects
-        assign = object.__setattr__
-        assign(self, "graph", graph)
-        assign(self, "name", name)  # unique name of value being created
-        assign(
-            self, "op", op
-        )  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
-        assign(
-            self, "target", target
-        )  # for method/module/function, the name of the method/module/function/attr
-        # being invoked, e.g add, layer1, or torch.add
-
-        # All `Node`-valued inputs. Key is the Node, value is don't-care.
-        # The public API for this is `all_input_nodes`, this private attribute
-        # should not be accessed directly.
-        assign(self, "_input_nodes", {})
-        self.__update_args_kwargs(args, kwargs)
-
-        # All of the nodes that use the value produced by this Node
-        # Note one user may correspond to several uses, e.g. the node fo ``x + x``
-        # would appear once here, but represents two uses.
-        #
-        # Is a dict to act as an "ordered set". Keys are significant, value dont-care
-        assign(self, "users", {})
-
-        # Type expression representing the output value of this node.
-        # This should contain the same class of Type objects that would appear
-        # as type annotations for function inputs/outputs.
-        #
-        # For placeholder nodes, this value will be used to type-annotate the
-        # generated function parameters.
-        # For the return node, this value will be used to type-annotate the
-        # generated function return type. (Note this is a special case. ``return``
-        # does not produce a value, it's more of a notation. Thus, this value
-        # describes the type of args[0] in the ``return`` node.
-        assign(self, "type", return_type)
-        assign(self, "_sort_key", ())
-
-        # If set, use this fn to print this node
-        assign(self, "_repr_fn", None)
-
-        # Dictionary to store metadata passes need to do their
-        # transformations. This metadata is preserved across node copies
-        assign(self, "meta", {})
+        super().__init__(graph, name, op, target, return_type)
+        self._update_args_kwargs(args, kwargs)

     def __getstate__(self) -> dict[str, Any]:
-        state = self.__dict__.copy()
-        state["_erased"] = self._erased
-        state["_prev"] = self._prev
-        state["_next"] = self._next
-        return state
+        return {
+            **self.__dict__,
+            "graph": self.graph,
+            "name": self.name,
+            "op": self.op,
+            "target": self.target,
+            "type": self.type,
+            "_sort_key": self._sort_key,
+            "_args": self._args,
+            "_kwargs": self._kwargs,
+            "_erased": self._erased,
+            "_prev": self._prev,
+            "_next": self._next,
+            "_input_nodes": self._input_nodes,
+            "users": self.users,
+            "_repr_fn": self._repr_fn,
+            "meta": self.meta,
+        }

     def __setstate__(self, state: dict[str, Any]) -> None:
-        _erased = state.pop("_erased")
-        _prev = state.pop("_prev")
-        _next = state.pop("_next")
-        self.__dict__.update(state)
-        self._erased = _erased
-        self._prev = _prev
-        self._next = _next
+        for k, v in state.items():
+            setattr(self, k, v)

     @property
     def next(self) -> "Node":
```
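
Because the slot-backed attributes no longer live in `__dict__`, `__getstate__` has to enumerate them explicitly for pickling and copying to round-trip. (Note the fix above: the `"type"` entry must read `self.type`, not `self.target`.) A sketch of the invariant this protects:
```python
import copy

import torch
import torch.fx as fx

gm = fx.symbolic_trace(torch.nn.Linear(2, 2))
gm2 = copy.deepcopy(gm)  # exercises Node.__getstate__/__setstate__
assert [n.name for n in gm2.graph.nodes] == [n.name for n in gm.graph.nodes]
```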
```diff
@@ -459,9 +445,9 @@ class Node(_NodeBase):
         depends on the node's opcode. See the ``fx.Graph`` docstring for more
         information.
         """
-        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
+        # DO NOT CALL `_update_args_kwargs` directly. The correct way to
         # set `args` is via direct assignment, i.e. `node.args = new_args`
-        self.__update_args_kwargs(a, self._kwargs)
+        self._update_args_kwargs(a, self._kwargs)

     @property
     def kwargs(self) -> dict[str, Argument]:
@@ -482,9 +468,9 @@ class Node(_NodeBase):
         depends on the node's opcode. See the ``fx.Graph`` docstring for more
         information.
         """
-        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
-        # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs`
-        self.__update_args_kwargs(self._args, k)
+        # DO NOT CALL `_update_args_kwargs` directly. The correct way to
+        # set `kwargs` is via direct assignment, i.e. `node.kwargs = new_kwargs`
+        self._update_args_kwargs(self._args, k)

     @property
     def all_input_nodes(self) -> list["Node"]:
```
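
As the comments say, property assignment is the supported entry point; it routes through the C++ helper and refreshes the use-def bookkeeping. For example:
```python
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
n = g.call_function(torch.add, (x, 1))
n.args = (x, 3)  # setter -> _NodeBase._update_args_kwargs
assert n.args == (x, 3)
```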
```diff
@@ -572,35 +558,6 @@ class Node(_NodeBase):
     def stack_trace(self, trace: Optional[str]) -> None:
         self.meta["stack_trace"] = trace

-    def __update_args_kwargs(
-        self, new_args: tuple["Argument", ...], new_kwargs: dict[str, "Argument"]
-    ) -> None:
-        """
-        This API is internal. Do *not* call it directly.
-        """
-
-        def update_users_and_input_nodes(n: Any) -> Any:
-            if isinstance(n, Node):
-                self._input_nodes.setdefault(n)
-                n.users.setdefault(self)
-            return n
-
-        # Clear prior users and input_nodes
-        for old_use in self._input_nodes.keys():
-            old_use.users.pop(self)
-        object.__setattr__(self, "_input_nodes", {})  # bypass Node.__setattr__
-
-        # We do three things in a single pass of the args
-        # - Normalize list->immutable_list, dict->immutable_dict, etc
-        # - Populate self._input_nodes
-        # - Populate arg.users[self] for each arg
-        object.__setattr__(
-            self, "_args", _fx_map_aggregate(new_args, update_users_and_input_nodes)
-        )
-        object.__setattr__(
-            self, "_kwargs", _fx_map_aggregate(new_kwargs, update_users_and_input_nodes)
-        )
-
     def __repr__(self) -> str:
         if self._repr_fn:
             return self._repr_fn(self)
```
```diff
@@ -751,7 +708,7 @@ class Node(_NodeBase):
             new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
             assert isinstance(new_args, tuple)
             assert isinstance(new_kwargs, dict)
-            use_node.__update_args_kwargs(new_args, new_kwargs)
+            use_node._update_args_kwargs(new_args, new_kwargs)

         assert len(self.users) - len(skipped) == 0
         return [n for n in to_process if n not in skipped]
```
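
Graph-rewrite APIs such as `replace_all_uses_with` funnel through the same helper, so they benefit from the C++ move as well. A sketch:
```python
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
a = g.call_function(torch.neg, (x,))
b = g.call_function(torch.relu, (a,))
a.replace_all_uses_with(x)  # rewrites b.args via _update_args_kwargs
assert b.args == (x,) and not a.users
```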
```diff
@@ -863,7 +820,7 @@ class Node(_NodeBase):
         new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
         assert isinstance(new_args, tuple)
         assert isinstance(new_kwargs, dict)
-        self.__update_args_kwargs(new_args, new_kwargs)
+        self._update_args_kwargs(new_args, new_kwargs)

     def _rename(self, candidate: str) -> None:
         if candidate == self.name:
```