[fx] Move Node._update_args_kwargs to C++ (#148260)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:

```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```

after:

```
24303536 function calls (23503339 primitive calls) in 10.726 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148260
Approved by: https://github.com/oulgen
ghstack dependencies: #148243
This commit is contained in:
parent edaff88f69
commit 0135f57f4a
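For context, the numbers quoted above come from profiling the trace under cProfile. A minimal repro sketch (not part of the commit; exact call counts will differ across Python and PyTorch versions):

```python
import cProfile
import functools
import operator

import torch.fx as fx

# Tracing this lambda creates ~100k call_function nodes, so per-node
# Python overhead (Node.__init__, _update_args_kwargs) dominates the profile.
cProfile.run(
    "fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))",
    sort="cumulative",
)
```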
```diff
@@ -1,32 +1,32 @@
-add_loop_eager,compile_time_instruction_count,2993000000,0.015
+add_loop_eager,compile_time_instruction_count,2958000000,0.015
-add_loop_eager_dynamic,compile_time_instruction_count,6349000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,6250000000,0.025
-add_loop_inductor,compile_time_instruction_count,28630000000,0.015
+add_loop_inductor,compile_time_instruction_count,28450000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,45240000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44690000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,24960000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,24770000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,960700000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,959200000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18120000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17950000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16340000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16030000000,0.015
```
```diff
@@ -34,32 +34,32 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,98740000
-update_hint_regression,compile_time_instruction_count,1699000000,0.02
+update_hint_regression,compile_time_instruction_count,1683000000,0.02
-sum_floordiv_regression,compile_time_instruction_count,1061000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1054000000,0.015
-symint_sum,compile_time_instruction_count,3194000000,0.015
+symint_sum,compile_time_instruction_count,3167000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2018000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2010000000,0.015
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5792000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5776000000,0.015
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,8703000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,8521000000,0.015
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3783000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3735000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10170000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10070000000,0.015
```
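For readers of these expected-results rows: the columns are benchmark name, metric, expected instruction count, and a relative noise threshold. A sketch of the implied check (illustrative only, not the benchmark harness's actual code):

```python
def within_tolerance(measured: int, expected: int, rel_tol: float) -> bool:
    """True if the measured instruction count is within rel_tol of expected."""
    return abs(measured - expected) <= rel_tol * expected

# e.g. add_loop_eager allows +/-1.5% drift around the new 2958000000 baseline
assert within_tolerance(2_960_000_000, 2_958_000_000, 0.015)
assert not within_tolerance(3_100_000_000, 2_958_000_000, 0.015)
```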
```diff
@@ -2505,6 +2505,15 @@ class _NodeBase:
     _erased: _bool
     _prev: FxNode
     _next: FxNode
+    def __init__(
+        self,
+        graph: Any,
+        name: str,
+        op: str,
+        target: Any,
+        return_type: Any,
+    ) -> None: ...
+    def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...

 class _NodeIter(Iterator):
     def __init__(self, root: FxNode, reversed: _bool) -> None: ...
```
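`torch.fx.Node` subclasses this `_NodeBase`, and `_prev`/`_next` stitch nodes into the graph's circular doubly linked list. A quick sanity-check sketch (it peeks at private attributes, for illustration only):

```python
import torch
import torch.fx as fx

gm = fx.symbolic_trace(lambda x: x + 1)
node = next(iter(gm.graph.nodes))

# Node is backed by the C++ base type
assert isinstance(node, torch._C._NodeBase)
# circular doubly-linked-list invariant
assert node._prev._next is node and node._next._prev is node
```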
```diff
@@ -6,6 +6,8 @@
 namespace {

 struct NodeBase;

+// Thrown to exit out of a C++ function and return an error to Python.
+class PythonError : public std::exception {};
```
```diff
@@ -153,6 +155,18 @@ struct NodeBase {
   bool _erased;
   NodeBase* _prev;
   NodeBase* _next;
+  PyObject* graph;
+  PyObject* name;
+  PyObject* op;
+  PyObject* target;
+  PyObject* type;
+  PyObject* _input_nodes;
+  PyObject* _args;
+  PyObject* _kwargs;
+  PyObject* users;
+  PyObject* _repr_fn;
+  PyObject* meta;
+  PyObject* _sort_key;
 };

 static PyObject* NodeBase_new(
```
```diff
@@ -166,11 +180,31 @@ static PyObject* NodeBase_new(
 }

 static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
+  PyObject* graph = nullptr;
+  PyObject* name = nullptr;
+  PyObject* op = nullptr;
+  PyObject* target = nullptr;
+  PyObject* type = nullptr;
+  if (!PyArg_ParseTuple(args, "OOOOO", &graph, &name, &op, &target, &type)) {
+    return -1;
+  }
   self->_erased = false;
   Py_INCREF(self);
   self->_prev = self;
   Py_INCREF(self);
   self->_next = self;
+  self->graph = Py_NewRef(graph);
+  self->name = Py_NewRef(name);
+  self->op = Py_NewRef(op);
+  self->target = Py_NewRef(target);
+  self->type = Py_NewRef(type);
+  self->_input_nodes = PyDict_New();
+  self->_args = nullptr; // set with _update_args_kwargs
+  self->_kwargs = nullptr; // set with _update_args_kwargs
+  self->users = PyDict_New();
+  self->_repr_fn = Py_NewRef(Py_None);
+  self->meta = PyDict_New();
+  self->_sort_key = PyTuple_New(0);
   return 0;
 }
```
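Note how the initializer links a fresh node to itself (`_prev` and `_next` both point at `self`) before `Graph` splices it into the node list. A sketch of that observable state, using the private `Node` constructor purely for illustration:

```python
import operator
import torch.fx as fx

g = fx.Graph()
# construct a Node directly, without inserting it into the graph
n = fx.Node(g, "n", "call_function", operator.add, (), {})
assert n._prev is n and n._next is n  # self-linked until spliced in
assert not n._erased
```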
```diff
@@ -179,18 +213,54 @@ static struct PyMemberDef NodeBase_members[] = {
     {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
     {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
     {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
+    {"graph", T_OBJECT_EX, offsetof(NodeBase, graph), 0, nullptr},
+    {"name", T_OBJECT_EX, offsetof(NodeBase, name), 0, nullptr},
+    {"op", T_OBJECT_EX, offsetof(NodeBase, op), 0, nullptr},
+    {"target", T_OBJECT_EX, offsetof(NodeBase, target), 0, nullptr},
+    {"type", T_OBJECT_EX, offsetof(NodeBase, type), 0, nullptr},
+    {"_input_nodes", T_OBJECT_EX, offsetof(NodeBase, _input_nodes), 0, nullptr},
+    {"_args", T_OBJECT_EX, offsetof(NodeBase, _args), 0, nullptr},
+    {"_kwargs", T_OBJECT_EX, offsetof(NodeBase, _kwargs), 0, nullptr},
+    {"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
+    {"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
+    {"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
+    {"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
     {nullptr} /* Sentinel */
 };

 static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
   Py_VISIT(self->_prev);
   Py_VISIT(self->_next);
+  Py_VISIT(self->graph);
+  Py_VISIT(self->name);
+  Py_VISIT(self->op);
+  Py_VISIT(self->target);
+  Py_VISIT(self->type);
+  Py_VISIT(self->_input_nodes);
+  Py_VISIT(self->_args);
+  Py_VISIT(self->_kwargs);
+  Py_VISIT(self->users);
+  Py_VISIT(self->_repr_fn);
+  Py_VISIT(self->meta);
+  Py_VISIT(self->_sort_key);
   return 0;
 }

 static int NodeBase_clear(NodeBase* self) {
   Py_CLEAR(self->_prev);
   Py_CLEAR(self->_next);
+  Py_CLEAR(self->graph);
+  Py_CLEAR(self->name);
+  Py_CLEAR(self->op);
+  Py_CLEAR(self->target);
+  Py_CLEAR(self->type);
+  Py_CLEAR(self->_input_nodes);
+  Py_CLEAR(self->_args);
+  Py_CLEAR(self->_kwargs);
+  Py_CLEAR(self->users);
+  Py_CLEAR(self->_repr_fn);
+  Py_CLEAR(self->meta);
+  Py_CLEAR(self->_sort_key);
   return 0;
 }
```
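The traverse/clear additions matter because FX graphs are full of reference cycles: `users` and `_input_nodes` point at each other, and every node points back at its graph. Without visiting the new C++-held fields, dropped graphs would leak. A hedged sketch of the observable effect:

```python
import gc
import weakref

import torch.fx as fx

gm = fx.symbolic_trace(lambda x: x + x)
node_ref = weakref.ref(next(iter(gm.graph.nodes)))

del gm
gc.collect()  # the cycle collector relies on the traverse/clear hooks above
print(node_ref())  # expected: None, the node cycles were reclaimed
```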
```diff
@@ -200,6 +270,69 @@ static void NodeBase_dealloc(PyObject* self) {
   Py_TYPE(self)->tp_free(self);
 }

+static PyObject* NodeBase__update_args_kwargs(
+    PyObject* self,
+    PyObject* const* args,
+    Py_ssize_t nargs) {
+  // Verify argument count
+  if (nargs != 2) {
+    PyErr_SetString(
+        PyExc_TypeError,
+        "_update_args_kwargs() requires exactly 2 arguments (new_args, new_kwargs)");
+    return nullptr;
+  }
+  auto node = reinterpret_cast<NodeBase*>(self);
+  auto input_nodes = node->_input_nodes;
+  if (PyDict_GET_SIZE(input_nodes) > 0) {
+    // Clear other.users containing us and input_nodes
+    PyObject *key = nullptr, *value = nullptr; // borrowed
+    Py_ssize_t pos = 0;
+    while (PyDict_Next(input_nodes, &pos, &key, &value)) {
+      // key.users.pop(self), intentionally ignore KeyError
+      PyDict_DelItem(reinterpret_cast<NodeBase*>(key)->users, self);
+    }
+    PyDict_Clear(input_nodes);
+  }
+
+  auto visit_fn = [self, input_nodes](PyObject* x) {
+    if (is_node(x)) {
+      // self._input_nodes.setdefault(x)
+      if (!PyDict_SetDefault(input_nodes, x, Py_None)) {
+        throw PythonError();
+      }
+      // x.users.setdefault(self)
+      if (!PyDict_SetDefault(
+              reinterpret_cast<NodeBase*>(x)->users, self, Py_None)) {
+        throw PythonError();
+      }
+    }
+    return Py_NewRef(x);
+  };
+
+  // We do three things in a single pass of the args
+  // - Normalize list->immutable_list, dict->immutable_dict, etc
+  // - Populate self._input_nodes
+  // - Populate arg.users[self] for each arg
+  try {
+    Py_CLEAR(node->_args);
+    node->_args = map_aggregate(args[0], visit_fn);
+    Py_CLEAR(node->_kwargs);
+    node->_kwargs = map_aggregate(args[1], visit_fn);
+    Py_RETURN_NONE;
+  } catch (const PythonError& e) {
+    return nullptr;
+  }
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
+static PyMethodDef NodeBase_methods[] = {
+    {"_update_args_kwargs",
+     (PyCFunction)(void*)(NodeBase__update_args_kwargs),
+     METH_FASTCALL,
+     "Internal method: do not call directly."},
+    {nullptr, nullptr, 0, nullptr} // Sentinel
+};
+
 PyTypeObject NodeBaseType = {
     PyVarObject_HEAD_INIT(nullptr, 0)
     "torch._C._NodeBase", /* tp_name */
```
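From Python, this C++ method is reached through the `args`/`kwargs` property setters shown further down. A short usage sketch of the bookkeeping it maintains:

```python
import operator

import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
y = g.placeholder("y")
add = g.call_function(operator.add, (x, y))

assert x in add._input_nodes and add in x.users

add.args = (y, y)  # routed through _update_args_kwargs
assert add not in x.users             # stale user entry removed
assert list(add._input_nodes) == [y]  # input set rebuilt in one pass
```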
```diff
@@ -229,7 +362,7 @@ PyTypeObject NodeBaseType = {
     0, /* tp_weaklistoffset */
     nullptr, /* tp_iter */
     nullptr, /* tp_iternext */
-    nullptr, /* tp_methods */
+    NodeBase_methods, /* tp_methods */
     NodeBase_members, /* tp_members */
     nullptr, /* tp_getset */
     nullptr, /* tp_base */
```

torch/fx/node.py (149 changed lines)
```diff
@@ -229,14 +229,38 @@ class Node(_NodeBase):
     _args: tuple["Argument", ...]
     _kwargs: dict[str, "Argument"]
     graph: "Graph"
+    # unique name of value being created
     name: str
+    # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
     op: str
+    # for method/module/function, the name of the method/module/function/attr
+    # being invoked, e.g. add, layer1, or torch.add
     target: "Target"
+    # All `Node`-valued inputs. Key is the Node, value is don't-care.
+    # The public API for this is `all_input_nodes`, this private attribute
+    # should not be accessed directly.
     _input_nodes: dict["Node", None]
+    # All of the nodes that use the value produced by this Node
+    # Note one user may correspond to several uses, e.g. the node for ``x + x``
+    # would appear once here, but represents two uses.
+    # Is a dict to act as an "ordered set". Keys are significant, value don't-care
     users: dict["Node", None]
+    # Type expression representing the output value of this node.
+    # This should contain the same class of Type objects that would appear
+    # as type annotations for function inputs/outputs.
+    #
+    # For placeholder nodes, this value will be used to type-annotate the
+    # generated function parameters.
+    # For the return node, this value will be used to type-annotate the
+    # generated function return type. (Note this is a special case. ``return``
+    # does not produce a value, it's more of a notation. Thus, this value
+    # describes the type of args[0] in the ``return`` node.
     type: Optional[Any]
     _sort_key: Any
+    # If set, use this fn to print this node
     _repr_fn: Optional[Callable[["Node"], str]]
+    # Dictionary to store metadata passes need to do their
+    # transformations. This metadata is preserved across node copies
     meta: dict[str, Any]

     @compatibility(is_backward_compatible=True)
```
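These annotations describe attributes that are now C++ slots on `_NodeBase` rather than `__dict__` entries; reading them from Python is unchanged. For example:

```python
import torch.fx as fx

gm = fx.symbolic_trace(lambda x: x + 1)
for n in gm.graph.nodes:
    # name/op/target/users are served by the C++ members added above
    print(n.name, n.op, n.target, list(n.users))
```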
```diff
@@ -276,7 +300,6 @@ class Node(_NodeBase):
             annotation of values in the generated code or for other types
             of analyses.
         """
-        assert op in _legal_ops
         if op == "call_function":
             if not callable(target):
                 raise ValueError(
```
```diff
@@ -284,75 +307,38 @@ class Node(_NodeBase):
                     "but a Callable is expected"
                 )
         else:
+            assert op in _legal_ops
             if not isinstance(target, str):
                 raise ValueError(
                     f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} "
                     "but a str is expected"
                 )
-        super().__init__()
-
-        # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects
-        assign = object.__setattr__
-
-        assign(self, "graph", graph)
-        assign(self, "name", name)  # unique name of value being created
-        assign(
-            self, "op", op
-        )  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
-
-        assign(
-            self, "target", target
-        )  # for method/module/function, the name of the method/module/function/attr
-        # being invoked, e.g. add, layer1, or torch.add
-
-        # All `Node`-valued inputs. Key is the Node, value is don't-care.
-        # The public API for this is `all_input_nodes`, this private attribute
-        # should not be accessed directly.
-        assign(self, "_input_nodes", {})
-        self.__update_args_kwargs(args, kwargs)
-
-        # All of the nodes that use the value produced by this Node
-        # Note one user may correspond to several uses, e.g. the node for ``x + x``
-        # would appear once here, but represents two uses.
-        #
-        # Is a dict to act as an "ordered set". Keys are significant, value don't-care
-        assign(self, "users", {})
-
-        # Type expression representing the output value of this node.
-        # This should contain the same class of Type objects that would appear
-        # as type annotations for function inputs/outputs.
-        #
-        # For placeholder nodes, this value will be used to type-annotate the
-        # generated function parameters.
-        # For the return node, this value will be used to type-annotate the
-        # generated function return type. (Note this is a special case. ``return``
-        # does not produce a value, it's more of a notation. Thus, this value
-        # describes the type of args[0] in the ``return`` node.
-        assign(self, "type", return_type)
-        assign(self, "_sort_key", ())
-
-        # If set, use this fn to print this node
-        assign(self, "_repr_fn", None)
-
-        # Dictionary to store metadata passes need to do their
-        # transformations. This metadata is preserved across node copies
-        assign(self, "meta", {})
+        super().__init__(graph, name, op, target, return_type)
+        self._update_args_kwargs(args, kwargs)

     def __getstate__(self) -> dict[str, Any]:
-        state = self.__dict__.copy()
-        state["_erased"] = self._erased
-        state["_prev"] = self._prev
-        state["_next"] = self._next
-        return state
+        return {
+            **self.__dict__,
+            "graph": self.graph,
+            "name": self.name,
+            "op": self.op,
+            "target": self.target,
+            "type": self.type,
+            "_sort_key": self._sort_key,
+            "_args": self._args,
+            "_kwargs": self._kwargs,
+            "_erased": self._erased,
+            "_prev": self._prev,
+            "_next": self._next,
+            "_input_nodes": self._input_nodes,
+            "users": self.users,
+            "_repr_fn": self._repr_fn,
+            "meta": self.meta,
+        }

     def __setstate__(self, state: dict[str, Any]) -> None:
-        _erased = state.pop("_erased")
-        _prev = state.pop("_prev")
-        _next = state.pop("_next")
-        self.__dict__.update(state)
-        self._erased = _erased
-        self._prev = _prev
-        self._next = _next
+        for k, v in state.items():
+            setattr(self, k, v)

     @property
     def next(self) -> "Node":
```
```diff
@@ -459,9 +445,9 @@ class Node(_NodeBase):
         depends on the node's opcode. See the ``fx.Graph`` docstring for more
         information.
         """
-        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
+        # DO NOT CALL `_update_args_kwargs` directly. The correct way to
         # set `args` is via direct assignment, i.e. `node.args = new_args`
-        self.__update_args_kwargs(a, self._kwargs)
+        self._update_args_kwargs(a, self._kwargs)

     @property
     def kwargs(self) -> dict[str, Argument]:
```
```diff
@@ -482,9 +468,9 @@ class Node(_NodeBase):
         depends on the node's opcode. See the ``fx.Graph`` docstring for more
         information.
         """
-        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
+        # DO NOT CALL `_update_args_kwargs` directly. The correct way to
         # set `kwargs` is via direct assignment, i.e. `node.kwargs = new_kwargs`
-        self.__update_args_kwargs(self._args, k)
+        self._update_args_kwargs(self._args, k)

     @property
     def all_input_nodes(self) -> list["Node"]:
```
```diff
@@ -572,35 +558,6 @@ class Node(_NodeBase):
     def stack_trace(self, trace: Optional[str]) -> None:
         self.meta["stack_trace"] = trace

-    def __update_args_kwargs(
-        self, new_args: tuple["Argument", ...], new_kwargs: dict[str, "Argument"]
-    ) -> None:
-        """
-        This API is internal. Do *not* call it directly.
-        """
-
-        def update_users_and_input_nodes(n: Any) -> Any:
-            if isinstance(n, Node):
-                self._input_nodes.setdefault(n)
-                n.users.setdefault(self)
-            return n
-
-        # Clear prior users and input_nodes
-        for old_use in self._input_nodes.keys():
-            old_use.users.pop(self)
-        object.__setattr__(self, "_input_nodes", {})  # bypass Node.__setattr__
-
-        # We do three things in a single pass of the args
-        # - Normalize list->immutable_list, dict->immutable_dict, etc
-        # - Populate self._input_nodes
-        # - Populate arg.users[self] for each arg
-        object.__setattr__(
-            self, "_args", _fx_map_aggregate(new_args, update_users_and_input_nodes)
-        )
-        object.__setattr__(
-            self, "_kwargs", _fx_map_aggregate(new_kwargs, update_users_and_input_nodes)
-        )
-
     def __repr__(self) -> str:
         if self._repr_fn:
             return self._repr_fn(self)
```
```diff
@@ -751,7 +708,7 @@ class Node(_NodeBase):
             new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
             assert isinstance(new_args, tuple)
             assert isinstance(new_kwargs, dict)
-            use_node.__update_args_kwargs(new_args, new_kwargs)
+            use_node._update_args_kwargs(new_args, new_kwargs)

         assert len(self.users) - len(skipped) == 0
         return [n for n in to_process if n not in skipped]
```
```diff
@@ -863,7 +820,7 @@ class Node(_NodeBase):
         new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
         assert isinstance(new_args, tuple)
         assert isinstance(new_kwargs, dict)
-        self.__update_args_kwargs(new_args, new_kwargs)
+        self._update_args_kwargs(new_args, new_kwargs)

     def _rename(self, candidate: str) -> None:
         if candidate == self.name:
```