[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)

The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.

Changes:

1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
This commit is contained in:
Xuehai Pan 2025-11-01 08:48:26 +08:00 committed by PyTorch MergeBot
parent 60333de85d
commit e8fadba28c
22 changed files with 404 additions and 164 deletions

View File

@ -91,13 +91,13 @@ from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._pytree import (
LeafSpec,
register_constant,
tree_flatten,
tree_map,
tree_unflatten,
TreeSpec,
treespec_dumps,
treespec_leaf,
treespec_loads,
)
@ -7791,7 +7791,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
dt = MyDataClass(x=3, y=4)
flat, spec = tree_flatten(dt)
self.assertTrue(spec, LeafSpec())
self.assertTrue(spec, treespec_leaf())
self.assertTrue(len(flat) == 1)
torch.export.register_dataclass(
@ -7802,7 +7802,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
flat, spec = tree_flatten(dt)
self.assertEqual(
spec,
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
TreeSpec(
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
),
)
self.assertEqual(flat, [3, 4])
@ -7835,7 +7837,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
TreeSpec(
MyOtherDataClass,
[["x", "y", "z"], []],
[LeafSpec(), LeafSpec(), LeafSpec()],
[treespec_leaf(), treespec_leaf(), treespec_leaf()],
),
)
self.assertEqual(flat, [3, 4, None])

View File

@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
A = auto()
python_leafspec = python_pytree.LeafSpec()
class TestGenericPytree(TestCase):
def test_aligned_public_apis(self):
public_apis = python_pytree.__all__
@ -197,7 +194,7 @@ class TestGenericPytree(TestCase):
def run_test_with_leaf(leaf):
values, treespec = pytree.tree_flatten(leaf)
self.assertEqual(values, [leaf])
self.assertEqual(treespec, pytree.LeafSpec())
self.assertEqual(treespec, pytree.treespec_leaf())
unflattened = pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, leaf)
@ -215,7 +212,7 @@ class TestGenericPytree(TestCase):
(
python_pytree,
lambda tup: python_pytree.TreeSpec(
tuple, None, [python_leafspec for _ in tup]
tuple, None, [python_pytree.treespec_leaf() for _ in tup]
),
),
name="python",
@ -250,7 +247,7 @@ class TestGenericPytree(TestCase):
(
python_pytree,
lambda lst: python_pytree.TreeSpec(
list, None, [python_leafspec for _ in lst]
list, None, [python_pytree.treespec_leaf() for _ in lst]
),
),
name="python",
@ -286,7 +283,7 @@ class TestGenericPytree(TestCase):
lambda dct: python_pytree.TreeSpec(
dict,
list(dct.keys()),
[python_leafspec for _ in dct.values()],
[python_pytree.treespec_leaf() for _ in dct.values()],
),
),
name="python",
@ -327,7 +324,7 @@ class TestGenericPytree(TestCase):
lambda odict: python_pytree.TreeSpec(
OrderedDict,
list(odict.keys()),
[python_leafspec for _ in odict.values()],
[python_pytree.treespec_leaf() for _ in odict.values()],
),
),
name="python",
@ -371,7 +368,7 @@ class TestGenericPytree(TestCase):
lambda ddct: python_pytree.TreeSpec(
defaultdict,
[ddct.default_factory, list(ddct.keys())],
[python_leafspec for _ in ddct.values()],
[python_pytree.treespec_leaf() for _ in ddct.values()],
),
),
name="python",
@ -413,7 +410,7 @@ class TestGenericPytree(TestCase):
(
python_pytree,
lambda deq: python_pytree.TreeSpec(
deque, deq.maxlen, [python_leafspec for _ in deq]
deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
),
),
name="python",
@ -453,7 +450,7 @@ class TestGenericPytree(TestCase):
def run_test(tup):
if pytree is python_pytree:
expected_spec = python_pytree.TreeSpec(
namedtuple, Point, [python_leafspec for _ in tup]
namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
)
else:
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
@ -848,16 +845,16 @@ if "optree" in sys.modules:
def test_treespec_equality(self):
self.assertEqual(
python_pytree.LeafSpec(),
python_pytree.LeafSpec(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
)
self.assertEqual(
python_pytree.TreeSpec(list, None, []),
python_pytree.TreeSpec(list, None, []),
)
self.assertEqual(
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
)
self.assertFalse(
python_pytree.TreeSpec(tuple, None, [])
@ -892,24 +889,32 @@ if "optree" in sys.modules:
# python_pytree.tree_structure({})
python_pytree.TreeSpec(dict, [], []),
# python_pytree.tree_structure([0])
python_pytree.TreeSpec(list, None, [python_leafspec]),
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
# python_pytree.tree_structure([0, 1])
python_pytree.TreeSpec(
list,
None,
[python_leafspec, python_leafspec],
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
),
# python_pytree.tree_structure((0, 1, 2))
python_pytree.TreeSpec(
tuple,
None,
[python_leafspec, python_leafspec, python_leafspec],
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
),
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
python_pytree.TreeSpec(
dict,
["a", "b", "c"],
[python_leafspec, python_leafspec, python_leafspec],
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
),
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
python_pytree.TreeSpec(
@ -919,13 +924,17 @@ if "optree" in sys.modules:
python_pytree.TreeSpec(
tuple,
None,
[python_leafspec, python_leafspec],
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
),
python_leafspec,
python_pytree.treespec_leaf(),
python_pytree.TreeSpec(
dict,
["a", "b", "c"],
[python_leafspec, python_leafspec, python_leafspec],
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
),
],
),
@ -938,12 +947,15 @@ if "optree" in sys.modules:
tuple,
None,
[
python_leafspec,
python_leafspec,
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.TreeSpec(
list,
None,
[python_leafspec, python_leafspec],
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
),
],
),
@ -957,12 +969,12 @@ if "optree" in sys.modules:
python_pytree.TreeSpec(
list,
None,
[python_leafspec, python_leafspec],
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
),
python_pytree.TreeSpec(
list,
None,
[python_leafspec, python_leafspec],
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
),
python_pytree.TreeSpec(dict, [], []),
],
@ -991,7 +1003,7 @@ if "optree" in sys.modules:
list,
None,
[
python_leafspec,
python_pytree.treespec_leaf(),
],
),
],
@ -1000,7 +1012,7 @@ if "optree" in sys.modules:
self.assertIsInstance(serialized_spec, str)
def test_pytree_serialize_enum(self):
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
serialized_spec = python_pytree.treespec_dumps(spec)
self.assertIsInstance(serialized_spec, str)
@ -1163,12 +1175,20 @@ if "optree" in sys.modules:
OrderedDict,
[1, 2, 3],
[
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
python_leafspec,
python_pytree.TreeSpec(
tuple,
None,
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
),
python_pytree.treespec_leaf(),
python_pytree.TreeSpec(
dict,
[4, 5, 6],
[python_leafspec, python_leafspec, python_leafspec],
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
),
],
)
@ -1453,7 +1473,7 @@ class TestCxxPytree(TestCase):
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
def test_treespec_equality(self):
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
def test_treespec_repr(self):
# Check that it looks sane

