mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[fx] Move Node._prepend/Node._remove_from_list to C++ (#148261)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before: ``` 24303536 function calls (23503339 primitive calls) in 10.726 seconds ``` after: ``` 20003454 function calls (19203257 primitive calls) in 8.936 seconds ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148261 Approved by: https://github.com/oulgen ghstack dependencies: #148243, #148260
This commit is contained in:
parent
bf752c36da
commit
5d4e7d58b4
|
|
@ -1,65 +1,65 @@
|
||||||
add_loop_eager,compile_time_instruction_count,2972000000,0.015
|
add_loop_eager,compile_time_instruction_count,2891000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_eager_dynamic,compile_time_instruction_count,5647000000,0.025
|
add_loop_eager_dynamic,compile_time_instruction_count,5554000000,0.025
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_inductor,compile_time_instruction_count,28480000000,0.015
|
add_loop_inductor,compile_time_instruction_count,27890000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42160000000,0.025
|
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40960000000,0.025
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_inductor_gpu,compile_time_instruction_count,24910000000,0.015
|
add_loop_inductor_gpu,compile_time_instruction_count,24330000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,960300000,0.015
|
basic_modules_ListOfLinears_eager,compile_time_instruction_count,954200000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17840000000,0.015
|
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17410000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15940000000,0.015
|
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15620000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10040000000,0.2
|
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9716000000,0.2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
update_hint_regression,compile_time_instruction_count,1593000000,0.02
|
update_hint_regression,compile_time_instruction_count,1548000000,0.02
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
sum_floordiv_regression,compile_time_instruction_count,1052000000,0.015
|
sum_floordiv_regression,compile_time_instruction_count,1039000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
symint_sum,compile_time_instruction_count,3135000000,0.015
|
symint_sum,compile_time_instruction_count,3075000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2004000000,0.015
|
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1980000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5753000000,0.015
|
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5702000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8484000000,0.015
|
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8018000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3718000000,0.015
|
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3609000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10020000000,0.015
|
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9809000000,0.015
|
||||||
|
|
|
||||||
|
|
|
@ -2528,6 +2528,12 @@ class _NodeBase:
|
||||||
return_type: Any,
|
return_type: Any,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
|
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
|
||||||
|
def _prepend(self, n: FxNode) -> None: ...
|
||||||
|
def _remove_from_list(self) -> None: ...
|
||||||
|
def __lt__(self, n: Self) -> _bool: ...
|
||||||
|
def __gt__(self, n: Self) -> _bool: ...
|
||||||
|
def __le__(self, n: Self) -> _bool: ...
|
||||||
|
def __ge__(self, n: Self) -> _bool: ...
|
||||||
|
|
||||||
class _NodeIter(Iterator):
|
class _NodeIter(Iterator):
|
||||||
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
|
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
#include <torch/csrc/fx/node.h>
|
#include <torch/csrc/fx/node.h>
|
||||||
|
|
||||||
|
#include <c10/util/SmallVector.h>
|
||||||
#include <structmember.h>
|
#include <structmember.h>
|
||||||
#include <torch/csrc/utils/object_ptr.h>
|
#include <torch/csrc/utils/object_ptr.h>
|
||||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using NodeSortKey = c10::SmallVector<int64_t, 4>;
|
||||||
struct NodeBase;
|
struct NodeBase;
|
||||||
|
|
||||||
// Thrown to exit out of a C++ function and return an error to Python.
|
// Thrown to exit out of a C++ function and return an error to Python.
|
||||||
|
|
@ -163,7 +166,22 @@ struct NodeBase {
|
||||||
PyObject* users;
|
PyObject* users;
|
||||||
PyObject* _repr_fn;
|
PyObject* _repr_fn;
|
||||||
PyObject* meta;
|
PyObject* meta;
|
||||||
PyObject* _sort_key;
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
|
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
|
||||||
|
|
||||||
|
inline NodeSortKey& sort_key() {
|
||||||
|
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equivalent to:
|
||||||
|
// p, n = self._prev, self._next
|
||||||
|
// p._next, n._prev = n, p
|
||||||
|
inline void remove_from_list() {
|
||||||
|
NodeBase* p = this->_prev;
|
||||||
|
NodeBase* n = this->_next;
|
||||||
|
p->_next = n;
|
||||||
|
n->_prev = p;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static PyObject* NodeBase_new(
|
static PyObject* NodeBase_new(
|
||||||
|
|
@ -173,6 +191,8 @@ static PyObject* NodeBase_new(
|
||||||
PyObject* self = type->tp_alloc(type, 0);
|
PyObject* self = type->tp_alloc(type, 0);
|
||||||
if (!self)
|
if (!self)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
|
||||||
|
NodeSortKey(); // placement new does not allocate
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -201,7 +221,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
|
||||||
self->users = PyDict_New();
|
self->users = PyDict_New();
|
||||||
self->_repr_fn = Py_NewRef(Py_None);
|
self->_repr_fn = Py_NewRef(Py_None);
|
||||||
self->meta = PyDict_New();
|
self->meta = PyDict_New();
|
||||||
self->_sort_key = PyTuple_New(0);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -221,7 +240,6 @@ static struct PyMemberDef NodeBase_members[] = {
|
||||||
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
|
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
|
||||||
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
|
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
|
||||||
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
|
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
|
||||||
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
|
|
||||||
{nullptr} /* Sentinel */
|
{nullptr} /* Sentinel */
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -239,7 +257,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
|
||||||
Py_VISIT(self->users);
|
Py_VISIT(self->users);
|
||||||
Py_VISIT(self->_repr_fn);
|
Py_VISIT(self->_repr_fn);
|
||||||
Py_VISIT(self->meta);
|
Py_VISIT(self->meta);
|
||||||
Py_VISIT(self->_sort_key);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -257,12 +274,12 @@ static int NodeBase_clear(NodeBase* self) {
|
||||||
Py_CLEAR(self->users);
|
Py_CLEAR(self->users);
|
||||||
Py_CLEAR(self->_repr_fn);
|
Py_CLEAR(self->_repr_fn);
|
||||||
Py_CLEAR(self->meta);
|
Py_CLEAR(self->meta);
|
||||||
Py_CLEAR(self->_sort_key);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void NodeBase_dealloc(PyObject* self) {
|
static void NodeBase_dealloc(PyObject* self) {
|
||||||
PyObject_GC_UnTrack(self);
|
PyObject_GC_UnTrack(self);
|
||||||
|
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
|
||||||
(void)NodeBase_clear((NodeBase*)self);
|
(void)NodeBase_clear((NodeBase*)self);
|
||||||
Py_TYPE(self)->tp_free(self);
|
Py_TYPE(self)->tp_free(self);
|
||||||
}
|
}
|
||||||
|
|
@ -321,15 +338,191 @@ static PyObject* NodeBase__update_args_kwargs(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject* NodeBase__remove_from_list(
|
||||||
|
PyObject* self,
|
||||||
|
PyObject* _ignored) {
|
||||||
|
reinterpret_cast<NodeBase*>(self)->remove_from_list();
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
|
||||||
|
if (self_ == arg) {
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
if (!is_node(arg)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
|
||||||
|
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
|
||||||
|
if (self->graph != x->graph) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_AssertionError,
|
||||||
|
"Attempting to move a Node into a different Graph");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
x->remove_from_list();
|
||||||
|
NodeBase* p = self->_prev;
|
||||||
|
p->_next = x;
|
||||||
|
x->_prev = p;
|
||||||
|
x->_next = self;
|
||||||
|
self->_prev = x;
|
||||||
|
|
||||||
|
// Now compute x.sort_key()
|
||||||
|
const NodeSortKey& psk = x->_prev->sort_key();
|
||||||
|
const NodeSortKey& nsk = x->_next->sort_key();
|
||||||
|
if (psk.size() > nsk.size()) {
|
||||||
|
// prefix = psk[: len(nsk)+1]
|
||||||
|
size_t slice_len = nsk.size() + 1;
|
||||||
|
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
|
||||||
|
// last element is idx => increment by 1
|
||||||
|
prefix.back()++;
|
||||||
|
x->sort_key() = std::move(prefix);
|
||||||
|
} else if (psk.size() < nsk.size()) {
|
||||||
|
// prefix = nsk[: len(psk)+1]
|
||||||
|
size_t slice_len = psk.size() + 1;
|
||||||
|
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
|
||||||
|
// last element is idx => decrement by 1
|
||||||
|
prefix.back()--;
|
||||||
|
x->sort_key() = std::move(prefix);
|
||||||
|
} else {
|
||||||
|
// same length => add a 0
|
||||||
|
x->sort_key() = psk;
|
||||||
|
x->sort_key().emplace_back(0);
|
||||||
|
}
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// __lt__(self, other): Return self.sort_key < other.sort_key
|
||||||
|
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
|
||||||
|
// METH_O => one argument: 'other'
|
||||||
|
if (!is_node(other)) {
|
||||||
|
Py_RETURN_NOTIMPLEMENTED;
|
||||||
|
}
|
||||||
|
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
|
||||||
|
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
|
||||||
|
bool less = std::lexicographical_compare(
|
||||||
|
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
|
||||||
|
if (less)
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
Py_RETURN_FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// __gt__(self, other): Return self.sort_key() > other.sort_key
|
||||||
|
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
|
||||||
|
if (!is_node(other)) {
|
||||||
|
Py_RETURN_NOTIMPLEMENTED;
|
||||||
|
}
|
||||||
|
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
|
||||||
|
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
|
||||||
|
// "a > b" is equivalent to "b < a"
|
||||||
|
bool greater = std::lexicographical_compare(
|
||||||
|
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
|
||||||
|
if (greater)
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
Py_RETURN_FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
|
||||||
|
if (self == other) {
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
}
|
||||||
|
return NodeBase___gt__(self, other);
|
||||||
|
}
|
||||||
|
|
||||||
|
// __le__(self, other): Return not (self > other)
|
||||||
|
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
|
||||||
|
if (self == other) {
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
}
|
||||||
|
return NodeBase___lt__(self, other);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
|
||||||
|
// Only used by pickle/__getstate__
|
||||||
|
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
|
||||||
|
NodeBase* node = reinterpret_cast<NodeBase*>(self);
|
||||||
|
const NodeSortKey& vec = node->sort_key();
|
||||||
|
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
|
||||||
|
THPObjectPtr tuple(PyTuple_New(n));
|
||||||
|
if (!tuple) {
|
||||||
|
return nullptr; // Out of memory
|
||||||
|
}
|
||||||
|
for (Py_ssize_t i = 0; i < n; i++) {
|
||||||
|
PyTuple_SET_ITEM(tuple.get(), i, PyLong_FromSsize_t(vec[i]));
|
||||||
|
}
|
||||||
|
return tuple.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
|
||||||
|
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
|
||||||
|
static int NodeBase_set_sort_key(
|
||||||
|
PyObject* self,
|
||||||
|
PyObject* value,
|
||||||
|
void* /*closure*/) {
|
||||||
|
NodeBase* node = reinterpret_cast<NodeBase*>(self);
|
||||||
|
if (!PyTuple_Check(value)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
Py_ssize_t size = PyTuple_GET_SIZE(value);
|
||||||
|
NodeSortKey new_vec;
|
||||||
|
new_vec.reserve(size);
|
||||||
|
for (Py_ssize_t i = 0; i < size; i++) {
|
||||||
|
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
|
||||||
|
if (val == -1 && PyErr_Occurred()) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
new_vec.emplace_back(val);
|
||||||
|
}
|
||||||
|
node->sort_key() = std::move(new_vec);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
static PyMethodDef NodeBase_methods[] = {
|
static PyMethodDef NodeBase_methods[] = {
|
||||||
{"_update_args_kwargs",
|
{"_update_args_kwargs",
|
||||||
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
|
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
|
||||||
METH_FASTCALL,
|
METH_FASTCALL,
|
||||||
"Internal method: do not call directly."},
|
"Internal method: do not call directly."},
|
||||||
|
{"_remove_from_list",
|
||||||
|
(PyCFunction)(void*)(NodeBase__remove_from_list),
|
||||||
|
METH_NOARGS,
|
||||||
|
"Internal method: do not call directly."},
|
||||||
|
{"_prepend",
|
||||||
|
(PyCFunction)(void*)(NodeBase__prepend),
|
||||||
|
METH_O,
|
||||||
|
"Internal method: do not call directly."},
|
||||||
|
{"__lt__",
|
||||||
|
(PyCFunction)(void*)NodeBase___lt__,
|
||||||
|
METH_O,
|
||||||
|
"Return True if self.sort_key < other.sort_key"},
|
||||||
|
{"__gt__",
|
||||||
|
(PyCFunction)(void*)NodeBase___gt__,
|
||||||
|
METH_O,
|
||||||
|
"Return True if self.sort_key > other.sort_key"},
|
||||||
|
{"__ge__",
|
||||||
|
(PyCFunction)(void*)NodeBase___ge__,
|
||||||
|
METH_O,
|
||||||
|
"Return True if self.sort_key >= other.sort_key"},
|
||||||
|
{"__le__",
|
||||||
|
(PyCFunction)(void*)NodeBase___le__,
|
||||||
|
METH_O,
|
||||||
|
"Return True if self.sort_key <= other.sort_key"},
|
||||||
{nullptr, nullptr, 0, nullptr} // Sentinel
|
{nullptr, nullptr, 0, nullptr} // Sentinel
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
|
static PyGetSetDef NodeBase_getset[] = {
|
||||||
|
{"_sort_key", // attribute name in Python
|
||||||
|
(getter)NodeBase_get_sort_key, // C getter function
|
||||||
|
(setter)NodeBase_set_sort_key, // C setter function
|
||||||
|
(char*)"The sort key as a tuple of ints", // docstring
|
||||||
|
nullptr},
|
||||||
|
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
|
||||||
|
};
|
||||||
|
|
||||||
PyTypeObject NodeBaseType = {
|
PyTypeObject NodeBaseType = {
|
||||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||||
"torch._C._NodeBase", /* tp_name */
|
"torch._C._NodeBase", /* tp_name */
|
||||||
|
|
@ -361,7 +554,7 @@ PyTypeObject NodeBaseType = {
|
||||||
nullptr, /* tp_iternext */
|
nullptr, /* tp_iternext */
|
||||||
NodeBase_methods, /* tp_methods */
|
NodeBase_methods, /* tp_methods */
|
||||||
NodeBase_members, /* tp_members */
|
NodeBase_members, /* tp_members */
|
||||||
nullptr, /* tp_getset */
|
NodeBase_getset, /* tp_getset */
|
||||||
nullptr, /* tp_base */
|
nullptr, /* tp_base */
|
||||||
nullptr, /* tp_dict */
|
nullptr, /* tp_dict */
|
||||||
nullptr, /* tp_descr_get */
|
nullptr, /* tp_descr_get */
|
||||||
|
|
|
||||||
|
|
@ -375,41 +375,7 @@ class Node(_NodeBase):
|
||||||
Args:
|
Args:
|
||||||
x (Node): The node to put before this node. Must be a member of the same graph.
|
x (Node): The node to put before this node. Must be a member of the same graph.
|
||||||
"""
|
"""
|
||||||
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
|
self._prepend(x)
|
||||||
if self == x:
|
|
||||||
log.debug(
|
|
||||||
"Trying to prepend a node to itself. This behavior has no effect on the graph."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
x._remove_from_list()
|
|
||||||
p = self._prev
|
|
||||||
p._next, x._prev = x, p
|
|
||||||
x._next, self._prev = self, x
|
|
||||||
|
|
||||||
# compute x._sort_key
|
|
||||||
psk = x._prev._sort_key
|
|
||||||
nsk = x._next._sort_key
|
|
||||||
if len(psk) > len(nsk):
|
|
||||||
idx: int
|
|
||||||
*prefix, idx = psk[: len(nsk) + 1]
|
|
||||||
x._sort_key = (*prefix, idx + 1)
|
|
||||||
elif len(psk) < len(nsk):
|
|
||||||
*prefix, idx = nsk[: len(psk) + 1]
|
|
||||||
x._sort_key = (*prefix, idx - 1)
|
|
||||||
else: # same length, increase length by 1
|
|
||||||
x._sort_key = (*psk, 0)
|
|
||||||
|
|
||||||
def __gt__(self, other: "Node") -> bool:
|
|
||||||
return self._sort_key > other._sort_key
|
|
||||||
|
|
||||||
def __lt__(self, other: "Node") -> bool:
|
|
||||||
return self._sort_key < other._sort_key
|
|
||||||
|
|
||||||
def __ge__(self, other: "Node") -> bool:
|
|
||||||
return self > other or self == other
|
|
||||||
|
|
||||||
def __le__(self, other: "Node") -> bool:
|
|
||||||
return self < other or self == other
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
def append(self, x: "Node") -> None:
|
def append(self, x: "Node") -> None:
|
||||||
|
|
@ -420,11 +386,7 @@ class Node(_NodeBase):
|
||||||
Args:
|
Args:
|
||||||
x (Node): The node to put after this node. Must be a member of the same graph.
|
x (Node): The node to put after this node. Must be a member of the same graph.
|
||||||
"""
|
"""
|
||||||
self._next.prepend(x)
|
self._next._prepend(x)
|
||||||
|
|
||||||
def _remove_from_list(self) -> None:
|
|
||||||
p, n = self._prev, self._next
|
|
||||||
p._next, n._prev = n, p
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> tuple[Argument, ...]:
|
def args(self) -> tuple[Argument, ...]:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user