[reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882)

Relands #148261 that was reverted by #150542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165882
Approved by: https://github.com/ezyang
This commit is contained in:
Jason Ansel 2025-10-20 07:57:33 -07:00 committed by PyTorch MergeBot
parent 0bd12c1168
commit 3c3b278872
5 changed files with 267 additions and 61 deletions

View File

@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3070000000,0.1 add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1 add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1 basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1719000000,0.1 update_hint_regression,compile_time_instruction_count,1645000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1 sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1 aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1 aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1 aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1 aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1 aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1 aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1 mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1 basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1 basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1

1 add_loop_eager compile_time_instruction_count 3070000000 3184000000 0.1
2 add_loop_eager_dynamic compile_time_instruction_count 4432000000 4595000000 0.1
3 add_loop_inductor compile_time_instruction_count 29660000000 29660000000 0.1
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 39910000000 39910000000 0.1
5 add_loop_inductor_gpu compile_time_instruction_count 26800000000 26800000000 0.1
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 1048000000 1096000000 0.1
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 15240000000 15240000000 0.1
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 17020000000 17720000000 0.1
18 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3442000000 3152000000 0.1
19 aotdispatcher_training_subclass_cpu compile_time_instruction_count 9239000000 8301000000 0.1
20 mm_loop_inductor_gpu compile_time_instruction_count 4820968837 4958000000 0.1
21 mm_loop_inductor_dynamic_gpu compile_time_instruction_count 9051000000 9051000000 0.1
22 basic_NestedModule_eager compile_time_instruction_count 9554000000 9990000000 0.1
23 basic_InlineMod_eager compile_time_instruction_count 7618000000 8126000000 0.1
24
26
27
28
29
30
31
32
34
35
36
37
38
39
40
41
42
43
44
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
82
83
84
85
86
87
88
89

View File

@ -6,6 +6,7 @@ import builtins
import collections import collections
import contextlib import contextlib
import copy import copy
import gc
import functools import functools
import inspect import inspect
import io import io
@ -19,6 +20,7 @@ import traceback
import types import types
import typing import typing
import unittest import unittest
import weakref
import warnings import warnings
from math import sqrt from math import sqrt
from torch.multiprocessing import Process from torch.multiprocessing import Process
@ -1624,6 +1626,25 @@ class TestFX(JitTestCase):
self.assertTrue(neg not in relu.users) self.assertTrue(neg not in relu.users)
@skipIfTorchDynamo("Dynamo does not free right away")
def test_prepend_does_not_leak(self):
g = Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
neg = g.call_function(torch.neg, (x,))
relu.prepend(neg)
ref = weakref.ref(neg)
g.erase_node(neg)
del g
del x
del relu
del neg
gc.collect()
self.assertIsNone(ref())
def test_remove_uses_with_custom_filter(self): def test_remove_uses_with_custom_filter(self):
g: torch.fx.Graph = Graph() g: torch.fx.Graph = Graph()
x: torch.fx.Node = g.placeholder("x") x: torch.fx.Node = g.placeholder("x")

View File

@ -2758,6 +2758,12 @@ class _NodeBase:
return_type: Any, return_type: Any,
) -> None: ... ) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ... def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...
class _NodeIter(Iterator[FxNode]): class _NodeIter(Iterator[FxNode]):
def __init__(self, root: FxNode, reversed: _bool) -> None: ... def __init__(self, root: FxNode, reversed: _bool) -> None: ...

View File