View File

@ -17,7 +17,7 @@ from ..decorators import substitute_in_graph
if TYPE_CHECKING:
import builtins
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Mapping
from typing_extensions import Self
@ -349,6 +349,113 @@ if python_pytree._cxx_pytree_dynamo_traceable:
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_leaf,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def treespec_leaf(
*,
none_is_leaf: bool = False,
namespace: str = "", # unused
) -> PyTreeSpec:
return PyTreeSpec(
(),
None,
None,
(),
None,
none_is_leaf=none_is_leaf,
namespace="",
)
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_tuple,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def treespec_tuple(
iterable: Iterable[PyTreeSpec] = (),
/,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTreeSpec:
children = tuple(iterable)
if any(not _is_pytreespec_instance(child) for child in children):
raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
if any(child.none_is_leaf != none_is_leaf for child in children):
raise ValueError(
"All children PyTreeSpecs must have the same `none_is_leaf` value "
f"as the parent; expected {none_is_leaf}, got: {children!r}.",
)
if any(child.namespace not in (namespace, "") for child in children):
raise ValueError(
"All children PyTreeSpecs must have the same `namespace` value "
f"as the parent; expected {namespace!r}, got: {children!r}.",
)
handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined]
assert handler is not None
return PyTreeSpec(
tuple(children),
tuple,
None,
tuple(range(len(children))),
handler.unflatten_func,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_dict,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def treespec_dict(
mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
/,
*,
none_is_leaf: bool = False,
namespace: str = "",
**kwargs: PyTreeSpec,
) -> PyTreeSpec:
dct = dict(mapping, **kwargs)
if any(not _is_pytreespec_instance(child) for child in dct.values()):
raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
if any(child.none_is_leaf != none_is_leaf for child in dct.values()):
raise ValueError(
"All children PyTreeSpecs must have the same `none_is_leaf` value "
f"as the parent; expected {none_is_leaf}, got: {dct!r}.",
)
if any(child.namespace not in (namespace, "") for child in dct.values()):
raise ValueError(
"All children PyTreeSpecs must have the same `namespace` value "
f"as the parent; expected {namespace!r}, got: {dct!r}.",
)
(
children,
metadata,
entries,
unflatten_func,
) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated]
dct, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
namespace=namespace,
)
return PyTreeSpec(
tuple(children), # type: ignore[arg-type]
dict,
metadata,
entries,
unflatten_func,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@substitute_in_graph( # type: ignore[arg-type]
optree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the

View File

@ -3727,9 +3727,7 @@ class SourcelessBuilder:
pass # failthrough to unimplemented branch
elif isinstance(value, torch.fx.graph_module.GraphModule):
return SourcelessGraphModuleVariable(value)
elif isinstance(
value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
):
elif isinstance(value, torch.utils._pytree.TreeSpec):
return UserDefinedObjectVariable(value)
elif PlacementVariable.is_placement(value):
return PlacementVariable(value)

View File

@ -1661,8 +1661,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
# We need to have this check this way, because in case init is a TreeSpec and carry
# but carry is only a LeafSpec, these two cannot be compared correctly.
if (
isinstance(xs_treespec.as_python_constant(), pytree.LeafSpec)
!= isinstance(_combine_treespec.as_python_constant(), pytree.LeafSpec)
xs_treespec.as_python_constant().is_leaf()
!= _combine_treespec.as_python_constant().is_leaf()
) or not _make_inlined(tx, pytree.TreeSpec.__eq__)(
xs_treespec, _combine_treespec
).as_python_constant():

View File

@ -1530,10 +1530,10 @@ class GraphModuleSerializer(metaclass=Final):
else:
raise AssertionError("TODO")
def serialize_treespec(self, treespec):
def serialize_treespec(self, treespec: pytree.TreeSpec) -> str:
# We want to additionally save all the field names of the namedtuples in
# case users want to check that the treespec types are equivalent
def store_namedtuple_fields(ts):
def store_namedtuple_fields(ts: pytree.TreeSpec) -> None:
if ts.type is None:
return
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
@ -1555,7 +1555,7 @@ class GraphModuleSerializer(metaclass=Final):
NamedTupleDef(field_names=ts.context._fields)
)
for child in ts.children_specs:
for child in ts.children():
store_namedtuple_fields(child)
serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION)

