[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)

The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.

Changes:

1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add functions `treespec_tuple(...)` and `treespec_dict(...)` to create treespecs for `tuple` / `dict`, which are used for `*args` / `**kwargs`. This avoids direct modification of `treespec` instances, which relies on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
This commit is contained in:
Xuehai Pan 2025-11-01 08:48:26 +08:00 committed by PyTorch MergeBot
parent 60333de85d
commit e8fadba28c
22 changed files with 404 additions and 164 deletions

View File

@ -91,13 +91,13 @@ from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu
from torch.testing._internal.two_tensor import TwoTensor from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._pytree import ( from torch.utils._pytree import (
LeafSpec,
register_constant, register_constant,
tree_flatten, tree_flatten,
tree_map, tree_map,
tree_unflatten, tree_unflatten,
TreeSpec, TreeSpec,
treespec_dumps, treespec_dumps,
treespec_leaf,
treespec_loads, treespec_loads,
) )
@ -7791,7 +7791,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
dt = MyDataClass(x=3, y=4) dt = MyDataClass(x=3, y=4)
flat, spec = tree_flatten(dt) flat, spec = tree_flatten(dt)
self.assertTrue(spec, LeafSpec()) self.assertTrue(spec, treespec_leaf())
self.assertTrue(len(flat) == 1) self.assertTrue(len(flat) == 1)
torch.export.register_dataclass( torch.export.register_dataclass(
@ -7802,7 +7802,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
flat, spec = tree_flatten(dt) flat, spec = tree_flatten(dt)
self.assertEqual( self.assertEqual(
spec, spec,
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]), TreeSpec(
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
),
) )
self.assertEqual(flat, [3, 4]) self.assertEqual(flat, [3, 4])
@ -7835,7 +7837,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
TreeSpec( TreeSpec(
MyOtherDataClass, MyOtherDataClass,
[["x", "y", "z"], []], [["x", "y", "z"], []],
[LeafSpec(), LeafSpec(), LeafSpec()], [treespec_leaf(), treespec_leaf(), treespec_leaf()],
), ),
) )
self.assertEqual(flat, [3, 4, None]) self.assertEqual(flat, [3, 4, None])

View File

