From e8fadba28cc225782ddce1bf0ef8315ee1743761 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 1 Nov 2025 08:48:26 +0800 Subject: [PATCH] [pytree] add `treespec_{leaf,tuple,dict}` functions for args_spec modification (#160843) The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class. Changes: 1. Add function `treespec_leaf()` to replace `LeafSpec()`. 2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class. 3. Change `len(spec.children_specs)` to `spec.num_children`. 4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`. ------ Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843 Approved by: https://github.com/mlazos --- test/export/test_export.py | 10 +- test/test_pytree.py | 86 +++++++---- torch/_dynamo/polyfills/pytree.py | 109 ++++++++++++- torch/_dynamo/variables/builder.py | 4 +- torch/_dynamo/variables/higher_order_ops.py | 4 +- torch/_export/serde/serialize.py | 6 +- torch/_functorch/_aot_autograd/utils.py | 2 +- torch/_functorch/aot_autograd.py | 6 +- torch/_inductor/compile_fx.py | 2 +- torch/ao/quantization/pt2e/utils.py | 13 +- torch/export/_swap.py | 13 +- torch/export/_trace.py | 2 +- torch/export/_tree_utils.py | 6 +- torch/export/_unlift.py | 12 +- torch/export/dynamic_shapes.py | 63 ++++---- torch/export/exported_program.py | 2 +- torch/export/unflatten.py | 8 +- torch/fx/_pytree.py | 2 +- torch/fx/_symbolic_trace.py | 2 +- torch/fx/graph.py | 13 +- torch/utils/_cxx_pytree.py | 42 ++++- torch/utils/_pytree.py | 161 +++++++++++++++----- 22 files changed, 404 insertions(+), 164 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 61b7b886a71..3908f03b11e 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -91,13 +91,13 @@ from torch.testing._internal.torchbind_impls import load_torchbind_test_lib from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu from torch.testing._internal.two_tensor import TwoTensor from torch.utils._pytree import ( - LeafSpec, register_constant, tree_flatten, tree_map, tree_unflatten, TreeSpec, treespec_dumps, + treespec_leaf, treespec_loads, ) @@ -7791,7 +7791,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): dt = MyDataClass(x=3, y=4) flat, spec = tree_flatten(dt) - self.assertTrue(spec, LeafSpec()) + self.assertTrue(spec, treespec_leaf()) self.assertTrue(len(flat) == 1) torch.export.register_dataclass( @@ -7802,7 +7802,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): flat, spec = tree_flatten(dt) self.assertEqual( spec, - TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]), + TreeSpec( + MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()] + ), ) self.assertEqual(flat, [3, 4]) @@ -7835,7 +7837,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): TreeSpec( MyOtherDataClass, [["x", "y", "z"], []], - [LeafSpec(), LeafSpec(), LeafSpec()], + [treespec_leaf(), treespec_leaf(), treespec_leaf()], ), ) self.assertEqual(flat, [3, 4, None]) diff --git a/test/test_pytree.py b/test/test_pytree.py index e19f1471267..7cc3b8affc0 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -65,9 +65,6 @@ class TestEnum(enum.Enum): A = auto() -python_leafspec = python_pytree.LeafSpec() - - class TestGenericPytree(TestCase): def test_aligned_public_apis(self): public_apis = python_pytree.__all__ @@ -197,7 +194,7 @@ class TestGenericPytree(TestCase): def run_test_with_leaf(leaf): values, treespec = pytree.tree_flatten(leaf) self.assertEqual(values, [leaf]) - self.assertEqual(treespec, pytree.LeafSpec()) + self.assertEqual(treespec, pytree.treespec_leaf()) unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, leaf) @@ -215,7 +212,7 @@ class TestGenericPytree(TestCase): ( python_pytree, lambda tup: python_pytree.TreeSpec( - tuple, None, [python_leafspec for _ in tup] + tuple, None, [python_pytree.treespec_leaf() for _ in tup] ), ), name="python", @@ -250,7 +247,7 @@ class TestGenericPytree(TestCase): ( python_pytree, lambda lst: python_pytree.TreeSpec( - list, None, [python_leafspec for _ in lst] + list, None, [python_pytree.treespec_leaf() for _ in lst] ), ), name="python", @@ -286,7 +283,7 @@ class TestGenericPytree(TestCase): lambda dct: python_pytree.TreeSpec( dict, list(dct.keys()), - [python_leafspec for _ in dct.values()], + [python_pytree.treespec_leaf() for _ in dct.values()], ), ), name="python", @@ -327,7 +324,7 @@ class TestGenericPytree(TestCase): lambda odict: python_pytree.TreeSpec( OrderedDict, list(odict.keys()), - [python_leafspec for _ in odict.values()], + [python_pytree.treespec_leaf() for _ in odict.values()], ), ), name="python", @@ -371,7 +368,7 @@ class TestGenericPytree(TestCase): lambda ddct: python_pytree.TreeSpec( defaultdict, [ddct.default_factory, list(ddct.keys())], - [python_leafspec for _ in ddct.values()], + [python_pytree.treespec_leaf() for _ in ddct.values()], ), ), name="python", @@ -413,7 +410,7 @@ class TestGenericPytree(TestCase): ( python_pytree, lambda deq: python_pytree.TreeSpec( - deque, deq.maxlen, [python_leafspec for _ in deq] + deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq] ), ), name="python", @@ -453,7 +450,7 @@ class TestGenericPytree(TestCase): def run_test(tup): if pytree is python_pytree: expected_spec = python_pytree.TreeSpec( - namedtuple, Point, [python_leafspec for _ in tup] + namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup] ) else: expected_spec = cxx_pytree.tree_structure(Point(0, 1)) @@ -848,16 +845,16 @@ if "optree" in sys.modules: def test_treespec_equality(self): self.assertEqual( - python_pytree.LeafSpec(), - python_pytree.LeafSpec(), + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), ) self.assertEqual( python_pytree.TreeSpec(list, None, []), python_pytree.TreeSpec(list, None, []), ) self.assertEqual( - python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]), - python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]), + python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]), + python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]), ) self.assertFalse( python_pytree.TreeSpec(tuple, None, []) @@ -892,24 +889,32 @@ if "optree" in sys.modules: # python_pytree.tree_structure({}) python_pytree.TreeSpec(dict, [], []), # python_pytree.tree_structure([0]) - python_pytree.TreeSpec(list, None, [python_leafspec]), + python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]), # python_pytree.tree_structure([0, 1]) python_pytree.TreeSpec( list, None, - [python_leafspec, python_leafspec], + [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], ), # python_pytree.tree_structure((0, 1, 2)) python_pytree.TreeSpec( tuple, None, - [python_leafspec, python_leafspec, python_leafspec], + [ + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + ], ), # python_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) python_pytree.TreeSpec( dict, ["a", "b", "c"], - [python_leafspec, python_leafspec, python_leafspec], + [ + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + ], ), # python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) python_pytree.TreeSpec( @@ -919,13 +924,17 @@ if "optree" in sys.modules: python_pytree.TreeSpec( tuple, None, - [python_leafspec, python_leafspec], + [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], ), - python_leafspec, + python_pytree.treespec_leaf(), python_pytree.TreeSpec( dict, ["a", "b", "c"], - [python_leafspec, python_leafspec, python_leafspec], + [ + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + ], ), ], ), @@ -938,12 +947,15 @@ if "optree" in sys.modules: tuple, None, [ - python_leafspec, - python_leafspec, + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), python_pytree.TreeSpec( list, None, - [python_leafspec, python_leafspec], + [ + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + ], ), ], ), @@ -957,12 +969,12 @@ if "optree" in sys.modules: python_pytree.TreeSpec( list, None, - [python_leafspec, python_leafspec], + [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], ), python_pytree.TreeSpec( list, None, - [python_leafspec, python_leafspec], + [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], ), python_pytree.TreeSpec(dict, [], []), ], @@ -991,7 +1003,7 @@ if "optree" in sys.modules: list, None, [ - python_leafspec, + python_pytree.treespec_leaf(), ], ), ], @@ -1000,7 +1012,7 @@ if "optree" in sys.modules: self.assertIsInstance(serialized_spec, str) def test_pytree_serialize_enum(self): - spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec]) + spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()]) serialized_spec = python_pytree.treespec_dumps(spec) self.assertIsInstance(serialized_spec, str) @@ -1163,12 +1175,20 @@ if "optree" in sys.modules: OrderedDict, [1, 2, 3], [ - python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]), - python_leafspec, + python_pytree.TreeSpec( + tuple, + None, + [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], + ), + python_pytree.treespec_leaf(), python_pytree.TreeSpec( dict, [4, 5, 6], - [python_leafspec, python_leafspec, python_leafspec], + [ + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + python_pytree.treespec_leaf(), + ], ), ], ) @@ -1453,7 +1473,7 @@ class TestCxxPytree(TestCase): raise unittest.SkipTest("C++ pytree tests are not supported in fbcode") def test_treespec_equality(self): - self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec()) + self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf()) def test_treespec_repr(self): # Check that it looks sane diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index ef308e90789..f9bdc0cce4a 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -17,7 +17,7 @@ from ..decorators import substitute_in_graph if TYPE_CHECKING: import builtins - from collections.abc import Callable, Iterable + from collections.abc import Callable, Iterable, Mapping from typing_extensions import Self @@ -349,6 +349,113 @@ if python_pytree._cxx_pytree_dynamo_traceable: def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: return isinstance(obj, PyTreeSpec) + @substitute_in_graph( # type: ignore[arg-type] + optree.treespec_leaf, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, + ) + def treespec_leaf( + *, + none_is_leaf: bool = False, + namespace: str = "", # unused + ) -> PyTreeSpec: + return PyTreeSpec( + (), + None, + None, + (), + None, + none_is_leaf=none_is_leaf, + namespace="", + ) + + @substitute_in_graph( # type: ignore[arg-type] + optree.treespec_tuple, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, + ) + def treespec_tuple( + iterable: Iterable[PyTreeSpec] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = "", + ) -> PyTreeSpec: + children = tuple(iterable) + if any(not _is_pytreespec_instance(child) for child in children): + raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.") + if any(child.none_is_leaf != none_is_leaf for child in children): + raise ValueError( + "All children PyTreeSpecs must have the same `none_is_leaf` value " + f"as the parent; expected {none_is_leaf}, got: {children!r}.", + ) + if any(child.namespace not in (namespace, "") for child in children): + raise ValueError( + "All children PyTreeSpecs must have the same `namespace` value " + f"as the parent; expected {namespace!r}, got: {children!r}.", + ) + handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined] + assert handler is not None + return PyTreeSpec( + tuple(children), + tuple, + None, + tuple(range(len(children))), + handler.unflatten_func, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + @substitute_in_graph( # type: ignore[arg-type] + optree.treespec_dict, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, + ) + def treespec_dict( + mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = "", + **kwargs: PyTreeSpec, + ) -> PyTreeSpec: + dct = dict(mapping, **kwargs) + if any(not _is_pytreespec_instance(child) for child in dct.values()): + raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.") + if any(child.none_is_leaf != none_is_leaf for child in dct.values()): + raise ValueError( + "All children PyTreeSpecs must have the same `none_is_leaf` value " + f"as the parent; expected {none_is_leaf}, got: {dct!r}.", + ) + if any(child.namespace not in (namespace, "") for child in dct.values()): + raise ValueError( + "All children PyTreeSpecs must have the same `namespace` value " + f"as the parent; expected {namespace!r}, got: {dct!r}.", + ) + + ( + children, + metadata, + entries, + unflatten_func, + ) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated] + dct, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + return PyTreeSpec( + tuple(children), # type: ignore[arg-type] + dict, + metadata, + entries, + unflatten_func, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + @substitute_in_graph( # type: ignore[arg-type] optree.tree_flatten, # We need to disable constant folding here because we want the function to reference the diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index a2bbcb65710..d19e7998dee 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -3727,9 +3727,7 @@ class SourcelessBuilder: pass # failthrough to unimplemented branch elif isinstance(value, torch.fx.graph_module.GraphModule): return SourcelessGraphModuleVariable(value) - elif isinstance( - value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec) - ): + elif isinstance(value, torch.utils._pytree.TreeSpec): return UserDefinedObjectVariable(value) elif PlacementVariable.is_placement(value): return PlacementVariable(value) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index bee608f1fb0..c330a700fd6 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1661,8 +1661,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable): # We need to have this check this way, because in case init is a TreeSpec and carry # but carry is only a LeafSpec, these two cannot be compared correctly. if ( - isinstance(xs_treespec.as_python_constant(), pytree.LeafSpec) - != isinstance(_combine_treespec.as_python_constant(), pytree.LeafSpec) + xs_treespec.as_python_constant().is_leaf() + != _combine_treespec.as_python_constant().is_leaf() ) or not _make_inlined(tx, pytree.TreeSpec.__eq__)( xs_treespec, _combine_treespec ).as_python_constant(): diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 7acacdd0ca1..a0a40666d03 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1530,10 +1530,10 @@ class GraphModuleSerializer(metaclass=Final): else: raise AssertionError("TODO") - def serialize_treespec(self, treespec): + def serialize_treespec(self, treespec: pytree.TreeSpec) -> str: # We want to additionally save all the field names of the namedtuples in # case users want to check that the treespec types are equivalent - def store_namedtuple_fields(ts): + def store_namedtuple_fields(ts: pytree.TreeSpec) -> None: if ts.type is None: return if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): @@ -1555,7 +1555,7 @@ class GraphModuleSerializer(metaclass=Final): NamedTupleDef(field_names=ts.context._fields) ) - for child in ts.children_specs: + for child in ts.children(): store_namedtuple_fields(child) serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION) diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 858c0e9e539..f3e8f6a91b3 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -158,7 +158,7 @@ class PytreeThunk: assert spec is not None self.spec: pytree.TreeSpec = spec if self.spec.type in {tuple, list} and all( - child.is_leaf() for child in spec.children_specs + child.is_leaf() for child in spec.children() ): self.is_simple = True if self.spec.is_leaf(): diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index bf561e62a39..f48cb04f08f 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -1590,7 +1590,7 @@ def aot_export_joint_simple( decompositions=decompositions, trace_joint=trace_joint, ) - in_spec, _kw_in_spec = in_spec.children_specs + in_spec, _kw_in_spec = in_spec.children() # At this point, we can just directly return the (joint or inference graph) that we traced. # First though: a bunch of assertions to make sure that our graph doesn't require # any calling convention changes compared to the original function. @@ -1617,7 +1617,7 @@ def aot_export_joint_simple( raise RuntimeError( f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}" ) - if not all(child.is_leaf() for child in in_spec.children_specs): + if not all(child.is_leaf() for child in in_spec.children()): raise RuntimeError( f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}" ) @@ -1625,7 +1625,7 @@ def aot_export_joint_simple( raise RuntimeError( f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}" ) - if not all(child.is_leaf() for child in out_spec.children_specs): + if not all(child.is_leaf() for child in out_spec.children()): raise RuntimeError( f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}" ) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 33b25b319da..4e40d2347b1 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -469,7 +469,7 @@ def _unlift_graph( gm, lifted_inputs, mutated_outputs, - pytree.LeafSpec(), + pytree.treespec_leaf(), None, ) return unlifted_gm diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 1575f936044..f6e9789e948 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -7,6 +7,7 @@ from typing import Any, Optional, Union import torch import torch.ao.quantization.pt2e._affine_quantization # noqa: F401 import torch.nn.functional as F +import torch.utils._pytree as pytree # Makes sure that quantized_decomposed ops are registered from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 @@ -14,7 +15,6 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.export.unflatten import _assign_attr, _AttrKind from torch.fx import GraphModule, Node from torch.nn.utils.fusion import fuse_conv_bn_weights -from torch.utils._pytree import LeafSpec __all__ = [ @@ -477,7 +477,10 @@ def _replace_literals_with_new_placeholders( exclude_literals = [] in_spec = gm._in_spec - args_spec = in_spec.children_specs[0] + assert in_spec.type is tuple + args_spec = in_spec.child(0) + assert args_spec.type is tuple + args_spec_children = args_spec.children() for node in gm.graph.nodes: if node.op == "placeholder": last_ph = node @@ -492,7 +495,7 @@ def _replace_literals_with_new_placeholders( else: ph_node = gm.graph.placeholder("arg" + str(cnt)) new_args.append(ph_node) - args_spec.children_specs.append(LeafSpec()) + args_spec_children.append(pytree.treespec_leaf()) cnt += 1 if merge_dup: literal_to_ph[arg] = ph_node @@ -503,8 +506,8 @@ def _replace_literals_with_new_placeholders( node.args = new_args # Update `num_nodes`, `num_leaves`, `num_children`. - args_spec.__post_init__() - in_spec.__post_init__() + args_spec = pytree.treespec_tuple(args_spec_children) + gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]]) return gm diff --git a/torch/export/_swap.py b/torch/export/_swap.py index 0c38d540b97..78c507da02d 100644 --- a/torch/export/_swap.py +++ b/torch/export/_swap.py @@ -195,17 +195,16 @@ def _construct_inputs( unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec) assert signature.in_spec.num_children == 2 + assert signature.in_spec.type is tuple + args_spec, kwargs_spec = signature.in_spec.children() + assert args_spec.type is tuple + assert kwargs_spec.type is dict - args_spec = signature.in_spec.children_specs[0] - assert args_spec.context is None args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0)) args_nodes = [ gm.graph.call_function(operator.getitem, (args_node, i)) for i in range(args_spec.num_children) ] - - kwargs_spec = signature.in_spec.children_specs[1] - assert kwargs_spec.context is not None kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1)) kwargs_nodes = { k: gm.graph.call_function(operator.getitem, (kwargs_node, k)) @@ -372,8 +371,8 @@ def _fix_input_output_signature( if forward_arg_names is None: forward_arg_names = [] assert signature.in_spec.num_children == 2 - arg_spec = signature.in_spec.children_specs[0] - kwarg_spec = signature.in_spec.children_specs[1] + arg_spec = signature.in_spec.child(0) + kwarg_spec = signature.in_spec.child(1) assert arg_spec.type is tuple assert kwarg_spec.type is dict for i in range(arg_spec.num_children): diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 685fe149714..934ee448820 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1533,7 +1533,7 @@ def _strict_export( # aot_export expect the return type to always be a tuple. if out_spec.type not in (list, tuple): - out_spec = pytree.TreeSpec(tuple, None, [out_spec]) + out_spec = pytree.treespec_tuple([out_spec]) orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] diff --git a/torch/export/_tree_utils.py b/torch/export/_tree_utils.py index 5c2d4426066..de768c38f21 100644 --- a/torch/export/_tree_utils.py +++ b/torch/export/_tree_utils.py @@ -15,7 +15,7 @@ def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any # Make sure that the spec is actually shaped like (args, kwargs) assert spec.type is tuple assert spec.num_children == 2 - kwargs_spec = spec.children_specs[1] + kwargs_spec = spec.child(1) assert kwargs_spec.type is dict if set(user_kwargs) != set(kwargs_spec.context): @@ -55,10 +55,10 @@ def is_equivalent( return False # Recurse on children - if len(spec1.children_specs) != len(spec2.children_specs): + if spec1.num_children != spec2.num_children: return False - for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs): + for child_spec1, child_spec2 in zip(spec1.children(), spec2.children()): if not is_equivalent(child_spec1, child_spec2, equivalence_fn): return False diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index b9e82481322..52d06a294fa 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -57,11 +57,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool: return False elif a.context != b.context: return False - if len(a.children_specs) != len(b.children_specs): + if a.num_children != b.num_children: return False return all( _match_normalized_structure(a, b) - for a, b in zip(a.children_specs, b.children_specs) + for a, b in zip(a.children(), b.children()) ) return _match_normalized_structure(self, other) @@ -357,13 +357,13 @@ def _get_codegen( elif ( in_spec.type is tuple and in_spec.num_children == 2 - and in_spec.children_specs[0].type is tuple - and in_spec.children_specs[1].type is dict + and in_spec.child(0).type is tuple + and in_spec.child(1).type is dict ): # if in_spec contains the args (tuple) and kwargs (dict) - names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] + names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)] # add kwarg names - names.extend(in_spec.children_specs[1].context) + names.extend(in_spec.child(1).context) else: names = [f"arg_{i}" for i in range(in_spec.num_children)] diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 24964c27472..1e1f1f40985 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -12,14 +12,16 @@ import torch from torch.utils._pytree import ( _get_node_type, BUILTIN_TYPES, + KeyPath, keystr, - LeafSpec, MappingKey, SequenceKey, SUPPORTED_NODES, - tree_flatten, + tree_iter, tree_map, tree_map_with_path, + tree_structure, + TreeSpec, ) from .exported_program import ExportedProgram @@ -655,53 +657,55 @@ def _tree_map_with_path( case_name="dynamic_shapes_validation", ) - def _compare(tree, dynamic_shapes, path): + def _compare( + treespec: TreeSpec, other_treespec: TreeSpec, path: KeyPath + ) -> None: # raise an error at the point where tree and dynamic_shapes differ, # including the path to that point and the reason for the difference rendered_path = keystr(path) - if isinstance(tree, LeafSpec): + if treespec.is_leaf(): return - if isinstance(dynamic_shapes, LeafSpec): + if other_treespec.is_leaf(): raise_mismatch_error( - f"`{tree_name}{rendered_path}` is a {tree.type}, " + f"`{tree_name}{rendered_path}` is a {treespec.type}, " f"but `dynamic_shapes{rendered_path}` is not" ) - if tree.type != dynamic_shapes.type: + if treespec.type != other_treespec.type: raise_mismatch_error( - f"`{tree_name}{rendered_path}` is a {tree.type}, " - f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}" + f"`{tree_name}{rendered_path}` is a {treespec.type}, " + f"but `dynamic_shapes{rendered_path}` is a {other_treespec.type}" ) - if len(tree.children_specs) != len(dynamic_shapes.children_specs): + if treespec.num_children != other_treespec.num_children: raise_mismatch_error( - f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, " - f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements" + f"`{tree_name}{rendered_path}` has {treespec.num_children} elements, " + f"but `dynamic_shapes{rendered_path}` has {other_treespec.num_children} elements" ) - if tree.type is dict: + if treespec.type is dict: # context, children could be out of order - if sorted(tree.context) != sorted(dynamic_shapes.context): + if set(treespec.context) != set(other_treespec.context): raise_mismatch_error( - f"`{tree_name}{rendered_path}` has keys {tree.context}, " - f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}" + f"`{tree_name}{rendered_path}` has keys {treespec.context}, " + f"but `dynamic_shapes{rendered_path}` has keys {other_treespec.context}" ) _remap = dict( - zip(dynamic_shapes.context, dynamic_shapes.children_specs) + zip(other_treespec.context, other_treespec.children()) ) - dynamic_shapes_children_specs = [_remap[k] for k in tree.context] + other_children = [_remap[k] for k in treespec.context] else: - dynamic_shapes_children_specs = dynamic_shapes.children_specs - for i, (tree_, dynamic_shapes_) in enumerate( - zip(tree.children_specs, dynamic_shapes_children_specs) + other_children = other_treespec.children() + for i, (child, other_child) in enumerate( + zip(treespec.children(), other_children) ): _compare( - tree_, - dynamic_shapes_, - path + [_key(tree.type, tree.context, i)], + child, + other_child, + path + (_key(treespec.type, treespec.context, i),), ) - _, tree_spec = tree_flatten(tree, is_leaf=is_leaf) + treespec = tree_structure(tree, is_leaf=is_leaf) for other_tree in dynamic_shapes: - _, other_tree_spec = tree_flatten(other_tree, is_leaf) - _compare(tree_spec, other_tree_spec, []) + other_treespec = tree_structure(other_tree, is_leaf) + _compare(treespec, other_treespec, ()) raise @@ -1231,10 +1235,7 @@ def _get_dim_name_mapping( dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], ): name_to_dim = {} - for dim in tree_flatten( - dynamic_shapes, - is_leaf=lambda x: isinstance(x, Dim), - )[0]: + for dim in tree_iter(dynamic_shapes, is_leaf=lambda x: isinstance(x, Dim)): if dim is None: # NOTE: this must denote a non-Tensor or automatic at this point. continue diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 3b2d0cb7401..f8faff582f0 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -391,7 +391,7 @@ def _decompose_and_get_gm_with_new_signature_constants( # aot_export expect the return type to always be a tuple. assert out_spec is not None if out_spec.type not in (list, tuple): - out_spec = pytree.TreeSpec(tuple, None, [out_spec]) + out_spec = pytree.treespec_tuple([out_spec]) mod.graph._codegen = _PyTreeCodeGen( _PyTreeInfo( diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 1e4931f4a19..3701ba99047 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -1124,10 +1124,10 @@ class _ModuleFrame: signature = module_call_graph.get(self.child_fqn) if signature is not None and self.parent is not None: assert signature.in_spec.num_children == 2 - args_spec = signature.in_spec.children_specs[0] - kwargs_spec = signature.in_spec.children_specs[1] - assert args_spec.context is None - assert kwargs_spec.context is not None + assert signature.in_spec.type is tuple + args_spec, kwargs_spec = signature.in_spec.children() + assert args_spec.type is tuple + assert kwargs_spec.type is dict with self.graph.inserting_after(None): arg_nodes = [ diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index 2f608816c49..bfb62f871eb 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -49,7 +49,7 @@ def tree_flatten_spec( flatten_fn_spec = SUPPORTED_NODES[spec.type] child_pytrees = flatten_fn_spec(pytree, spec) result = [] - for child, child_spec in zip(child_pytrees, spec.children_specs): + for child, child_spec in zip(child_pytrees, spec.children()): flat = tree_flatten_spec(child, child_spec) result += flat return result diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index cc72d410fd7..cfce00fb05e 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -709,7 +709,7 @@ class Tracer(TracerBase): root_fn = _patch_function(root_fn, len(args)) flat_args, in_spec = pytree.tree_flatten(tuple(args)) - if not all(child.is_leaf() for child in in_spec.children_specs): + if not all(child.is_leaf() for child in in_spec.children()): # In the case that we have pytree-flattened inputs in # `concrete_args`, generate a flattening wrapper around the # original root function and return that. diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 3a319333529..fc6f4c5b270 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -933,24 +933,25 @@ class _PyTreeCodeGen(CodeGen): return "\n " + "".join(x + "; " for x in has_annotation) + "\n" def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: + in_spec = self.pytree_info.in_spec # when kwargs is present, in_spec is tuple(args, kwargs) has_args_kwargs_tuple = ( - self.pytree_info.in_spec.type is tuple - and self.pytree_info.in_spec.num_children == 2 - and self.pytree_info.in_spec.children_specs[0].type is tuple - and self.pytree_info.in_spec.children_specs[1].type is dict + in_spec.type is tuple + and in_spec.num_children == 2 + and in_spec.child(0).type is tuple + and in_spec.child(1).type is dict ) fn_kwargs = "{}" fn_signature = f"[{', '.join(fn_args)}], self._in_spec" if has_args_kwargs_tuple: - count_args = self.pytree_info.in_spec.children_specs[0].num_children + count_args = in_spec.child(0).num_children fn_args = self.pytree_info.orig_args[:count_args] fn_kwargs = ( "{" + ", ".join( f"'{k}':{v}" for k, v in zip( - self.pytree_info.in_spec.children_specs[1].context, + in_spec.child(1).context, self.pytree_info.orig_args[count_args:], ) ) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 1d67bb41da2..603625ed97c 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -14,9 +14,9 @@ collection support for PyTorch APIs. import functools import types -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Mapping from typing import Any, Optional, overload, TypeVar, Union -from typing_extensions import deprecated, TypeIs +from typing_extensions import deprecated, Self, TypeAlias, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion @@ -42,7 +42,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable: import optree -from optree import PyTreeSpec as TreeSpec # direct import for type annotations +from optree import PyTreeSpec # direct import for type annotations __all__ = [ @@ -53,6 +53,7 @@ __all__ = [ "DumpableContext", "ToDumpableContextFn", "FromDumpableContextFn", + "PyTreeSpec", "TreeSpec", "LeafSpec", "keystr", @@ -100,6 +101,8 @@ U = TypeVar("U") R = TypeVar("R") +TreeSpec: TypeAlias = PyTreeSpec + Context = Any PyTree = Any FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]] @@ -267,6 +270,30 @@ def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]: return isinstance(obj, TreeSpec) +def treespec_leaf() -> TreeSpec: + """Make a treespec representing a leaf node.""" + return optree.treespec_leaf(none_is_leaf=True, namespace="torch") + + +def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec: + """Make a tuple treespec from an iterable of child treespecs.""" + return optree.treespec_tuple(iterable, none_is_leaf=True, namespace="torch") + + +def treespec_dict( + mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (), + /, + **kwargs: TreeSpec, +) -> TreeSpec: + """Make a dict treespec from a dict of child treespecs.""" + return optree.treespec_dict( + mapping, + **kwargs, + none_is_leaf=True, + namespace="torch", + ) + + def tree_is_leaf( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, @@ -985,9 +1012,14 @@ class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc] return _is_pytreespec_instance(instance) and instance.is_leaf() +@deprecated( + "`isinstance(treespec, LeafSpec)` is deprecated, " + "use `isinstance(treespec, TreeSpec)` and `treespec.is_leaf()` instead.", + category=FutureWarning, +) class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final] - def __new__(cls) -> "LeafSpec": - return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value] + def __new__(cls) -> Self: + return treespec_leaf() # type: ignore[return-value] def tree_flatten_with_path( diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index dac064c8966..56704bb3f80 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -39,7 +39,7 @@ from typing import ( TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self +from typing_extensions import deprecated, NamedTuple, Self, TypeAlias from torch.torch_version import TorchVersion as _TorchVersion @@ -52,6 +52,7 @@ __all__ = [ "DumpableContext", "ToDumpableContextFn", "FromDumpableContextFn", + "PyTreeSpec", "TreeSpec", "LeafSpec", "keystr", @@ -475,7 +476,7 @@ class ConstantNode: def _is_constant_holder(spec: "TreeSpec") -> bool: """Checks if the spec is from a pytree registered with register_constant""" - return isinstance(spec.context, ConstantNode) + return isinstance(spec._context, ConstantNode) def _retrieve_constant(spec: "TreeSpec") -> Any: @@ -1076,39 +1077,60 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) - # A TreeSpec represents the structure of a pytree. It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False) +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children(): specs for each child of the root Node +# num_nodes: the total number of nodes +# num_leaves: the number of leaves +# num_children: the number of children of the root Node (i.e., len(children())) +# is_leaf(): whether the root Node is a leaf +@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False) class TreeSpec: type: Any - context: Context - children_specs: list["TreeSpec"] + _context: Context + _children: list[Self] num_nodes: int = dataclasses.field(init=False) num_leaves: int = dataclasses.field(init=False) num_children: int = dataclasses.field(init=False) + def __init__( + self, + type: Any, + context: Context, # keep for backward compatibility + children_specs: list[Self], # keep for backward compatibility + ) -> None: + object.__setattr__(self, "type", type) + object.__setattr__(self, "_context", context) + object.__setattr__(self, "_children", children_specs) + self.__post_init__() + def __post_init__(self) -> None: - num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1) - num_leaves = sum(spec.num_leaves for spec in self.children_specs) - num_children = len(self.children_specs) + if self.type is None: + assert self._context is None + assert len(self._children) == 0 + num_nodes = 1 + num_leaves = 1 + num_children = 0 + else: + num_nodes = sum((spec.num_nodes for spec in self._children), start=1) + num_leaves = sum(spec.num_leaves for spec in self._children) + num_children = len(self._children) object.__setattr__(self, "num_nodes", num_nodes) object.__setattr__(self, "num_leaves", num_leaves) object.__setattr__(self, "num_children", num_children) def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self._context}, [" children_specs_str: str = "" if self.num_children > 0: indent += 2 - children_specs_str += self.children_specs[0].__repr__(indent) + children_specs_str += self._children[0].__repr__(indent) children_specs_str += "," if self.num_children > 1 else "" children_specs_str += ",".join( [ "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] + for child in self._children[1:] ] ) repr_suffix: str = f"{children_specs_str}])" @@ -1120,16 +1142,36 @@ class TreeSpec: elif other.__class__ is self.__class__: if str(self.type) != str(other.type): return False - if self.context != other.context: + if self._context != other._context: return False - elif self.children_specs != other.children_specs: + elif self._children != other._children: return False return True return NotImplemented + @property + def context(self) -> Context: + return self._context + + @property + @deprecated( + "`treespec.children_specs` is deprecated. " + "Use `treespec.child(index)` to access a single child, " + "or `treespec.children()` to get all children.", + category=FutureWarning, + ) + def children_specs(self) -> list[Self]: + return self._children + def is_leaf(self) -> bool: return self.num_nodes == 1 and self.num_leaves == 1 + def children(self) -> list[Self]: + return self._children.copy() + + def child(self, index: int) -> Self: + return self._children[index] + def flatten_up_to(self, tree: PyTree) -> list[PyTree]: def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: if treespec.is_leaf(): @@ -1151,7 +1193,7 @@ class TreeSpec: f"Node arity mismatch; " f"expected {treespec.num_children}, but got {len(children)}.", ) - if context != treespec.context: + if context != treespec._context: raise ValueError( f"Node context mismatch for custom node type {treespec.type!r}.", ) @@ -1176,10 +1218,10 @@ class TreeSpec: if both_standard_dict: # dictionary types are compatible with each other dict_context = ( - treespec.context + treespec._context if treespec.type is not defaultdict # ignore mismatch of `default_factory` for defaultdict - else treespec.context[1] + else treespec._context[1] ) expected_keys = dict_context got_key_set = set(tree) @@ -1200,13 +1242,13 @@ class TreeSpec: children, context = flatten_fn(tree) if ( node_type is not deque # ignore mismatch of `maxlen` for deque - ) and context != treespec.context: + ) and context != treespec._context: raise ValueError( f"Node context mismatch for node type {treespec.type!r}; " - f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch + f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch ) - for subtree, subspec in zip(children, treespec.children_specs): + for subtree, subspec in zip(children, treespec._children, strict=True): helper(subspec, subtree, subtrees) subtrees: list[PyTree] = [] @@ -1231,24 +1273,24 @@ class TreeSpec: start = 0 end = 0 child_pytrees = [] - for child_spec in self.children_specs: + for child_spec in self._children: end += child_spec.num_leaves child_pytrees.append(child_spec.unflatten(leaves[start:end])) start = end - return unflatten_fn(child_pytrees, self.context) + return unflatten_fn(child_pytrees, self._context) def __hash__(self) -> int: node_type = self.type if node_type is defaultdict: - default_factory, dict_context = self.context + default_factory, dict_context = self._context hashable_context = (default_factory, tuple(dict_context)) elif node_type in (dict, OrderedDict): - hashable_context = tuple(self.context) + hashable_context = tuple(self._context) elif node_type is None or node_type in BUILTIN_TYPES: - hashable_context = self.context - elif isinstance(self.context, ConstantNode): - hashable_context = self.context.value + hashable_context = self._context + elif isinstance(self._context, ConstantNode): + hashable_context = self._context.value else: # The context for user-defined node types might not be hashable. # Ignore it for hashing. @@ -1256,20 +1298,26 @@ class TreeSpec: # same hash. This might increase the hash collision rate, but we # don't care about that. hashable_context = None - return hash((node_type, hashable_context, tuple(self.children_specs))) + return hash((node_type, hashable_context, tuple(self._children))) + + +PyTreeSpec: TypeAlias = TreeSpec # NOTE: subclassing a dataclass is subtle. In order to enable reasoning about # this class with `dataclasses.fields`, etc., while having a simplified # constructor that takes no argument, we wrap with `dataclass(init=True, ...)` # again, with fields that have `init=False`. +@deprecated( + "`isinstance(treespec, LeafSpec)` is deprecated, " + "use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.", + category=FutureWarning, +) @dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False) class LeafSpec(TreeSpec): type: Any = dataclasses.field(default=None, init=False) - context: Context = dataclasses.field(default=None, init=False) - children_specs: list["TreeSpec"] = dataclasses.field( - default_factory=list, init=False - ) + _context: Context = dataclasses.field(default=None, init=False) + _children: list[Self] = dataclasses.field(default_factory=list, init=False) def __post_init__(self) -> None: # Override `__post_init__` for `num_leaves` derivation. @@ -1283,7 +1331,36 @@ class LeafSpec(TreeSpec): # All leaves are equivalent, so represent with a single object to save on # object construction time -_LEAF_SPEC = LeafSpec() +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=FutureWarning, module=__name__, append=False + ) + _LEAF_SPEC = LeafSpec() + + +def treespec_leaf() -> LeafSpec: + """Make a treespec representing a leaf node.""" + return _LEAF_SPEC + + +def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec: + """Make a tuple treespec from an iterable of child treespecs.""" + children = list(iterable) + if any(not isinstance(child, TreeSpec) for child in children): + raise ValueError(f"Expected a tuple of TreeSpec values, got: {children!r}.") + return TreeSpec(tuple, None, children) + + +def treespec_dict( + mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (), + /, + **kwargs: TreeSpec, +) -> TreeSpec: + """Make a dict treespec from a dict of child treespecs.""" + dct = dict(mapping, **kwargs) + if any(not isinstance(child, TreeSpec) for child in dct.values()): + raise ValueError(f"Expected a dictionary of TreeSpec values, got: {dct!r}.") + return TreeSpec(dict, list(dct.keys()), list(dct.values())) def tree_flatten( @@ -1762,15 +1839,15 @@ def _broadcast_to_and_flatten( return None flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - child_pytrees, ctx = flatten_fn(tree) + child_pytrees, context = flatten_fn(tree) # Check if the Node is different from the spec - if len(child_pytrees) != treespec.num_children or ctx != treespec.context: + if len(child_pytrees) != treespec.num_children or context != treespec._context: return None # Recursively flatten the children result: list[Any] = [] - for child, child_spec in zip(child_pytrees, treespec.children_specs): + for child, child_spec in zip(child_pytrees, treespec._children, strict=True): flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) if flat is not None: result += flat @@ -1824,7 +1901,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: if serialize_node_def.to_dumpable_context is None: try: - serialized_context = json.dumps(treespec.context, cls=EnumEncoder) + serialized_context = json.dumps(treespec._context, cls=EnumEncoder) except TypeError as e: raise TypeError( "Unable to serialize context. " @@ -1832,9 +1909,9 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: "custom serializer using _register_pytree_node." ) from e else: - serialized_context = serialize_node_def.to_dumpable_context(treespec.context) + serialized_context = serialize_node_def.to_dumpable_context(treespec._context) - child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] + child_schemas = [_treespec_to_json(child) for child in treespec._children] return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)