mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882)
Relands #148261 that was reverted by #150542 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165882 Approved by: https://github.com/ezyang
This commit is contained in:
parent
0bd12c1168
commit
3c3b278872
|
|
@ -1,8 +1,8 @@
|
||||||
add_loop_eager,compile_time_instruction_count,3070000000,0.1
|
add_loop_eager,compile_time_instruction_count,3184000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
|
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
|
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
|
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
update_hint_regression,compile_time_instruction_count,1719000000,0.1
|
update_hint_regression,compile_time_instruction_count,1645000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
|
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
|
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
|
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
|
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
|
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
|
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
|
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
|
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
|
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
|
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
|
||||||
|
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import builtins
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
|
import gc
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import io
|
import io
|
||||||
|
|
@ -19,6 +20,7 @@ import traceback
|
||||||
import types
|
import types
|
||||||
import typing
|
import typing
|
||||||
import unittest
|
import unittest
|
||||||
|
import weakref
|
||||||
import warnings
|
import warnings
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
from torch.multiprocessing import Process
|
from torch.multiprocessing import Process
|
||||||
|
|
@ -1624,6 +1626,25 @@ class TestFX(JitTestCase):
|
||||||
|
|
||||||
self.assertTrue(neg not in relu.users)
|
self.assertTrue(neg not in relu.users)
|
||||||
|
|
||||||
|
@skipIfTorchDynamo("Dynamo does not free right away")
|
||||||
|
def test_prepend_does_not_leak(self):
|
||||||
|
g = Graph()
|
||||||
|
x = g.placeholder("x")
|
||||||
|
relu = g.call_function(torch.relu, (x,))
|
||||||
|
neg = g.call_function(torch.neg, (x,))
|
||||||
|
|
||||||
|
relu.prepend(neg)
|
||||||
|
|
||||||
|
ref = weakref.ref(neg)
|
||||||
|
g.erase_node(neg)
|
||||||
|
del g
|
||||||
|
del x
|
||||||
|
del relu
|
||||||
|
del neg
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
self.assertIsNone(ref())
|
||||||
|
|
||||||
def test_remove_uses_with_custom_filter(self):
|
def test_remove_uses_with_custom_filter(self):
|
||||||
g: torch.fx.Graph = Graph()
|
g: torch.fx.Graph = Graph()
|
||||||
x: torch.fx.Node = g.placeholder("x")
|
x: torch.fx.Node = g.placeholder("x")
|
||||||
|
|
|
||||||
|
|
@ -2758,6 +2758,12 @@ class _NodeBase:
|
||||||
return_type: Any,
|
return_type: Any,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
|
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
|
||||||
|
def _prepend(self, n: FxNode) -> None: ...
|
||||||
|
def _remove_from_list(self) -> None: ...
|
||||||
|
def __lt__(self, n: Self) -> _bool: ...
|
||||||
|
def __gt__(self, n: Self) -> _bool: ...
|
||||||
|
def __le__(self, n: Self) -> _bool: ...
|
||||||
|
def __ge__(self, n: Self) -> _bool: ...
|
||||||
|
|
||||||
class _NodeIter(Iterator[FxNode]):
|
class _NodeIter(Iterator[FxNode]):
|
||||||
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
|
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,15 @@
|
||||||
#include <torch/csrc/fx/node.h>
|
#include <torch/csrc/fx/node.h>
|
||||||
|
|
||||||
|
#include <c10/util/Exception.h>
|
||||||
|
#include <c10/util/SmallVector.h>
|
||||||
#include <structmember.h>
|
#include <structmember.h>
|
||||||
#include <torch/csrc/utils/object_ptr.h>
|
#include <torch/csrc/utils/object_ptr.h>
|
||||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using NodeSortKey = c10::SmallVector<int64_t, 4>;
|
||||||
struct NodeBase;
|
struct NodeBase;
|
||||||
|
|
||||||
// Thrown to exit out of a C++ function and return an error to Python.
|
// Thrown to exit out of a C++ function and return an error to Python.
|
||||||
|
|
@ -163,7 +167,41 @@ struct NodeBase {
|
||||||
PyObject* users;
|
PyObject* users;
|
||||||
PyObject* _repr_fn;
|
PyObject* _repr_fn;
|
||||||
PyObject* meta;
|
PyObject* meta;
|
||||||
PyObject* _sort_key;
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
|
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
|
||||||
|
|
||||||
|
inline NodeSortKey& sort_key() {
|
||||||
|
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void set_prev(NodeBase* value) {
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||||
|
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||||
|
NodeBase* old = _prev;
|
||||||
|
_prev = value;
|
||||||
|
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void set_next(NodeBase* value) {
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||||
|
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||||
|
NodeBase* old = _next;
|
||||||
|
_next = value;
|
||||||
|
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equivalent to:
|
||||||
|
// p, n = self._prev, self._next
|
||||||
|
// p._next, n._prev = n, p
|
||||||
|
inline void remove_from_list() {
|
||||||
|
if (this->_prev == this && this->_next == this) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
NodeBase* p = this->_prev;
|
||||||
|
NodeBase* n = this->_next;
|
||||||
|
p->set_next(n);
|
||||||
|
n->set_prev(p);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static PyObject* NodeBase_new(
|
static PyObject* NodeBase_new(
|
||||||
|
|
@ -173,6 +211,8 @@ static PyObject* NodeBase_new(
|
||||||
PyObject* self = type->tp_alloc(type, 0);
|
PyObject* self = type->tp_alloc(type, 0);
|
||||||
if (!self)
|
if (!self)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
|
||||||
|
NodeSortKey(); // placement new does not allocate
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -201,7 +241,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
|
||||||
self->users = PyDict_New();
|
self->users = PyDict_New();
|
||||||
self->_repr_fn = Py_NewRef(Py_None);
|
self->_repr_fn = Py_NewRef(Py_None);
|
||||||
self->meta = PyDict_New();
|
self->meta = PyDict_New();
|
||||||
self->_sort_key = PyTuple_New(0);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -221,7 +260,6 @@ static struct PyMemberDef NodeBase_members[] = {
|
||||||
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
|
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
|
||||||
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
|
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
|
||||||
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
|
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
|
||||||
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
|
|
||||||
{nullptr} /* Sentinel */
|
{nullptr} /* Sentinel */
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -239,7 +277,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
|
||||||
Py_VISIT(self->users);
|
Py_VISIT(self->users);
|
||||||
Py_VISIT(self->_repr_fn);
|
Py_VISIT(self->_repr_fn);
|
||||||
Py_VISIT(self->meta);
|
Py_VISIT(self->meta);
|
||||||
Py_VISIT(self->_sort_key);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -257,12 +294,12 @@ static int NodeBase_clear(NodeBase* self) {
|
||||||
Py_CLEAR(self->users);
|
Py_CLEAR(self->users);
|
||||||
Py_CLEAR(self->_repr_fn);
|
Py_CLEAR(self->_repr_fn);
|
||||||
Py_CLEAR(self->meta);
|
Py_CLEAR(self->meta);
|
||||||
Py_CLEAR(self->_sort_key);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void NodeBase_dealloc(PyObject* self) {
|
static void NodeBase_dealloc(PyObject* self) {
|
||||||
PyObject_GC_UnTrack(self);
|
PyObject_GC_UnTrack(self);
|
||||||
|
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
|
||||||
(void)NodeBase_clear((NodeBase*)self);
|
(void)NodeBase_clear((NodeBase*)self);
|
||||||
Py_TYPE(self)->tp_free(self);
|
Py_TYPE(self)->tp_free(self);
|
||||||
}
|
}
|
||||||
|
|
@ -321,15 +358,195 @@ static PyObject* NodeBase__update_args_kwargs(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject* NodeBase__remove_from_list(
|
||||||
|
PyObject* self,
|
||||||
|
PyObject* _ignored) {
|
||||||
|
reinterpret_cast<NodeBase*>(self)->remove_from_list();
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
|
||||||
|
if (self_ == arg) {
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
if (!is_node(arg)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
|
||||||
|
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
|
||||||
|
if (self->graph != x->graph) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_AssertionError,
|
||||||
|
"Attempting to move a Node into a different Graph");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
x->remove_from_list();
|
||||||
|
NodeBase* p = self->_prev;
|
||||||
|
p->set_next(x);
|
||||||
|
x->set_prev(p);
|
||||||
|
x->set_next(self);
|
||||||
|
self->set_prev(x);
|
||||||
|
|
||||||
|
// Now compute x.sort_key()
|
||||||
|
const NodeSortKey& psk = x->_prev->sort_key();
|
||||||
|
const NodeSortKey& nsk = x->_next->sort_key();
|
||||||
|
if (psk.size() > nsk.size()) {
|
||||||
|
// prefix = psk[: len(nsk)+1]
|
||||||
|
size_t slice_len = nsk.size() + 1;
|
||||||
|
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
|
||||||
|
// last element is idx => increment by 1
|
||||||
|
prefix.back()++;
|
||||||
|
x->sort_key() = std::move(prefix);
|
||||||
|
} else if (psk.size() < nsk.size()) {
|
||||||
|
// prefix = nsk[: len(psk)+1]
|
||||||
|
size_t slice_len = psk.size() + 1;
|
||||||
|
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
|
||||||
|
// last element is idx => decrement by 1
|
||||||
|
prefix.back()--;
|
||||||
|
x->sort_key() = std::move(prefix);
|
||||||
|
} else {
|
||||||
|
// same length => add a 0
|
||||||
|
x->sort_key() = psk;
|
||||||
|
x->sort_key().emplace_back(0);
|
||||||
|
}
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// __lt__(self, other): Return self.sort_key < other.sort_key
|
||||||
|
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
|
||||||
|
// METH_O => one argument: 'other'
|
||||||
|
if (!is_node(other)) {
|
||||||
|
Py_RETURN_NOTIMPLEMENTED;
|
||||||
|
}
|
||||||
|
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
|
||||||
|
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
|
||||||
|
bool less = std::lexicographical_compare(
|
||||||
|
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
|
||||||
|
if (less)
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
Py_RETURN_FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// __gt__(self, other): Return self.sort_key() > other.sort_key
|
||||||
|
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
|
||||||
|
if (!is_node(other)) {
|
||||||
|
Py_RETURN_NOTIMPLEMENTED;
|
||||||
|
}
|
||||||
|
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
|
||||||
|
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
|
||||||
|
// "a > b" is equivalent to "b < a"
|
||||||
|
bool greater = std::lexicographical_compare(
|
||||||
|
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
|
||||||
|
if (greater)
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
Py_RETURN_FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
|
||||||
|
if (self == other) {
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
}
|
||||||
|
return NodeBase___gt__(self, other);
|
||||||
|
}
|
||||||
|
|
||||||
|
// __le__(self, other): Return not (self > other)
|
||||||
|
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
|
||||||
|
if (self == other) {
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
}
|
||||||
|
return NodeBase___lt__(self, other);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
|
||||||
|
// Only used by pickle/__getstate__
|
||||||
|
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
|
||||||
|
NodeBase* node = reinterpret_cast<NodeBase*>(self);
|
||||||
|
const NodeSortKey& vec = node->sort_key();
|
||||||
|
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
|
||||||
|
THPObjectPtr tuple(PyTuple_New(n));
|
||||||
|
if (!tuple) {
|
||||||
|
return nullptr; // Out of memory
|
||||||
|
}
|
||||||
|
for (Py_ssize_t i = 0; i < n; i++) {
|
||||||
|
PyObject* value = PyLong_FromSsize_t(vec[i]);
|
||||||
|
if (!value) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
PyTuple_SET_ITEM(tuple.get(), i, value);
|
||||||
|
}
|
||||||
|
return tuple.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
|
||||||
|
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
|
||||||
|
static int NodeBase_set_sort_key(
|
||||||
|
PyObject* self,
|
||||||
|
PyObject* value,
|
||||||
|
void* /*closure*/) {
|
||||||
|
NodeBase* node = reinterpret_cast<NodeBase*>(self);
|
||||||
|
if (!PyTuple_Check(value)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
Py_ssize_t size = PyTuple_GET_SIZE(value);
|
||||||
|
NodeSortKey new_vec;
|
||||||
|
new_vec.reserve(size);
|
||||||
|
for (Py_ssize_t i = 0; i < size; i++) {
|
||||||
|
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
|
||||||
|
if (val == -1 && PyErr_Occurred()) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
new_vec.emplace_back(val);
|
||||||
|
}
|
||||||
|
node->sort_key() = std::move(new_vec);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
static PyMethodDef NodeBase_methods[] = {
|
static PyMethodDef NodeBase_methods[] = {
|
||||||
{"_update_args_kwargs",
|
{"_update_args_kwargs",
|
||||||
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
|
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
|
||||||
METH_FASTCALL,
|
METH_FASTCALL,
|
||||||
"Internal method: do not call directly."},
|
"Internal method: do not call directly."},
|
||||||
|
{"_remove_from_list",
|
||||||
|
(PyCFunction)(void*)(NodeBase__remove_from_list),
|
||||||
|
METH_NOARGS,
|
||||||
|
"Internal method: do not call directly."},
|
||||||
|
{"_prepend",
|
||||||
|
(PyCFunction)(void*)(NodeBase__prepend),
|
||||||
|
METH_O,
|
||||||
|
"Internal method: do not call directly."},
|
||||||
|
{"__lt__",
|
||||||
|
(PyCFunction)(void*)NodeBase___lt__,
|
||||||
|
METH_O,
|
||||||
|
"Return True if self.sort_key < other.sort_key"},
|
||||||
|
{"__gt__",
|
||||||
|
(PyCFunction)(void*)NodeBase___gt__,
|
||||||
|
METH_O,
|
||||||
|
"Return True if self.sort_key > other.sort_key"},
|
||||||
|
{"__ge__",
|
||||||
|
(PyCFunction)(void*)NodeBase___ge__,
|
||||||
|
METH_O,
|
||||||
|
"Return True if self.sort_key >= other.sort_key"},
|
||||||
|
{"__le__",
|
||||||
|
(PyCFunction)(void*)NodeBase___le__,
|
||||||
|
METH_O,
|
||||||
|
"Return True if self.sort_key <= other.sort_key"},
|
||||||
{nullptr, nullptr, 0, nullptr} // Sentinel
|
{nullptr, nullptr, 0, nullptr} // Sentinel
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
|
static PyGetSetDef NodeBase_getset[] = {
|
||||||
|
{"_sort_key", // attribute name in Python
|
||||||
|
(getter)NodeBase_get_sort_key, // C getter function
|
||||||
|
(setter)NodeBase_set_sort_key, // C setter function
|
||||||
|
(char*)"The sort key as a tuple of ints", // docstring
|
||||||
|
nullptr},
|
||||||
|
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
|
||||||
|
};
|
||||||
|
|
||||||
PyTypeObject NodeBaseType = {
|
PyTypeObject NodeBaseType = {
|
||||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||||
"torch._C._NodeBase", /* tp_name */
|
"torch._C._NodeBase", /* tp_name */
|
||||||
|
|
@ -361,7 +578,7 @@ PyTypeObject NodeBaseType = {
|
||||||
nullptr, /* tp_iternext */
|
nullptr, /* tp_iternext */
|
||||||
NodeBase_methods, /* tp_methods */
|
NodeBase_methods, /* tp_methods */
|
||||||
NodeBase_members, /* tp_members */
|
NodeBase_members, /* tp_members */
|
||||||
nullptr, /* tp_getset */
|
NodeBase_getset, /* tp_getset */
|
||||||
nullptr, /* tp_base */
|
nullptr, /* tp_base */
|
||||||
nullptr, /* tp_dict */
|
nullptr, /* tp_dict */
|
||||||
nullptr, /* tp_descr_get */
|
nullptr, /* tp_descr_get */
|
||||||
|
|
|
||||||
|
|
@ -385,41 +385,7 @@ class Node(_NodeBase):
|
||||||
Args:
|
Args:
|
||||||
x (Node): The node to put before this node. Must be a member of the same graph.
|
x (Node): The node to put before this node. Must be a member of the same graph.
|
||||||
"""
|
"""
|
||||||
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
|
self._prepend(x)
|
||||||
if self == x:
|
|
||||||
log.debug(
|
|
||||||
"Trying to prepend a node to itself. This behavior has no effect on the graph."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
x._remove_from_list()
|
|
||||||
p = self._prev
|
|
||||||
p._next, x._prev = x, p
|
|
||||||
x._next, self._prev = self, x
|
|
||||||
|
|
||||||
# compute x._sort_key
|
|
||||||
psk = x._prev._sort_key
|
|
||||||
nsk = x._next._sort_key
|
|
||||||
if len(psk) > len(nsk):
|
|
||||||
idx: int
|
|
||||||
*prefix, idx = psk[: len(nsk) + 1]
|
|
||||||
x._sort_key = (*prefix, idx + 1)
|
|
||||||
elif len(psk) < len(nsk):
|
|
||||||
*prefix, idx = nsk[: len(psk) + 1]
|
|
||||||
x._sort_key = (*prefix, idx - 1)
|
|
||||||
else: # same length, increase length by 1
|
|
||||||
x._sort_key = (*psk, 0)
|
|
||||||
|
|
||||||
def __gt__(self, other: "Node") -> bool:
|
|
||||||
return self._sort_key > other._sort_key
|
|
||||||
|
|
||||||
def __lt__(self, other: "Node") -> bool:
|
|
||||||
return self._sort_key < other._sort_key
|
|
||||||
|
|
||||||
def __ge__(self, other: "Node") -> bool:
|
|
||||||
return self > other or self == other
|
|
||||||
|
|
||||||
def __le__(self, other: "Node") -> bool:
|
|
||||||
return self < other or self == other
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
def append(self, x: "Node") -> None:
|
def append(self, x: "Node") -> None:
|
||||||
|
|
@ -430,11 +396,7 @@ class Node(_NodeBase):
|
||||||
Args:
|
Args:
|
||||||
x (Node): The node to put after this node. Must be a member of the same graph.
|
x (Node): The node to put after this node. Must be a member of the same graph.
|
||||||
"""
|
"""
|
||||||
self._next.prepend(x)
|
self._next._prepend(x)
|
||||||
|
|
||||||
def _remove_from_list(self) -> None:
|
|
||||||
p, n = self._prev, self._next
|
|
||||||
p._next, n._prev = n, p
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> tuple[Argument, ...]:
|
def args(self) -> tuple[Argument, ...]:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user