@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
A = auto() A = auto()
python_leafspec = python_pytree.LeafSpec()
class TestGenericPytree(TestCase): class TestGenericPytree(TestCase):
def test_aligned_public_apis(self): def test_aligned_public_apis(self):
public_apis = python_pytree.__all__ public_apis = python_pytree.__all__
@ -197,7 +194,7 @@ class TestGenericPytree(TestCase):
def run_test_with_leaf(leaf): def run_test_with_leaf(leaf):
values, treespec = pytree.tree_flatten(leaf) values, treespec = pytree.tree_flatten(leaf)
self.assertEqual(values, [leaf]) self.assertEqual(values, [leaf])
self.assertEqual(treespec, pytree.LeafSpec()) self.assertEqual(treespec, pytree.treespec_leaf())
unflattened = pytree.tree_unflatten(values, treespec) unflattened = pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, leaf) self.assertEqual(unflattened, leaf)
@ -215,7 +212,7 @@ class TestGenericPytree(TestCase):
( (
python_pytree, python_pytree,
lambda tup: python_pytree.TreeSpec( lambda tup: python_pytree.TreeSpec(
tuple, None, [python_leafspec for _ in tup] tuple, None, [python_pytree.treespec_leaf() for _ in tup]
), ),
), ),
name="python", name="python",
@ -250,7 +247,7 @@ class TestGenericPytree(TestCase):
( (
python_pytree, python_pytree,
lambda lst: python_pytree.TreeSpec( lambda lst: python_pytree.TreeSpec(
list, None, [python_leafspec for _ in lst] list, None, [python_pytree.treespec_leaf() for _ in lst]
), ),
), ),
name="python", name="python",
@ -286,7 +283,7 @@ class TestGenericPytree(TestCase):
lambda dct: python_pytree.TreeSpec( lambda dct: python_pytree.TreeSpec(
dict, dict,
list(dct.keys()), list(dct.keys()),
[python_leafspec for _ in dct.values()], [python_pytree.treespec_leaf() for _ in dct.values()],
), ),
), ),
name="python", name="python",
@ -327,7 +324,7 @@ class TestGenericPytree(TestCase):
lambda odict: python_pytree.TreeSpec( lambda odict: python_pytree.TreeSpec(
OrderedDict, OrderedDict,
list(odict.keys()), list(odict.keys()),
[python_leafspec for _ in odict.values()], [python_pytree.treespec_leaf() for _ in odict.values()],
), ),
), ),
name="python", name="python",
@ -371,7 +368,7 @@ class TestGenericPytree(TestCase):
lambda ddct: python_pytree.TreeSpec( lambda ddct: python_pytree.TreeSpec(
defaultdict, defaultdict,
[ddct.default_factory, list(ddct.keys())], [ddct.default_factory, list(ddct.keys())],
[python_leafspec for _ in ddct.values()], [python_pytree.treespec_leaf() for _ in ddct.values()],
), ),
), ),
name="python", name="python",
@ -413,7 +410,7 @@ class TestGenericPytree(TestCase):
( (
python_pytree, python_pytree,
lambda deq: python_pytree.TreeSpec( lambda deq: python_pytree.TreeSpec(
deque, deq.maxlen, [python_leafspec for _ in deq] deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
), ),
), ),
name="python", name="python",
@ -453,7 +450,7 @@ class TestGenericPytree(TestCase):
def run_test(tup): def run_test(tup):
if pytree is python_pytree: if pytree is python_pytree:
expected_spec = python_pytree.TreeSpec( expected_spec = python_pytree.TreeSpec(
namedtuple, Point, [python_leafspec for _ in tup] namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
) )
else: else:
expected_spec = cxx_pytree.tree_structure(Point(0, 1)) expected_spec = cxx_pytree.tree_structure(Point(0, 1))
@ -848,16 +845,16 @@ if "optree" in sys.modules:
def test_treespec_equality(self): def test_treespec_equality(self):
self.assertEqual( self.assertEqual(
python_pytree.LeafSpec(), python_pytree.treespec_leaf(),
python_pytree.LeafSpec(), python_pytree.treespec_leaf(),
) )
self.assertEqual( self.assertEqual(
python_pytree.TreeSpec(list, None, []), python_pytree.TreeSpec(list, None, []),
python_pytree.TreeSpec(list, None, []), python_pytree.TreeSpec(list, None, []),
) )
self.assertEqual( self.assertEqual(
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]), python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]), python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
) )
self.assertFalse( self.assertFalse(
python_pytree.TreeSpec(tuple, None, []) python_pytree.TreeSpec(tuple, None, [])
@ -892,24 +889,32 @@ if "optree" in sys.modules:
# python_pytree.tree_structure({}) # python_pytree.tree_structure({})
python_pytree.TreeSpec(dict, [], []), python_pytree.TreeSpec(dict, [], []),
# python_pytree.tree_structure([0]) # python_pytree.tree_structure([0])
python_pytree.TreeSpec(list, None, [python_leafspec]), python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
# python_pytree.tree_structure([0, 1]) # python_pytree.tree_structure([0, 1])
python_pytree.TreeSpec( python_pytree.TreeSpec(
list, list,
None, None,
[python_leafspec, python_leafspec], [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
), ),
# python_pytree.tree_structure((0, 1, 2)) # python_pytree.tree_structure((0, 1, 2))
python_pytree.TreeSpec( python_pytree.TreeSpec(
tuple, tuple,
None, None,
[python_leafspec, python_leafspec, python_leafspec], [
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
), ),
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) # python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
python_pytree.TreeSpec( python_pytree.TreeSpec(
dict, dict,
["a", "b", "c"], ["a", "b", "c"],
[python_leafspec, python_leafspec, python_leafspec], [
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
), ),
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) # python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
python_pytree.TreeSpec( python_pytree.TreeSpec(
@ -919,13 +924,17 @@ if "optree" in sys.modules:
python_pytree.TreeSpec( python_pytree.TreeSpec(
tuple, tuple,
None, None,
[python_leafspec, python_leafspec], [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
), ),
python_leafspec, python_pytree.treespec_leaf(),
python_pytree.TreeSpec( python_pytree.TreeSpec(
dict, dict,
["a", "b", "c"], ["a", "b", "c"],
[python_leafspec, python_leafspec, python_leafspec], [
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
), ),
], ],
), ),
@ -938,12 +947,15 @@ if "optree" in sys.modules:
tuple, tuple,
None, None,
[ [
python_leafspec, python_pytree.treespec_leaf(),
python_leafspec, python_pytree.treespec_leaf(),
python_pytree.TreeSpec( python_pytree.TreeSpec(
list, list,
None, None,
[python_leafspec, python_leafspec], [
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
), ),
], ],
), ),
@ -957,12 +969,12 @@ if "optree" in sys.modules:
python_pytree.TreeSpec( python_pytree.TreeSpec(
list, list,
None, None,
[python_leafspec, python_leafspec], [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
), ),
python_pytree.TreeSpec( python_pytree.TreeSpec(
list, list,
None, None,
[python_leafspec, python_leafspec], [python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
), ),
python_pytree.TreeSpec(dict, [], []), python_pytree.TreeSpec(dict, [], []),
], ],
@ -991,7 +1003,7 @@ if "optree" in sys.modules:
list, list,
None, None,
[ [
python_leafspec, python_pytree.treespec_leaf(),
], ],
), ),
], ],
@ -1000,7 +1012,7 @@ if "optree" in sys.modules:
self.assertIsInstance(serialized_spec, str) self.assertIsInstance(serialized_spec, str)
def test_pytree_serialize_enum(self): def test_pytree_serialize_enum(self):
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec]) spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
serialized_spec = python_pytree.treespec_dumps(spec) serialized_spec = python_pytree.treespec_dumps(spec)
self.assertIsInstance(serialized_spec, str) self.assertIsInstance(serialized_spec, str)
@ -1163,12 +1175,20 @@ if "optree" in sys.modules:
OrderedDict, OrderedDict,
[1, 2, 3], [1, 2, 3],
[ [
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]), python_pytree.TreeSpec(
python_leafspec, tuple,
None,
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
),
python_pytree.treespec_leaf(),
python_pytree.TreeSpec( python_pytree.TreeSpec(
dict, dict,
[4, 5, 6], [4, 5, 6],
[python_leafspec, python_leafspec, python_leafspec], [
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
), ),
], ],
) )
@ -1453,7 +1473,7 @@ class TestCxxPytree(TestCase):
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode") raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
def test_treespec_equality(self): def test_treespec_equality(self):
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec()) self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
def test_treespec_repr(self): def test_treespec_repr(self):
# Check that it looks sane # Check that it looks sane

View File

