[fx] Move Node._prepend/Node._remove_from_list to C++ (#148261)

Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
24303536 function calls (23503339 primitive calls) in 10.726 seconds
```
after:
```
20003454 function calls (19203257 primitive calls) in 8.936 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148261
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260
This commit is contained in:
Jason Ansel 2025-03-02 10:46:19 -08:00 committed by PyTorch MergeBot
parent 0135f57f4a
commit 29c2de9ae1
4 changed files with 225 additions and 67 deletions

View File

@ -1,4 +1,4 @@
add_loop_eager,compile_time_instruction_count,2958000000,0.015
add_loop_eager,compile_time_instruction_count,2892000000,0.015
@ -6,27 +6,27 @@ add_loop_eager_dynamic,compile_time_instruction_count,6250000000,0.025
add_loop_inductor,compile_time_instruction_count,28450000000,0.015
add_loop_inductor,compile_time_instruction_count,27960000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44690000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43820000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,24770000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,24280000000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,959200000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,954100000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17950000000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17510000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16030000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15690000000,0.015
@ -34,32 +34,32 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,98740000
update_hint_regression,compile_time_instruction_count,1683000000,0.02
update_hint_regression,compile_time_instruction_count,1641000000,0.02
sum_floordiv_regression,compile_time_instruction_count,1054000000,0.015
sum_floordiv_regression,compile_time_instruction_count,1036000000,0.015
symint_sum,compile_time_instruction_count,3167000000,0.015
symint_sum,compile_time_instruction_count,3094000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2010000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1992000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5776000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5734000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8521000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8066000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3735000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3631000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10070000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9862000000,0.015

1 add_loop_eager compile_time_instruction_count 2958000000 2892000000 0.015
2 add_loop_eager_dynamic compile_time_instruction_count 6250000000 6250000000 0.025
3 add_loop_inductor compile_time_instruction_count 28450000000 27960000000 0.015
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 44690000000 43820000000 0.025
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 959200000 954100000 0.015
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 17950000000 17510000000 0.015
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 16030000000 15690000000 0.015
9 basic_modules_ListOfLinears_inductor_gpu compile_time_instruction_count 9874000000 9874000000 0.2
10 update_hint_regression compile_time_instruction_count 1683000000 1641000000 0.02
11 sum_floordiv_regression compile_time_instruction_count 1054000000 1036000000 0.015
12 symint_sum compile_time_instruction_count 3167000000 3094000000 0.015
13 aotdispatcher_inference_nosubclass_cpu compile_time_instruction_count 2010000000 1992000000 0.015
14 aotdispatcher_inference_subclass_cpu compile_time_instruction_count 5776000000 5734000000 0.015
15 aotdispatcher_partitioner_cpu compile_time_instruction_count 8521000000 8066000000 0.015
16 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3735000000 3631000000 0.015
17 aotdispatcher_training_subclass_cpu compile_time_instruction_count 10070000000 9862000000 0.015
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

View File

@ -2514,6 +2514,12 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...
class _NodeIter(Iterator):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

View File

@ -1,11 +1,14 @@
#include <torch/csrc/fx/node.h>
#include <c10/util/SmallVector.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <algorithm>
namespace {
using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase;
// Thrown to exit out of a C++ function and return an error to Python.
@ -60,8 +63,7 @@ inline static PyObject* map_aggregate(PyObject* a, F fn) {
if (PyTuple_Check(a)) {
Py_ssize_t n = PyTuple_GET_SIZE(a);
if (n == 0) {
Py_INCREF(a);
return a;
return Py_NewRef(a);
}
THPObjectPtr new_tuple(PyTuple_New(n));
if (!new_tuple) {
@ -85,8 +87,7 @@ inline static PyObject* map_aggregate(PyObject* a, F fn) {
else if (PyList_Check(a)) {
Py_ssize_t n = PyList_GET_SIZE(a);
if (n == 0 && exact_type(a, immutable_list_cls())) {
Py_INCREF(a);
return a;
return Py_NewRef(a);
}
THPObjectPtr result(PyObject_CallNoArgs(immutable_list_cls()));
if (!result) {
@ -104,8 +105,7 @@ inline static PyObject* map_aggregate(PyObject* a, F fn) {
// Case 3: a is a dict.
else if (PyDict_Check(a)) {
if (PyDict_GET_SIZE(a) == 0 && exact_type(a, immutable_dict_cls())) {
Py_INCREF(a);
return a;
return Py_NewRef(a);
}
THPObjectPtr result(PyObject_CallNoArgs(immutable_dict_cls()));
if (!result) {
@ -166,7 +166,22 @@ struct NodeBase {
PyObject* users;
PyObject* _repr_fn;
PyObject* meta;
PyObject* _sort_key;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
// Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
inline void remove_from_list() {
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->_next = n;
n->_prev = p;
}
};
static PyObject* NodeBase_new(
@ -176,6 +191,8 @@ static PyObject* NodeBase_new(
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
NodeSortKey(); // placement new does not allocate
return self;
}
@ -204,7 +221,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->users = PyDict_New();
self->_repr_fn = Py_NewRef(Py_None);
self->meta = PyDict_New();
self->_sort_key = PyTuple_New(0);
return 0;
}
@ -224,7 +240,6 @@ static struct PyMemberDef NodeBase_members[] = {
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
{nullptr} /* Sentinel */
};
@ -242,7 +257,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->users);
Py_VISIT(self->_repr_fn);
Py_VISIT(self->meta);
Py_VISIT(self->_sort_key);
return 0;
}
@ -260,12 +274,12 @@ static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->users);
Py_CLEAR(self->_repr_fn);
Py_CLEAR(self->meta);
Py_CLEAR(self->_sort_key);
return 0;
}
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
@ -324,15 +338,191 @@ static PyObject* NodeBase__update_args_kwargs(
}
}
// Python-visible `_remove_from_list()`: unlinks this node from its neighbours
// by delegating to NodeBase::remove_from_list(). Takes no Python arguments
// (the second parameter is unused) and always returns None.
static PyObject* NodeBase__remove_from_list(
    PyObject* self,
    PyObject* _ignored) {
  auto* node = reinterpret_cast<NodeBase*>(self);
  node->remove_from_list();
  Py_RETURN_NONE;
}
// Python-visible `_prepend(x)`: moves node `x` so that it sits immediately
// before `self` in the linked list of nodes, then recomputes x's sort key so
// that it orders between its new neighbours.
//
// Returns None on success; prepending a node to itself is a no-op.
// Raises TypeError if `arg` is not a Node, and AssertionError if the two
// nodes do not share the same `graph` object.
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
  if (self_ == arg) {
    Py_RETURN_NONE;
  }
  if (!is_node(arg)) {
    PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
    return nullptr;
  }
  auto* self = reinterpret_cast<NodeBase*>(self_);
  auto* x = reinterpret_cast<NodeBase*>(arg);
  if (self->graph != x->graph) {
    PyErr_SetString(
        PyExc_AssertionError,
        "Attempting to move a Node into a different Graph");
    return nullptr;
  }

  // Unlink x from its current position first; only afterwards read
  // self->_prev, because the unlink may have changed it (x could have been
  // self's predecessor).
  x->remove_from_list();
  NodeBase* before = self->_prev;
  before->_next = x;
  x->_prev = before;
  x->_next = self;
  self->_prev = x;

  // Derive x's key from the keys of its new neighbours.
  const NodeSortKey& prev_key = before->sort_key();
  const NodeSortKey& next_key = self->sort_key();
  NodeSortKey& key = x->sort_key();
  const size_t prev_len = prev_key.size();
  const size_t next_len = next_key.size();
  if (prev_len > next_len) {
    // key = prev_key[: len(next_key) + 1], with the last component bumped up
    key = NodeSortKey(prev_key.begin(), prev_key.begin() + (next_len + 1));
    key.back()++;
  } else if (prev_len < next_len) {
    // key = next_key[: len(prev_key) + 1], with the last component bumped down
    key = NodeSortKey(next_key.begin(), next_key.begin() + (prev_len + 1));
    key.back()--;
  } else {
    // Equal lengths: extend the predecessor's key by one trailing 0.
    key = prev_key;
    key.emplace_back(0);
  }
  Py_RETURN_NONE;
}
// __lt__(self, other): True when self's sort key compares lexicographically
// less than other's sort key. Takes exactly one argument, `other`, and
// returns NotImplemented when it is not a Node.
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const NodeSortKey& a = reinterpret_cast<NodeBase*>(self)->sort_key();
  const NodeSortKey& b = reinterpret_cast<NodeBase*>(other)->sort_key();
  const bool is_less =
      std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
  return PyBool_FromLong(is_less ? 1 : 0);
}
// __gt__(self, other): True when self's sort key compares lexicographically
// greater than other's sort key. Returns NotImplemented when `other` is not
// a Node.
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const NodeSortKey& a = reinterpret_cast<NodeBase*>(self)->sort_key();
  const NodeSortKey& b = reinterpret_cast<NodeBase*>(other)->sort_key();
  // a > b is evaluated as b < a, so the operand ranges are swapped.
  const bool is_greater =
      std::lexicographical_compare(b.begin(), b.end(), a.begin(), a.end());
  return PyBool_FromLong(is_greater ? 1 : 0);
}
// __ge__(self, other): True when `other` is the very same object as `self`;
// otherwise the answer is whatever __gt__ returns (including NotImplemented
// for non-Node arguments).
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
  if (self != other) {
    return NodeBase___gt__(self, other);
  }
  Py_RETURN_TRUE;
}
// __le__(self, other): True when `other` is the very same object as `self`;
// otherwise the answer is whatever __lt__ returns (including NotImplemented
// for non-Node arguments).
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
  if (self != other) {
    return NodeBase___lt__(self, other);
  }
  Py_RETURN_TRUE;
}
// Getter for the Python attribute `_sort_key`: converts the node's
// NodeSortKey (a small vector of int64_t) into a new Python tuple of ints.
// Only used by pickle/__getstate__.
//
// Returns a new reference, or nullptr with a Python exception set if the
// tuple or any of its elements cannot be allocated.
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
  const NodeSortKey& vec = reinterpret_cast<NodeBase*>(self)->sort_key();
  const Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
  THPObjectPtr tuple(PyTuple_New(n));
  if (!tuple) {
    return nullptr; // Out of memory
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    // Elements are int64_t, so convert through long long rather than
    // Py_ssize_t, which may be narrower than 64 bits on some platforms.
    PyObject* item = PyLong_FromLongLong(static_cast<long long>(vec[i]));
    if (!item) {
      // Do not hand back a tuple containing a NULL slot while an exception
      // is pending; the partially filled tuple is dropped when `tuple` goes
      // out of scope.
      return nullptr;
    }
    // PyTuple_SET_ITEM takes ownership of `item`.
    PyTuple_SET_ITEM(tuple.get(), i, item);
  }
  return tuple.release();
}
// Setter for the Python attribute `_sort_key`: expects a tuple of ints, e.g.
// node._sort_key = (1, 2, 3). Only used by pickle/__setstate__.
//
// Returns 0 on success. Returns -1 with a Python exception set when the
// attribute is being deleted, when `value` is not a tuple, or when an element
// cannot be converted to a 64-bit integer. The node's existing key is left
// untouched on failure because the new key is built in a temporary first.
static int NodeBase_set_sort_key(
    PyObject* self,
    PyObject* value,
    void* /*closure*/) {
  NodeBase* node = reinterpret_cast<NodeBase*>(self);
  // CPython invokes getset setters with value == NULL for
  // `del obj._sort_key`; PyTuple_Check must not be called on a null pointer.
  if (value == nullptr) {
    PyErr_SetString(PyExc_TypeError, "cannot delete the _sort_key attribute");
    return -1;
  }
  if (!PyTuple_Check(value)) {
    PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
    return -1;
  }
  Py_ssize_t size = PyTuple_GET_SIZE(value);
  NodeSortKey new_vec;
  new_vec.reserve(size);
  for (Py_ssize_t i = 0; i < size; i++) {
    // Storage is int64_t, so convert through long long rather than
    // Py_ssize_t, which may be narrower than 64 bits on some platforms.
    long long val = PyLong_AsLongLong(PyTuple_GET_ITEM(value, i));
    // -1 is also a legitimate value; only treat it as failure if an
    // exception is actually pending.
    if (val == -1 && PyErr_Occurred()) {
      return -1;
    }
    new_vec.emplace_back(static_cast<int64_t>(val));
  }
  node->sort_key() = std::move(new_vec);
  return 0;
}
// Method table for the _NodeBase type. Each entry gives the Python-visible
// name, the C implementation, the calling convention flag, and a docstring.
// The list is terminated by an all-null sentinel entry.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
// Internal helpers that implement graph/list manipulation.
{"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL,
"Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
// Ordering comparisons based on the node's sort key; each takes a single
// argument (METH_O).
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// Computed-attribute table for the _NodeBase type. `_sort_key` is exposed
// through a getter/setter pair that converts between the C++ NodeSortKey
// stored in the object and a Python tuple of ints. The list is terminated by
// an all-null sentinel entry.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
@ -364,7 +554,7 @@ PyTypeObject NodeBaseType = {
nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */
nullptr, /* tp_getset */
NodeBase_getset, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */

View File

@ -375,41 +375,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
if self == x:
log.debug(
"Trying to prepend a node to itself. This behavior has no effect on the graph."
)
return
x._remove_from_list()
p = self._prev
p._next, x._prev = x, p
x._next, self._prev = self, x
# compute x._sort_key
psk = x._prev._sort_key
nsk = x._next._sort_key
if len(psk) > len(nsk):
idx: int
*prefix, idx = psk[: len(nsk) + 1]
x._sort_key = (*prefix, idx + 1)
elif len(psk) < len(nsk):
*prefix, idx = nsk[: len(psk) + 1]
x._sort_key = (*prefix, idx - 1)
else: # same length, increase length by 1
x._sort_key = (*psk, 0)
def __gt__(self, other: "Node") -> bool:
return self._sort_key > other._sort_key
def __lt__(self, other: "Node") -> bool:
return self._sort_key < other._sort_key
def __ge__(self, other: "Node") -> bool:
return self > other or self == other
def __le__(self, other: "Node") -> bool:
return self < other or self == other
self._prepend(x)
@compatibility(is_backward_compatible=True)
def append(self, x: "Node") -> None:
@ -420,11 +386,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
self._next.prepend(x)
def _remove_from_list(self) -> None:
p, n = self._prev, self._next
p._next, n._prev = n, p
self._next._prepend(x)
@property
def args(self) -> tuple[Argument, ...]: