mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[fx] Move Node._update_args_kwargs to C++ (#148260)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before: ``` 25203549 function calls (24403352 primitive calls) in 12.090 seconds ``` after: ``` 24303536 function calls (23503339 primitive calls) in 10.726 seconds ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148260 Approved by: https://github.com/oulgen ghstack dependencies: #148243
This commit is contained in:
parent
edaff88f69
commit
0135f57f4a
|
|
@ -1,32 +1,32 @@
|
||||||
add_loop_eager,compile_time_instruction_count,2993000000,0.015
|
add_loop_eager,compile_time_instruction_count,2958000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_eager_dynamic,compile_time_instruction_count,6349000000,0.025
|
add_loop_eager_dynamic,compile_time_instruction_count,6250000000,0.025
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_inductor,compile_time_instruction_count,28630000000,0.015
|
add_loop_inductor,compile_time_instruction_count,28450000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,45240000000,0.025
|
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44690000000,0.025
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_inductor_gpu,compile_time_instruction_count,24960000000,0.015
|
add_loop_inductor_gpu,compile_time_instruction_count,24770000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,960700000,0.015
|
basic_modules_ListOfLinears_eager,compile_time_instruction_count,959200000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18120000000,0.015
|
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17950000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16340000000,0.015
|
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16030000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -34,32 +34,32 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,98740000
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
update_hint_regression,compile_time_instruction_count,1699000000,0.02
|
update_hint_regression,compile_time_instruction_count,1683000000,0.02
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
sum_floordiv_regression,compile_time_instruction_count,1061000000,0.015
|
sum_floordiv_regression,compile_time_instruction_count,1054000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
symint_sum,compile_time_instruction_count,3194000000,0.015
|
symint_sum,compile_time_instruction_count,3167000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2018000000,0.015
|
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2010000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5792000000,0.015
|
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5776000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8703000000,0.015
|
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8521000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3783000000,0.015
|
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3735000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10170000000,0.015
|
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10070000000,0.015
|
||||||
|
|
|
||||||
|
|
|
@ -2505,6 +2505,15 @@ class _NodeBase:
|
||||||
_erased: _bool
|
_erased: _bool
|
||||||
_prev: FxNode
|
_prev: FxNode
|
||||||
_next: FxNode
|
_next: FxNode
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
graph: Any,
|
||||||
|
name: str,
|
||||||
|
op: str,
|
||||||
|
target: Any,
|
||||||
|
return_type: Any,
|
||||||
|
) -> None: ...
|
||||||
|
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
|
||||||
|
|
||||||
class _NodeIter(Iterator):
|
class _NodeIter(Iterator):
|
||||||
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
|
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
struct NodeBase;
|
||||||
|
|
||||||
// Thrown to exit out of a C++ function and return an error to Python.
|
// Thrown to exit out of a C++ function and return an error to Python.
|
||||||
class PythonError : public std::exception {};
|
class PythonError : public std::exception {};
|
||||||
|
|
||||||
|
|
@ -153,6 +155,18 @@ struct NodeBase {
|
||||||
bool _erased;
|
bool _erased;
|
||||||
NodeBase* _prev;
|
NodeBase* _prev;
|
||||||
NodeBase* _next;
|
NodeBase* _next;
|
||||||
|
PyObject* graph;
|
||||||
|
PyObject* name;
|
||||||
|
PyObject* op;
|
||||||
|
PyObject* target;
|
||||||
|
PyObject* type;
|
||||||
|
PyObject* _input_nodes;
|
||||||
|
PyObject* _args;
|
||||||
|
PyObject* _kwargs;
|
||||||
|
PyObject* users;
|
||||||
|
PyObject* _repr_fn;
|
||||||
|
PyObject* meta;
|
||||||
|
PyObject* _sort_key;
|
||||||
};
|
};
|
||||||
|
|
||||||
static PyObject* NodeBase_new(
|
static PyObject* NodeBase_new(
|
||||||
|
|
@ -166,11 +180,31 @@ static PyObject* NodeBase_new(
|
||||||
}
|
}
|
||||||
|
|
||||||
static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
|
static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
|
||||||
|
PyObject* graph = nullptr;
|
||||||
|
PyObject* name = nullptr;
|
||||||
|
PyObject* op = nullptr;
|
||||||
|
PyObject* target = nullptr;
|
||||||
|
PyObject* type = nullptr;
|
||||||
|
if (!PyArg_ParseTuple(args, "OOOOO", &graph, &name, &op, &target, &type)) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
self->_erased = false;
|
self->_erased = false;
|
||||||
Py_INCREF(self);
|
Py_INCREF(self);
|
||||||
self->_prev = self;
|
self->_prev = self;
|
||||||
Py_INCREF(self);
|
Py_INCREF(self);
|
||||||
self->_next = self;
|
self->_next = self;
|
||||||
|
self->graph = Py_NewRef(graph);
|
||||||
|
self->name = Py_NewRef(name);
|
||||||
|
self->op = Py_NewRef(op);
|
||||||
|
self->target = Py_NewRef(target);
|
||||||
|
self->type = Py_NewRef(type);
|
||||||
|
self->_input_nodes = PyDict_New();
|
||||||
|
self->_args = nullptr; // set with _update_args_kwargs
|
||||||
|
self->_kwargs = nullptr; // set with _update_args_kwargs
|
||||||
|
self->users = PyDict_New();
|
||||||
|
self->_repr_fn = Py_NewRef(Py_None);
|
||||||
|
self->meta = PyDict_New();
|
||||||
|
self->_sort_key = PyTuple_New(0);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -179,18 +213,54 @@ static struct PyMemberDef NodeBase_members[] = {
|
||||||
{"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
|
{"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
|
||||||
{"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
|
{"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
|
||||||
{"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
|
{"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
|
||||||
|
{"graph", T_OBJECT_EX, offsetof(NodeBase, graph), 0, nullptr},
|
||||||
|
{"name", T_OBJECT_EX, offsetof(NodeBase, name), 0, nullptr},
|
||||||
|
{"op", T_OBJECT_EX, offsetof(NodeBase, op), 0, nullptr},
|
||||||
|
{"target", T_OBJECT_EX, offsetof(NodeBase, target), 0, nullptr},
|
||||||
|
{"type", T_OBJECT_EX, offsetof(NodeBase, type), 0, nullptr},
|
||||||
|
{"_input_nodes", T_OBJECT_EX, offsetof(NodeBase, _input_nodes), 0, nullptr},
|
||||||
|
{"_args", T_OBJECT_EX, offsetof(NodeBase, _args), 0, nullptr},
|
||||||
|
{"_kwargs", T_OBJECT_EX, offsetof(NodeBase, _kwargs), 0, nullptr},
|
||||||
|
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
|
||||||
|
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
|
||||||
|
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
|
||||||
|
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
|
||||||
{nullptr} /* Sentinel */
|
{nullptr} /* Sentinel */
|
||||||
};
|
};
|
||||||
|
|
||||||
static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
|
static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
|
||||||
Py_VISIT(self->_prev);
|
Py_VISIT(self->_prev);
|
||||||
Py_VISIT(self->_next);
|
Py_VISIT(self->_next);
|
||||||
|
Py_VISIT(self->graph);
|
||||||
|
Py_VISIT(self->name);
|
||||||
|
Py_VISIT(self->op);
|
||||||
|
Py_VISIT(self->target);
|
||||||
|
Py_VISIT(self->type);
|
||||||
|
Py_VISIT(self->_input_nodes);
|
||||||
|
Py_VISIT(self->_args);
|
||||||
|
Py_VISIT(self->_kwargs);
|
||||||
|
Py_VISIT(self->users);
|
||||||
|
Py_VISIT(self->_repr_fn);
|
||||||
|
Py_VISIT(self->meta);
|
||||||
|
Py_VISIT(self->_sort_key);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int NodeBase_clear(NodeBase* self) {
|
static int NodeBase_clear(NodeBase* self) {
|
||||||
Py_CLEAR(self->_prev);
|
Py_CLEAR(self->_prev);
|
||||||
Py_CLEAR(self->_next);
|
Py_CLEAR(self->_next);
|
||||||
|
Py_CLEAR(self->graph);
|
||||||
|
Py_CLEAR(self->name);
|
||||||
|
Py_CLEAR(self->op);
|
||||||
|
Py_CLEAR(self->target);
|
||||||
|
Py_CLEAR(self->type);
|
||||||
|
Py_CLEAR(self->_input_nodes);
|
||||||
|
Py_CLEAR(self->_args);
|
||||||
|
Py_CLEAR(self->_kwargs);
|
||||||
|
Py_CLEAR(self->users);
|
||||||
|
Py_CLEAR(self->_repr_fn);
|
||||||
|
Py_CLEAR(self->meta);
|
||||||
|
Py_CLEAR(self->_sort_key);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -200,6 +270,69 @@ static void NodeBase_dealloc(PyObject* self) {
|
||||||
Py_TYPE(self)->tp_free(self);
|
Py_TYPE(self)->tp_free(self);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject* NodeBase__update_args_kwargs(
|
||||||
|
PyObject* self,
|
||||||
|
PyObject* const* args,
|
||||||
|
Py_ssize_t nargs) {
|
||||||
|
// Verify argument count
|
||||||
|
if (nargs != 2) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_TypeError,
|
||||||
|
"_update_args_kwargs() requires exactly 2 arguments (new_args, new_kwargs)");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto node = reinterpret_cast<NodeBase*>(self);
|
||||||
|
auto input_nodes = node->_input_nodes;
|
||||||
|
if (PyDict_GET_SIZE(input_nodes) > 0) {
|
||||||
|
// Clear other.users containing us and input_nodes
|
||||||
|
PyObject *key = nullptr, *value = nullptr; // borrowed
|
||||||
|
Py_ssize_t pos = 0;
|
||||||
|
while (PyDict_Next(input_nodes, &pos, &key, &value)) {
|
||||||
|
// key.users.pop(self), intentionally ignore KeyError
|
||||||
|
PyDict_DelItem(reinterpret_cast<NodeBase*>(key)->users, self);
|
||||||
|
}
|
||||||
|
PyDict_Clear(input_nodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto visit_fn = [self, input_nodes](PyObject* x) {
|
||||||
|
if (is_node(x)) {
|
||||||
|
// self._input_nodes.setdefault(x)
|
||||||
|
if (!PyDict_SetDefault(input_nodes, x, Py_None)) {
|
||||||
|
throw PythonError();
|
||||||
|
}
|
||||||
|
// x.users.setdefault(self)
|
||||||
|
if (!PyDict_SetDefault(
|
||||||
|
reinterpret_cast<NodeBase*>(x)->users, self, Py_None)) {
|
||||||
|
throw PythonError();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Py_NewRef(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
// We do three things in a single pass of the args
|
||||||
|
// - Normalize list->immutable_list, dict->immutable_dict, etc
|
||||||
|
// - Populate self._input_nodes
|
||||||
|
// - Populate arg.users[self] for each arg
|
||||||
|
try {
|
||||||
|
Py_CLEAR(node->_args);
|
||||||
|
node->_args = map_aggregate(args[0], visit_fn);
|
||||||
|
Py_CLEAR(node->_kwargs);
|
||||||
|
node->_kwargs = map_aggregate(args[1], visit_fn);
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
} catch (const PythonError& e) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
|
static PyMethodDef NodeBase_methods[] = {
|
||||||
|
{"_update_args_kwargs",
|
||||||
|
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
|
||||||
|
METH_FASTCALL,
|
||||||
|
"Internal method: do not call directly."},
|
||||||
|
{nullptr, nullptr, 0, nullptr} // Sentinel
|
||||||
|
};
|
||||||
|
|
||||||
PyTypeObject NodeBaseType = {
|
PyTypeObject NodeBaseType = {
|
||||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||||
"torch._C._NodeBase", /* tp_name */
|
"torch._C._NodeBase", /* tp_name */
|
||||||
|
|
@ -229,7 +362,7 @@ PyTypeObject NodeBaseType = {
|
||||||
0, /* tp_weaklistoffset */
|
0, /* tp_weaklistoffset */
|
||||||
nullptr, /* tp_iter */
|
nullptr, /* tp_iter */
|
||||||
nullptr, /* tp_iternext */
|
nullptr, /* tp_iternext */
|
||||||
nullptr, /* tp_methods */
|
NodeBase_methods, /* tp_methods */
|
||||||
NodeBase_members, /* tp_members */
|
NodeBase_members, /* tp_members */
|
||||||
nullptr, /* tp_getset */
|
nullptr, /* tp_getset */
|
||||||
nullptr, /* tp_base */
|
nullptr, /* tp_base */
|
||||||
|
|
|
||||||
149
torch/fx/node.py
149
torch/fx/node.py
|
|
@ -229,14 +229,38 @@ class Node(_NodeBase):
|
||||||
_args: tuple["Argument", ...]
|
_args: tuple["Argument", ...]
|
||||||
_kwargs: dict[str, "Argument"]
|
_kwargs: dict[str, "Argument"]
|
||||||
graph: "Graph"
|
graph: "Graph"
|
||||||
|
# unique name of value being created
|
||||||
name: str
|
name: str
|
||||||
|
# the kind of operation = placeholder|call_method|call_module|call_function|get_attr
|
||||||
op: str
|
op: str
|
||||||
|
# for method/module/function, the name of the method/module/function/attr
|
||||||
|
# being invoked, e.g add, layer1, or torch.add
|
||||||
target: "Target"
|
target: "Target"
|
||||||
|
# All `Node`-valued inputs. Key is the Node, value is don't-care.
|
||||||
|
# The public API for this is `all_input_nodes`, this private attribute
|
||||||
|
# should not be accessed directly.
|
||||||
_input_nodes: dict["Node", None]
|
_input_nodes: dict["Node", None]
|
||||||
|
# All of the nodes that use the value produced by this Node
|
||||||
|
# Note one user may correspond to several uses, e.g. the node fo ``x + x``
|
||||||
|
# would appear once here, but represents two uses.
|
||||||
|
# Is a dict to act as an "ordered set". Keys are significant, value dont-care
|
||||||
users: dict["Node", None]
|
users: dict["Node", None]
|
||||||
|
# Type expression representing the output value of this node.
|
||||||
|
# This should contain the same class of Type objects that would appear
|
||||||
|
# as type annotations for function inputs/outputs.
|
||||||
|
#
|
||||||
|
# For placeholder nodes, this value will be used to type-annotate the
|
||||||
|
# generated function parameters.
|
||||||
|
# For the return node, this value will be used to type-annotate the
|
||||||
|
# generated function return type. (Note this is a special case. ``return``
|
||||||
|
# does not produce a value, it's more of a notation. Thus, this value
|
||||||
|
# describes the type of args[0] in the ``return`` node.
|
||||||
type: Optional[Any]
|
type: Optional[Any]
|
||||||
_sort_key: Any
|
_sort_key: Any
|
||||||
|
# If set, use this fn to print this node
|
||||||
_repr_fn: Optional[Callable[["Node"], str]]
|
_repr_fn: Optional[Callable[["Node"], str]]
|
||||||
|
# Dictionary to store metadata passes need to do their
|
||||||
|
# transformations. This metadata is preserved across node copies
|
||||||
meta: dict[str, Any]
|
meta: dict[str, Any]
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
|
|
@ -276,7 +300,6 @@ class Node(_NodeBase):
|
||||||
annotation of values in the generated code or for other types
|
annotation of values in the generated code or for other types
|
||||||
of analyses.
|
of analyses.
|
||||||
"""
|
"""
|
||||||
assert op in _legal_ops
|
|
||||||
if op == "call_function":
|
if op == "call_function":
|
||||||
if not callable(target):
|
if not callable(target):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
@ -284,75 +307,38 @@ class Node(_NodeBase):
|
||||||
"but a Callable is expected"
|
"but a Callable is expected"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert op in _legal_ops
|
||||||
if not isinstance(target, str):
|
if not isinstance(target, str):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} "
|
f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} "
|
||||||
"but a str is expected"
|
"but a str is expected"
|
||||||
)
|
)
|
||||||
super().__init__()
|
super().__init__(graph, name, op, target, return_type)
|
||||||
|
self._update_args_kwargs(args, kwargs)
|
||||||
# bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects
|
|
||||||
assign = object.__setattr__
|
|
||||||
|
|
||||||
assign(self, "graph", graph)
|
|
||||||
assign(self, "name", name) # unique name of value being created
|
|
||||||
assign(
|
|
||||||
self, "op", op
|
|
||||||
) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
|
|
||||||
|
|
||||||
assign(
|
|
||||||
self, "target", target
|
|
||||||
) # for method/module/function, the name of the method/module/function/attr
|
|
||||||
# being invoked, e.g add, layer1, or torch.add
|
|
||||||
|
|
||||||
# All `Node`-valued inputs. Key is the Node, value is don't-care.
|
|
||||||
# The public API for this is `all_input_nodes`, this private attribute
|
|
||||||
# should not be accessed directly.
|
|
||||||
assign(self, "_input_nodes", {})
|
|
||||||
self.__update_args_kwargs(args, kwargs)
|
|
||||||
|
|
||||||
# All of the nodes that use the value produced by this Node
|
|
||||||
# Note one user may correspond to several uses, e.g. the node fo ``x + x``
|
|
||||||
# would appear once here, but represents two uses.
|
|
||||||
#
|
|
||||||
# Is a dict to act as an "ordered set". Keys are significant, value dont-care
|
|
||||||
assign(self, "users", {})
|
|
||||||
|
|
||||||
# Type expression representing the output value of this node.
|
|
||||||
# This should contain the same class of Type objects that would appear
|
|
||||||
# as type annotations for function inputs/outputs.
|
|
||||||
#
|
|
||||||
# For placeholder nodes, this value will be used to type-annotate the
|
|
||||||
# generated function parameters.
|
|
||||||
# For the return node, this value will be used to type-annotate the
|
|
||||||
# generated function return type. (Note this is a special case. ``return``
|
|
||||||
# does not produce a value, it's more of a notation. Thus, this value
|
|
||||||
# describes the type of args[0] in the ``return`` node.
|
|
||||||
assign(self, "type", return_type)
|
|
||||||
assign(self, "_sort_key", ())
|
|
||||||
|
|
||||||
# If set, use this fn to print this node
|
|
||||||
assign(self, "_repr_fn", None)
|
|
||||||
|
|
||||||
# Dictionary to store metadata passes need to do their
|
|
||||||
# transformations. This metadata is preserved across node copies
|
|
||||||
assign(self, "meta", {})
|
|
||||||
|
|
||||||
def __getstate__(self) -> dict[str, Any]:
|
def __getstate__(self) -> dict[str, Any]:
|
||||||
state = self.__dict__.copy()
|
return {
|
||||||
state["_erased"] = self._erased
|
**self.__dict__,
|
||||||
state["_prev"] = self._prev
|
"graph": self.graph,
|
||||||
state["_next"] = self._next
|
"name": self.name,
|
||||||
return state
|
"op": self.op,
|
||||||
|
"target": self.target,
|
||||||
|
"type": self.target,
|
||||||
|
"_sort_key": self._sort_key,
|
||||||
|
"_args": self._args,
|
||||||
|
"_kwargs": self._kwargs,
|
||||||
|
"_erased": self._erased,
|
||||||
|
"_prev": self._prev,
|
||||||
|
"_next": self._next,
|
||||||
|
"_input_nodes": self._input_nodes,
|
||||||
|
"users": self.users,
|
||||||
|
"_repr_fn": self._repr_fn,
|
||||||
|
"meta": self.meta,
|
||||||
|
}
|
||||||
|
|
||||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||||
_erased = state.pop("_erased")
|
for k, v in state.items():
|
||||||
_prev = state.pop("_prev")
|
setattr(self, k, v)
|
||||||
_next = state.pop("_next")
|
|
||||||
self.__dict__.update(state)
|
|
||||||
self._erased = _erased
|
|
||||||
self._prev = _prev
|
|
||||||
self._next = _next
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next(self) -> "Node":
|
def next(self) -> "Node":
|
||||||
|
|
@ -459,9 +445,9 @@ class Node(_NodeBase):
|
||||||
depends on the node's opcode. See the ``fx.Graph`` docstring for more
|
depends on the node's opcode. See the ``fx.Graph`` docstring for more
|
||||||
information.
|
information.
|
||||||
"""
|
"""
|
||||||
# DO NOT CALL `__update_args_kwargs` directly. The correct way to
|
# DO NOT CALL `_update_args_kwargs` directly. The correct way to
|
||||||
# set `args` is via direct assignment, i.e. `node.args = new_args`
|
# set `args` is via direct assignment, i.e. `node.args = new_args`
|
||||||
self.__update_args_kwargs(a, self._kwargs)
|
self._update_args_kwargs(a, self._kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kwargs(self) -> dict[str, Argument]:
|
def kwargs(self) -> dict[str, Argument]:
|
||||||
|
|
@ -482,9 +468,9 @@ class Node(_NodeBase):
|
||||||
depends on the node's opcode. See the ``fx.Graph`` docstring for more
|
depends on the node's opcode. See the ``fx.Graph`` docstring for more
|
||||||
information.
|
information.
|
||||||
"""
|
"""
|
||||||
# DO NOT CALL `__update_args_kwargs` directly. The correct way to
|
# DO NOT CALL `_update_args_kwargs` directly. The correct way to
|
||||||
# set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs`
|
# set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs`
|
||||||
self.__update_args_kwargs(self._args, k)
|
self._update_args_kwargs(self._args, k)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_input_nodes(self) -> list["Node"]:
|
def all_input_nodes(self) -> list["Node"]:
|
||||||
|
|
@ -572,35 +558,6 @@ class Node(_NodeBase):
|
||||||
def stack_trace(self, trace: Optional[str]) -> None:
|
def stack_trace(self, trace: Optional[str]) -> None:
|
||||||
self.meta["stack_trace"] = trace
|
self.meta["stack_trace"] = trace
|
||||||
|
|
||||||
def __update_args_kwargs(
|
|
||||||
self, new_args: tuple["Argument", ...], new_kwargs: dict[str, "Argument"]
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
This API is internal. Do *not* call it directly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def update_users_and_input_nodes(n: Any) -> Any:
|
|
||||||
if isinstance(n, Node):
|
|
||||||
self._input_nodes.setdefault(n)
|
|
||||||
n.users.setdefault(self)
|
|
||||||
return n
|
|
||||||
|
|
||||||
# Clear prior users and input_nodes
|
|
||||||
for old_use in self._input_nodes.keys():
|
|
||||||
old_use.users.pop(self)
|
|
||||||
object.__setattr__(self, "_input_nodes", {}) # bypass Node.__setattr__
|
|
||||||
|
|
||||||
# We do three things in a single pass of the args
|
|
||||||
# - Normalize list->immutable_list, dict->immutable_dict, etc
|
|
||||||
# - Populate self._input_nodes
|
|
||||||
# - Populate arg.users[self] for each arg
|
|
||||||
object.__setattr__(
|
|
||||||
self, "_args", _fx_map_aggregate(new_args, update_users_and_input_nodes)
|
|
||||||
)
|
|
||||||
object.__setattr__(
|
|
||||||
self, "_kwargs", _fx_map_aggregate(new_kwargs, update_users_and_input_nodes)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
if self._repr_fn:
|
if self._repr_fn:
|
||||||
return self._repr_fn(self)
|
return self._repr_fn(self)
|
||||||
|
|
@ -751,7 +708,7 @@ class Node(_NodeBase):
|
||||||
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
|
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
|
||||||
assert isinstance(new_args, tuple)
|
assert isinstance(new_args, tuple)
|
||||||
assert isinstance(new_kwargs, dict)
|
assert isinstance(new_kwargs, dict)
|
||||||
use_node.__update_args_kwargs(new_args, new_kwargs)
|
use_node._update_args_kwargs(new_args, new_kwargs)
|
||||||
|
|
||||||
assert len(self.users) - len(skipped) == 0
|
assert len(self.users) - len(skipped) == 0
|
||||||
return [n for n in to_process if n not in skipped]
|
return [n for n in to_process if n not in skipped]
|
||||||
|
|
@ -863,7 +820,7 @@ class Node(_NodeBase):
|
||||||
new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
|
new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
|
||||||
assert isinstance(new_args, tuple)
|
assert isinstance(new_args, tuple)
|
||||||
assert isinstance(new_kwargs, dict)
|
assert isinstance(new_kwargs, dict)
|
||||||
self.__update_args_kwargs(new_args, new_kwargs)
|
self._update_args_kwargs(new_args, new_kwargs)
|
||||||
|
|
||||||
def _rename(self, candidate: str) -> None:
|
def _rename(self, candidate: str) -> None:
|
||||||
if candidate == self.name:
|
if candidate == self.name:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user