@ -17,7 +17,7 @@ from ..decorators import substitute_in_graph
if TYPE_CHECKING: if TYPE_CHECKING:
import builtins import builtins
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable, Mapping
from typing_extensions import Self from typing_extensions import Self
@ -349,6 +349,113 @@ if python_pytree._cxx_pytree_dynamo_traceable:
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec) return isinstance(obj, PyTreeSpec)
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_leaf,
    # Constant folding stays disabled: the result must be built from the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_leaf(
    *,
    none_is_leaf: bool = False,
    namespace: str = "",  # unused
) -> PyTreeSpec:
    """Create the treespec that stands for a single leaf.

    Args:
        none_is_leaf: Forwarded unchanged to the ``PyTreeSpec`` constructor.
        namespace: Accepted but ignored; the returned spec is always
            constructed with an empty namespace.

    Returns:
        A ``PyTreeSpec`` built with empty tuples and ``None`` for all of its
        positional fields.
    """
    leaf_spec = PyTreeSpec(
        (),
        None,
        None,
        (),
        None,
        none_is_leaf=none_is_leaf,
        namespace="",
    )
    return leaf_spec
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_tuple,
    # Constant folding stays disabled: the result must be built from the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_tuple(
    iterable: Iterable[PyTreeSpec] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> PyTreeSpec:
    """Create a ``tuple`` treespec whose children are the given treespecs.

    Args:
        iterable: The child treespecs, in order. Consumed once into a tuple.
        none_is_leaf: Value every child must share; also set on the result.
        namespace: Namespace of the result. Each child must use either this
            namespace or the empty namespace.

    Returns:
        A ``PyTreeSpec`` of type ``tuple`` with one entry index per child.

    Raises:
        ValueError: If any child is not a ``PyTreeSpec``, or has a different
            ``none_is_leaf`` value, or has an incompatible namespace.
    """
    children = tuple(iterable)

    if not all(_is_pytreespec_instance(spec) for spec in children):
        raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
    if not all(spec.none_is_leaf == none_is_leaf for spec in children):
        raise ValueError(
            "All children PyTreeSpecs must have the same `none_is_leaf` value "
            f"as the parent; expected {none_is_leaf}, got: {children!r}.",
        )
    # A child with the empty namespace is accepted under any parent namespace.
    if not all(spec.namespace in (namespace, "") for spec in children):
        raise ValueError(
            "All children PyTreeSpecs must have the same `namespace` value "
            f"as the parent; expected {namespace!r}, got: {children!r}.",
        )

    # Look up the handler registered for `tuple` to obtain its unflatten function.
    tuple_handler = optree.register_pytree_node.get(tuple, namespace=namespace)  # type: ignore[attr-defined]
    assert tuple_handler is not None

    # The entries are the positional indices 0..len(children)-1.
    entries = tuple(range(len(children)))
    return PyTreeSpec(
        children,
        tuple,
        None,
        entries,
        tuple_handler.unflatten_func,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_dict,
    # Constant folding stays disabled: the result must be built from the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_dict(
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Create a ``dict`` treespec from key-to-treespec pairs.

    Args:
        mapping: A mapping, or an iterable of ``(key, treespec)`` pairs.
        none_is_leaf: Value every child must share; also set on the result.
        namespace: Namespace of the result. Each child must use either this
            namespace or the empty namespace.
        **kwargs: Additional ``key=treespec`` pairs merged into ``mapping``
            via ``dict(mapping, **kwargs)``.

    Returns:
        A ``PyTreeSpec`` of type ``dict`` whose children, metadata, entries and
        unflatten function come from flattening the merged dict one level.

    Raises:
        ValueError: If any value is not a ``PyTreeSpec``, or has a different
            ``none_is_leaf`` value, or has an incompatible namespace.
    """
    dct = dict(mapping, **kwargs)

    if not all(_is_pytreespec_instance(spec) for spec in dct.values()):
        raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
    if not all(spec.none_is_leaf == none_is_leaf for spec in dct.values()):
        raise ValueError(
            "All children PyTreeSpecs must have the same `none_is_leaf` value "
            f"as the parent; expected {none_is_leaf}, got: {dct!r}.",
        )
    # A child with the empty namespace is accepted under any parent namespace.
    if not all(spec.namespace in (namespace, "") for spec in dct.values()):
        raise ValueError(
            "All children PyTreeSpecs must have the same `namespace` value "
            f"as the parent; expected {namespace!r}, got: {dct!r}.",
        )

    # Let optree flatten the dict itself, so the child order, metadata and
    # entries are whatever optree produces for a dict node.
    one_level = optree.tree_flatten_one_level(  # type: ignore[assignment,var-annotated]
        dct,  # type: ignore[arg-type]
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    children, metadata, entries, unflatten_func = one_level

    return PyTreeSpec(
        tuple(children),  # type: ignore[arg-type]
        dict,
        metadata,
        entries,
        unflatten_func,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
@substitute_in_graph( # type: ignore[arg-type] @substitute_in_graph( # type: ignore[arg-type]
optree.tree_flatten, optree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the # We need to disable constant folding here because we want the function to reference the

View File

@ -3727,9 +3727,7 @@ class SourcelessBuilder:
pass # failthrough to unimplemented branch pass # failthrough to unimplemented branch
elif isinstance(value, torch.fx.graph_module.GraphModule): elif isinstance(value, torch.fx.graph_module.GraphModule):
return SourcelessGraphModuleVariable(value) return SourcelessGraphModuleVariable(value)
elif isinstance( elif isinstance(value, torch.utils._pytree.TreeSpec):
value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
):
return UserDefinedObjectVariable(value) return UserDefinedObjectVariable(value)
elif PlacementVariable.is_placement(value): elif PlacementVariable.is_placement(value):
return PlacementVariable(value) return PlacementVariable(value)

View File

@ -1661,8 +1661,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
# We need to have this check this way, because in case init is a TreeSpec and carry # We need to have this check this way, because in case init is a TreeSpec and carry
# but carry is only a LeafSpec, these two cannot be compared correctly. # but carry is only a LeafSpec, these two cannot be compared correctly.
if ( if (
isinstance(xs_treespec.as_python_constant(), pytree.LeafSpec) xs_treespec.as_python_constant().is_leaf()
!= isinstance(_combine_treespec.as_python_constant(), pytree.LeafSpec) != _combine_treespec.as_python_constant().is_leaf()
) or not _make_inlined(tx, pytree.TreeSpec.__eq__)( ) or not _make_inlined(tx, pytree.TreeSpec.__eq__)(
xs_treespec, _combine_treespec xs_treespec, _combine_treespec
).as_python_constant(): ).as_python_constant():

View File

@ -1530,10 +1530,10 @@ class GraphModuleSerializer(metaclass=Final):
else: else:
raise AssertionError("TODO") raise AssertionError("TODO")
def serialize_treespec(self, treespec): def serialize_treespec(self, treespec: pytree.TreeSpec) -> str:
# We want to additionally save all the field names of the namedtuples in # We want to additionally save all the field names of the namedtuples in
# case users want to check that the treespec types are equivalent # case users want to check that the treespec types are equivalent
def store_namedtuple_fields(ts): def store_namedtuple_fields(ts: pytree.TreeSpec) -> None:
if ts.type is None: if ts.type is None:
return return
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
@ -1555,7 +1555,7 @@ class GraphModuleSerializer(metaclass=Final):
NamedTupleDef(field_names=ts.context._fields) NamedTupleDef(field_names=ts.context._fields)
) )
for child in ts.children_specs: for child in ts.children():
store_namedtuple_fields(child) store_namedtuple_fields(child)
serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION) serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION)

View File

@ -158,7 +158,7 @@ class PytreeThunk:
assert spec is not None assert spec is not None
self.spec: pytree.TreeSpec = spec self.spec: pytree.TreeSpec = spec
if self.spec.type in {tuple, list} and all( if self.spec.type in {tuple, list} and all(
child.is_leaf() for child in spec.children_specs child.is_leaf() for child in spec.children()
): ):
self.is_simple = True self.is_simple = True
if self.spec.is_leaf(): if self.spec.is_leaf():

View File

@ -1590,7 +1590,7 @@ def aot_export_joint_simple(
decompositions=decompositions, decompositions=decompositions,
trace_joint=trace_joint, trace_joint=trace_joint,
) )
in_spec, _kw_in_spec = in_spec.children_specs in_spec, _kw_in_spec = in_spec.children()
# At this point, we can just directly return the (joint or inference graph) that we traced. # At this point, we can just directly return the (joint or inference graph) that we traced.
# First though: a bunch of assertions to make sure that our graph doesn't require # First though: a bunch of assertions to make sure that our graph doesn't require
# any calling convention changes compared to the original function. # any calling convention changes compared to the original function.
@ -1617,7 +1617,7 @@ def aot_export_joint_simple(
raise RuntimeError( raise RuntimeError(
f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}" f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
) )
if not all(child.is_leaf() for child in in_spec.children_specs): if not all(child.is_leaf() for child in in_spec.children()):
raise RuntimeError( raise RuntimeError(
f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}" f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
) )
@ -1625,7 +1625,7 @@ def aot_export_joint_simple(
raise RuntimeError( raise RuntimeError(
f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}" f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
) )
if not all(child.is_leaf() for child in out_spec.children_specs): if not all(child.is_leaf() for child in out_spec.children()):
raise RuntimeError( raise RuntimeError(
f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}" f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
) )

View File

@ -469,7 +469,7 @@ def _unlift_graph(
gm, gm,
lifted_inputs, lifted_inputs,
mutated_outputs, mutated_outputs,
pytree.LeafSpec(), pytree.treespec_leaf(),
None, None,
) )
return unlifted_gm return unlifted_gm

View File

@ -7,6 +7,7 @@ from typing import Any, Optional, Union
import torch import torch
import torch.ao.quantization.pt2e._affine_quantization # noqa: F401 import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils._pytree as pytree
# Makes sure that quantized_decomposed ops are registered # Makes sure that quantized_decomposed ops are registered
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
@ -14,7 +15,6 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.export.unflatten import _assign_attr, _AttrKind from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node from torch.fx import GraphModule, Node
from torch.nn.utils.fusion import fuse_conv_bn_weights from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.utils._pytree import LeafSpec
__all__ = [ __all__ = [
@ -477,7 +477,10 @@ def _replace_literals_with_new_placeholders(
exclude_literals = [] exclude_literals = []
in_spec = gm._in_spec in_spec = gm._in_spec
args_spec = in_spec.children_specs[0] assert in_spec.type is tuple
args_spec = in_spec.child(0)
assert args_spec.type is tuple
args_spec_children = args_spec.children()
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.op == "placeholder": if node.op == "placeholder":
last_ph = node last_ph = node
@ -492,7 +495,7 @@ def _replace_literals_with_new_placeholders(
else: else:
ph_node = gm.graph.placeholder("arg" + str(cnt)) ph_node = gm.graph.placeholder("arg" + str(cnt))
new_args.append(ph_node) new_args.append(ph_node)
args_spec.children_specs.append(LeafSpec()) args_spec_children.append(pytree.treespec_leaf())
cnt += 1 cnt += 1
if merge_dup: if merge_dup:
literal_to_ph[arg] = ph_node literal_to_ph[arg] = ph_node
@ -503,8 +506,8 @@ def _replace_literals_with_new_placeholders(
node.args = new_args node.args = new_args
# Update `num_nodes`, `num_leaves`, `num_children`. # Update `num_nodes`, `num_leaves`, `num_children`.
args_spec.__post_init__() args_spec = pytree.treespec_tuple(args_spec_children)
in_spec.__post_init__() gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]])
return gm return gm

View File

@ -195,17 +195,16 @@ def _construct_inputs(
unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec) unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)
assert signature.in_spec.num_children == 2 assert signature.in_spec.num_children == 2
assert signature.in_spec.type is tuple
args_spec, kwargs_spec = signature.in_spec.children()
assert args_spec.type is tuple
assert kwargs_spec.type is dict
args_spec = signature.in_spec.children_specs[0]
assert args_spec.context is None
args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0)) args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0))
args_nodes = [ args_nodes = [
gm.graph.call_function(operator.getitem, (args_node, i)) gm.graph.call_function(operator.getitem, (args_node, i))
for i in range(args_spec.num_children) for i in range(args_spec.num_children)
] ]
kwargs_spec = signature.in_spec.children_specs[1]
assert kwargs_spec.context is not None
kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1)) kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1))
kwargs_nodes = { kwargs_nodes = {
k: gm.graph.call_function(operator.getitem, (kwargs_node, k)) k: gm.graph.call_function(operator.getitem, (kwargs_node, k))
@ -372,8 +371,8 @@ def _fix_input_output_signature(
if forward_arg_names is None: if forward_arg_names is None:
forward_arg_names = [] forward_arg_names = []
assert signature.in_spec.num_children == 2 assert signature.in_spec.num_children == 2
arg_spec = signature.in_spec.children_specs[0] arg_spec = signature.in_spec.child(0)
kwarg_spec = signature.in_spec.children_specs[1] kwarg_spec = signature.in_spec.child(1)
assert arg_spec.type is tuple assert arg_spec.type is tuple
assert kwarg_spec.type is dict assert kwarg_spec.type is dict
for i in range(arg_spec.num_children): for i in range(arg_spec.num_children):

View File

@ -1533,7 +1533,7 @@ def _strict_export(
# aot_export expect the return type to always be a tuple. # aot_export expect the return type to always be a tuple.
if out_spec.type not in (list, tuple): if out_spec.type not in (list, tuple):
out_spec = pytree.TreeSpec(tuple, None, [out_spec]) out_spec = pytree.treespec_tuple([out_spec])
orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]

View File

@ -15,7 +15,7 @@ def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any
# Make sure that the spec is actually shaped like (args, kwargs) # Make sure that the spec is actually shaped like (args, kwargs)
assert spec.type is tuple assert spec.type is tuple
assert spec.num_children == 2 assert spec.num_children == 2
kwargs_spec = spec.children_specs[1] kwargs_spec = spec.child(1)
assert kwargs_spec.type is dict assert kwargs_spec.type is dict
if set(user_kwargs) != set(kwargs_spec.context): if set(user_kwargs) != set(kwargs_spec.context):
@ -55,10 +55,10 @@ def is_equivalent(
return False return False
# Recurse on children # Recurse on children
if len(spec1.children_specs) != len(spec2.children_specs): if spec1.num_children != spec2.num_children:
return False return False
for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs): for child_spec1, child_spec2 in zip(spec1.children(), spec2.children()):
if not is_equivalent(child_spec1, child_spec2, equivalence_fn): if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
return False return False

View File

@ -57,11 +57,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
return False return False
elif a.context != b.context: elif a.context != b.context:
return False return False
if len(a.children_specs) != len(b.children_specs): if a.num_children != b.num_children:
return False return False
return all( return all(
_match_normalized_structure(a, b) _match_normalized_structure(a, b)
for a, b in zip(a.children_specs, b.children_specs) for a, b in zip(a.children(), b.children())
) )
return _match_normalized_structure(self, other) return _match_normalized_structure(self, other)
@ -357,13 +357,13 @@ def _get_codegen(
elif ( elif (
in_spec.type is tuple in_spec.type is tuple
and in_spec.num_children == 2 and in_spec.num_children == 2
and in_spec.children_specs[0].type is tuple and in_spec.child(0).type is tuple
and in_spec.children_specs[1].type is dict and in_spec.child(1).type is dict
): ):
# if in_spec contains the args (tuple) and kwargs (dict) # if in_spec contains the args (tuple) and kwargs (dict)
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)]
# add kwarg names # add kwarg names
names.extend(in_spec.children_specs[1].context) names.extend(in_spec.child(1).context)
else: else:
names = [f"arg_{i}" for i in range(in_spec.num_children)] names = [f"arg_{i}" for i in range(in_spec.num_children)]

View File

@ -12,14 +12,16 @@ import torch
from torch.utils._pytree import ( from torch.utils._pytree import (
_get_node_type, _get_node_type,
BUILTIN_TYPES, BUILTIN_TYPES,
KeyPath,
keystr, keystr,
LeafSpec,
MappingKey, MappingKey,
SequenceKey, SequenceKey,
SUPPORTED_NODES, SUPPORTED_NODES,
tree_flatten, tree_iter,
tree_map, tree_map,
tree_map_with_path, tree_map_with_path,
tree_structure,
TreeSpec,
) )
from .exported_program import ExportedProgram from .exported_program import ExportedProgram
@ -655,53 +657,55 @@ def _tree_map_with_path(
case_name="dynamic_shapes_validation", case_name="dynamic_shapes_validation",
) )
def _compare(tree, dynamic_shapes, path): def _compare(
treespec: TreeSpec, other_treespec: TreeSpec, path: KeyPath
) -> None:
# raise an error at the point where tree and dynamic_shapes differ, # raise an error at the point where tree and dynamic_shapes differ,
# including the path to that point and the reason for the difference # including the path to that point and the reason for the difference
rendered_path = keystr(path) rendered_path = keystr(path)
if isinstance(tree, LeafSpec): if treespec.is_leaf():
return return
if isinstance(dynamic_shapes, LeafSpec): if other_treespec.is_leaf():
raise_mismatch_error( raise_mismatch_error(
f"`{tree_name}{rendered_path}` is a {tree.type}, " f"`{tree_name}{rendered_path}` is a {treespec.type}, "
f"but `dynamic_shapes{rendered_path}` is not" f"but `dynamic_shapes{rendered_path}` is not"
) )
if tree.type != dynamic_shapes.type: if treespec.type != other_treespec.type:
raise_mismatch_error( raise_mismatch_error(
f"`{tree_name}{rendered_path}` is a {tree.type}, " f"`{tree_name}{rendered_path}` is a {treespec.type}, "
f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}" f"but `dynamic_shapes{rendered_path}` is a {other_treespec.type}"
) )
if len(tree.children_specs) != len(dynamic_shapes.children_specs): if treespec.num_children != other_treespec.num_children:
raise_mismatch_error( raise_mismatch_error(
f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, " f"`{tree_name}{rendered_path}` has {treespec.num_children} elements, "
f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements" f"but `dynamic_shapes{rendered_path}` has {other_treespec.num_children} elements"
) )
if tree.type is dict: if treespec.type is dict:
# context, children could be out of order # context, children could be out of order
if sorted(tree.context) != sorted(dynamic_shapes.context): if set(treespec.context) != set(other_treespec.context):
raise_mismatch_error( raise_mismatch_error(
f"`{tree_name}{rendered_path}` has keys {tree.context}, " f"`{tree_name}{rendered_path}` has keys {treespec.context}, "
f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}" f"but `dynamic_shapes{rendered_path}` has keys {other_treespec.context}"
) )
_remap = dict( _remap = dict(
zip(dynamic_shapes.context, dynamic_shapes.children_specs) zip(other_treespec.context, other_treespec.children())
) )
dynamic_shapes_children_specs = [_remap[k] for k in tree.context] other_children = [_remap[k] for k in treespec.context]
else: else:
dynamic_shapes_children_specs = dynamic_shapes.children_specs other_children = other_treespec.children()
for i, (tree_, dynamic_shapes_) in enumerate( for i, (child, other_child) in enumerate(
zip(tree.children_specs, dynamic_shapes_children_specs) zip(treespec.children(), other_children)
): ):
_compare( _compare(
tree_, child,
dynamic_shapes_, other_child,
path + [_key(tree.type, tree.context, i)], path + (_key(treespec.type, treespec.context, i),),
) )
_, tree_spec = tree_flatten(tree, is_leaf=is_leaf) treespec = tree_structure(tree, is_leaf=is_leaf)
for other_tree in dynamic_shapes: for other_tree in dynamic_shapes:
_, other_tree_spec = tree_flatten(other_tree, is_leaf) other_treespec = tree_structure(other_tree, is_leaf)
_compare(tree_spec, other_tree_spec, []) _compare(treespec, other_treespec, ())
raise raise
@ -1231,10 +1235,7 @@ def _get_dim_name_mapping(
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
): ):
name_to_dim = {} name_to_dim = {}
for dim in tree_flatten( for dim in tree_iter(dynamic_shapes, is_leaf=lambda x: isinstance(x, Dim)):
dynamic_shapes,
is_leaf=lambda x: isinstance(x, Dim),
)[0]:
if dim is None: if dim is None:
# NOTE: this must denote a non-Tensor or automatic at this point. # NOTE: this must denote a non-Tensor or automatic at this point.
continue continue

View File

@ -391,7 +391,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
# aot_export expect the return type to always be a tuple. # aot_export expect the return type to always be a tuple.
assert out_spec is not None assert out_spec is not None
if out_spec.type not in (list, tuple): if out_spec.type not in (list, tuple):
out_spec = pytree.TreeSpec(tuple, None, [out_spec]) out_spec = pytree.treespec_tuple([out_spec])
mod.graph._codegen = _PyTreeCodeGen( mod.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo( _PyTreeInfo(

View File

@ -1124,10 +1124,10 @@ class _ModuleFrame:
signature = module_call_graph.get(self.child_fqn) signature = module_call_graph.get(self.child_fqn)
if signature is not None and self.parent is not None: if signature is not None and self.parent is not None:
assert signature.in_spec.num_children == 2 assert signature.in_spec.num_children == 2
args_spec = signature.in_spec.children_specs[0] assert signature.in_spec.type is tuple
kwargs_spec = signature.in_spec.children_specs[1] args_spec, kwargs_spec = signature.in_spec.children()
assert args_spec.context is None assert args_spec.type is tuple
assert kwargs_spec.context is not None assert kwargs_spec.type is dict
with self.graph.inserting_after(None): with self.graph.inserting_after(None):
arg_nodes = [ arg_nodes = [

View File

@ -49,7 +49,7 @@ def tree_flatten_spec(
flatten_fn_spec = SUPPORTED_NODES[spec.type] flatten_fn_spec = SUPPORTED_NODES[spec.type]
child_pytrees = flatten_fn_spec(pytree, spec) child_pytrees = flatten_fn_spec(pytree, spec)
result = [] result = []
for child, child_spec in zip(child_pytrees, spec.children_specs): for child, child_spec in zip(child_pytrees, spec.children()):
flat = tree_flatten_spec(child, child_spec) flat = tree_flatten_spec(child, child_spec)
result += flat result += flat
return result return result

View File

@ -709,7 +709,7 @@ class Tracer(TracerBase):
root_fn = _patch_function(root_fn, len(args)) root_fn = _patch_function(root_fn, len(args))
flat_args, in_spec = pytree.tree_flatten(tuple(args)) flat_args, in_spec = pytree.tree_flatten(tuple(args))
if not all(child.is_leaf() for child in in_spec.children_specs): if not all(child.is_leaf() for child in in_spec.children()):
# In the case that we have pytree-flattened inputs in # In the case that we have pytree-flattened inputs in
# `concrete_args`, generate a flattening wrapper around the # `concrete_args`, generate a flattening wrapper around the
# original root function and return that. # original root function and return that.

View File

@ -933,24 +933,25 @@ class _PyTreeCodeGen(CodeGen):
return "\n " + "".join(x + "; " for x in has_annotation) + "\n" return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
in_spec = self.pytree_info.in_spec
# when kwargs is present, in_spec is tuple(args, kwargs) # when kwargs is present, in_spec is tuple(args, kwargs)
has_args_kwargs_tuple = ( has_args_kwargs_tuple = (
self.pytree_info.in_spec.type is tuple in_spec.type is tuple
and self.pytree_info.in_spec.num_children == 2 and in_spec.num_children == 2
and self.pytree_info.in_spec.children_specs[0].type is tuple and in_spec.child(0).type is tuple
and self.pytree_info.in_spec.children_specs[1].type is dict and in_spec.child(1).type is dict
) )
fn_kwargs = "{}" fn_kwargs = "{}"
fn_signature = f"[{', '.join(fn_args)}], self._in_spec" fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
if has_args_kwargs_tuple: if has_args_kwargs_tuple:
count_args = self.pytree_info.in_spec.children_specs[0].num_children count_args = in_spec.child(0).num_children
fn_args = self.pytree_info.orig_args[:count_args] fn_args = self.pytree_info.orig_args[:count_args]
fn_kwargs = ( fn_kwargs = (
"{" "{"
+ ", ".join( + ", ".join(
f"'{k}':{v}" f"'{k}':{v}"
for k, v in zip( for k, v in zip(
self.pytree_info.in_spec.children_specs[1].context, in_spec.child(1).context,
self.pytree_info.orig_args[count_args:], self.pytree_info.orig_args[count_args:],
) )
) )

View File

@ -14,9 +14,9 @@ collection support for PyTorch APIs.
import functools import functools
import types import types
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, overload, TypeVar, Union from typing import Any, Optional, overload, TypeVar, Union
from typing_extensions import deprecated, TypeIs from typing_extensions import deprecated, Self, TypeAlias, TypeIs
import torch.utils._pytree as python_pytree import torch.utils._pytree as python_pytree
from torch.torch_version import TorchVersion as _TorchVersion from torch.torch_version import TorchVersion as _TorchVersion
@ -42,7 +42,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
import optree import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations from optree import PyTreeSpec # direct import for type annotations
__all__ = [ __all__ = [
@ -53,6 +53,7 @@ __all__ = [
"DumpableContext", "DumpableContext",
"ToDumpableContextFn", "ToDumpableContextFn",
"FromDumpableContextFn", "FromDumpableContextFn",
"PyTreeSpec",
"TreeSpec", "TreeSpec",
"LeafSpec", "LeafSpec",
"keystr", "keystr",
@ -100,6 +101,8 @@ U = TypeVar("U")
R = TypeVar("R") R = TypeVar("R")
TreeSpec: TypeAlias = PyTreeSpec
Context = Any Context = Any
PyTree = Any PyTree = Any
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]] FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
@ -267,6 +270,30 @@ def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec) return isinstance(obj, TreeSpec)
def treespec_leaf() -> TreeSpec:
"""Make a treespec representing a leaf node."""
return optree.treespec_leaf(none_is_leaf=True, namespace="torch")
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
"""Make a tuple treespec from an iterable of child treespecs."""
return optree.treespec_tuple(iterable, none_is_leaf=True, namespace="torch")
def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
"""Make a dict treespec from a dict of child treespecs."""
return optree.treespec_dict(
mapping,
**kwargs,
none_is_leaf=True,
namespace="torch",
)
def tree_is_leaf( def tree_is_leaf(
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -985,9 +1012,14 @@ class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
return _is_pytreespec_instance(instance) and instance.is_leaf() return _is_pytreespec_instance(instance) and instance.is_leaf()
@deprecated(
"`isinstance(treespec, LeafSpec)` is deprecated, "
"use `isinstance(treespec, TreeSpec)` and `treespec.is_leaf()` instead.",
category=FutureWarning,
)
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final] class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final]
def __new__(cls) -> "LeafSpec": def __new__(cls) -> Self:
return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value] return treespec_leaf() # type: ignore[return-value]
def tree_flatten_with_path( def tree_flatten_with_path(

View File

@ -39,7 +39,7 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
) )
from typing_extensions import deprecated, NamedTuple, Self from typing_extensions import deprecated, NamedTuple, Self, TypeAlias
from torch.torch_version import TorchVersion as _TorchVersion from torch.torch_version import TorchVersion as _TorchVersion
@ -52,6 +52,7 @@ __all__ = [
"DumpableContext", "DumpableContext",
"ToDumpableContextFn", "ToDumpableContextFn",
"FromDumpableContextFn", "FromDumpableContextFn",
"PyTreeSpec",
"TreeSpec", "TreeSpec",
"LeafSpec", "LeafSpec",
"keystr", "keystr",
@ -475,7 +476,7 @@ class ConstantNode:
def _is_constant_holder(spec: "TreeSpec") -> bool: def _is_constant_holder(spec: "TreeSpec") -> bool:
"""Checks if the spec is from a pytree registered with register_constant""" """Checks if the spec is from a pytree registered with register_constant"""
return isinstance(spec.context, ConstantNode) return isinstance(spec._context, ConstantNode)
def _retrieve_constant(spec: "TreeSpec") -> Any: def _retrieve_constant(spec: "TreeSpec") -> Any:
@ -1076,39 +1077,60 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
# A TreeSpec represents the structure of a pytree. It holds: # A TreeSpec represents the structure of a pytree. It holds:
# "type": the type of root Node of the pytree # "type": the type of root Node of the pytree
# context: some context that is useful in unflattening the pytree # context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node # children(): specs for each child of the root Node
# num_leaves: the number of leaves # num_nodes: the total number of nodes
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False) # num_leaves: the number of leaves
# num_children: the number of children of the root Node (i.e., len(children()))
# is_leaf(): whether the root Node is a leaf
@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False)
class TreeSpec: class TreeSpec:
type: Any type: Any
context: Context _context: Context
children_specs: list["TreeSpec"] _children: list[Self]
num_nodes: int = dataclasses.field(init=False) num_nodes: int = dataclasses.field(init=False)
num_leaves: int = dataclasses.field(init=False) num_leaves: int = dataclasses.field(init=False)
num_children: int = dataclasses.field(init=False) num_children: int = dataclasses.field(init=False)
def __init__(
self,
type: Any,
context: Context, # keep for backward compatibility
children_specs: list[Self], # keep for backward compatibility
) -> None:
object.__setattr__(self, "type", type)
object.__setattr__(self, "_context", context)
object.__setattr__(self, "_children", children_specs)
self.__post_init__()
def __post_init__(self) -> None: def __post_init__(self) -> None:
num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1) if self.type is None:
num_leaves = sum(spec.num_leaves for spec in self.children_specs) assert self._context is None
num_children = len(self.children_specs) assert len(self._children) == 0
num_nodes = 1
num_leaves = 1
num_children = 0
else:
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
num_leaves = sum(spec.num_leaves for spec in self._children)
num_children = len(self._children)
object.__setattr__(self, "num_nodes", num_nodes) object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves) object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children) object.__setattr__(self, "num_children", num_children)
def __repr__(self, indent: int = 0) -> str: def __repr__(self, indent: int = 0) -> str:
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" repr_prefix: str = f"TreeSpec({self.type.__name__}, {self._context}, ["
children_specs_str: str = "" children_specs_str: str = ""
if self.num_children > 0: if self.num_children > 0:
indent += 2 indent += 2
children_specs_str += self.children_specs[0].__repr__(indent) children_specs_str += self._children[0].__repr__(indent)
children_specs_str += "," if self.num_children > 1 else "" children_specs_str += "," if self.num_children > 1 else ""
children_specs_str += ",".join( children_specs_str += ",".join(
[ [
"\n" + " " * indent + child.__repr__(indent) "\n" + " " * indent + child.__repr__(indent)
for child in self.children_specs[1:] for child in self._children[1:]
] ]
) )
repr_suffix: str = f"{children_specs_str}])" repr_suffix: str = f"{children_specs_str}])"
@ -1120,16 +1142,36 @@ class TreeSpec:
elif other.__class__ is self.__class__: elif other.__class__ is self.__class__:
if str(self.type) != str(other.type): if str(self.type) != str(other.type):
return False return False
if self.context != other.context: if self._context != other._context:
return False return False
elif self.children_specs != other.children_specs: elif self._children != other._children:
return False return False
return True return True
return NotImplemented return NotImplemented
@property
def context(self) -> Context:
return self._context
@property
@deprecated(
"`treespec.children_specs` is deprecated. "
"Use `treespec.child(index)` to access a single child, "
"or `treespec.children()` to get all children.",
category=FutureWarning,
)
def children_specs(self) -> list[Self]:
return self._children
def is_leaf(self) -> bool: def is_leaf(self) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1 return self.num_nodes == 1 and self.num_leaves == 1
def children(self) -> list[Self]:
return self._children.copy()
def child(self, index: int) -> Self:
return self._children[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]: def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
if treespec.is_leaf(): if treespec.is_leaf():
@ -1151,7 +1193,7 @@ class TreeSpec:
f"Node arity mismatch; " f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(children)}.", f"expected {treespec.num_children}, but got {len(children)}.",
) )
if context != treespec.context: if context != treespec._context:
raise ValueError( raise ValueError(
f"Node context mismatch for custom node type {treespec.type!r}.", f"Node context mismatch for custom node type {treespec.type!r}.",
) )
@ -1176,10 +1218,10 @@ class TreeSpec:
if both_standard_dict: if both_standard_dict:
# dictionary types are compatible with each other # dictionary types are compatible with each other
dict_context = ( dict_context = (
treespec.context treespec._context
if treespec.type is not defaultdict if treespec.type is not defaultdict
# ignore mismatch of `default_factory` for defaultdict # ignore mismatch of `default_factory` for defaultdict
else treespec.context[1] else treespec._context[1]
) )
expected_keys = dict_context expected_keys = dict_context
got_key_set = set(tree) got_key_set = set(tree)
@ -1200,13 +1242,13 @@ class TreeSpec:
children, context = flatten_fn(tree) children, context = flatten_fn(tree)
if ( if (
node_type is not deque # ignore mismatch of `maxlen` for deque node_type is not deque # ignore mismatch of `maxlen` for deque
) and context != treespec.context: ) and context != treespec._context:
raise ValueError( raise ValueError(
f"Node context mismatch for node type {treespec.type!r}; " f"Node context mismatch for node type {treespec.type!r}; "
f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch
) )
for subtree, subspec in zip(children, treespec.children_specs): for subtree, subspec in zip(children, treespec._children, strict=True):
helper(subspec, subtree, subtrees) helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = [] subtrees: list[PyTree] = []
@ -1231,24 +1273,24 @@ class TreeSpec:
start = 0 start = 0
end = 0 end = 0
child_pytrees = [] child_pytrees = []
for child_spec in self.children_specs: for child_spec in self._children:
end += child_spec.num_leaves end += child_spec.num_leaves
child_pytrees.append(child_spec.unflatten(leaves[start:end])) child_pytrees.append(child_spec.unflatten(leaves[start:end]))
start = end start = end
return unflatten_fn(child_pytrees, self.context) return unflatten_fn(child_pytrees, self._context)
def __hash__(self) -> int: def __hash__(self) -> int:
node_type = self.type node_type = self.type
if node_type is defaultdict: if node_type is defaultdict:
default_factory, dict_context = self.context default_factory, dict_context = self._context
hashable_context = (default_factory, tuple(dict_context)) hashable_context = (default_factory, tuple(dict_context))
elif node_type in (dict, OrderedDict): elif node_type in (dict, OrderedDict):
hashable_context = tuple(self.context) hashable_context = tuple(self._context)
elif node_type is None or node_type in BUILTIN_TYPES: elif node_type is None or node_type in BUILTIN_TYPES:
hashable_context = self.context hashable_context = self._context
elif isinstance(self.context, ConstantNode): elif isinstance(self._context, ConstantNode):
hashable_context = self.context.value hashable_context = self._context.value
else: else:
# The context for user-defined node types might not be hashable. # The context for user-defined node types might not be hashable.
# Ignore it for hashing. # Ignore it for hashing.
@ -1256,20 +1298,26 @@ class TreeSpec:
# same hash. This might increase the hash collision rate, but we # same hash. This might increase the hash collision rate, but we
# don't care about that. # don't care about that.
hashable_context = None hashable_context = None
return hash((node_type, hashable_context, tuple(self.children_specs))) return hash((node_type, hashable_context, tuple(self._children)))
PyTreeSpec: TypeAlias = TreeSpec
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about # NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
# this class with `dataclasses.fields`, etc., while having a simplified # this class with `dataclasses.fields`, etc., while having a simplified
# constructor that takes no argument, we wrap with `dataclass(init=True, ...)` # constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
# again, with fields that have `init=False`. # again, with fields that have `init=False`.
@deprecated(
"`isinstance(treespec, LeafSpec)` is deprecated, "
"use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.",
category=FutureWarning,
)
@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False) @dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
class LeafSpec(TreeSpec): class LeafSpec(TreeSpec):
type: Any = dataclasses.field(default=None, init=False) type: Any = dataclasses.field(default=None, init=False)
context: Context = dataclasses.field(default=None, init=False) _context: Context = dataclasses.field(default=None, init=False)
children_specs: list["TreeSpec"] = dataclasses.field( _children: list[Self] = dataclasses.field(default_factory=list, init=False)
default_factory=list, init=False
)
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Override `__post_init__` for `num_leaves` derivation. # Override `__post_init__` for `num_leaves` derivation.
@ -1283,7 +1331,36 @@ class LeafSpec(TreeSpec):
# All leaves are equivalent, so represent with a single object to save on # All leaves are equivalent, so represent with a single object to save on
# object construction time # object construction time
_LEAF_SPEC = LeafSpec() with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=FutureWarning, module=__name__, append=False
)
_LEAF_SPEC = LeafSpec()
def treespec_leaf() -> LeafSpec:
"""Make a treespec representing a leaf node."""
return _LEAF_SPEC
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
"""Make a tuple treespec from an iterable of child treespecs."""
children = list(iterable)
if any(not isinstance(child, TreeSpec) for child in children):
raise ValueError(f"Expected a tuple of TreeSpec values, got: {children!r}.")
return TreeSpec(tuple, None, children)
def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
"""Make a dict treespec from a dict of child treespecs."""
dct = dict(mapping, **kwargs)
if any(not isinstance(child, TreeSpec) for child in dct.values()):
raise ValueError(f"Expected a dictionary of TreeSpec values, got: {dct!r}.")
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
def tree_flatten( def tree_flatten(
@ -1762,15 +1839,15 @@ def _broadcast_to_and_flatten(
return None return None
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, ctx = flatten_fn(tree) child_pytrees, context = flatten_fn(tree)
# Check if the Node is different from the spec # Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or ctx != treespec.context: if len(child_pytrees) != treespec.num_children or context != treespec._context:
return None return None
# Recursively flatten the children # Recursively flatten the children
result: list[Any] = [] result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec.children_specs): for child, child_spec in zip(child_pytrees, treespec._children, strict=True):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None: if flat is not None:
result += flat result += flat
@ -1824,7 +1901,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
if serialize_node_def.to_dumpable_context is None: if serialize_node_def.to_dumpable_context is None:
try: try:
serialized_context = json.dumps(treespec.context, cls=EnumEncoder) serialized_context = json.dumps(treespec._context, cls=EnumEncoder)
except TypeError as e: except TypeError as e:
raise TypeError( raise TypeError(
"Unable to serialize context. " "Unable to serialize context. "
@ -1832,9 +1909,9 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
"custom serializer using _register_pytree_node." "custom serializer using _register_pytree_node."
) from e ) from e
else: else:
serialized_context = serialize_node_def.to_dumpable_context(treespec.context) serialized_context = serialize_node_def.to_dumpable_context(treespec._context)
child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] child_schemas = [_treespec_to_json(child) for child in treespec._children]
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)