@ -1,11 +1,15 @@
#include <torch/csrc/fx/node.h> #include <torch/csrc/fx/node.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <structmember.h> #include <structmember.h>
#include <torch/csrc/utils/object_ptr.h> #include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h> #include <torch/csrc/utils/pythoncapi_compat.h>
#include <algorithm>
namespace { namespace {
using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase; struct NodeBase;
// Thrown to exit out of a C++ function and return an error to Python. // Thrown to exit out of a C++ function and return an error to Python.
@ -163,7 +167,41 @@ struct NodeBase {
PyObject* users; PyObject* users;
PyObject* _repr_fn; PyObject* _repr_fn;
PyObject* meta; PyObject* meta;
PyObject* _sort_key; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
inline void set_prev(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _prev;
_prev = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
inline void set_next(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _next;
_next = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
// Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
inline void remove_from_list() {
if (this->_prev == this && this->_next == this) {
return;
}
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->set_next(n);
n->set_prev(p);
}
}; };
static PyObject* NodeBase_new( static PyObject* NodeBase_new(
@ -173,6 +211,8 @@ static PyObject* NodeBase_new(
PyObject* self = type->tp_alloc(type, 0); PyObject* self = type->tp_alloc(type, 0);
if (!self) if (!self)
return nullptr; return nullptr;
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
NodeSortKey(); // placement new does not allocate
return self; return self;
} }
@ -201,7 +241,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->users = PyDict_New(); self->users = PyDict_New();
self->_repr_fn = Py_NewRef(Py_None); self->_repr_fn = Py_NewRef(Py_None);
self->meta = PyDict_New(); self->meta = PyDict_New();
self->_sort_key = PyTuple_New(0);
return 0; return 0;
} }
@ -221,7 +260,6 @@ static struct PyMemberDef NodeBase_members[] = {
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr}, {"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr}, {"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr}, {"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
{nullptr} /* Sentinel */ {nullptr} /* Sentinel */
}; };
@ -239,7 +277,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->users); Py_VISIT(self->users);
Py_VISIT(self->_repr_fn); Py_VISIT(self->_repr_fn);
Py_VISIT(self->meta); Py_VISIT(self->meta);
Py_VISIT(self->_sort_key);
return 0; return 0;
} }
@ -257,12 +294,12 @@ static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->users); Py_CLEAR(self->users);
Py_CLEAR(self->_repr_fn); Py_CLEAR(self->_repr_fn);
Py_CLEAR(self->meta); Py_CLEAR(self->meta);
Py_CLEAR(self->_sort_key);
return 0; return 0;
} }
static void NodeBase_dealloc(PyObject* self) { static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self); PyObject_GC_UnTrack(self);
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
(void)NodeBase_clear((NodeBase*)self); (void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self); Py_TYPE(self)->tp_free(self);
} }
@ -321,15 +358,195 @@ static PyObject* NodeBase__update_args_kwargs(
} }
} }
static PyObject* NodeBase__remove_from_list(
PyObject* self,
PyObject* _ignored) {
reinterpret_cast<NodeBase*>(self)->remove_from_list();
Py_RETURN_NONE;
}
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
}
if (!is_node(arg)) {
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
return nullptr;
}
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
if (self->graph != x->graph) {
PyErr_SetString(
PyExc_AssertionError,
"Attempting to move a Node into a different Graph");
return nullptr;
}
x->remove_from_list();
NodeBase* p = self->_prev;
p->set_next(x);
x->set_prev(p);
x->set_next(self);
self->set_prev(x);
// Now compute x.sort_key()
const NodeSortKey& psk = x->_prev->sort_key();
const NodeSortKey& nsk = x->_next->sort_key();
if (psk.size() > nsk.size()) {
// prefix = psk[: len(nsk)+1]
size_t slice_len = nsk.size() + 1;
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
// last element is idx => increment by 1
prefix.back()++;
x->sort_key() = std::move(prefix);
} else if (psk.size() < nsk.size()) {
// prefix = nsk[: len(psk)+1]
size_t slice_len = psk.size() + 1;
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
// last element is idx => decrement by 1
prefix.back()--;
x->sort_key() = std::move(prefix);
} else {
// same length => add a 0
x->sort_key() = psk;
x->sort_key().emplace_back(0);
}
Py_RETURN_NONE;
}
// __lt__(self, other): Return self.sort_key < other.sort_key
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
// METH_O => one argument: 'other'
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
bool less = std::lexicographical_compare(
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
if (less)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
// __gt__(self, other): Return self.sort_key() > other.sort_key
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
// "a > b" is equivalent to "b < a"
bool greater = std::lexicographical_compare(
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
if (greater)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___gt__(self, other);
}
// __le__(self, other): Return not (self > other)
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___lt__(self, other);
}
// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
// Only used by pickle/__getstate__
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
const NodeSortKey& vec = node->sort_key();
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
THPObjectPtr tuple(PyTuple_New(n));
if (!tuple) {
return nullptr; // Out of memory
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject* value = PyLong_FromSsize_t(vec[i]);
if (!value) {
return nullptr;
}
PyTuple_SET_ITEM(tuple.get(), i, value);
}
return tuple.release();
}
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
static int NodeBase_set_sort_key(
PyObject* self,
PyObject* value,
void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
if (!PyTuple_Check(value)) {
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
return -1;
}
Py_ssize_t size = PyTuple_GET_SIZE(value);
NodeSortKey new_vec;
new_vec.reserve(size);
for (Py_ssize_t i = 0; i < size; i++) {
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
if (val == -1 && PyErr_Occurred()) {
return -1;
}
new_vec.emplace_back(val);
}
node->sort_key() = std::move(new_vec);
return 0;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = { static PyMethodDef NodeBase_methods[] = {
{"_update_args_kwargs", {"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs), (PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL, METH_FASTCALL,
"Internal method: do not call directly."}, "Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel {nullptr, nullptr, 0, nullptr} // Sentinel
}; };
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
PyTypeObject NodeBaseType = { PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0) PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */ "torch._C._NodeBase", /* tp_name */
@ -361,7 +578,7 @@ PyTypeObject NodeBaseType = {
nullptr, /* tp_iternext */ nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */ NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */ NodeBase_members, /* tp_members */
nullptr, /* tp_getset */ NodeBase_getset, /* tp_getset */
nullptr, /* tp_base */ nullptr, /* tp_base */
nullptr, /* tp_dict */ nullptr, /* tp_dict */
nullptr, /* tp_descr_get */ nullptr, /* tp_descr_get */

View File

@ -385,41 +385,7 @@ class Node(_NodeBase):
Args: Args:
x (Node): The node to put before this node. Must be a member of the same graph. x (Node): The node to put before this node. Must be a member of the same graph.
""" """
assert self.graph == x.graph, "Attempting to move a Node into a different Graph" self._prepend(x)
if self == x:
log.debug(
"Trying to prepend a node to itself. This behavior has no effect on the graph."
)
return
x._remove_from_list()
p = self._prev
p._next, x._prev = x, p
x._next, self._prev = self, x
# compute x._sort_key
psk = x._prev._sort_key
nsk = x._next._sort_key
if len(psk) > len(nsk):
idx: int
*prefix, idx = psk[: len(nsk) + 1]
x._sort_key = (*prefix, idx + 1)
elif len(psk) < len(nsk):
*prefix, idx = nsk[: len(psk) + 1]
x._sort_key = (*prefix, idx - 1)
else: # same length, increase length by 1
x._sort_key = (*psk, 0)
def __gt__(self, other: "Node") -> bool:
return self._sort_key > other._sort_key
def __lt__(self, other: "Node") -> bool:
return self._sort_key < other._sort_key
def __ge__(self, other: "Node") -> bool:
return self > other or self == other
def __le__(self, other: "Node") -> bool:
return self < other or self == other
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def append(self, x: "Node") -> None: def append(self, x: "Node") -> None:
@ -430,11 +396,7 @@ class Node(_NodeBase):
Args: Args:
x (Node): The node to put after this node. Must be a member of the same graph. x (Node): The node to put after this node. Must be a member of the same graph.
""" """
self._next.prepend(x) self._next._prepend(x)
def _remove_from_list(self) -> None:
p, n = self._prev, self._next
p._next, n._prev = n, p
@property @property
def args(self) -> tuple[Argument, ...]: def args(self) -> tuple[Argument, ...]: