#include #include #include #include #include #include #include namespace { using NodeSortKey = c10::SmallVector; struct NodeBase; // Thrown to exit out of a C++ function and return an error to Python. class PythonError : public std::exception {}; inline static PyObject* import_from(const char* module_name, const char* name) { THPObjectPtr module(PyImport_ImportModule(module_name)); if (!module) { throw PythonError(); } PyObject* result = PyObject_GetAttrString(module, name); if (!result) { throw PythonError(); } return result; } inline static PyObject* immutable_list_cls() { static PyObject* immutable_list_cls = nullptr; if (!immutable_list_cls) { immutable_list_cls = import_from("torch.fx.immutable_collections", "immutable_list"); } return immutable_list_cls; } inline static PyObject* immutable_dict_cls() { static PyObject* immutable_dict_cls = nullptr; if (!immutable_dict_cls) { immutable_dict_cls = import_from("torch.fx.immutable_collections", "immutable_dict"); } return immutable_dict_cls; } inline static bool is_node(PyObject* obj) { static PyObject* node_cls = nullptr; if (!node_cls) { node_cls = import_from("torch.fx.node", "Node"); } return PyObject_TypeCheck(obj, reinterpret_cast(node_cls)); } inline static bool exact_type(PyObject* obj, PyObject* typ) { return Py_TYPE(obj) == reinterpret_cast(typ); } template inline static PyObject* map_aggregate(PyObject* a, F fn) { // Invariant: this function will throw an exception and never return nullptr. // Case 1: a is a tuple. if (PyTuple_Check(a)) { Py_ssize_t n = PyTuple_GET_SIZE(a); if (n == 0 && PyTuple_CheckExact(a)) { return Py_NewRef(a); } THPObjectPtr new_tuple(PyTuple_New(n)); if (!new_tuple) { throw PythonError(); } for (Py_ssize_t i = 0; i < n; i++) { PyObject* elem = PyTuple_GET_ITEM(a, i); // Borrowed reference. // PyTuple_SET_ITEM steals reference to result of map_aggregate PyTuple_SET_ITEM(new_tuple.get(), i, map_aggregate(elem, fn)); } // If the tuple has a "_fields" attribute, assume it is a NamedTuple. if (!PyTuple_CheckExact(a) && PyObject_HasAttrString(a, "_fields")) { // Call type_obj with new_tuple as arguments (i.e. type(a)(*new_tuple)) return PyObject_CallObject( reinterpret_cast(Py_TYPE(a)), new_tuple); } else { return new_tuple.release(); } } // Case 2: a is a list. else if (PyList_Check(a)) { Py_ssize_t n = PyList_GET_SIZE(a); if (n == 0 && exact_type(a, immutable_list_cls())) { return Py_NewRef(a); } THPObjectPtr result(PyObject_CallNoArgs(immutable_list_cls())); if (!result) { throw PythonError(); } for (Py_ssize_t i = 0; i < n; i++) { PyObject* elem = PyList_GET_ITEM(a, i); // borrowed ref THPObjectPtr mapped(map_aggregate(elem, fn)); if (PyList_Append(result.get(), mapped.get()) < 0) { throw PythonError(); } } return result.release(); } // Case 3: a is a dict. else if (PyDict_Check(a)) { if (PyDict_GET_SIZE(a) == 0 && exact_type(a, immutable_dict_cls())) { return Py_NewRef(a); } THPObjectPtr result(PyObject_CallNoArgs(immutable_dict_cls())); if (!result) { throw PythonError(); } PyObject *key = nullptr, *value = nullptr; // borrowed Py_ssize_t pos = 0; while (PyDict_Next(a, &pos, &key, &value)) { THPObjectPtr mapped(map_aggregate(value, fn)); if (PyDict_SetItem(result.get(), key, mapped.get()) < 0) { throw PythonError(); } } return result.release(); } // Case 4: a is a slice. else if (PySlice_Check(a)) { // Get start, stop, and step attributes. THPObjectPtr start(PyObject_GetAttrString(a, "start")); THPObjectPtr stop(PyObject_GetAttrString(a, "stop")); THPObjectPtr step(PyObject_GetAttrString(a, "step")); if (!start || !stop || !step) { throw PythonError(); } THPObjectPtr mapped_start(map_aggregate(start, fn)); THPObjectPtr mapped_stop(map_aggregate(stop, fn)); THPObjectPtr mapped_step(map_aggregate(step, fn)); return PySlice_New( mapped_start.get(), mapped_stop.get(), mapped_step.get()); } // Default case: call fn(a). else { PyObject* result = fn(a); if (!result) { throw PythonError(); } return result; } } //////////////////////////////// // NodeBase /////////////////////////////// struct NodeBase { PyObject_HEAD bool _erased; NodeBase* _prev; NodeBase* _next; PyObject* graph; PyObject* name; PyObject* op; PyObject* target; PyObject* type; PyObject* _input_nodes; PyObject* _args; PyObject* _kwargs; PyObject* users; PyObject* _repr_fn; PyObject* meta; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)]; inline NodeSortKey& sort_key() { return *reinterpret_cast(sort_key_buf); } inline void set_prev(NodeBase* value) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value); Py_INCREF(reinterpret_cast(value)); NodeBase* old = _prev; _prev = value; Py_DECREF(reinterpret_cast(old)); } inline void set_next(NodeBase* value) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value); Py_INCREF(reinterpret_cast(value)); NodeBase* old = _next; _next = value; Py_DECREF(reinterpret_cast(old)); } // Equivalent to: // p, n = self._prev, self._next // p._next, n._prev = n, p inline void remove_from_list() { if (this->_prev == this && this->_next == this) { return; } NodeBase* p = this->_prev; NodeBase* n = this->_next; p->set_next(n); n->set_prev(p); } }; static PyObject* NodeBase_new( PyTypeObject* type, PyObject* args, PyObject* kwds) { PyObject* self = type->tp_alloc(type, 0); if (!self) return nullptr; new (reinterpret_cast(self)->sort_key_buf) NodeSortKey(); // placement new does not allocate return self; } static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) { PyObject* graph = nullptr; PyObject* name = nullptr; PyObject* op = nullptr; PyObject* target = nullptr; PyObject* type = nullptr; if (!PyArg_ParseTuple(args, "OOOOO", &graph, &name, &op, &target, &type)) { return -1; } self->_erased = false; Py_INCREF(self); self->_prev = self; Py_INCREF(self); self->_next = self; self->graph = Py_NewRef(graph); self->name = Py_NewRef(name); self->op = Py_NewRef(op); self->target = Py_NewRef(target); self->type = Py_NewRef(type); self->_input_nodes = PyDict_New(); self->_args = nullptr; // set with _update_args_kwargs self->_kwargs = nullptr; // set with _update_args_kwargs self->users = PyDict_New(); self->_repr_fn = Py_NewRef(Py_None); self->meta = PyDict_New(); return 0; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static struct PyMemberDef NodeBase_members[] = { {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr}, {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr}, {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr}, {"graph", T_OBJECT_EX, offsetof(NodeBase, graph), 0, nullptr}, {"name", T_OBJECT_EX, offsetof(NodeBase, name), 0, nullptr}, {"op", T_OBJECT_EX, offsetof(NodeBase, op), 0, nullptr}, {"target", T_OBJECT_EX, offsetof(NodeBase, target), 0, nullptr}, {"type", T_OBJECT_EX, offsetof(NodeBase, type), 0, nullptr}, {"_input_nodes", T_OBJECT_EX, offsetof(NodeBase, _input_nodes), 0, nullptr}, {"_args", T_OBJECT_EX, offsetof(NodeBase, _args), 0, nullptr}, {"_kwargs", T_OBJECT_EX, offsetof(NodeBase, _kwargs), 0, nullptr}, {"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr}, {"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr}, {"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr}, {nullptr} /* Sentinel */ }; static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) { Py_VISIT(self->_prev); Py_VISIT(self->_next); Py_VISIT(self->graph); Py_VISIT(self->name); Py_VISIT(self->op); Py_VISIT(self->target); Py_VISIT(self->type); Py_VISIT(self->_input_nodes); Py_VISIT(self->_args); Py_VISIT(self->_kwargs); Py_VISIT(self->users); Py_VISIT(self->_repr_fn); Py_VISIT(self->meta); return 0; } static int NodeBase_clear(NodeBase* self) { Py_CLEAR(self->_prev); Py_CLEAR(self->_next); Py_CLEAR(self->graph); Py_CLEAR(self->name); Py_CLEAR(self->op); Py_CLEAR(self->target); Py_CLEAR(self->type); Py_CLEAR(self->_input_nodes); Py_CLEAR(self->_args); Py_CLEAR(self->_kwargs); Py_CLEAR(self->users); Py_CLEAR(self->_repr_fn); Py_CLEAR(self->meta); return 0; } static void NodeBase_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); reinterpret_cast(self)->sort_key().~NodeSortKey(); (void)NodeBase_clear((NodeBase*)self); Py_TYPE(self)->tp_free(self); } static PyObject* NodeBase__update_args_kwargs( PyObject* self, PyObject* const* args, Py_ssize_t nargs) { // Verify argument count if (nargs != 2) { PyErr_SetString( PyExc_TypeError, "_update_args_kwargs() requires exactly 2 arguments (new_args, new_kwargs)"); return nullptr; } auto node = reinterpret_cast(self); auto input_nodes = node->_input_nodes; if (PyDict_GET_SIZE(input_nodes) > 0) { // Clear other.users containing us and input_nodes PyObject *key = nullptr, *value = nullptr; // borrowed Py_ssize_t pos = 0; while (PyDict_Next(input_nodes, &pos, &key, &value)) { // key.users.pop(self), intentionally ignore KeyError PyDict_DelItem(reinterpret_cast(key)->users, self); } PyDict_Clear(input_nodes); } auto visit_fn = [self, input_nodes](PyObject* x) { if (is_node(x)) { // self._input_nodes.setdefault(x) if (!PyDict_SetDefault(input_nodes, x, Py_None)) { throw PythonError(); } // x.users.setdefault(self) if (!PyDict_SetDefault( reinterpret_cast(x)->users, self, Py_None)) { throw PythonError(); } } return Py_NewRef(x); }; // We do three things in a single pass of the args // - Normalize list->immutable_list, dict->immutable_dict, etc // - Populate self._input_nodes // - Populate arg.users[self] for each arg try { Py_CLEAR(node->_args); node->_args = map_aggregate(args[0], visit_fn); Py_CLEAR(node->_kwargs); node->_kwargs = map_aggregate(args[1], visit_fn); Py_RETURN_NONE; } catch (const PythonError& e) { return nullptr; } } static PyObject* NodeBase__remove_from_list( PyObject* self, PyObject* _ignored) { reinterpret_cast(self)->remove_from_list(); Py_RETURN_NONE; } static PyObject* NodeBase__replace_input_with( PyObject* self, PyObject* const* args, Py_ssize_t nargs) { if (nargs != 2) { PyErr_SetString( PyExc_TypeError, "_replace_input_with() requires exactly 2 arguments (old_input, new_input)"); return nullptr; } PyObject* old_input = args[0]; PyObject* new_input = args[1]; auto replace_fn = [old_input, new_input](PyObject* maybe_node) { if (maybe_node == old_input) { return Py_NewRef(new_input); } return Py_NewRef(maybe_node); }; auto node = reinterpret_cast(self); try { THPObjectPtr new_args(map_aggregate(node->_args, replace_fn)); if (!new_args) { return nullptr; } THPObjectPtr new_kwargs(map_aggregate(node->_kwargs, replace_fn)); if (!new_kwargs) { return nullptr; } PyObject* update_args[2] = {new_args.get(), new_kwargs.get()}; return NodeBase__update_args_kwargs(self, update_args, 2); } catch (const PythonError& e) { return nullptr; } } static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) { if (self_ == arg) { Py_RETURN_NONE; } if (!is_node(arg)) { PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node"); return nullptr; } NodeBase* self = reinterpret_cast(self_); NodeBase* x = reinterpret_cast(arg); if (self->graph != x->graph) { PyErr_SetString( PyExc_AssertionError, "Attempting to move a Node into a different Graph"); return nullptr; } x->remove_from_list(); NodeBase* p = self->_prev; p->set_next(x); x->set_prev(p); x->set_next(self); self->set_prev(x); // Now compute x.sort_key() const NodeSortKey& psk = x->_prev->sort_key(); const NodeSortKey& nsk = x->_next->sort_key(); if (psk.size() > nsk.size()) { // prefix = psk[: len(nsk)+1] size_t slice_len = nsk.size() + 1; NodeSortKey prefix(psk.begin(), psk.begin() + slice_len); // last element is idx => increment by 1 prefix.back()++; x->sort_key() = std::move(prefix); } else if (psk.size() < nsk.size()) { // prefix = nsk[: len(psk)+1] size_t slice_len = psk.size() + 1; NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len); // last element is idx => decrement by 1 prefix.back()--; x->sort_key() = std::move(prefix); } else { // same length => add a 0 x->sort_key() = psk; x->sort_key().emplace_back(0); } Py_RETURN_NONE; } // __lt__(self, other): Return self.sort_key < other.sort_key static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) { // METH_O => one argument: 'other' if (!is_node(other)) { Py_RETURN_NOTIMPLEMENTED; } const NodeSortKey& lhs = reinterpret_cast(self)->sort_key(); const NodeSortKey& rhs = reinterpret_cast(other)->sort_key(); bool less = std::lexicographical_compare( lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); if (less) Py_RETURN_TRUE; Py_RETURN_FALSE; } // __gt__(self, other): Return self.sort_key() > other.sort_key static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) { if (!is_node(other)) { Py_RETURN_NOTIMPLEMENTED; } const NodeSortKey& lhs = reinterpret_cast(self)->sort_key(); const NodeSortKey& rhs = reinterpret_cast(other)->sort_key(); // "a > b" is equivalent to "b < a" bool greater = std::lexicographical_compare( rhs.begin(), rhs.end(), lhs.begin(), lhs.end()); if (greater) Py_RETURN_TRUE; Py_RETURN_FALSE; } static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) { if (self == other) { Py_RETURN_TRUE; } return NodeBase___gt__(self, other); } // __le__(self, other): Return not (self > other) static PyObject* NodeBase___le__(PyObject* self, PyObject* other) { if (self == other) { Py_RETURN_TRUE; } return NodeBase___lt__(self, other); } // Convert the NodeBase::sort_key vector into a Python tuple of ints // Only used by pickle/__getstate__ static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) { NodeBase* node = reinterpret_cast(self); const NodeSortKey& vec = node->sort_key(); Py_ssize_t n = static_cast(vec.size()); THPObjectPtr tuple(PyTuple_New(n)); if (!tuple) { return nullptr; // Out of memory } for (Py_ssize_t i = 0; i < n; i++) { PyObject* value = PyLong_FromSsize_t(vec[i]); if (!value) { return nullptr; } PyTuple_SET_ITEM(tuple.get(), i, value); } return tuple.release(); } // Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g. // node._sort_key = (1,2,3) Only used by pickle/__setstate__ static int NodeBase_set_sort_key( PyObject* self, PyObject* value, void* /*closure*/) { NodeBase* node = reinterpret_cast(self); if (!PyTuple_Check(value)) { PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints"); return -1; } Py_ssize_t size = PyTuple_GET_SIZE(value); NodeSortKey new_vec; new_vec.reserve(size); for (Py_ssize_t i = 0; i < size; i++) { int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i)); if (val == -1 && PyErr_Occurred()) { return -1; } new_vec.emplace_back(val); } node->sort_key() = std::move(new_vec); return 0; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static PyMethodDef NodeBase_methods[] = { {"_update_args_kwargs", (PyCFunction)(void*)(NodeBase__update_args_kwargs), METH_FASTCALL, "Internal method: do not call directly."}, {"_remove_from_list", (PyCFunction)(void*)(NodeBase__remove_from_list), METH_NOARGS, "Internal method: do not call directly."}, {"_replace_input_with", (PyCFunction)(void*)(NodeBase__replace_input_with), METH_FASTCALL, "Internal method: replace occurrences of one input Node with another."}, {"_prepend", (PyCFunction)(void*)(NodeBase__prepend), METH_O, "Internal method: do not call directly."}, {"__lt__", (PyCFunction)(void*)NodeBase___lt__, METH_O, "Return True if self.sort_key < other.sort_key"}, {"__gt__", (PyCFunction)(void*)NodeBase___gt__, METH_O, "Return True if self.sort_key > other.sort_key"}, {"__ge__", (PyCFunction)(void*)NodeBase___ge__, METH_O, "Return True if self.sort_key >= other.sort_key"}, {"__le__", (PyCFunction)(void*)NodeBase___le__, METH_O, "Return True if self.sort_key <= other.sort_key"}, {nullptr, nullptr, 0, nullptr} // Sentinel }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static PyGetSetDef NodeBase_getset[] = { {"_sort_key", // attribute name in Python (getter)NodeBase_get_sort_key, // C getter function (setter)NodeBase_set_sort_key, // C setter function (char*)"The sort key as a tuple of ints", // docstring nullptr}, {nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel }; PyTypeObject NodeBaseType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */ sizeof(NodeBase), /* tp_basicsize */ 0, /* tp_itemsize */ NodeBase_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ nullptr, /* tp_repr */ nullptr, /* tp_as_number */ nullptr, /* tp_as_sequence */ nullptr, /* tp_as_mapping */ nullptr, /* tp_hash */ nullptr, /* tp_call */ nullptr, /* tp_str */ nullptr, /* tp_getattro */ nullptr, /* tp_setattro */ nullptr, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ nullptr, /* tp_doc */ (traverseproc)NodeBase_traverse, /* tp_traverse */ (inquiry)NodeBase_clear, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ NodeBase_methods, /* tp_methods */ NodeBase_members, /* tp_members */ NodeBase_getset, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ nullptr, /* tp_descr_get */ nullptr, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)NodeBase_init_fn, /* tp_init */ nullptr, /* tp_alloc */ NodeBase_new, /* tp_new */ }; } // namespace //////////////////////////////// // NodeIter //////////////////////////////// struct NodeIter { PyObject_HEAD bool _reversed; NodeBase* _root; NodeBase* _cur; }; static PyObject* NodeIter_new( PyTypeObject* type, PyObject* args, PyObject* kwds) { PyObject* self = type->tp_alloc(type, 0); if (!self) return nullptr; return self; } static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) { NodeBase* root = nullptr; bool reversed = false; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) constexpr const char* keywords[] = {"root", "reversed", nullptr}; if (!PyArg_ParseTupleAndKeywords( args, kwargs, "Ob|", // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(keywords), &root, &reversed)) { return -1; } self->_reversed = reversed; Py_INCREF(root); self->_root = root; Py_INCREF(root); self->_cur = root; return 0; } template static PyObject* NodeIter_iternext_helper(NodeIter* self) { // It should be possible to relax the ref counting here // but in practice, we do not have that many _erased Nodes, // so probably not worth it. if constexpr (reversed) { NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev); Py_CLEAR(self->_cur); self->_cur = prev; } else { NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next); Py_CLEAR(self->_cur); self->_cur = next; } while (self->_cur != self->_root) { if (!self->_cur->_erased) { Py_INCREF(self->_cur); return (PyObject*)self->_cur; } if constexpr (reversed) { NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev); Py_CLEAR(self->_cur); self->_cur = prev; } else { NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next); Py_CLEAR(self->_cur); self->_cur = next; } } PyErr_SetNone(PyExc_StopIteration); return nullptr; } static PyObject* NodeIter_iternext(PyObject* _self) { NodeIter* self = (NodeIter*)_self; if (self->_reversed) { return NodeIter_iternext_helper(self); } else { return NodeIter_iternext_helper(self); } } static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) { Py_VISIT(self->_root); Py_VISIT(self->_cur); return 0; } static int NodeIter_clear(NodeIter* self) { Py_CLEAR(self->_root); Py_CLEAR(self->_cur); return 0; } static void NodeIter_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); (void)NodeIter_clear((NodeIter*)self); Py_TYPE(self)->tp_free(self); } static PyTypeObject NodeIterType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */ sizeof(NodeIter), /* tp_basicsize */ 0, /* tp_itemsize */ (destructor)NodeIter_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ nullptr, /* tp_repr */ nullptr, /* tp_as_number */ nullptr, /* tp_as_sequence */ nullptr, /* tp_as_mapping */ nullptr, /* tp_hash */ nullptr, /* tp_call */ nullptr, /* tp_str */ nullptr, /* tp_getattro */ nullptr, /* tp_setattro */ nullptr, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */ nullptr, /* tp_doc */ (traverseproc)NodeIter_traverse, /* tp_traverse */ (inquiry)NodeIter_clear, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ PyObject_SelfIter, /* tp_iter */ NodeIter_iternext, /* tp_iternext */ nullptr, /* tp_methods */ nullptr, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ nullptr, /* tp_descr_get */ nullptr, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)NodeIter_init_fn, /* tp_init */ nullptr, /* tp_alloc */ NodeIter_new, /* tp_new */ }; bool NodeIter_init(PyObject* module) { if (PyModule_AddType(module, &NodeIterType) < 0) { return false; } return true; } //////////////////////////////// // Global methods //////////////////////////////// static PyObject* py_map_aggregate( PyObject* self, PyObject* const* args, Py_ssize_t nargs) { if (nargs != 2) { PyErr_SetString( PyExc_TypeError, "map_aggregate() takes exactly two arguments"); return nullptr; } try { PyObject* fn = args[1]; // args[0]: aggregate, args[1]: callable fn return map_aggregate( args[0], [fn](PyObject* a) { return PyObject_CallOneArg(fn, a); }); } catch (const PythonError& e) { return nullptr; // error should already be set } } static PyObject* py_map_arg( PyObject* self, PyObject* const* args, Py_ssize_t nargs) { if (nargs != 2) { PyErr_SetString(PyExc_TypeError, "map_arg() takes exactly two arguments"); return nullptr; } try { PyObject* fn = args[1]; // args[0]: aggregate, args[1]: callable fn return map_aggregate(args[0], [fn](PyObject* a) { if (is_node(a)) { return PyObject_CallOneArg(fn, a); } return Py_NewRef(a); }); } catch (const PythonError& e) { return nullptr; // error should already be set } } // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static PyMethodDef extra_methods[] = { {"_fx_map_aggregate", (PyCFunction)(void*)(py_map_aggregate), METH_FASTCALL, "Recursively apply a function to every element in an aggregate object."}, {"_fx_map_arg", (PyCFunction)(void*)(py_map_arg), METH_FASTCALL, "Recursively apply a function to every Node in an aggregate object."}, {nullptr, nullptr, 0, nullptr} // Sentinel }; bool NodeBase_init(PyObject* module) { if (PyModule_AddType(module, &NodeBaseType) < 0) { return false; } if (PyModule_AddFunctions(module, extra_methods) < 0) { return false; } return true; }