pytorch/torch/csrc/fx/node.cpp
cyy 8fa81a6066 Enable misc-use-internal-linkage check and apply fixes (#148948)
Enables clang-tidy rule [`misc-use-internal-linkage`](https://clang.llvm.org/extra/clang-tidy/checks/misc/use-internal-linkage.html). This check was introduced in Clang-Tidy 18 and is available now that Clang-Tidy has been updated to version 19.

The check marks functions and variables used only in the translation unit as static. Therefore undesired symbols are not leaked into other units, more link time optimisations are possible and the resulting binaries may be smaller.

Most of the detected violations were fixed by marking the symbols `static`. In the remaining cases the symbols were in fact used by other files, so their declaring headers were included instead. A few declarations were also wrong and have been fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148948
Approved by: https://github.com/Skylion007
2025-03-12 14:22:56 +00:00

789 lines
23 KiB
C++

#include <torch/csrc/fx/node.h>
#include <c10/util/SmallVector.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <algorithm>
namespace {
// Sort key of a node: a sequence of integers compared lexicographically
// (see NodeBase___lt__). The SmallVector keeps up to 4 entries inline.
using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase;
// Thrown to unwind out of a C++ helper after a Python error has been set;
// entry points such as py_map_aggregate catch it and return nullptr so the
// pending Python exception propagates.
struct PythonError : std::exception {};
// Import `module_name` and return a new reference to its attribute `name`.
// Throws PythonError (with the Python error left set) if either the import or
// the attribute lookup fails.
inline static PyObject* import_from(const char* module_name, const char* name) {
  THPObjectPtr module(PyImport_ImportModule(module_name));
  PyObject* attr =
      module ? PyObject_GetAttrString(module.get(), name) : nullptr;
  if (attr == nullptr) {
    throw PythonError();
  }
  return attr;
}
// Borrowed-style accessor for torch.fx.immutable_collections.immutable_list.
// The class is imported on first use and the reference is kept for the rest
// of the process. Throws PythonError if the import fails; a later call retries.
inline static PyObject* immutable_list_cls() {
  static PyObject* cached = nullptr;
  if (cached == nullptr) {
    cached = import_from("torch.fx.immutable_collections", "immutable_list");
  }
  return cached;
}
// Borrowed-style accessor for torch.fx.immutable_collections.immutable_dict.
// The class is imported on first use and the reference is kept for the rest
// of the process. Throws PythonError if the import fails; a later call retries.
inline static PyObject* immutable_dict_cls() {
  static PyObject* cached = nullptr;
  if (cached == nullptr) {
    cached = import_from("torch.fx.immutable_collections", "immutable_dict");
  }
  return cached;
}
// True if `obj` is an instance of torch.fx.node.Node or a subclass of it.
// The Node class is imported on first use and cached for the rest of the
// process. Throws PythonError if that import fails.
inline static bool is_node(PyObject* obj) {
  static PyObject* node_type = nullptr;
  if (node_type == nullptr) {
    node_type = import_from("torch.fx.node", "Node");
  }
  auto* as_type = reinterpret_cast<PyTypeObject*>(node_type);
  return PyObject_TypeCheck(obj, as_type);
}
// True iff the type of `obj` is exactly `typ`; subclasses do not match.
inline static bool exact_type(PyObject* obj, PyObject* typ) {
  PyTypeObject* const expected = reinterpret_cast<PyTypeObject*>(typ);
  return expected == Py_TYPE(obj);
}
// Recursively rebuild the aggregate `a`, applying `fn` to every leaf:
//   - tuple  -> tuple of mapped elements; a tuple subclass with a `_fields`
//               attribute (assumed NamedTuple) is rebuilt as type(a)(*elems)
//   - list   -> immutable_list of mapped elements
//   - dict   -> immutable_dict with the same keys and mapped values
//   - slice  -> slice(mapped start, mapped stop, mapped step)
//   - other  -> fn(a)
// Exact empty tuples, and empty instances of exactly immutable_list /
// immutable_dict, are returned as-is.
//
// `fn` must return a new reference, or nullptr with a Python error set.
//
// Invariant: returns a new reference and never nullptr. On failure a Python
// error is set and PythonError is thrown. Every fallible CPython call below is
// therefore checked before its result is used or stored.
template <typename F>
inline static PyObject* map_aggregate(PyObject* a, F fn) {
  // Case 1: a is a tuple.
  if (PyTuple_Check(a)) {
    Py_ssize_t n = PyTuple_GET_SIZE(a);
    if (n == 0 && PyTuple_CheckExact(a)) {
      return Py_NewRef(a);
    }
    THPObjectPtr new_tuple(PyTuple_New(n));
    if (!new_tuple) {
      throw PythonError();
    }
    for (Py_ssize_t i = 0; i < n; i++) {
      // Tuples are immutable, so this borrowed element stays alive.
      PyObject* elem = PyTuple_GET_ITEM(a, i);
      // PyTuple_SET_ITEM steals the new reference returned by map_aggregate.
      PyTuple_SET_ITEM(new_tuple.get(), i, map_aggregate(elem, fn));
    }
    // If the tuple has a "_fields" attribute, assume it is a NamedTuple and
    // rebuild it as type(a)(*new_tuple).
    if (!PyTuple_CheckExact(a) && PyObject_HasAttrString(a, "_fields")) {
      PyObject* rebuilt = PyObject_CallObject(
          reinterpret_cast<PyObject*>(Py_TYPE(a)), new_tuple.get());
      if (!rebuilt) {
        throw PythonError();
      }
      return rebuilt;
    }
    return new_tuple.release();
  }
  // Case 2: a is a list.
  if (PyList_Check(a)) {
    PyObject* list_cls = immutable_list_cls();
    if (PyList_GET_SIZE(a) == 0 && exact_type(a, list_cls)) {
      return Py_NewRef(a);
    }
    THPObjectPtr result(PyObject_CallNoArgs(list_cls));
    if (!result) {
      throw PythonError();
    }
    // `fn` may run arbitrary Python code that mutates `a`. Re-read the size
    // on every iteration and hold a strong reference to each element while
    // it is being mapped.
    for (Py_ssize_t i = 0; i < PyList_GET_SIZE(a); i++) {
      THPObjectPtr elem(Py_NewRef(PyList_GET_ITEM(a, i)));
      THPObjectPtr mapped(map_aggregate(elem.get(), fn));
      if (PyList_Append(result.get(), mapped.get()) < 0) {
        throw PythonError();
      }
    }
    return result.release();
  }
  // Case 3: a is a dict.
  if (PyDict_Check(a)) {
    PyObject* dict_cls = immutable_dict_cls();
    if (PyDict_GET_SIZE(a) == 0 && exact_type(a, dict_cls)) {
      return Py_NewRef(a);
    }
    THPObjectPtr result(PyObject_CallNoArgs(dict_cls));
    if (!result) {
      throw PythonError();
    }
    PyObject *key = nullptr, *value = nullptr; // borrowed
    Py_ssize_t pos = 0;
    while (PyDict_Next(a, &pos, &key, &value)) {
      // Pin key and value: `fn` may run Python code that mutates `a`.
      THPObjectPtr key_ref(Py_NewRef(key));
      THPObjectPtr value_ref(Py_NewRef(value));
      THPObjectPtr mapped(map_aggregate(value_ref.get(), fn));
      if (PyDict_SetItem(result.get(), key_ref.get(), mapped.get()) < 0) {
        throw PythonError();
      }
    }
    return result.release();
  }
  // Case 4: a is a slice.
  if (PySlice_Check(a)) {
    THPObjectPtr start(PyObject_GetAttrString(a, "start"));
    THPObjectPtr stop(PyObject_GetAttrString(a, "stop"));
    THPObjectPtr step(PyObject_GetAttrString(a, "step"));
    if (!start || !stop || !step) {
      throw PythonError();
    }
    THPObjectPtr mapped_start(map_aggregate(start.get(), fn));
    THPObjectPtr mapped_stop(map_aggregate(stop.get(), fn));
    THPObjectPtr mapped_step(map_aggregate(step.get(), fn));
    PyObject* slice =
        PySlice_New(mapped_start.get(), mapped_stop.get(), mapped_step.get());
    if (!slice) {
      throw PythonError();
    }
    return slice;
  }
  // Default case: a is a leaf, so call fn(a).
  PyObject* result = fn(a);
  if (!result) {
    throw PythonError();
  }
  return result;
}
////////////////////////////////
// NodeBase
///////////////////////////////
// C storage behind torch._C._NodeBase. The PyObject*/pointer fields are
// exposed to Python through NodeBase_members (offsetof-based, so field order
// and types matter), and the sort key is exposed as _sort_key through
// NodeBase_getset.
struct NodeBase {
PyObject_HEAD
// Initialized to false by NodeBase_init_fn and writable from Python;
// NodeIter skips nodes with this set.
bool _erased;
// Neighbours in a circular doubly linked list. NodeBase_init_fn takes a
// reference for each (initially to the node itself), and NodeBase_traverse /
// NodeBase_clear treat both as owned references.
NodeBase* _prev;
NodeBase* _next;
PyObject* graph;
PyObject* name;
PyObject* op;
PyObject* target;
PyObject* type;
// Dict whose keys are the Nodes found in _args/_kwargs (values are None);
// rebuilt by NodeBase__update_args_kwargs.
PyObject* _input_nodes;
// Normalized args/kwargs; null until _update_args_kwargs is first called.
PyObject* _args;
PyObject* _kwargs;
// Dict whose keys are the Nodes that list this node among their inputs;
// maintained by NodeBase__update_args_kwargs.
PyObject* users;
PyObject* _repr_fn;
PyObject* meta;
// Raw, suitably aligned storage for a NodeSortKey. NodeBase_new constructs
// it with placement new and NodeBase_dealloc runs its destructor explicitly,
// because the Python allocator does not run C++ constructors/destructors.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
// The NodeSortKey living in sort_key_buf.
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
// Unlink this node from its neighbours. Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
// This node's own _prev/_next are left pointing at the old neighbours;
// NodeIter_iternext_helper follows _next/_prev out of erased nodes.
// NOTE(review): the pointers are rewired without Py_INCREF/Py_DECREF even
// though _prev/_next hold owned references. When NodeBase__prepend re-links
// the node right afterwards the counts appear to balance again, but a bare
// _remove_from_list seems to leave this node over-counted and its neighbours
// under-counted relative to the pointers. Confirm this is intended.
inline void remove_from_list() {
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->_next = n;
n->_prev = p;
}
};
// tp_new for _NodeBase: allocate the object via tp_alloc, then start the
// lifetime of the NodeSortKey stored in sort_key_buf. Placement new only
// constructs in the existing buffer; it does not allocate. Returns nullptr
// if tp_alloc fails.
static PyObject* NodeBase_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  PyObject* obj = type->tp_alloc(type, 0);
  if (obj != nullptr) {
    auto* node = reinterpret_cast<NodeBase*>(obj);
    new (node->sort_key_buf) NodeSortKey();
  }
  return obj;
}
// tp_init for _NodeBase: __init__(graph, name, op, target, type).
//
// Stores new references to the five positional arguments and makes the node
// a one-element circular list: _prev and _next both point at the node, and
// each owns a reference to it. Creates fresh dicts for _input_nodes, users
// and meta. _args/_kwargs stay null until _update_args_kwargs is called.
// Returns 0 on success, or -1 with a Python error set.
// NOTE(review): calling __init__ a second time overwrites these fields
// without releasing the previous references.
static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
  PyObject* graph = nullptr;
  PyObject* name = nullptr;
  PyObject* op = nullptr;
  PyObject* target = nullptr;
  PyObject* type = nullptr;
  if (!PyArg_ParseTuple(args, "OOOOO", &graph, &name, &op, &target, &type)) {
    return -1;
  }
  self->_erased = false;
  Py_INCREF(self);
  self->_prev = self;
  Py_INCREF(self);
  self->_next = self;
  self->graph = Py_NewRef(graph);
  self->name = Py_NewRef(name);
  self->op = Py_NewRef(op);
  self->target = Py_NewRef(target);
  self->type = Py_NewRef(type);
  self->_input_nodes = PyDict_New();
  self->_args = nullptr; // set with _update_args_kwargs
  self->_kwargs = nullptr; // set with _update_args_kwargs
  self->users = PyDict_New();
  self->_repr_fn = Py_NewRef(Py_None);
  self->meta = PyDict_New();
  // PyDict_New returns nullptr (with MemoryError set) on allocation failure.
  // Report it instead of returning success with an exception pending. Fields
  // that were already populated are released by NodeBase_clear when the
  // object is collected.
  if (!self->_input_nodes || !self->users || !self->meta) {
    return -1;
  }
  return 0;
}
// Attributes exposed directly on _NodeBase instances. All are read/write
// (flags 0). Reading a T_OBJECT_EX member whose field is null raises
// AttributeError, e.g. _args/_kwargs before _update_args_kwargs has run.
// NOTE(review): _prev/_next are NodeBase* in C, but assigning them from Python
// through these descriptors accepts any object, while the C code later
// dereferences them as NodeBase (remove_from_list, NodeIter). Confirm that
// only Nodes are ever assigned.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static struct PyMemberDef NodeBase_members[] = {
{"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
{"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
{"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
{"graph", T_OBJECT_EX, offsetof(NodeBase, graph), 0, nullptr},
{"name", T_OBJECT_EX, offsetof(NodeBase, name), 0, nullptr},
{"op", T_OBJECT_EX, offsetof(NodeBase, op), 0, nullptr},
{"target", T_OBJECT_EX, offsetof(NodeBase, target), 0, nullptr},
{"type", T_OBJECT_EX, offsetof(NodeBase, type), 0, nullptr},
{"_input_nodes", T_OBJECT_EX, offsetof(NodeBase, _input_nodes), 0, nullptr},
{"_args", T_OBJECT_EX, offsetof(NodeBase, _args), 0, nullptr},
{"_kwargs", T_OBJECT_EX, offsetof(NodeBase, _kwargs), 0, nullptr},
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
{nullptr} /* Sentinel */
};
// tp_traverse for _NodeBase: report every owned object reference to the
// cyclic GC, in field declaration order. Null fields are skipped by Py_VISIT,
// and a non-zero result from `visit` is returned immediately.
static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  PyObject* const owned[] = {
      reinterpret_cast<PyObject*>(self->_prev),
      reinterpret_cast<PyObject*>(self->_next),
      self->graph,
      self->name,
      self->op,
      self->target,
      self->type,
      self->_input_nodes,
      self->_args,
      self->_kwargs,
      self->users,
      self->_repr_fn,
      self->meta,
  };
  for (PyObject* field : owned) {
    Py_VISIT(field);
  }
  return 0;
}
// tp_clear for _NodeBase: drop every owned reference, in field declaration
// order. Py_CLEAR nulls each field before releasing its old value, so code
// triggered by a release never sees a dangling pointer in this node.
static int NodeBase_clear(NodeBase* self) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  PyObject** const slots[] = {
      reinterpret_cast<PyObject**>(&self->_prev),
      reinterpret_cast<PyObject**>(&self->_next),
      &self->graph,
      &self->name,
      &self->op,
      &self->target,
      &self->type,
      &self->_input_nodes,
      &self->_args,
      &self->_kwargs,
      &self->users,
      &self->_repr_fn,
      &self->meta,
  };
  for (PyObject** slot : slots) {
    Py_CLEAR(*slot);
  }
  return 0;
}
// tp_dealloc for _NodeBase: stop GC tracking, destroy the NodeSortKey that
// NodeBase_new placement-constructed, release every owned reference, and
// free the memory through the instance type's tp_free.
static void NodeBase_dealloc(PyObject* self) {
  NodeBase* const node = reinterpret_cast<NodeBase*>(self);
  PyObject_GC_UnTrack(self);
  node->sort_key().~NodeSortKey();
  static_cast<void>(NodeBase_clear(node));
  Py_TYPE(self)->tp_free(self);
}
// _update_args_kwargs(new_args, new_kwargs), called as METH_FASTCALL.
//
// 1. Removes this node from the `users` dict of every node currently in
//    self._input_nodes, then empties _input_nodes.
// 2. Maps new_args/new_kwargs through map_aggregate in a single pass, which:
//    - normalizes list -> immutable_list, dict -> immutable_dict, etc.;
//    - records every Node found in self._input_nodes;
//    - registers self in that Node's `users`.
// Returns None, or nullptr with a Python error set.
static PyObject* NodeBase__update_args_kwargs(
    PyObject* self,
    PyObject* const* args,
    Py_ssize_t nargs) {
  // Verify argument count
  if (nargs != 2) {
    PyErr_SetString(
        PyExc_TypeError,
        "_update_args_kwargs() requires exactly 2 arguments (new_args, new_kwargs)");
    return nullptr;
  }
  auto node = reinterpret_cast<NodeBase*>(self);
  auto input_nodes = node->_input_nodes;
  if (PyDict_GET_SIZE(input_nodes) > 0) {
    // Clear other.users containing us and input_nodes
    PyObject *key = nullptr, *value = nullptr; // borrowed
    Py_ssize_t pos = 0;
    while (PyDict_Next(input_nodes, &pos, &key, &value)) {
      // key.users.pop(self, None). A missing entry is fine, but PyDict_DelItem
      // reports it by setting KeyError. That error must be cleared: CPython
      // raises SystemError when a C function returns a value with an exception
      // still set. Other errors are propagated.
      if (PyDict_DelItem(reinterpret_cast<NodeBase*>(key)->users, self) < 0) {
        if (!PyErr_ExceptionMatches(PyExc_KeyError)) {
          return nullptr;
        }
        PyErr_Clear();
      }
    }
    PyDict_Clear(input_nodes);
  }
  auto visit_fn = [self, input_nodes](PyObject* x) {
    if (is_node(x)) {
      // self._input_nodes.setdefault(x)
      if (!PyDict_SetDefault(input_nodes, x, Py_None)) {
        throw PythonError();
      }
      // x.users.setdefault(self)
      if (!PyDict_SetDefault(
              reinterpret_cast<NodeBase*>(x)->users, self, Py_None)) {
        throw PythonError();
      }
    }
    return Py_NewRef(x);
  };
  try {
    // Build both replacements before touching the old values. An argument
    // that aliases node->_args / node->_kwargs (e.g. the current _kwargs passed
    // back in) is then never read after release, and on failure the previous
    // values stay in place instead of becoming null.
    THPObjectPtr new_args(map_aggregate(args[0], visit_fn));
    THPObjectPtr new_kwargs(map_aggregate(args[1], visit_fn));
    PyObject* old_args = node->_args;
    PyObject* old_kwargs = node->_kwargs;
    node->_args = new_args.release();
    node->_kwargs = new_kwargs.release();
    Py_XDECREF(old_args);
    Py_XDECREF(old_kwargs);
    Py_RETURN_NONE;
  } catch (const PythonError&) {
    return nullptr; // error already set
  }
}
// _remove_from_list() (METH_NOARGS): unlink this node from its neighbours via
// NodeBase::remove_from_list and return None.
static PyObject* NodeBase__remove_from_list(
    PyObject* self,
    PyObject* /*_ignored*/) {
  NodeBase* const node = reinterpret_cast<NodeBase*>(self);
  node->remove_from_list();
  Py_RETURN_NONE;
}
// _prepend(x) (METH_O): move node `x` so that it sits immediately before
// `self` in the linked list, then give x a new sort key derived from its new
// neighbours' keys.
// Returns None (a no-op when x is self). Raises TypeError if x is not a Node,
// and AssertionError if x belongs to a different graph.
// NOTE(review): is_node can throw PythonError (via import_from), and nothing
// here catches it.
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
}
if (!is_node(arg)) {
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
return nullptr;
}
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
if (self->graph != x->graph) {
PyErr_SetString(
PyExc_AssertionError,
"Attempting to move a Node into a different Graph");
return nullptr;
}
x->remove_from_list();
// Splice x in between p (self's current predecessor) and self. These raw
// pointer writes, together with remove_from_list above, appear to leave every
// node with the same number of incoming _prev/_next links it had before, which
// is why no Py_INCREF/Py_DECREF is done here.
NodeBase* p = self->_prev;
p->_next = x;
x->_prev = p;
x->_next = self;
self->_prev = x;
// Now compute x.sort_key() from psk (new predecessor) and nsk (self).
// Assuming psk < nsk lexicographically, each branch yields a key strictly
// between them.
const NodeSortKey& psk = x->_prev->sort_key();
const NodeSortKey& nsk = x->_next->sort_key();
if (psk.size() > nsk.size()) {
// prefix = psk[: len(nsk)+1]
size_t slice_len = nsk.size() + 1;
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
// last element is idx => increment by 1
prefix.back()++;
x->sort_key() = std::move(prefix);
} else if (psk.size() < nsk.size()) {
// prefix = nsk[: len(psk)+1]
size_t slice_len = psk.size() + 1;
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
// last element is idx => decrement by 1
prefix.back()--;
x->sort_key() = std::move(prefix);
} else {
// same length => add a 0
x->sort_key() = psk;
x->sort_key().emplace_back(0);
}
Py_RETURN_NONE;
}
// __lt__(self, other) (METH_O): True when self's sort key orders strictly
// before other's under lexicographic comparison. Returns NotImplemented if
// `other` is not a Node.
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const auto& mine = reinterpret_cast<NodeBase*>(self)->sort_key();
  const auto& theirs = reinterpret_cast<NodeBase*>(other)->sort_key();
  const bool before = std::lexicographical_compare(
      mine.begin(), mine.end(), theirs.begin(), theirs.end());
  return PyBool_FromLong(before);
}
// __gt__(self, other) (METH_O): True when self's sort key orders strictly
// after other's, computed as "other < self". Returns NotImplemented if
// `other` is not a Node.
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const auto& mine = reinterpret_cast<NodeBase*>(self)->sort_key();
  const auto& theirs = reinterpret_cast<NodeBase*>(other)->sort_key();
  const bool after = std::lexicographical_compare(
      theirs.begin(), theirs.end(), mine.begin(), mine.end());
  return PyBool_FromLong(after);
}
// __ge__(self, other) (METH_O): True when `other` is the very same object;
// otherwise the result of __gt__ (which may be NotImplemented).
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
  return self == other ? Py_NewRef(Py_True) : NodeBase___gt__(self, other);
}
// __le__(self, other) (METH_O): True when `other` is the very same object;
// otherwise the result of __lt__ (which may be NotImplemented).
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
  return self == other ? Py_NewRef(Py_True) : NodeBase___lt__(self, other);
}
// Getter for _sort_key: converts NodeBase::sort_key (int64_t elements) into a
// Python tuple of ints. Only used by pickle/__getstate__.
// Returns a new tuple, or nullptr with a Python error set.
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
  const NodeSortKey& vec = reinterpret_cast<NodeBase*>(self)->sort_key();
  Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
  THPObjectPtr tuple(PyTuple_New(n));
  if (!tuple) {
    return nullptr; // Out of memory
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    // PyLong_FromLongLong is lossless for int64_t on every platform, whereas
    // Py_ssize_t is only 32 bits wide on 32-bit builds.
    PyObject* item = PyLong_FromLongLong(static_cast<long long>(vec[i]));
    if (!item) {
      // The partially filled tuple is released by THPObjectPtr; unset slots
      // are null, which tuple deallocation tolerates.
      return nullptr;
    }
    PyTuple_SET_ITEM(tuple.get(), i, item);
  }
  return tuple.release();
}
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1, 2, 3). Only used by pickle/__setstate__.
// Returns 0 on success, or -1 with a Python error set. On error the existing
// sort key is left unchanged.
static int NodeBase_set_sort_key(
    PyObject* self,
    PyObject* value,
    void* /*closure*/) {
  // CPython calls a setter with value == nullptr for `del node._sort_key`.
  // The key cannot be removed, and PyTuple_Check(nullptr) would crash.
  if (value == nullptr) {
    PyErr_SetString(PyExc_TypeError, "cannot delete _sort_key");
    return -1;
  }
  NodeBase* node = reinterpret_cast<NodeBase*>(self);
  if (!PyTuple_Check(value)) {
    PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
    return -1;
  }
  Py_ssize_t size = PyTuple_GET_SIZE(value);
  NodeSortKey new_vec;
  new_vec.reserve(size);
  for (Py_ssize_t i = 0; i < size; i++) {
    int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
    if (val == -1 && PyErr_Occurred()) {
      return -1;
    }
    new_vec.emplace_back(val);
  }
  node->sort_key() = std::move(new_vec);
  return 0;
}
// Methods of _NodeBase. The rich-comparison dunders are registered here as
// ordinary METH_O methods; NodeBaseType itself leaves tp_richcompare null.
// NOTE(review): presumably the `<`/`>` operators are wired up only for Python
// subclasses that inherit these methods. Confirm whether bare _NodeBase
// instances need to support the operators.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
{"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL,
"Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// Computed attributes of _NodeBase. _sort_key converts between the C++
// NodeSortKey and a Python tuple of ints via NodeBase_get_sort_key /
// NodeBase_set_sort_key.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
// Static type object for torch._C._NodeBase, added to the module by
// NodeBase_init. Py_TPFLAGS_BASETYPE allows Python subclasses, and
// Py_TPFLAGS_HAVE_GC pairs with NodeBase_traverse / NodeBase_clear because
// nodes reference each other through _prev/_next.
// The initializer is positional: the trailing comments name each slot, and
// entries must stay in PyTypeObject field order.
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
sizeof(NodeBase), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)NodeBase_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_HAVE_GC, /* tp_flags */
nullptr, /* tp_doc */
(traverseproc)NodeBase_traverse, /* tp_traverse */
(inquiry)NodeBase_clear, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */
NodeBase_getset, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)NodeBase_init_fn, /* tp_init */
nullptr, /* tp_alloc */
NodeBase_new, /* tp_new */
};
} // namespace
////////////////////////////////
// NodeIter
////////////////////////////////
// State of a torch._C._NodeIter. It walks the circular Node list starting
// after _root, skips erased nodes, and stops when it arrives back at _root.
struct NodeIter {
PyObject_HEAD
// Walk via _prev instead of _next.
bool _reversed;
// Owned reference to the node where iteration starts and stops.
NodeBase* _root;
// Owned reference to the node most recently reached.
NodeBase* _cur;
};
// tp_new for _NodeIter: allocate the object via tp_alloc and return it
// (nullptr on allocation failure). Fields are set later by NodeIter_init_fn.
static PyObject* NodeIter_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  return type->tp_alloc(type, 0);
}
// tp_init for _NodeIter: __init__(root, reversed).
//
// `root` must be a _NodeBase instance, enforced with "O!" because iteration
// dereferences it as a NodeBase. Iteration starts after root and stops when
// root is reached again; a truthy `reversed` walks _prev instead of _next.
// Installs new references to root in both _root and _cur, then releases any
// references left over from a previous __init__ call.
// Returns 0 on success, or -1 with a Python error set.
static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) {
  PyObject* root_obj = nullptr;
  // The "b" format code stores an unsigned char, so parse into one rather
  // than writing through a bool*.
  unsigned char reversed = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  constexpr const char* keywords[] = {"root", "reversed", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "O!b|",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<char**>(keywords),
          &NodeBaseType,
          &root_obj,
          &reversed)) {
    return -1;
  }
  NodeBase* root = reinterpret_cast<NodeBase*>(root_obj);
  NodeBase* old_root = self->_root;
  NodeBase* old_cur = self->_cur;
  self->_reversed = reversed != 0;
  Py_INCREF(root);
  self->_root = root;
  Py_INCREF(root);
  self->_cur = root;
  // Both are null on first initialization.
  Py_XDECREF(old_root);
  Py_XDECREF(old_cur);
  return 0;
}
// Advance the iterator to the next non-erased node in the chosen direction
// (_prev when `reversed`, otherwise _next) and return a new reference to it.
//
// Raises StopIteration when the walk gets back to _root. At that point _cur
// is dropped, so later calls keep raising StopIteration instead of starting
// another pass over the list. A null _cur (exhausted, cleared by tp_clear, or
// never initialized) also raises StopIteration rather than dereferencing null.
template <bool reversed>
static PyObject* NodeIter_iternext_helper(NodeIter* self) {
  if (self->_cur == nullptr) {
    PyErr_SetNone(PyExc_StopIteration);
    return nullptr;
  }
  // Move _cur one step along the list, keeping it an owned reference.
  // It should be possible to relax the ref counting here, but in practice
  // there are not many _erased Nodes, so it is probably not worth it.
  auto step = [self]() {
    NodeBase* neighbor = nullptr;
    if constexpr (reversed) {
      neighbor = self->_cur->_prev;
    } else {
      neighbor = self->_cur->_next;
    }
    Py_INCREF(neighbor);
    Py_CLEAR(self->_cur);
    self->_cur = neighbor;
  };
  step();
  while (self->_cur != self->_root) {
    if (!self->_cur->_erased) {
      Py_INCREF(self->_cur);
      return reinterpret_cast<PyObject*>(self->_cur);
    }
    step();
  }
  // Back at the root: the iterator is exhausted for good.
  Py_CLEAR(self->_cur);
  PyErr_SetNone(PyExc_StopIteration);
  return nullptr;
}
// tp_iternext for _NodeIter: dispatch to the direction-specialized helper.
static PyObject* NodeIter_iternext(PyObject* _self) {
  auto* self = reinterpret_cast<NodeIter*>(_self);
  return self->_reversed ? NodeIter_iternext_helper<true>(self)
                         : NodeIter_iternext_helper<false>(self);
}
// tp_traverse for _NodeIter: report the owned _root and _cur references to
// the cyclic GC. Null fields are skipped by Py_VISIT.
static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) {
  PyObject* const root = reinterpret_cast<PyObject*>(self->_root);
  PyObject* const cur = reinterpret_cast<PyObject*>(self->_cur);
  Py_VISIT(root);
  Py_VISIT(cur);
  return 0;
}
// tp_clear for _NodeIter: drop the owned _root and _cur references, in that
// order. Py_CLEAR nulls each field before releasing it.
static int NodeIter_clear(NodeIter* self) {
  for (NodeBase** field : {&self->_root, &self->_cur}) {
    Py_CLEAR(*field);
  }
  return 0;
}
// tp_dealloc for _NodeIter: stop GC tracking, release the node references,
// and free the memory through the instance type's tp_free.
static void NodeIter_dealloc(PyObject* self) {
  NodeIter* const iter = reinterpret_cast<NodeIter*>(self);
  PyObject_GC_UnTrack(self);
  static_cast<void>(NodeIter_clear(iter));
  Py_TYPE(self)->tp_free(self);
}
// Static type object for torch._C._NodeIter, added to the module by
// NodeIter_init. It is its own iterator (tp_iter = PyObject_SelfIter) and is
// GC-tracked because it holds node references. Unlike NodeBaseType it does
// not set Py_TPFLAGS_BASETYPE, so it cannot be subclassed from Python.
// The initializer is positional: entries must stay in PyTypeObject field order.
static PyTypeObject NodeIterType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeIter", /* tp_name */
sizeof(NodeIter), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)NodeIter_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
nullptr, /* tp_doc */
(traverseproc)NodeIter_traverse, /* tp_traverse */
(inquiry)NodeIter_clear, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
NodeIter_iternext, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)NodeIter_init_fn, /* tp_init */
nullptr, /* tp_alloc */
NodeIter_new, /* tp_new */
};
// Register the _NodeIter type on `module`. Returns false if
// PyModule_AddType fails.
bool NodeIter_init(PyObject* module) {
  const int status = PyModule_AddType(module, &NodeIterType);
  return status >= 0;
}
////////////////////////////////
// Global methods
////////////////////////////////
// _fx_map_aggregate(aggregate, fn) (METH_FASTCALL): rebuild `aggregate` with
// map_aggregate, calling the Python callable `fn` on every leaf.
// Returns the new object, or nullptr with a Python error set.
static PyObject* py_map_aggregate(
    PyObject* self,
    PyObject* const* args,
    Py_ssize_t nargs) {
  if (nargs != 2) {
    PyErr_SetString(
        PyExc_TypeError, "map_aggregate() takes exactly two arguments");
    return nullptr;
  }
  PyObject* const aggregate = args[0];
  PyObject* const fn = args[1];
  auto call_fn = [fn](PyObject* leaf) { return PyObject_CallOneArg(fn, leaf); };
  try {
    return map_aggregate(aggregate, call_fn);
  } catch (const PythonError&) {
    return nullptr; // the Python error is already set
  }
}
// _fx_map_arg(aggregate, fn) (METH_FASTCALL): like _fx_map_aggregate, but only
// leaves that are Nodes are passed to `fn`; every other leaf is kept as-is.
// Returns the new object, or nullptr with a Python error set.
static PyObject* py_map_arg(
    PyObject* self,
    PyObject* const* args,
    Py_ssize_t nargs) {
  if (nargs != 2) {
    PyErr_SetString(PyExc_TypeError, "map_arg() takes exactly two arguments");
    return nullptr;
  }
  PyObject* const aggregate = args[0];
  PyObject* const fn = args[1];
  auto map_nodes_only = [fn](PyObject* leaf) -> PyObject* {
    return is_node(leaf) ? PyObject_CallOneArg(fn, leaf) : Py_NewRef(leaf);
  };
  try {
    return map_aggregate(aggregate, map_nodes_only);
  } catch (const PythonError&) {
    return nullptr; // the Python error is already set
  }
}
// Module-level functions added alongside _NodeBase by NodeBase_init.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef extra_methods[] = {
{"_fx_map_aggregate",
(PyCFunction)(void*)(py_map_aggregate),
METH_FASTCALL,
"Recursively apply a function to every element in an aggregate object."},
{"_fx_map_arg",
(PyCFunction)(void*)(py_map_arg),
METH_FASTCALL,
"Recursively apply a function to every Node in an aggregate object."},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// Register the _NodeBase type and the _fx_map_aggregate / _fx_map_arg
// functions on `module`. Stops at, and returns false for, the first
// registration that fails.
bool NodeBase_init(PyObject* module) {
  return PyModule_AddType(module, &NodeBaseType) >= 0 &&
      PyModule_AddFunctions(module, extra_methods) >= 0;
}