View File

@ -158,7 +158,7 @@ class PytreeThunk:
assert spec is not None
self.spec: pytree.TreeSpec = spec
if self.spec.type in {tuple, list} and all(
child.is_leaf() for child in spec.children_specs
child.is_leaf() for child in spec.children()
):
self.is_simple = True
if self.spec.is_leaf():

View File

@ -1590,7 +1590,7 @@ def aot_export_joint_simple(
decompositions=decompositions,
trace_joint=trace_joint,
)
in_spec, _kw_in_spec = in_spec.children_specs
in_spec, _kw_in_spec = in_spec.children()
# At this point, we can just directly return the (joint or inference graph) that we traced.
# First though: a bunch of assertions to make sure that our graph doesn't require
# any calling convention changes compared to the original function.
@ -1617,7 +1617,7 @@ def aot_export_joint_simple(
raise RuntimeError(
f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
)
if not all(child.is_leaf() for child in in_spec.children_specs):
if not all(child.is_leaf() for child in in_spec.children()):
raise RuntimeError(
f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
)
@ -1625,7 +1625,7 @@ def aot_export_joint_simple(
raise RuntimeError(
f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
)
if not all(child.is_leaf() for child in out_spec.children_specs):
if not all(child.is_leaf() for child in out_spec.children()):
raise RuntimeError(
f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
)

View File

@ -469,7 +469,7 @@ def _unlift_graph(
gm,
lifted_inputs,
mutated_outputs,
pytree.LeafSpec(),
pytree.treespec_leaf(),
None,
)
return unlifted_gm

View File

@ -7,6 +7,7 @@ from typing import Any, Optional, Union
import torch
import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
import torch.nn.functional as F
import torch.utils._pytree as pytree
# Makes sure that quantized_decomposed ops are registered
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
@ -14,7 +15,6 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.utils._pytree import LeafSpec
__all__ = [
@ -477,7 +477,10 @@ def _replace_literals_with_new_placeholders(
exclude_literals = []
in_spec = gm._in_spec
args_spec = in_spec.children_specs[0]
assert in_spec.type is tuple
args_spec = in_spec.child(0)
assert args_spec.type is tuple
args_spec_children = args_spec.children()
for node in gm.graph.nodes:
if node.op == "placeholder":
last_ph = node
@ -492,7 +495,7 @@ def _replace_literals_with_new_placeholders(
else:
ph_node = gm.graph.placeholder("arg" + str(cnt))
new_args.append(ph_node)
args_spec.children_specs.append(LeafSpec())
args_spec_children.append(pytree.treespec_leaf())
cnt += 1
if merge_dup:
literal_to_ph[arg] = ph_node
@ -503,8 +506,8 @@ def _replace_literals_with_new_placeholders(
node.args = new_args
# Update `num_nodes`, `num_leaves`, `num_children`.
args_spec.__post_init__()
in_spec.__post_init__()
args_spec = pytree.treespec_tuple(args_spec_children)
gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]])
return gm

View File

@ -195,17 +195,16 @@ def _construct_inputs(
unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)
assert signature.in_spec.num_children == 2
assert signature.in_spec.type is tuple
args_spec, kwargs_spec = signature.in_spec.children()
assert args_spec.type is tuple
assert kwargs_spec.type is dict
args_spec = signature.in_spec.children_specs[0]
assert args_spec.context is None
args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0))
args_nodes = [
gm.graph.call_function(operator.getitem, (args_node, i))
for i in range(args_spec.num_children)
]
kwargs_spec = signature.in_spec.children_specs[1]
assert kwargs_spec.context is not None
kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1))
kwargs_nodes = {
k: gm.graph.call_function(operator.getitem, (kwargs_node, k))
@ -372,8 +371,8 @@ def _fix_input_output_signature(
if forward_arg_names is None:
forward_arg_names = []
assert signature.in_spec.num_children == 2
arg_spec = signature.in_spec.children_specs[0]
kwarg_spec = signature.in_spec.children_specs[1]
arg_spec = signature.in_spec.child(0)
kwarg_spec = signature.in_spec.child(1)
assert arg_spec.type is tuple
assert kwarg_spec.type is dict
for i in range(arg_spec.num_children):

View File

@ -1533,7 +1533,7 @@ def _strict_export(
# aot_export expect the return type to always be a tuple.
if out_spec.type not in (list, tuple):
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
out_spec = pytree.treespec_tuple([out_spec])
orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]

View File

@ -15,7 +15,7 @@ def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any
# Make sure that the spec is actually shaped like (args, kwargs)
assert spec.type is tuple
assert spec.num_children == 2
kwargs_spec = spec.children_specs[1]
kwargs_spec = spec.child(1)
assert kwargs_spec.type is dict
if set(user_kwargs) != set(kwargs_spec.context):
@ -55,10 +55,10 @@ def is_equivalent(
return False
# Recurse on children
if len(spec1.children_specs) != len(spec2.children_specs):
if spec1.num_children != spec2.num_children:
return False
for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs):
for child_spec1, child_spec2 in zip(spec1.children(), spec2.children()):
if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
return False

View File

@ -57,11 +57,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
return False
elif a.context != b.context:
return False
if len(a.children_specs) != len(b.children_specs):
if a.num_children != b.num_children:
return False
return all(
_match_normalized_structure(a, b)
for a, b in zip(a.children_specs, b.children_specs)
for a, b in zip(a.children(), b.children())
)
return _match_normalized_structure(self, other)
@ -357,13 +357,13 @@ def _get_codegen(
elif (
in_spec.type is tuple
and in_spec.num_children == 2
and in_spec.children_specs[0].type is tuple
and in_spec.children_specs[1].type is dict
and in_spec.child(0).type is tuple
and in_spec.child(1).type is dict
):
# if in_spec contains the args (tuple) and kwargs (dict)
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)]
# add kwarg names
names.extend(in_spec.children_specs[1].context)
names.extend(in_spec.child(1).context)
else:
names = [f"arg_{i}" for i in range(in_spec.num_children)]

View File

@ -12,14 +12,16 @@ import torch
from torch.utils._pytree import (
_get_node_type,
BUILTIN_TYPES,
KeyPath,
keystr,
LeafSpec,
MappingKey,
SequenceKey,
SUPPORTED_NODES,
tree_flatten,
tree_iter,
tree_map,
tree_map_with_path,
tree_structure,
TreeSpec,
)
from .exported_program import ExportedProgram
@ -655,53 +657,55 @@ def _tree_map_with_path(
case_name="dynamic_shapes_validation",
)
def _compare(tree, dynamic_shapes, path):
def _compare(
treespec: TreeSpec, other_treespec: TreeSpec, path: KeyPath
) -> None:
# raise an error at the point where tree and dynamic_shapes differ,
# including the path to that point and the reason for the difference
rendered_path = keystr(path)
if isinstance(tree, LeafSpec):
if treespec.is_leaf():
return
if isinstance(dynamic_shapes, LeafSpec):
if other_treespec.is_leaf():
raise_mismatch_error(
f"`{tree_name}{rendered_path}` is a {tree.type}, "
f"`{tree_name}{rendered_path}` is a {treespec.type}, "
f"but `dynamic_shapes{rendered_path}` is not"
)
if tree.type != dynamic_shapes.type:
if treespec.type != other_treespec.type:
raise_mismatch_error(
f"`{tree_name}{rendered_path}` is a {tree.type}, "
f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}"
f"`{tree_name}{rendered_path}` is a {treespec.type}, "
f"but `dynamic_shapes{rendered_path}` is a {other_treespec.type}"
)
if len(tree.children_specs) != len(dynamic_shapes.children_specs):
if treespec.num_children != other_treespec.num_children:
raise_mismatch_error(
f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, "
f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements"
f"`{tree_name}{rendered_path}` has {treespec.num_children} elements, "
f"but `dynamic_shapes{rendered_path}` has {other_treespec.num_children} elements"
)
if tree.type is dict:
if treespec.type is dict:
# context, children could be out of order
if sorted(tree.context) != sorted(dynamic_shapes.context):
if set(treespec.context) != set(other_treespec.context):
raise_mismatch_error(
f"`{tree_name}{rendered_path}` has keys {tree.context}, "
f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}"
f"`{tree_name}{rendered_path}` has keys {treespec.context}, "
f"but `dynamic_shapes{rendered_path}` has keys {other_treespec.context}"
)
_remap = dict(
zip(dynamic_shapes.context, dynamic_shapes.children_specs)
zip(other_treespec.context, other_treespec.children())
)
dynamic_shapes_children_specs = [_remap[k] for k in tree.context]
other_children = [_remap[k] for k in treespec.context]
else:
dynamic_shapes_children_specs = dynamic_shapes.children_specs
for i, (tree_, dynamic_shapes_) in enumerate(
zip(tree.children_specs, dynamic_shapes_children_specs)
other_children = other_treespec.children()
for i, (child, other_child) in enumerate(
zip(treespec.children(), other_children)
):
_compare(
tree_,
dynamic_shapes_,
path + [_key(tree.type, tree.context, i)],
child,
other_child,
path + (_key(treespec.type, treespec.context, i),),
)
_, tree_spec = tree_flatten(tree, is_leaf=is_leaf)
treespec = tree_structure(tree, is_leaf=is_leaf)
for other_tree in dynamic_shapes:
_, other_tree_spec = tree_flatten(other_tree, is_leaf)
_compare(tree_spec, other_tree_spec, [])
other_treespec = tree_structure(other_tree, is_leaf)
_compare(treespec, other_treespec, ())
raise
@ -1231,10 +1235,7 @@ def _get_dim_name_mapping(
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
):
name_to_dim = {}
for dim in tree_flatten(
dynamic_shapes,
is_leaf=lambda x: isinstance(x, Dim),
)[0]:
for dim in tree_iter(dynamic_shapes, is_leaf=lambda x: isinstance(x, Dim)):
if dim is None:
# NOTE: this must denote a non-Tensor or automatic at this point.
continue

View File

@ -391,7 +391,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
# aot_export expect the return type to always be a tuple.
assert out_spec is not None
if out_spec.type not in (list, tuple):
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
out_spec = pytree.treespec_tuple([out_spec])
mod.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(

View File

@ -1124,10 +1124,10 @@ class _ModuleFrame:
signature = module_call_graph.get(self.child_fqn)
if signature is not None and self.parent is not None:
assert signature.in_spec.num_children == 2
args_spec = signature.in_spec.children_specs[0]
kwargs_spec = signature.in_spec.children_specs[1]
assert args_spec.context is None
assert kwargs_spec.context is not None
assert signature.in_spec.type is tuple
args_spec, kwargs_spec = signature.in_spec.children()
assert args_spec.type is tuple
assert kwargs_spec.type is dict
with self.graph.inserting_after(None):
arg_nodes = [

View File

@ -49,7 +49,7 @@ def tree_flatten_spec(
flatten_fn_spec = SUPPORTED_NODES[spec.type]
child_pytrees = flatten_fn_spec(pytree, spec)
result = []
for child, child_spec in zip(child_pytrees, spec.children_specs):
for child, child_spec in zip(child_pytrees, spec.children()):
flat = tree_flatten_spec(child, child_spec)
result += flat
return result

View File

@ -709,7 +709,7 @@ class Tracer(TracerBase):
root_fn = _patch_function(root_fn, len(args))
flat_args, in_spec = pytree.tree_flatten(tuple(args))
if not all(child.is_leaf() for child in in_spec.children_specs):
if not all(child.is_leaf() for child in in_spec.children()):
# In the case that we have pytree-flattened inputs in
# `concrete_args`, generate a flattening wrapper around the
# original root function and return that.

View File

@ -933,24 +933,25 @@ class _PyTreeCodeGen(CodeGen):
return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
in_spec = self.pytree_info.in_spec
# when kwargs is present, in_spec is tuple(args, kwargs)
has_args_kwargs_tuple = (
self.pytree_info.in_spec.type is tuple
and self.pytree_info.in_spec.num_children == 2
and self.pytree_info.in_spec.children_specs[0].type is tuple
and self.pytree_info.in_spec.children_specs[1].type is dict
in_spec.type is tuple
and in_spec.num_children == 2
and in_spec.child(0).type is tuple
and in_spec.child(1).type is dict
)
fn_kwargs = "{}"
fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
if has_args_kwargs_tuple:
count_args = self.pytree_info.in_spec.children_specs[0].num_children
count_args = in_spec.child(0).num_children
fn_args = self.pytree_info.orig_args[:count_args]
fn_kwargs = (
"{"
+ ", ".join(
f"'{k}':{v}"
for k, v in zip(
self.pytree_info.in_spec.children_specs[1].context,
in_spec.child(1).context,
self.pytree_info.orig_args[count_args:],
)
)

View File

@ -14,9 +14,9 @@ collection support for PyTorch APIs.
import functools
import types
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, overload, TypeVar, Union
from typing_extensions import deprecated, TypeIs
from typing_extensions import deprecated, Self, TypeAlias, TypeIs
import torch.utils._pytree as python_pytree
from torch.torch_version import TorchVersion as _TorchVersion
@ -42,7 +42,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
from optree import PyTreeSpec # direct import for type annotations
__all__ = [
@ -53,6 +53,7 @@ __all__ = [
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"PyTreeSpec",
"TreeSpec",
"LeafSpec",
"keystr",
@ -100,6 +101,8 @@ U = TypeVar("U")
R = TypeVar("R")
TreeSpec: TypeAlias = PyTreeSpec
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
@ -267,6 +270,30 @@ def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)
def treespec_leaf() -> TreeSpec:
"""Make a treespec representing a leaf node."""
return optree.treespec_leaf(none_is_leaf=True, namespace="torch")
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
"""Make a tuple treespec from an iterable of child treespecs."""
return optree.treespec_tuple(iterable, none_is_leaf=True, namespace="torch")
def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
"""Make a dict treespec from a dict of child treespecs."""
return optree.treespec_dict(
mapping,
**kwargs,
none_is_leaf=True,
namespace="torch",
)
def tree_is_leaf(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -985,9 +1012,14 @@ class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
return _is_pytreespec_instance(instance) and instance.is_leaf()
@deprecated(
"`isinstance(treespec, LeafSpec)` is deprecated, "
"use `isinstance(treespec, TreeSpec)` and `treespec.is_leaf()` instead.",
category=FutureWarning,
)
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final]
def __new__(cls) -> "LeafSpec":
return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value]
def __new__(cls) -> Self:
return treespec_leaf() # type: ignore[return-value]
def tree_flatten_with_path(

View File

@ -39,7 +39,7 @@ from typing import (
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple, Self
from typing_extensions import deprecated, NamedTuple, Self, TypeAlias
from torch.torch_version import TorchVersion as _TorchVersion
@ -52,6 +52,7 @@ __all__ = [
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"PyTreeSpec",
"TreeSpec",
"LeafSpec",
"keystr",
@ -475,7 +476,7 @@ class ConstantNode:
def _is_constant_holder(spec: "TreeSpec") -> bool:
"""Checks if the spec is from a pytree registered with register_constant"""
return isinstance(spec.context, ConstantNode)
return isinstance(spec._context, ConstantNode)
def _retrieve_constant(spec: "TreeSpec") -> Any:
@ -1076,39 +1077,60 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
# A TreeSpec represents the structure of a pytree. It holds:
# "type": the type of root Node of the pytree
# context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node
# num_leaves: the number of leaves
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
# "type": the type of root Node of the pytree
# context: some context that is useful in unflattening the pytree
# children(): specs for each child of the root Node
# num_nodes: the total number of nodes
# num_leaves: the number of leaves
# num_children: the number of children of the root Node (i.e., len(children()))
# is_leaf(): whether the root Node is a leaf
@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False)
class TreeSpec:
type: Any
context: Context
children_specs: list["TreeSpec"]
_context: Context
_children: list[Self]
num_nodes: int = dataclasses.field(init=False)
num_leaves: int = dataclasses.field(init=False)
num_children: int = dataclasses.field(init=False)
def __init__(
self,
type: Any,
context: Context, # keep for backward compatibility
children_specs: list[Self], # keep for backward compatibility
) -> None:
object.__setattr__(self, "type", type)
object.__setattr__(self, "_context", context)
object.__setattr__(self, "_children", children_specs)
self.__post_init__()
def __post_init__(self) -> None:
num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
num_leaves = sum(spec.num_leaves for spec in self.children_specs)
num_children = len(self.children_specs)
if self.type is None:
assert self._context is None
assert len(self._children) == 0
num_nodes = 1
num_leaves = 1
num_children = 0
else:
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
num_leaves = sum(spec.num_leaves for spec in self._children)
num_children = len(self._children)
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
def __repr__(self, indent: int = 0) -> str:
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self._context}, ["
children_specs_str: str = ""
if self.num_children > 0:
indent += 2
children_specs_str += self.children_specs[0].__repr__(indent)
children_specs_str += self._children[0].__repr__(indent)
children_specs_str += "," if self.num_children > 1 else ""
children_specs_str += ",".join(
[
"\n" + " " * indent + child.__repr__(indent)
for child in self.children_specs[1:]
for child in self._children[1:]
]
)
repr_suffix: str = f"{children_specs_str}])"
@ -1120,16 +1142,36 @@ class TreeSpec:
elif other.__class__ is self.__class__:
if str(self.type) != str(other.type):
return False
if self.context != other.context:
if self._context != other._context:
return False
elif self.children_specs != other.children_specs:
elif self._children != other._children:
return False
return True
return NotImplemented
@property
def context(self) -> Context:
return self._context
@property
@deprecated(
"`treespec.children_specs` is deprecated. "
"Use `treespec.child(index)` to access a single child, "
"or `treespec.children()` to get all children.",
category=FutureWarning,
)
def children_specs(self) -> list[Self]:
return self._children
def is_leaf(self) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1
def children(self) -> list[Self]:
return self._children.copy()
def child(self, index: int) -> Self:
return self._children[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
if treespec.is_leaf():
@ -1151,7 +1193,7 @@ class TreeSpec:
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(children)}.",
)
if context != treespec.context:
if context != treespec._context:
raise ValueError(
f"Node context mismatch for custom node type {treespec.type!r}.",
)
@ -1176,10 +1218,10 @@ class TreeSpec:
if both_standard_dict:
# dictionary types are compatible with each other
dict_context = (
treespec.context
treespec._context
if treespec.type is not defaultdict
# ignore mismatch of `default_factory` for defaultdict
else treespec.context[1]
else treespec._context[1]
)
expected_keys = dict_context
got_key_set = set(tree)
@ -1200,13 +1242,13 @@ class TreeSpec:
children, context = flatten_fn(tree)
if (
node_type is not deque # ignore mismatch of `maxlen` for deque
) and context != treespec.context:
) and context != treespec._context:
raise ValueError(
f"Node context mismatch for node type {treespec.type!r}; "
f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch
f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec.children_specs):
for subtree, subspec in zip(children, treespec._children, strict=True):
helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = []
@ -1231,24 +1273,24 @@ class TreeSpec:
start = 0
end = 0
child_pytrees = []
for child_spec in self.children_specs:
for child_spec in self._children:
end += child_spec.num_leaves
child_pytrees.append(child_spec.unflatten(leaves[start:end]))
start = end
return unflatten_fn(child_pytrees, self.context)
return unflatten_fn(child_pytrees, self._context)
def __hash__(self) -> int:
node_type = self.type
if node_type is defaultdict:
default_factory, dict_context = self.context
default_factory, dict_context = self._context
hashable_context = (default_factory, tuple(dict_context))
elif node_type in (dict, OrderedDict):
hashable_context = tuple(self.context)
hashable_context = tuple(self._context)
elif node_type is None or node_type in BUILTIN_TYPES:
hashable_context = self.context
elif isinstance(self.context, ConstantNode):
hashable_context = self.context.value
hashable_context = self._context
elif isinstance(self._context, ConstantNode):
hashable_context = self._context.value
else:
# The context for user-defined node types might not be hashable.
# Ignore it for hashing.
@ -1256,20 +1298,26 @@ class TreeSpec:
# same hash. This might increase the hash collision rate, but we
# don't care about that.
hashable_context = None
return hash((node_type, hashable_context, tuple(self.children_specs)))
return hash((node_type, hashable_context, tuple(self._children)))
PyTreeSpec: TypeAlias = TreeSpec
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
# this class with `dataclasses.fields`, etc., while having a simplified
# constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
# again, with fields that have `init=False`.
@deprecated(
"`isinstance(treespec, LeafSpec)` is deprecated, "
"use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.",
category=FutureWarning,
)
@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
class LeafSpec(TreeSpec):
type: Any = dataclasses.field(default=None, init=False)
context: Context = dataclasses.field(default=None, init=False)
children_specs: list["TreeSpec"] = dataclasses.field(
default_factory=list, init=False
)
_context: Context = dataclasses.field(default=None, init=False)
_children: list[Self] = dataclasses.field(default_factory=list, init=False)
def __post_init__(self) -> None:
# Override `__post_init__` for `num_leaves` derivation.
@ -1283,7 +1331,36 @@ class LeafSpec(TreeSpec):
# All leaves are equivalent, so represent with a single object to save on
# object construction time
_LEAF_SPEC = LeafSpec()
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=FutureWarning, module=__name__, append=False
)
_LEAF_SPEC = LeafSpec()
def treespec_leaf() -> LeafSpec:
"""Make a treespec representing a leaf node."""
return _LEAF_SPEC
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
"""Make a tuple treespec from an iterable of child treespecs."""
children = list(iterable)
if any(not isinstance(child, TreeSpec) for child in children):
raise ValueError(f"Expected a tuple of TreeSpec values, got: {children!r}.")
return TreeSpec(tuple, None, children)
def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
"""Make a dict treespec from a dict of child treespecs."""
dct = dict(mapping, **kwargs)
if any(not isinstance(child, TreeSpec) for child in dct.values()):
raise ValueError(f"Expected a dictionary of TreeSpec values, got: {dct!r}.")
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
def tree_flatten(
@ -1762,15 +1839,15 @@ def _broadcast_to_and_flatten(
return None
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, ctx = flatten_fn(tree)
child_pytrees, context = flatten_fn(tree)
# Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or ctx != treespec.context:
if len(child_pytrees) != treespec.num_children or context != treespec._context:
return None
# Recursively flatten the children
result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec.children_specs):
for child, child_spec in zip(child_pytrees, treespec._children, strict=True):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None:
result += flat
@ -1824,7 +1901,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
if serialize_node_def.to_dumpable_context is None:
try:
serialized_context = json.dumps(treespec.context, cls=EnumEncoder)
serialized_context = json.dumps(treespec._context, cls=EnumEncoder)
except TypeError as e:
raise TypeError(
"Unable to serialize context. "
@ -1832,9 +1909,9 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
"custom serializer using _register_pytree_node."
) from e
else:
serialized_context = serialize_node_def.to_dumpable_context(treespec.context)
serialized_context = serialize_node_def.to_dumpable_context(treespec._context)
child_schemas = [_treespec_to_json(child) for child in treespec.children_specs]
child_schemas = [_treespec_to_json(child) for child in treespec._children]
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)