[fx] Move Node._prepend/Node._remove_from_list to C++ (#148261)

Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
24303536 function calls (23503339 primitive calls) in 10.726 seconds
```
after:
```
20003454 function calls (19203257 primitive calls) in 8.936 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148261
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260
This commit is contained in:
Jason Ansel 2025-03-09 21:21:58 -07:00 committed by PyTorch MergeBot
parent bf752c36da
commit 5d4e7d58b4
4 changed files with 224 additions and 63 deletions

View File

@ -1,65 +1,65 @@
add_loop_eager,compile_time_instruction_count,2972000000,0.015
add_loop_eager,compile_time_instruction_count,2891000000,0.015
add_loop_eager_dynamic,compile_time_instruction_count,5647000000,0.025
add_loop_eager_dynamic,compile_time_instruction_count,5554000000,0.025
add_loop_inductor,compile_time_instruction_count,28480000000,0.015
add_loop_inductor,compile_time_instruction_count,27890000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42160000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40960000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,24910000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,24330000000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,960300000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,954200000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17840000000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17410000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15940000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15620000000,0.015
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10040000000,0.2
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9716000000,0.2
update_hint_regression,compile_time_instruction_count,1593000000,0.02
update_hint_regression,compile_time_instruction_count,1548000000,0.02
sum_floordiv_regression,compile_time_instruction_count,1052000000,0.015
sum_floordiv_regression,compile_time_instruction_count,1039000000,0.015
symint_sum,compile_time_instruction_count,3135000000,0.015
symint_sum,compile_time_instruction_count,3075000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2004000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1980000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5753000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5702000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8484000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8018000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3718000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3609000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10020000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9809000000,0.015

1 add_loop_eager compile_time_instruction_count 2972000000 2891000000 0.015
2 add_loop_eager_dynamic compile_time_instruction_count 5647000000 5554000000 0.025
3 add_loop_inductor compile_time_instruction_count 28480000000 27890000000 0.015
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 42160000000 40960000000 0.025
5 add_loop_inductor_gpu compile_time_instruction_count 24910000000 24330000000 0.015
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 960300000 954200000 0.015
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 17840000000 17410000000 0.015
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 15940000000 15620000000 0.015
9 basic_modules_ListOfLinears_inductor_gpu compile_time_instruction_count 10040000000 9716000000 0.2
10 update_hint_regression compile_time_instruction_count 1593000000 1548000000 0.02
11 sum_floordiv_regression compile_time_instruction_count 1052000000 1039000000 0.015
12 symint_sum compile_time_instruction_count 3135000000 3075000000 0.015
13 aotdispatcher_inference_nosubclass_cpu compile_time_instruction_count 2004000000 1980000000 0.015
14 aotdispatcher_inference_subclass_cpu compile_time_instruction_count 5753000000 5702000000 0.015
15 aotdispatcher_partitioner_cpu compile_time_instruction_count 8484000000 8018000000 0.015
16 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3718000000 3609000000 0.015
17 aotdispatcher_training_subclass_cpu compile_time_instruction_count 10020000000 9809000000 0.015
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

View File

@ -2528,6 +2528,12 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...
class _NodeIter(Iterator):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

View File

@ -1,11 +1,14 @@
#include <torch/csrc/fx/node.h>
#include <c10/util/SmallVector.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <algorithm>
namespace {
using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase;
// Thrown to exit out of a C++ function and return an error to Python.
@ -163,7 +166,22 @@ struct NodeBase {
PyObject* users;
PyObject* _repr_fn;
PyObject* meta;
PyObject* _sort_key;
// Raw, suitably aligned storage that holds this node's NodeSortKey in place.
// The key object is created in this buffer with placement new (see
// NodeBase_new) and destroyed explicitly (see NodeBase_dealloc).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
// Typed accessor for the NodeSortKey stored in sort_key_buf.
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
// Unlink this node from the doubly linked node list by pointing its two
// neighbours at each other. This node's own _prev/_next fields are not
// modified here.
// Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
inline void remove_from_list() {
// Read both neighbours before writing, so the update is also correct when
// p and n are the same node.
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->_next = n;
n->_prev = p;
}
};
static PyObject* NodeBase_new(
@ -173,6 +191,8 @@ static PyObject* NodeBase_new(
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
NodeSortKey(); // placement new does not allocate
return self;
}
@ -201,7 +221,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->users = PyDict_New();
self->_repr_fn = Py_NewRef(Py_None);
self->meta = PyDict_New();
self->_sort_key = PyTuple_New(0);
return 0;
}
@ -221,7 +240,6 @@ static struct PyMemberDef NodeBase_members[] = {
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
{nullptr} /* Sentinel */
};
@ -239,7 +257,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->users);
Py_VISIT(self->_repr_fn);
Py_VISIT(self->meta);
Py_VISIT(self->_sort_key);
return 0;
}
@ -257,12 +274,12 @@ static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->users);
Py_CLEAR(self->_repr_fn);
Py_CLEAR(self->meta);
Py_CLEAR(self->_sort_key);
return 0;
}
// Deallocator for NodeBase instances: untracks the object from the GC,
// destroys the C++ sort key, clears the Python-object fields and frees the
// memory through the type's tp_free.
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
// The sort key is a C++ object living in sort_key_buf; tp_free will not run
// its destructor, so it is invoked explicitly here.
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
// Return value intentionally ignored.
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
@ -321,15 +338,191 @@ static PyObject* NodeBase__update_args_kwargs(
}
}
// Python-visible `_remove_from_list()`: unlinks the receiver from the node
// list via NodeBase::remove_from_list() and returns None. Takes no arguments.
static PyObject* NodeBase__remove_from_list(
    PyObject* self,
    PyObject* /*unused*/) {
  auto* node = reinterpret_cast<NodeBase*>(self);
  node->remove_from_list();
  Py_RETURN_NONE;
}
// Python-visible `_prepend(x)`: moves node `x` so that it sits immediately
// before the receiver in the node list, then assigns `x` a sort key that
// orders it between its new neighbours.
//
// Returns None on success (including the no-op case where `x` is the
// receiver). Raises TypeError if `x` is not a Node, and AssertionError if
// the two nodes do not share the same graph object.
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
  // Prepending a node to itself changes nothing.
  if (self_ == arg) {
    Py_RETURN_NONE;
  }
  if (!is_node(arg)) {
    PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
    return nullptr;
  }

  auto* anchor = reinterpret_cast<NodeBase*>(self_);
  auto* moved = reinterpret_cast<NodeBase*>(arg);
  if (anchor->graph != moved->graph) {
    PyErr_SetString(
        PyExc_AssertionError,
        "Attempting to move a Node into a different Graph");
    return nullptr;
  }

  // Unlink first: the anchor's predecessor must be read *after* the moved
  // node has left the list, since the moved node may have been that
  // predecessor.
  moved->remove_from_list();
  NodeBase* before = anchor->_prev;
  before->_next = moved;
  moved->_prev = before;
  moved->_next = anchor;
  anchor->_prev = moved;

  // Derive the moved node's key from its new neighbours' keys.
  const NodeSortKey& prev_key = before->sort_key();
  const NodeSortKey& next_key = anchor->sort_key();
  const size_t prev_len = prev_key.size();
  const size_t next_len = next_key.size();

  if (prev_len > next_len) {
    // Take prev_key[: next_len + 1] and bump its last component up by one.
    NodeSortKey key(prev_key.begin(), prev_key.begin() + (next_len + 1));
    ++key.back();
    moved->sort_key() = std::move(key);
  } else if (prev_len < next_len) {
    // Take next_key[: prev_len + 1] and bump its last component down by one.
    NodeSortKey key(next_key.begin(), next_key.begin() + (prev_len + 1));
    --key.back();
    moved->sort_key() = std::move(key);
  } else {
    // Equal lengths: extend the predecessor's key with a trailing 0.
    NodeSortKey key(prev_key);
    key.emplace_back(0);
    moved->sort_key() = std::move(key);
  }
  Py_RETURN_NONE;
}
// __lt__(self, other): True when self's sort key is lexicographically
// smaller than other's. Registered as METH_O, so `other` is the single
// argument. Returns NotImplemented when `other` is not a Node.
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const NodeSortKey& mine = reinterpret_cast<NodeBase*>(self)->sort_key();
  const NodeSortKey& theirs = reinterpret_cast<NodeBase*>(other)->sort_key();
  return PyBool_FromLong(std::lexicographical_compare(
      mine.begin(), mine.end(), theirs.begin(), theirs.end()));
}
// __gt__(self, other): True when self's sort key is lexicographically
// greater than other's. Returns NotImplemented when `other` is not a Node.
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const NodeSortKey& mine = reinterpret_cast<NodeBase*>(self)->sort_key();
  const NodeSortKey& theirs = reinterpret_cast<NodeBase*>(other)->sort_key();
  // "mine > theirs" is evaluated as "theirs < mine".
  return PyBool_FromLong(std::lexicographical_compare(
      theirs.begin(), theirs.end(), mine.begin(), mine.end()));
}
// __ge__(self, other): True when `other` is the very same object as self,
// otherwise whatever __gt__ yields (including NotImplemented for non-Nodes).
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
  if (self != other) {
    return NodeBase___gt__(self, other);
  }
  Py_RETURN_TRUE;
}
// __le__(self, other): True when `other` is the very same object as self,
// otherwise whatever __lt__ yields (including NotImplemented for non-Nodes).
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
  if (self != other) {
    return NodeBase___lt__(self, other);
  }
  Py_RETURN_TRUE;
}
// Getter for the Python attribute `_sort_key`: converts the node's
// NodeSortKey (a SmallVector of int64_t) into a new Python tuple of ints.
// Only used by pickle/__getstate__.
//
// Returns a new reference, or nullptr with a Python exception set if the
// tuple or any of its elements cannot be allocated.
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
  NodeBase* node = reinterpret_cast<NodeBase*>(self);
  const NodeSortKey& vec = node->sort_key();
  Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
  THPObjectPtr tuple(PyTuple_New(n));
  if (!tuple) {
    return nullptr; // Out of memory
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    PyObject* item = PyLong_FromSsize_t(static_cast<Py_ssize_t>(vec[i]));
    if (item == nullptr) {
      // Never hand back a tuple with an unfilled slot. THPObjectPtr releases
      // the partially built tuple when it goes out of scope.
      return nullptr;
    }
    // PyTuple_SET_ITEM steals the reference to `item`.
    PyTuple_SET_ITEM(tuple.get(), i, item);
  }
  return tuple.release();
}
// Setter for the Python attribute `_sort_key`: expects a tuple of ints, e.g.
// node._sort_key = (1, 2, 3). Only used by pickle/__setstate__.
//
// Returns 0 on success. Returns -1 with a Python exception set if the
// attribute is being deleted, if `value` is not a tuple, or if an element
// cannot be converted to an integer. The node's key is only replaced once
// every element has been converted, so a failure leaves it unchanged.
static int NodeBase_set_sort_key(
    PyObject* self,
    PyObject* value,
    void* /*closure*/) {
  NodeBase* node = reinterpret_cast<NodeBase*>(self);
  // `del node._sort_key` invokes the setter with value == NULL; reject it
  // before any type check dereferences the pointer.
  if (value == nullptr) {
    PyErr_SetString(PyExc_TypeError, "Cannot delete the _sort_key attribute");
    return -1;
  }
  if (!PyTuple_Check(value)) {
    PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
    return -1;
  }
  Py_ssize_t size = PyTuple_GET_SIZE(value);
  NodeSortKey new_vec;
  new_vec.reserve(size);
  for (Py_ssize_t i = 0; i < size; i++) {
    int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
    // -1 is also a legitimate value, so consult the error indicator.
    if (val == -1 && PyErr_Occurred()) {
      return -1;
    }
    new_vec.emplace_back(val);
  }
  node->sort_key() = std::move(new_vec);
  return 0;
}
// Method table for NodeBase (installed as tp_methods of NodeBaseType).
// Each entry is {Python name, C function, calling convention, docstring}.
// The ordering comparisons are registered here as ordinary single-argument
// (METH_O) methods named __lt__/__gt__/__ge__/__le__.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
{"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL,
"Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// Computed-attribute table for NodeBase (installed as tp_getset of
// NodeBaseType). Exposes the C++ sort key to Python as `_sort_key`, converting
// to and from a tuple of ints through the getter/setter below.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
@ -361,7 +554,7 @@ PyTypeObject NodeBaseType = {
nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */
nullptr, /* tp_getset */
NodeBase_getset, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */

View File

@ -375,41 +375,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
if self == x:
log.debug(
"Trying to prepend a node to itself. This behavior has no effect on the graph."
)
return
x._remove_from_list()
p = self._prev
p._next, x._prev = x, p
x._next, self._prev = self, x
# compute x._sort_key
psk = x._prev._sort_key
nsk = x._next._sort_key
if len(psk) > len(nsk):
idx: int
*prefix, idx = psk[: len(nsk) + 1]
x._sort_key = (*prefix, idx + 1)
elif len(psk) < len(nsk):
*prefix, idx = nsk[: len(psk) + 1]
x._sort_key = (*prefix, idx - 1)
else: # same length, increase length by 1
x._sort_key = (*psk, 0)
def __gt__(self, other: "Node") -> bool:
return self._sort_key > other._sort_key
def __lt__(self, other: "Node") -> bool:
return self._sort_key < other._sort_key
def __ge__(self, other: "Node") -> bool:
return self > other or self == other
def __le__(self, other: "Node") -> bool:
return self < other or self == other
self._prepend(x)
@compatibility(is_backward_compatible=True)
def append(self, x: "Node") -> None:
@ -420,11 +386,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
self._next.prepend(x)
def _remove_from_list(self) -> None:
p, n = self._prev, self._next
p._next, n._prev = n, p
self._next._prepend(x)
@property
def args(self) -> tuple[Argument, ...]: