diff --git a/test/export/test_export.py b/test/export/test_export.py index 3250d82c3ea..762ad512ae3 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -91,13 +91,13 @@ from torch.testing._internal.torchbind_impls import load_torchbind_test_lib from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu from torch.testing._internal.two_tensor import TwoTensor from torch.utils._pytree import ( + LeafSpec, register_constant, tree_flatten, tree_map, tree_unflatten, TreeSpec, treespec_dumps, - treespec_leaf, treespec_loads, ) @@ -7791,7 +7791,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): dt = MyDataClass(x=3, y=4) flat, spec = tree_flatten(dt) - self.assertTrue(spec, treespec_leaf()) + self.assertTrue(spec, LeafSpec()) self.assertTrue(len(flat) == 1) torch.export.register_dataclass( @@ -7802,9 +7802,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): flat, spec = tree_flatten(dt) self.assertEqual( spec, - TreeSpec( - MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()] - ), + TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]), ) self.assertEqual(flat, [3, 4]) @@ -7837,7 +7835,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): TreeSpec( MyOtherDataClass, [["x", "y", "z"], []], - [treespec_leaf(), treespec_leaf(), treespec_leaf()], + [LeafSpec(), LeafSpec(), LeafSpec()], ), ) self.assertEqual(flat, [3, 4, None]) diff --git a/test/test_pytree.py b/test/test_pytree.py index 7cc3b8affc0..e19f1471267 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -65,6 +65,9 @@ class TestEnum(enum.Enum): A = auto() +python_leafspec = python_pytree.LeafSpec() + + class TestGenericPytree(TestCase): def test_aligned_public_apis(self): public_apis = python_pytree.__all__ @@ -194,7 +197,7 @@ class TestGenericPytree(TestCase): def run_test_with_leaf(leaf): values, treespec = pytree.tree_flatten(leaf) self.assertEqual(values, [leaf]) - self.assertEqual(treespec, pytree.treespec_leaf()) + self.assertEqual(treespec, pytree.LeafSpec()) unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, leaf) @@ -212,7 +215,7 @@ class TestGenericPytree(TestCase): ( python_pytree, lambda tup: python_pytree.TreeSpec( - tuple, None, [python_pytree.treespec_leaf() for _ in tup] + tuple, None, [python_leafspec for _ in tup] ), ), name="python", @@ -247,7 +250,7 @@ class TestGenericPytree(TestCase): ( python_pytree, lambda lst: python_pytree.TreeSpec( - list, None, [python_pytree.treespec_leaf() for _ in lst] + list, None, [python_leafspec for _ in lst] ), ), name="python", @@ -283,7 +286,7 @@ class TestGenericPytree(TestCase): lambda dct: python_pytree.TreeSpec( dict, list(dct.keys()), - [python_pytree.treespec_leaf() for _ in dct.values()], + [python_leafspec for _ in dct.values()], ), ), name="python", @@ -324,7 +327,7 @@ class TestGenericPytree(TestCase): lambda odict: python_pytree.TreeSpec( OrderedDict, list(odict.keys()), - [python_pytree.treespec_leaf() for _ in odict.values()], + [python_leafspec for _ in odict.values()], ), ), name="python", @@ -368,7 +371,7 @@ class TestGenericPytree(TestCase): lambda ddct: python_pytree.TreeSpec( defaultdict, [ddct.default_factory, list(ddct.keys())], - [python_pytree.treespec_leaf() for _ in ddct.values()], + [python_leafspec for _ in ddct.values()], ), ), name="python", @@ -410,7 +413,7 @@ class TestGenericPytree(TestCase): ( python_pytree, lambda deq: python_pytree.TreeSpec( - deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq] + deque, deq.maxlen, [python_leafspec for _ in deq] ), ), name="python", @@ -450,7 +453,7 @@ class TestGenericPytree(TestCase): def run_test(tup): if pytree is python_pytree: expected_spec = python_pytree.TreeSpec( - namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup] + namedtuple, Point, [python_leafspec for _ in tup] ) else: expected_spec = cxx_pytree.tree_structure(Point(0, 1)) @@ -845,16 +848,16 @@ if "optree" in sys.modules: def test_treespec_equality(self): self.assertEqual( - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), + python_pytree.LeafSpec(), + python_pytree.LeafSpec(), ) self.assertEqual( python_pytree.TreeSpec(list, None, []), python_pytree.TreeSpec(list, None, []), ) self.assertEqual( - python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]), - python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]), + python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]), + python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]), ) self.assertFalse( python_pytree.TreeSpec(tuple, None, []) @@ -889,32 +892,24 @@ if "optree" in sys.modules: # python_pytree.tree_structure({}) python_pytree.TreeSpec(dict, [], []), # python_pytree.tree_structure([0]) - python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]), + python_pytree.TreeSpec(list, None, [python_leafspec]), # python_pytree.tree_structure([0, 1]) python_pytree.TreeSpec( list, None, - [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], + [python_leafspec, python_leafspec], ), # python_pytree.tree_structure((0, 1, 2)) python_pytree.TreeSpec( tuple, None, - [ - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - ], + [python_leafspec, python_leafspec, python_leafspec], ), # python_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) python_pytree.TreeSpec( dict, ["a", "b", "c"], - [ - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - ], + [python_leafspec, python_leafspec, python_leafspec], ), # python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) python_pytree.TreeSpec( @@ -924,17 +919,13 @@ if "optree" in sys.modules: python_pytree.TreeSpec( tuple, None, - [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], + [python_leafspec, python_leafspec], ), - python_pytree.treespec_leaf(), + python_leafspec, python_pytree.TreeSpec( dict, ["a", "b", "c"], - [ - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - ], + [python_leafspec, python_leafspec, python_leafspec], ), ], ), @@ -947,15 +938,12 @@ if "optree" in sys.modules: tuple, None, [ - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), + python_leafspec, + python_leafspec, python_pytree.TreeSpec( list, None, - [ - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - ], + [python_leafspec, python_leafspec], ), ], ), @@ -969,12 +957,12 @@ if "optree" in sys.modules: python_pytree.TreeSpec( list, None, - [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], + [python_leafspec, python_leafspec], ), python_pytree.TreeSpec( list, None, - [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], + [python_leafspec, python_leafspec], ), python_pytree.TreeSpec(dict, [], []), ], @@ -1003,7 +991,7 @@ if "optree" in sys.modules: list, None, [ - python_pytree.treespec_leaf(), + python_leafspec, ], ), ], @@ -1012,7 +1000,7 @@ if "optree" in sys.modules: self.assertIsInstance(serialized_spec, str) def test_pytree_serialize_enum(self): - spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()]) + spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec]) serialized_spec = python_pytree.treespec_dumps(spec) self.assertIsInstance(serialized_spec, str) @@ -1175,20 +1163,12 @@ if "optree" in sys.modules: OrderedDict, [1, 2, 3], [ - python_pytree.TreeSpec( - tuple, - None, - [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()], - ), - python_pytree.treespec_leaf(), + python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]), + python_leafspec, python_pytree.TreeSpec( dict, [4, 5, 6], - [ - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - python_pytree.treespec_leaf(), - ], + [python_leafspec, python_leafspec, python_leafspec], ), ], ) @@ -1473,7 +1453,7 @@ class TestCxxPytree(TestCase): raise unittest.SkipTest("C++ pytree tests are not supported in fbcode") def test_treespec_equality(self): - self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf()) + self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec()) def test_treespec_repr(self): # Check that it looks sane diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index f9bdc0cce4a..ef308e90789 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -17,7 +17,7 @@ from ..decorators import substitute_in_graph if TYPE_CHECKING: import builtins - from collections.abc import Callable, Iterable, Mapping + from collections.abc import Callable, Iterable from typing_extensions import Self @@ -349,113 +349,6 @@ if python_pytree._cxx_pytree_dynamo_traceable: def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: return isinstance(obj, PyTreeSpec) - @substitute_in_graph( # type: ignore[arg-type] - optree.treespec_leaf, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. - can_constant_fold_through=False, - ) - def treespec_leaf( - *, - none_is_leaf: bool = False, - namespace: str = "", # unused - ) -> PyTreeSpec: - return PyTreeSpec( - (), - None, - None, - (), - None, - none_is_leaf=none_is_leaf, - namespace="", - ) - - @substitute_in_graph( # type: ignore[arg-type] - optree.treespec_tuple, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. - can_constant_fold_through=False, - ) - def treespec_tuple( - iterable: Iterable[PyTreeSpec] = (), - /, - *, - none_is_leaf: bool = False, - namespace: str = "", - ) -> PyTreeSpec: - children = tuple(iterable) - if any(not _is_pytreespec_instance(child) for child in children): - raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.") - if any(child.none_is_leaf != none_is_leaf for child in children): - raise ValueError( - "All children PyTreeSpecs must have the same `none_is_leaf` value " - f"as the parent; expected {none_is_leaf}, got: {children!r}.", - ) - if any(child.namespace not in (namespace, "") for child in children): - raise ValueError( - "All children PyTreeSpecs must have the same `namespace` value " - f"as the parent; expected {namespace!r}, got: {children!r}.", - ) - handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined] - assert handler is not None - return PyTreeSpec( - tuple(children), - tuple, - None, - tuple(range(len(children))), - handler.unflatten_func, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - - @substitute_in_graph( # type: ignore[arg-type] - optree.treespec_dict, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. - can_constant_fold_through=False, - ) - def treespec_dict( - mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), - /, - *, - none_is_leaf: bool = False, - namespace: str = "", - **kwargs: PyTreeSpec, - ) -> PyTreeSpec: - dct = dict(mapping, **kwargs) - if any(not _is_pytreespec_instance(child) for child in dct.values()): - raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.") - if any(child.none_is_leaf != none_is_leaf for child in dct.values()): - raise ValueError( - "All children PyTreeSpecs must have the same `none_is_leaf` value " - f"as the parent; expected {none_is_leaf}, got: {dct!r}.", - ) - if any(child.namespace not in (namespace, "") for child in dct.values()): - raise ValueError( - "All children PyTreeSpecs must have the same `namespace` value " - f"as the parent; expected {namespace!r}, got: {dct!r}.", - ) - - ( - children, - metadata, - entries, - unflatten_func, - ) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated] - dct, # type: ignore[arg-type] - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - return PyTreeSpec( - tuple(children), # type: ignore[arg-type] - dict, - metadata, - entries, - unflatten_func, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - @substitute_in_graph( # type: ignore[arg-type] optree.tree_flatten, # We need to disable constant folding here because we want the function to reference the diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d19e7998dee..a2bbcb65710 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -3727,7 +3727,9 @@ class SourcelessBuilder: pass # failthrough to unimplemented branch elif isinstance(value, torch.fx.graph_module.GraphModule): return SourcelessGraphModuleVariable(value) - elif isinstance(value, torch.utils._pytree.TreeSpec): + elif isinstance( + value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec) + ): return UserDefinedObjectVariable(value) elif PlacementVariable.is_placement(value): return PlacementVariable(value) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index c330a700fd6..bee608f1fb0 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1661,8 +1661,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable): # We need to have this check this way, because in case init is a TreeSpec and carry # but carry is only a LeafSpec, these two cannot be compared correctly. if ( - xs_treespec.as_python_constant().is_leaf() - != _combine_treespec.as_python_constant().is_leaf() + isinstance(xs_treespec.as_python_constant(), pytree.LeafSpec) + != isinstance(_combine_treespec.as_python_constant(), pytree.LeafSpec) ) or not _make_inlined(tx, pytree.TreeSpec.__eq__)( xs_treespec, _combine_treespec ).as_python_constant(): diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index a0a40666d03..7acacdd0ca1 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1530,10 +1530,10 @@ class GraphModuleSerializer(metaclass=Final): else: raise AssertionError("TODO") - def serialize_treespec(self, treespec: pytree.TreeSpec) -> str: + def serialize_treespec(self, treespec): # We want to additionally save all the field names of the namedtuples in # case users want to check that the treespec types are equivalent - def store_namedtuple_fields(ts: pytree.TreeSpec) -> None: + def store_namedtuple_fields(ts): if ts.type is None: return if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): @@ -1555,7 +1555,7 @@ class GraphModuleSerializer(metaclass=Final): NamedTupleDef(field_names=ts.context._fields) ) - for child in ts.children(): + for child in ts.children_specs: store_namedtuple_fields(child) serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION) diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index f3e8f6a91b3..858c0e9e539 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -158,7 +158,7 @@ class PytreeThunk: assert spec is not None self.spec: pytree.TreeSpec = spec if self.spec.type in {tuple, list} and all( - child.is_leaf() for child in spec.children() + child.is_leaf() for child in spec.children_specs ): self.is_simple = True if self.spec.is_leaf(): diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index f48cb04f08f..bf561e62a39 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -1590,7 +1590,7 @@ def aot_export_joint_simple( decompositions=decompositions, trace_joint=trace_joint, ) - in_spec, _kw_in_spec = in_spec.children() + in_spec, _kw_in_spec = in_spec.children_specs # At this point, we can just directly return the (joint or inference graph) that we traced. # First though: a bunch of assertions to make sure that our graph doesn't require # any calling convention changes compared to the original function. @@ -1617,7 +1617,7 @@ def aot_export_joint_simple( raise RuntimeError( f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}" ) - if not all(child.is_leaf() for child in in_spec.children()): + if not all(child.is_leaf() for child in in_spec.children_specs): raise RuntimeError( f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}" ) @@ -1625,7 +1625,7 @@ def aot_export_joint_simple( raise RuntimeError( f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}" ) - if not all(child.is_leaf() for child in out_spec.children()): + if not all(child.is_leaf() for child in out_spec.children_specs): raise RuntimeError( f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}" ) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 4e40d2347b1..33b25b319da 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -469,7 +469,7 @@ def _unlift_graph( gm, lifted_inputs, mutated_outputs, - pytree.treespec_leaf(), + pytree.LeafSpec(), None, ) return unlifted_gm diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index f6e9789e948..1575f936044 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -7,7 +7,6 @@ from typing import Any, Optional, Union import torch import torch.ao.quantization.pt2e._affine_quantization # noqa: F401 import torch.nn.functional as F -import torch.utils._pytree as pytree # Makes sure that quantized_decomposed ops are registered from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 @@ -15,6 +14,7 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.export.unflatten import _assign_attr, _AttrKind from torch.fx import GraphModule, Node from torch.nn.utils.fusion import fuse_conv_bn_weights +from torch.utils._pytree import LeafSpec __all__ = [ @@ -477,10 +477,7 @@ def _replace_literals_with_new_placeholders( exclude_literals = [] in_spec = gm._in_spec - assert in_spec.type is tuple - args_spec = in_spec.child(0) - assert args_spec.type is tuple - args_spec_children = args_spec.children() + args_spec = in_spec.children_specs[0] for node in gm.graph.nodes: if node.op == "placeholder": last_ph = node @@ -495,7 +492,7 @@ def _replace_literals_with_new_placeholders( else: ph_node = gm.graph.placeholder("arg" + str(cnt)) new_args.append(ph_node) - args_spec_children.append(pytree.treespec_leaf()) + args_spec.children_specs.append(LeafSpec()) cnt += 1 if merge_dup: literal_to_ph[arg] = ph_node @@ -506,8 +503,8 @@ def _replace_literals_with_new_placeholders( node.args = new_args # Update `num_nodes`, `num_leaves`, `num_children`. - args_spec = pytree.treespec_tuple(args_spec_children) - gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]]) + args_spec.__post_init__() + in_spec.__post_init__() return gm diff --git a/torch/export/_swap.py b/torch/export/_swap.py index 78c507da02d..0c38d540b97 100644 --- a/torch/export/_swap.py +++ b/torch/export/_swap.py @@ -195,16 +195,17 @@ def _construct_inputs( unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec) assert signature.in_spec.num_children == 2 - assert signature.in_spec.type is tuple - args_spec, kwargs_spec = signature.in_spec.children() - assert args_spec.type is tuple - assert kwargs_spec.type is dict + args_spec = signature.in_spec.children_specs[0] + assert args_spec.context is None args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0)) args_nodes = [ gm.graph.call_function(operator.getitem, (args_node, i)) for i in range(args_spec.num_children) ] + + kwargs_spec = signature.in_spec.children_specs[1] + assert kwargs_spec.context is not None kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1)) kwargs_nodes = { k: gm.graph.call_function(operator.getitem, (kwargs_node, k)) @@ -371,8 +372,8 @@ def _fix_input_output_signature( if forward_arg_names is None: forward_arg_names = [] assert signature.in_spec.num_children == 2 - arg_spec = signature.in_spec.child(0) - kwarg_spec = signature.in_spec.child(1) + arg_spec = signature.in_spec.children_specs[0] + kwarg_spec = signature.in_spec.children_specs[1] assert arg_spec.type is tuple assert kwarg_spec.type is dict for i in range(arg_spec.num_children): diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 934ee448820..685fe149714 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1533,7 +1533,7 @@ def _strict_export( # aot_export expect the return type to always be a tuple. if out_spec.type not in (list, tuple): - out_spec = pytree.treespec_tuple([out_spec]) + out_spec = pytree.TreeSpec(tuple, None, [out_spec]) orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] diff --git a/torch/export/_tree_utils.py b/torch/export/_tree_utils.py index de768c38f21..5c2d4426066 100644 --- a/torch/export/_tree_utils.py +++ b/torch/export/_tree_utils.py @@ -15,7 +15,7 @@ def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any # Make sure that the spec is actually shaped like (args, kwargs) assert spec.type is tuple assert spec.num_children == 2 - kwargs_spec = spec.child(1) + kwargs_spec = spec.children_specs[1] assert kwargs_spec.type is dict if set(user_kwargs) != set(kwargs_spec.context): @@ -55,10 +55,10 @@ def is_equivalent( return False # Recurse on children - if spec1.num_children != spec2.num_children: + if len(spec1.children_specs) != len(spec2.children_specs): return False - for child_spec1, child_spec2 in zip(spec1.children(), spec2.children()): + for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs): if not is_equivalent(child_spec1, child_spec2, equivalence_fn): return False diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 52d06a294fa..b9e82481322 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -57,11 +57,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool: return False elif a.context != b.context: return False - if a.num_children != b.num_children: + if len(a.children_specs) != len(b.children_specs): return False return all( _match_normalized_structure(a, b) - for a, b in zip(a.children(), b.children()) + for a, b in zip(a.children_specs, b.children_specs) ) return _match_normalized_structure(self, other) @@ -357,13 +357,13 @@ def _get_codegen( elif ( in_spec.type is tuple and in_spec.num_children == 2 - and in_spec.child(0).type is tuple - and in_spec.child(1).type is dict + and in_spec.children_specs[0].type is tuple + and in_spec.children_specs[1].type is dict ): # if in_spec contains the args (tuple) and kwargs (dict) - names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)] + names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] # add kwarg names - names.extend(in_spec.child(1).context) + names.extend(in_spec.children_specs[1].context) else: names = [f"arg_{i}" for i in range(in_spec.num_children)] diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 1e1f1f40985..24964c27472 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -12,16 +12,14 @@ import torch from torch.utils._pytree import ( _get_node_type, BUILTIN_TYPES, - KeyPath, keystr, + LeafSpec, MappingKey, SequenceKey, SUPPORTED_NODES, - tree_iter, + tree_flatten, tree_map, tree_map_with_path, - tree_structure, - TreeSpec, ) from .exported_program import ExportedProgram @@ -657,55 +655,53 @@ def _tree_map_with_path( case_name="dynamic_shapes_validation", ) - def _compare( - treespec: TreeSpec, other_treespec: TreeSpec, path: KeyPath - ) -> None: + def _compare(tree, dynamic_shapes, path): # raise an error at the point where tree and dynamic_shapes differ, # including the path to that point and the reason for the difference rendered_path = keystr(path) - if treespec.is_leaf(): + if isinstance(tree, LeafSpec): return - if other_treespec.is_leaf(): + if isinstance(dynamic_shapes, LeafSpec): raise_mismatch_error( - f"`{tree_name}{rendered_path}` is a {treespec.type}, " + f"`{tree_name}{rendered_path}` is a {tree.type}, " f"but `dynamic_shapes{rendered_path}` is not" ) - if treespec.type != other_treespec.type: + if tree.type != dynamic_shapes.type: raise_mismatch_error( - f"`{tree_name}{rendered_path}` is a {treespec.type}, " - f"but `dynamic_shapes{rendered_path}` is a {other_treespec.type}" + f"`{tree_name}{rendered_path}` is a {tree.type}, " + f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}" ) - if treespec.num_children != other_treespec.num_children: + if len(tree.children_specs) != len(dynamic_shapes.children_specs): raise_mismatch_error( - f"`{tree_name}{rendered_path}` has {treespec.num_children} elements, " - f"but `dynamic_shapes{rendered_path}` has {other_treespec.num_children} elements" + f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, " + f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements" ) - if treespec.type is dict: + if tree.type is dict: # context, children could be out of order - if set(treespec.context) != set(other_treespec.context): + if sorted(tree.context) != sorted(dynamic_shapes.context): raise_mismatch_error( - f"`{tree_name}{rendered_path}` has keys {treespec.context}, " - f"but `dynamic_shapes{rendered_path}` has keys {other_treespec.context}" + f"`{tree_name}{rendered_path}` has keys {tree.context}, " + f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}" ) _remap = dict( - zip(other_treespec.context, other_treespec.children()) + zip(dynamic_shapes.context, dynamic_shapes.children_specs) ) - other_children = [_remap[k] for k in treespec.context] + dynamic_shapes_children_specs = [_remap[k] for k in tree.context] else: - other_children = other_treespec.children() - for i, (child, other_child) in enumerate( - zip(treespec.children(), other_children) + dynamic_shapes_children_specs = dynamic_shapes.children_specs + for i, (tree_, dynamic_shapes_) in enumerate( + zip(tree.children_specs, dynamic_shapes_children_specs) ): _compare( - child, - other_child, - path + (_key(treespec.type, treespec.context, i),), + tree_, + dynamic_shapes_, + path + [_key(tree.type, tree.context, i)], ) - treespec = tree_structure(tree, is_leaf=is_leaf) + _, tree_spec = tree_flatten(tree, is_leaf=is_leaf) for other_tree in dynamic_shapes: - other_treespec = tree_structure(other_tree, is_leaf) - _compare(treespec, other_treespec, ()) + _, other_tree_spec = tree_flatten(other_tree, is_leaf) + _compare(tree_spec, other_tree_spec, []) raise @@ -1235,7 +1231,10 @@ def _get_dim_name_mapping( dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], ): name_to_dim = {} - for dim in tree_iter(dynamic_shapes, is_leaf=lambda x: isinstance(x, Dim)): + for dim in tree_flatten( + dynamic_shapes, + is_leaf=lambda x: isinstance(x, Dim), + )[0]: if dim is None: # NOTE: this must denote a non-Tensor or automatic at this point. continue diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index f8faff582f0..3b2d0cb7401 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -391,7 +391,7 @@ def _decompose_and_get_gm_with_new_signature_constants( # aot_export expect the return type to always be a tuple. assert out_spec is not None if out_spec.type not in (list, tuple): - out_spec = pytree.treespec_tuple([out_spec]) + out_spec = pytree.TreeSpec(tuple, None, [out_spec]) mod.graph._codegen = _PyTreeCodeGen( _PyTreeInfo( diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 3701ba99047..1e4931f4a19 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -1124,10 +1124,10 @@ class _ModuleFrame: signature = module_call_graph.get(self.child_fqn) if signature is not None and self.parent is not None: assert signature.in_spec.num_children == 2 - assert signature.in_spec.type is tuple - args_spec, kwargs_spec = signature.in_spec.children() - assert args_spec.type is tuple - assert kwargs_spec.type is dict + args_spec = signature.in_spec.children_specs[0] + kwargs_spec = signature.in_spec.children_specs[1] + assert args_spec.context is None + assert kwargs_spec.context is not None with self.graph.inserting_after(None): arg_nodes = [ diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index bfb62f871eb..2f608816c49 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -49,7 +49,7 @@ def tree_flatten_spec( flatten_fn_spec = SUPPORTED_NODES[spec.type] child_pytrees = flatten_fn_spec(pytree, spec) result = [] - for child, child_spec in zip(child_pytrees, spec.children()): + for child, child_spec in zip(child_pytrees, spec.children_specs): flat = tree_flatten_spec(child, child_spec) result += flat return result diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index cfce00fb05e..cc72d410fd7 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -709,7 +709,7 @@ class Tracer(TracerBase): root_fn = _patch_function(root_fn, len(args)) flat_args, in_spec = pytree.tree_flatten(tuple(args)) - if not all(child.is_leaf() for child in in_spec.children()): + if not all(child.is_leaf() for child in in_spec.children_specs): # In the case that we have pytree-flattened inputs in # `concrete_args`, generate a flattening wrapper around the # original root function and return that. diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fc6f4c5b270..3a319333529 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -933,25 +933,24 @@ class _PyTreeCodeGen(CodeGen): return "\n " + "".join(x + "; " for x in has_annotation) + "\n" def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: - in_spec = self.pytree_info.in_spec # when kwargs is present, in_spec is tuple(args, kwargs) has_args_kwargs_tuple = ( - in_spec.type is tuple - and in_spec.num_children == 2 - and in_spec.child(0).type is tuple - and in_spec.child(1).type is dict + self.pytree_info.in_spec.type is tuple + and self.pytree_info.in_spec.num_children == 2 + and self.pytree_info.in_spec.children_specs[0].type is tuple + and self.pytree_info.in_spec.children_specs[1].type is dict ) fn_kwargs = "{}" fn_signature = f"[{', '.join(fn_args)}], self._in_spec" if has_args_kwargs_tuple: - count_args = in_spec.child(0).num_children + count_args = self.pytree_info.in_spec.children_specs[0].num_children fn_args = self.pytree_info.orig_args[:count_args] fn_kwargs = ( "{" + ", ".join( f"'{k}':{v}" for k, v in zip( - in_spec.child(1).context, + self.pytree_info.in_spec.children_specs[1].context, self.pytree_info.orig_args[count_args:], ) ) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 603625ed97c..1d67bb41da2 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -14,9 +14,9 @@ collection support for PyTorch APIs. import functools import types -from collections.abc import Callable, Iterable, Mapping +from collections.abc import Callable, Iterable from typing import Any, Optional, overload, TypeVar, Union -from typing_extensions import deprecated, Self, TypeAlias, TypeIs +from typing_extensions import deprecated, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion @@ -42,7 +42,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable: import optree -from optree import PyTreeSpec # direct import for type annotations +from optree import PyTreeSpec as TreeSpec # direct import for type annotations __all__ = [ @@ -53,7 +53,6 @@ __all__ = [ "DumpableContext", "ToDumpableContextFn", "FromDumpableContextFn", - "PyTreeSpec", "TreeSpec", "LeafSpec", "keystr", @@ -101,8 +100,6 @@ U = TypeVar("U") R = TypeVar("R") -TreeSpec: TypeAlias = PyTreeSpec - Context = Any PyTree = Any FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]] @@ -270,30 +267,6 @@ def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]: return isinstance(obj, TreeSpec) -def treespec_leaf() -> TreeSpec: - """Make a treespec representing a leaf node.""" - return optree.treespec_leaf(none_is_leaf=True, namespace="torch") - - -def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec: - """Make a tuple treespec from an iterable of child treespecs.""" - return optree.treespec_tuple(iterable, none_is_leaf=True, namespace="torch") - - -def treespec_dict( - mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (), - /, - **kwargs: TreeSpec, -) -> TreeSpec: - """Make a dict treespec from a dict of child treespecs.""" - return optree.treespec_dict( - mapping, - **kwargs, - none_is_leaf=True, - namespace="torch", - ) - - def tree_is_leaf( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, @@ -1012,14 +985,9 @@ class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc] return _is_pytreespec_instance(instance) and instance.is_leaf() -@deprecated( - "`isinstance(treespec, LeafSpec)` is deprecated, " - "use `isinstance(treespec, TreeSpec)` and `treespec.is_leaf()` instead.", - category=FutureWarning, -) class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final] - def __new__(cls) -> Self: - return treespec_leaf() # type: ignore[return-value] + def __new__(cls) -> "LeafSpec": + return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value] def tree_flatten_with_path( diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 119e96ef897..0d6d1e71445 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -39,7 +39,7 @@ from typing import ( TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self, TypeAlias +from typing_extensions import deprecated, NamedTuple, Self from torch.torch_version import TorchVersion as _TorchVersion @@ -52,7 +52,6 @@ __all__ = [ "DumpableContext", "ToDumpableContextFn", "FromDumpableContextFn", - "PyTreeSpec", "TreeSpec", "LeafSpec", "keystr", @@ -472,7 +471,7 @@ class ConstantNode: def _is_constant_holder(spec: "TreeSpec") -> bool: """Checks if the spec is from a pytree registered with register_constant""" - return isinstance(spec._context, ConstantNode) + return isinstance(spec.context, ConstantNode) def _retrieve_constant(spec: "TreeSpec") -> Any: @@ -1072,53 +1071,35 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) - # context: some context that is useful in unflattening the pytree # children_specs: specs for each child of the root Node # num_leaves: the number of leaves -@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False) +@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False) class TreeSpec: type: Any - _context: Context - _children: list[Self] + context: Context + children_specs: list["TreeSpec"] num_nodes: int = dataclasses.field(init=False) num_leaves: int = dataclasses.field(init=False) num_children: int = dataclasses.field(init=False) - def __init__( - self, - type: Any, - context: Context, # keep for backward compatibility - children: list[Self], # keep for backward compatibility - ) -> None: - object.__setattr__(self, "type", type) - object.__setattr__(self, "_context", context) - object.__setattr__(self, "_children", children) - self.__post_init__() - def __post_init__(self) -> None: - if self.type is None: - assert self._context is None - assert len(self._children) == 0 - num_nodes = 1 - num_leaves = 1 - num_children = 0 - else: - num_nodes = sum((spec.num_nodes for spec in self._children), start=1) - num_leaves = sum(spec.num_leaves for spec in self._children) - num_children = len(self._children) + num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1) + num_leaves = sum(spec.num_leaves for spec in self.children_specs) + num_children = len(self.children_specs) object.__setattr__(self, "num_nodes", num_nodes) object.__setattr__(self, "num_leaves", num_leaves) object.__setattr__(self, "num_children", num_children) def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self._context}, [" + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" children_specs_str: str = "" if self.num_children > 0: indent += 2 - children_specs_str += self._children[0].__repr__(indent) + children_specs_str += self.children_specs[0].__repr__(indent) children_specs_str += "," if self.num_children > 1 else "" children_specs_str += ",".join( [ "\n" + " " * indent + child.__repr__(indent) - for child in self._children[1:] + for child in self.children_specs[1:] ] ) repr_suffix: str = f"{children_specs_str}])" @@ -1130,36 +1111,16 @@ class TreeSpec: elif other.__class__ is self.__class__: if str(self.type) != str(other.type): return False - if self._context != other._context: + if self.context != other.context: return False - elif self._children != other._children: + elif self.children_specs != other.children_specs: return False return True return NotImplemented - @property - def context(self) -> Context: - return self._context - - @property - @deprecated( - "`treespec.children_specs` is deprecated. " - "Use `treespec.child(index)` to access a single child, " - "or `treespec.children()` to get all children.", - category=FutureWarning, - ) - def children_specs(self) -> list[Self]: - return self._children - def is_leaf(self) -> bool: return self.num_nodes == 1 and self.num_leaves == 1 - def children(self) -> list[Self]: - return self._children.copy() - - def child(self, index: int) -> Self: - return self._children[index] - def flatten_up_to(self, tree: PyTree) -> list[PyTree]: def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: if treespec.is_leaf(): @@ -1181,7 +1142,7 @@ class TreeSpec: f"Node arity mismatch; " f"expected {treespec.num_children}, but got {len(children)}.", ) - if context != treespec._context: + if context != treespec.context: raise ValueError( f"Node context mismatch for custom node type {treespec.type!r}.", ) @@ -1206,10 +1167,10 @@ class TreeSpec: if both_standard_dict: # dictionary types are compatible with each other dict_context = ( - treespec._context + treespec.context if treespec.type is not defaultdict # ignore mismatch of `default_factory` for defaultdict - else treespec._context[1] + else treespec.context[1] ) expected_keys = dict_context got_key_set = set(tree) @@ -1230,13 +1191,13 @@ class TreeSpec: children, context = flatten_fn(tree) if ( node_type is not deque # ignore mismatch of `maxlen` for deque - ) and context != treespec._context: + ) and context != treespec.context: raise ValueError( f"Node context mismatch for node type {treespec.type!r}; " - f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch + f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch ) - for subtree, subspec in zip(children, treespec._children, strict=True): + for subtree, subspec in zip(children, treespec.children_specs): helper(subspec, subtree, subtrees) subtrees: list[PyTree] = [] @@ -1261,24 +1222,24 @@ class TreeSpec: start = 0 end = 0 child_pytrees = [] - for child_spec in self._children: + for child_spec in self.children_specs: end += child_spec.num_leaves child_pytrees.append(child_spec.unflatten(leaves[start:end])) start = end - return unflatten_fn(child_pytrees, self._context) + return unflatten_fn(child_pytrees, self.context) def __hash__(self) -> int: node_type = self.type if node_type is defaultdict: - default_factory, dict_context = self._context + default_factory, dict_context = self.context hashable_context = (default_factory, tuple(dict_context)) elif node_type in (dict, OrderedDict): - hashable_context = tuple(self._context) + hashable_context = tuple(self.context) elif node_type is None or node_type in BUILTIN_TYPES: - hashable_context = self._context - elif isinstance(self._context, ConstantNode): - hashable_context = self._context.value + hashable_context = self.context + elif isinstance(self.context, ConstantNode): + hashable_context = self.context.value else: # The context for user-defined node types might not be hashable. # Ignore it for hashing. @@ -1286,26 +1247,20 @@ class TreeSpec: # same hash. This might increase the hash collision rate, but we # don't care about that. hashable_context = None - return hash((node_type, hashable_context, tuple(self._children))) - - -PyTreeSpec: TypeAlias = TreeSpec + return hash((node_type, hashable_context, tuple(self.children_specs))) # NOTE: subclassing a dataclass is subtle. In order to enable reasoning about # this class with `dataclasses.fields`, etc., while having a simplified # constructor that takes no argument, we wrap with `dataclass(init=True, ...)` # again, with fields that have `init=False`. -@deprecated( - "`isinstance(treespec, LeafSpec)` is deprecated, " - "use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.", - category=FutureWarning, -) @dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False) class LeafSpec(TreeSpec): type: Any = dataclasses.field(default=None, init=False) - _context: Context = dataclasses.field(default=None, init=False) - _children: list[Self] = dataclasses.field(default_factory=list, init=False) + context: Context = dataclasses.field(default=None, init=False) + children_specs: list["TreeSpec"] = dataclasses.field( + default_factory=list, init=False + ) def __post_init__(self) -> None: # Override `__post_init__` for `num_leaves` derivation. @@ -1319,36 +1274,7 @@ class LeafSpec(TreeSpec): # All leaves are equivalent, so represent with a single object to save on # object construction time -with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=FutureWarning, module=__name__, append=False - ) - _LEAF_SPEC = LeafSpec() - - -def treespec_leaf() -> LeafSpec: - """Make a treespec representing a leaf node.""" - return _LEAF_SPEC - - -def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec: - """Make a tuple treespec from an iterable of child treespecs.""" - children = list(iterable) - if any(not isinstance(child, TreeSpec) for child in children): - raise ValueError(f"Expected a tuple of TreeSpec values, got: {children!r}.") - return TreeSpec(tuple, None, children) - - -def treespec_dict( - mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (), - /, - **kwargs: TreeSpec, -) -> TreeSpec: - """Make a dict treespec from a dict of child treespecs.""" - dct = dict(mapping, **kwargs) - if any(not isinstance(child, TreeSpec) for child in dct.values()): - raise ValueError(f"Expected a dictionary of TreeSpec values, got: {dct!r}.") - return TreeSpec(dict, list(dct.keys()), list(dct.values())) +_LEAF_SPEC = LeafSpec() def tree_flatten( @@ -1827,15 +1753,15 @@ def _broadcast_to_and_flatten( return None flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - child_pytrees, context = flatten_fn(tree) + child_pytrees, ctx = flatten_fn(tree) # Check if the Node is different from the spec - if len(child_pytrees) != treespec.num_children or context != treespec._context: + if len(child_pytrees) != treespec.num_children or ctx != treespec.context: return None # Recursively flatten the children result: list[Any] = [] - for child, child_spec in zip(child_pytrees, treespec._children, strict=True): + for child, child_spec in zip(child_pytrees, treespec.children_specs): flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) if flat is not None: result += flat @@ -1889,7 +1815,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: if serialize_node_def.to_dumpable_context is None: try: - serialized_context = json.dumps(treespec._context, cls=EnumEncoder) + serialized_context = json.dumps(treespec.context, cls=EnumEncoder) except TypeError as e: raise TypeError( "Unable to serialize context. " @@ -1897,9 +1823,9 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: "custom serializer using _register_pytree_node." ) from e else: - serialized_context = serialize_node_def.to_dumpable_context(treespec._context) + serialized_context = serialize_node_def.to_dumpable_context(treespec.context) - child_schemas = [_treespec_to_json(child) for child in treespec._children] + child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)