Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)"

This reverts commit 108bb224f7.

Reverted https://github.com/pytorch/pytorch/pull/160843 on behalf of https://github.com/atalman due to failing internal builds ([comment](https://github.com/pytorch/pytorch/pull/160843#issuecomment-3474354428))
This commit is contained in:
PyTorch MergeBot 2025-10-31 18:31:32 +00:00
parent b71966f67b
commit 85b85f6c2c
22 changed files with 160 additions and 397 deletions

View File

@ -91,13 +91,13 @@ from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._pytree import (
LeafSpec,
register_constant,
tree_flatten,
tree_map,
tree_unflatten,
TreeSpec,
treespec_dumps,
treespec_leaf,
treespec_loads,
)
@ -7791,7 +7791,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
dt = MyDataClass(x=3, y=4)
flat, spec = tree_flatten(dt)
self.assertTrue(spec, treespec_leaf())
self.assertTrue(spec, LeafSpec())
self.assertTrue(len(flat) == 1)
torch.export.register_dataclass(
@ -7802,9 +7802,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
flat, spec = tree_flatten(dt)
self.assertEqual(
spec,
TreeSpec(
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
),
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
)
self.assertEqual(flat, [3, 4])
@ -7837,7 +7835,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
TreeSpec(
MyOtherDataClass,
[["x", "y", "z"], []],
[treespec_leaf(), treespec_leaf(), treespec_leaf()],
[LeafSpec(), LeafSpec(), LeafSpec()],
),
)
self.assertEqual(flat, [3, 4, None])

View File

@ -65,6 +65,9 @@ class TestEnum(enum.Enum):
A = auto()
python_leafspec = python_pytree.LeafSpec()
class TestGenericPytree(TestCase):
def test_aligned_public_apis(self):
public_apis = python_pytree.__all__
@ -194,7 +197,7 @@ class TestGenericPytree(TestCase):
def run_test_with_leaf(leaf):
values, treespec = pytree.tree_flatten(leaf)
self.assertEqual(values, [leaf])
self.assertEqual(treespec, pytree.treespec_leaf())
self.assertEqual(treespec, pytree.LeafSpec())
unflattened = pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, leaf)
@ -212,7 +215,7 @@ class TestGenericPytree(TestCase):
(
python_pytree,
lambda tup: python_pytree.TreeSpec(
tuple, None, [python_pytree.treespec_leaf() for _ in tup]
tuple, None, [python_leafspec for _ in tup]
),
),
name="python",
@ -247,7 +250,7 @@ class TestGenericPytree(TestCase):
(
python_pytree,
lambda lst: python_pytree.TreeSpec(
list, None, [python_pytree.treespec_leaf() for _ in lst]
list, None, [python_leafspec for _ in lst]
),
),
name="python",
@ -283,7 +286,7 @@ class TestGenericPytree(TestCase):
lambda dct: python_pytree.TreeSpec(
dict,
list(dct.keys()),
[python_pytree.treespec_leaf() for _ in dct.values()],
[python_leafspec for _ in dct.values()],
),
),
name="python",
@ -324,7 +327,7 @@ class TestGenericPytree(TestCase):
lambda odict: python_pytree.TreeSpec(
OrderedDict,
list(odict.keys()),
[python_pytree.treespec_leaf() for _ in odict.values()],
[python_leafspec for _ in odict.values()],
),
),
name="python",
@ -368,7 +371,7 @@ class TestGenericPytree(TestCase):
lambda ddct: python_pytree.TreeSpec(
defaultdict,
[ddct.default_factory, list(ddct.keys())],
[python_pytree.treespec_leaf() for _ in ddct.values()],
[python_leafspec for _ in ddct.values()],
),
),
name="python",
@ -410,7 +413,7 @@ class TestGenericPytree(TestCase):
(
python_pytree,
lambda deq: python_pytree.TreeSpec(
deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
deque, deq.maxlen, [python_leafspec for _ in deq]
),
),
name="python",
@ -450,7 +453,7 @@ class TestGenericPytree(TestCase):
def run_test(tup):
if pytree is python_pytree:
expected_spec = python_pytree.TreeSpec(
namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
namedtuple, Point, [python_leafspec for _ in tup]
)
else:
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
@ -845,16 +848,16 @@ if "optree" in sys.modules:
def test_treespec_equality(self):
self.assertEqual(
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.LeafSpec(),
python_pytree.LeafSpec(),
)
self.assertEqual(
python_pytree.TreeSpec(list, None, []),
python_pytree.TreeSpec(list, None, []),
)
self.assertEqual(
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
)
self.assertFalse(
python_pytree.TreeSpec(tuple, None, [])
@ -889,32 +892,24 @@ if "optree" in sys.modules:
# python_pytree.tree_structure({})
python_pytree.TreeSpec(dict, [], []),
# python_pytree.tree_structure([0])
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
python_pytree.TreeSpec(list, None, [python_leafspec]),
# python_pytree.tree_structure([0, 1])
python_pytree.TreeSpec(
list,
None,
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
[python_leafspec, python_leafspec],
),
# python_pytree.tree_structure((0, 1, 2))
python_pytree.TreeSpec(
tuple,
None,
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
[python_leafspec, python_leafspec, python_leafspec],
),
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
python_pytree.TreeSpec(
dict,
["a", "b", "c"],
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
[python_leafspec, python_leafspec, python_leafspec],
),
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
python_pytree.TreeSpec(
@ -924,17 +919,13 @@ if "optree" in sys.modules:
python_pytree.TreeSpec(
tuple,
None,
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
[python_leafspec, python_leafspec],
),
python_pytree.treespec_leaf(),
python_leafspec,
python_pytree.TreeSpec(
dict,
["a", "b", "c"],
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
[python_leafspec, python_leafspec, python_leafspec],
),
],
),
@ -947,15 +938,12 @@ if "optree" in sys.modules:
tuple,
None,
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_leafspec,
python_leafspec,
python_pytree.TreeSpec(
list,
None,
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
[python_leafspec, python_leafspec],
),
],
),
@ -969,12 +957,12 @@ if "optree" in sys.modules:
python_pytree.TreeSpec(
list,
None,
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
[python_leafspec, python_leafspec],
),
python_pytree.TreeSpec(
list,
None,
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
[python_leafspec, python_leafspec],
),
python_pytree.TreeSpec(dict, [], []),
],
@ -1003,7 +991,7 @@ if "optree" in sys.modules:
list,
None,
[
python_pytree.treespec_leaf(),
python_leafspec,
],
),
],
@ -1012,7 +1000,7 @@ if "optree" in sys.modules:
self.assertIsInstance(serialized_spec, str)
def test_pytree_serialize_enum(self):
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
serialized_spec = python_pytree.treespec_dumps(spec)
self.assertIsInstance(serialized_spec, str)
@ -1175,20 +1163,12 @@ if "optree" in sys.modules:
OrderedDict,
[1, 2, 3],
[
python_pytree.TreeSpec(
tuple,
None,
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
),
python_pytree.treespec_leaf(),
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
python_leafspec,
python_pytree.TreeSpec(
dict,
[4, 5, 6],
[
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
python_pytree.treespec_leaf(),
],
[python_leafspec, python_leafspec, python_leafspec],
),
],
)
@ -1473,7 +1453,7 @@ class TestCxxPytree(TestCase):
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
def test_treespec_equality(self):
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
def test_treespec_repr(self):
# Check that it looks sane

View File

@ -17,7 +17,7 @@ from ..decorators import substitute_in_graph
if TYPE_CHECKING:
import builtins
from collections.abc import Callable, Iterable, Mapping
from collections.abc import Callable, Iterable
from typing_extensions import Self
@ -349,113 +349,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_leaf,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def treespec_leaf(
*,
none_is_leaf: bool = False,
namespace: str = "", # unused
) -> PyTreeSpec:
return PyTreeSpec(
(),
None,
None,
(),
None,
none_is_leaf=none_is_leaf,
namespace="",
)
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_tuple,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def treespec_tuple(
iterable: Iterable[PyTreeSpec] = (),
/,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTreeSpec:
children = tuple(iterable)
if any(not _is_pytreespec_instance(child) for child in children):
raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
if any(child.none_is_leaf != none_is_leaf for child in children):
raise ValueError(
"All children PyTreeSpecs must have the same `none_is_leaf` value "
f"as the parent; expected {none_is_leaf}, got: {children!r}.",
)
if any(child.namespace not in (namespace, "") for child in children):
raise ValueError(
"All children PyTreeSpecs must have the same `namespace` value "
f"as the parent; expected {namespace!r}, got: {children!r}.",
)
handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined]
assert handler is not None
return PyTreeSpec(
tuple(children),
tuple,
None,
tuple(range(len(children))),
handler.unflatten_func,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_dict,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def treespec_dict(
mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
/,
*,
none_is_leaf: bool = False,
namespace: str = "",
**kwargs: PyTreeSpec,
) -> PyTreeSpec:
dct = dict(mapping, **kwargs)
if any(not _is_pytreespec_instance(child) for child in dct.values()):
raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
if any(child.none_is_leaf != none_is_leaf for child in dct.values()):
raise ValueError(
"All children PyTreeSpecs must have the same `none_is_leaf` value "
f"as the parent; expected {none_is_leaf}, got: {dct!r}.",
)
if any(child.namespace not in (namespace, "") for child in dct.values()):
raise ValueError(
"All children PyTreeSpecs must have the same `namespace` value "
f"as the parent; expected {namespace!r}, got: {dct!r}.",
)
(
children,
metadata,
entries,
unflatten_func,
) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated]
dct, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
namespace=namespace,
)
return PyTreeSpec(
tuple(children), # type: ignore[arg-type]
dict,
metadata,
entries,
unflatten_func,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@substitute_in_graph( # type: ignore[arg-type]
optree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the

View File

@ -3727,7 +3727,9 @@ class SourcelessBuilder:
pass # failthrough to unimplemented branch
elif isinstance(value, torch.fx.graph_module.GraphModule):
return SourcelessGraphModuleVariable(value)
elif isinstance(value, torch.utils._pytree.TreeSpec):
elif isinstance(
value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
):
return UserDefinedObjectVariable(value)
elif PlacementVariable.is_placement(value):
return PlacementVariable(value)

View File

@ -1661,8 +1661,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
# We need to have this check this way, because in case init is a TreeSpec and carry
# but carry is only a LeafSpec, these two cannot be compared correctly.
if (
xs_treespec.as_python_constant().is_leaf()
!= _combine_treespec.as_python_constant().is_leaf()
isinstance(xs_treespec.as_python_constant(), pytree.LeafSpec)
!= isinstance(_combine_treespec.as_python_constant(), pytree.LeafSpec)
) or not _make_inlined(tx, pytree.TreeSpec.__eq__)(
xs_treespec, _combine_treespec
).as_python_constant():

View File

@ -1530,10 +1530,10 @@ class GraphModuleSerializer(metaclass=Final):
else:
raise AssertionError("TODO")
def serialize_treespec(self, treespec: pytree.TreeSpec) -> str:
def serialize_treespec(self, treespec):
# We want to additionally save all the field names of the namedtuples in
# case users want to check that the treespec types are equivalent
def store_namedtuple_fields(ts: pytree.TreeSpec) -> None:
def store_namedtuple_fields(ts):
if ts.type is None:
return
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
@ -1555,7 +1555,7 @@ class GraphModuleSerializer(metaclass=Final):
NamedTupleDef(field_names=ts.context._fields)
)
for child in ts.children():
for child in ts.children_specs:
store_namedtuple_fields(child)
serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION)

View File

@ -158,7 +158,7 @@ class PytreeThunk:
assert spec is not None
self.spec: pytree.TreeSpec = spec
if self.spec.type in {tuple, list} and all(
child.is_leaf() for child in spec.children()
child.is_leaf() for child in spec.children_specs
):
self.is_simple = True
if self.spec.is_leaf():

View File

@ -1590,7 +1590,7 @@ def aot_export_joint_simple(
decompositions=decompositions,
trace_joint=trace_joint,
)
in_spec, _kw_in_spec = in_spec.children()
in_spec, _kw_in_spec = in_spec.children_specs
# At this point, we can just directly return the (joint or inference graph) that we traced.
# First though: a bunch of assertions to make sure that our graph doesn't require
# any calling convention changes compared to the original function.
@ -1617,7 +1617,7 @@ def aot_export_joint_simple(
raise RuntimeError(
f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
)
if not all(child.is_leaf() for child in in_spec.children()):
if not all(child.is_leaf() for child in in_spec.children_specs):
raise RuntimeError(
f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
)
@ -1625,7 +1625,7 @@ def aot_export_joint_simple(
raise RuntimeError(
f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
)
if not all(child.is_leaf() for child in out_spec.children()):
if not all(child.is_leaf() for child in out_spec.children_specs):
raise RuntimeError(
f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
)

View File

@ -469,7 +469,7 @@ def _unlift_graph(
gm,
lifted_inputs,
mutated_outputs,
pytree.treespec_leaf(),
pytree.LeafSpec(),
None,
)
return unlifted_gm

View File

@ -7,7 +7,6 @@ from typing import Any, Optional, Union
import torch
import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
import torch.nn.functional as F
import torch.utils._pytree as pytree
# Makes sure that quantized_decomposed ops are registered
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
@ -15,6 +14,7 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.utils._pytree import LeafSpec
__all__ = [
@ -477,10 +477,7 @@ def _replace_literals_with_new_placeholders(
exclude_literals = []
in_spec = gm._in_spec
assert in_spec.type is tuple
args_spec = in_spec.child(0)
assert args_spec.type is tuple
args_spec_children = args_spec.children()
args_spec = in_spec.children_specs[0]
for node in gm.graph.nodes:
if node.op == "placeholder":
last_ph = node
@ -495,7 +492,7 @@ def _replace_literals_with_new_placeholders(
else:
ph_node = gm.graph.placeholder("arg" + str(cnt))
new_args.append(ph_node)
args_spec_children.append(pytree.treespec_leaf())
args_spec.children_specs.append(LeafSpec())
cnt += 1
if merge_dup:
literal_to_ph[arg] = ph_node
@ -506,8 +503,8 @@ def _replace_literals_with_new_placeholders(
node.args = new_args
# Update `num_nodes`, `num_leaves`, `num_children`.
args_spec = pytree.treespec_tuple(args_spec_children)
gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]])
args_spec.__post_init__()
in_spec.__post_init__()
return gm

View File

@ -195,16 +195,17 @@ def _construct_inputs(
unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)
assert signature.in_spec.num_children == 2
assert signature.in_spec.type is tuple
args_spec, kwargs_spec = signature.in_spec.children()
assert args_spec.type is tuple
assert kwargs_spec.type is dict
args_spec = signature.in_spec.children_specs[0]
assert args_spec.context is None
args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0))
args_nodes = [
gm.graph.call_function(operator.getitem, (args_node, i))
for i in range(args_spec.num_children)
]
kwargs_spec = signature.in_spec.children_specs[1]
assert kwargs_spec.context is not None
kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1))
kwargs_nodes = {
k: gm.graph.call_function(operator.getitem, (kwargs_node, k))
@ -371,8 +372,8 @@ def _fix_input_output_signature(
if forward_arg_names is None:
forward_arg_names = []
assert signature.in_spec.num_children == 2
arg_spec = signature.in_spec.child(0)
kwarg_spec = signature.in_spec.child(1)
arg_spec = signature.in_spec.children_specs[0]
kwarg_spec = signature.in_spec.children_specs[1]
assert arg_spec.type is tuple
assert kwarg_spec.type is dict
for i in range(arg_spec.num_children):

View File

@ -1533,7 +1533,7 @@ def _strict_export(
# aot_export expect the return type to always be a tuple.
if out_spec.type not in (list, tuple):
out_spec = pytree.treespec_tuple([out_spec])
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]

View File

@ -15,7 +15,7 @@ def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any
# Make sure that the spec is actually shaped like (args, kwargs)
assert spec.type is tuple
assert spec.num_children == 2
kwargs_spec = spec.child(1)
kwargs_spec = spec.children_specs[1]
assert kwargs_spec.type is dict
if set(user_kwargs) != set(kwargs_spec.context):
@ -55,10 +55,10 @@ def is_equivalent(
return False
# Recurse on children
if spec1.num_children != spec2.num_children:
if len(spec1.children_specs) != len(spec2.children_specs):
return False
for child_spec1, child_spec2 in zip(spec1.children(), spec2.children()):
for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs):
if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
return False

View File

@ -57,11 +57,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
return False
elif a.context != b.context:
return False
if a.num_children != b.num_children:
if len(a.children_specs) != len(b.children_specs):
return False
return all(
_match_normalized_structure(a, b)
for a, b in zip(a.children(), b.children())
for a, b in zip(a.children_specs, b.children_specs)
)
return _match_normalized_structure(self, other)
@ -357,13 +357,13 @@ def _get_codegen(
elif (
in_spec.type is tuple
and in_spec.num_children == 2
and in_spec.child(0).type is tuple
and in_spec.child(1).type is dict
and in_spec.children_specs[0].type is tuple
and in_spec.children_specs[1].type is dict
):
# if in_spec contains the args (tuple) and kwargs (dict)
names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)]
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
# add kwarg names
names.extend(in_spec.child(1).context)
names.extend(in_spec.children_specs[1].context)
else:
names = [f"arg_{i}" for i in range(in_spec.num_children)]

View File

@ -12,16 +12,14 @@ import torch
from torch.utils._pytree import (
_get_node_type,
BUILTIN_TYPES,
KeyPath,
keystr,
LeafSpec,
MappingKey,
SequenceKey,
SUPPORTED_NODES,
tree_iter,
tree_flatten,
tree_map,
tree_map_with_path,
tree_structure,
TreeSpec,
)
from .exported_program import ExportedProgram
@ -657,55 +655,53 @@ def _tree_map_with_path(
case_name="dynamic_shapes_validation",
)
def _compare(
treespec: TreeSpec, other_treespec: TreeSpec, path: KeyPath
) -> None:
def _compare(tree, dynamic_shapes, path):
# raise an error at the point where tree and dynamic_shapes differ,
# including the path to that point and the reason for the difference
rendered_path = keystr(path)
if treespec.is_leaf():
if isinstance(tree, LeafSpec):
return
if other_treespec.is_leaf():
if isinstance(dynamic_shapes, LeafSpec):
raise_mismatch_error(
f"`{tree_name}{rendered_path}` is a {treespec.type}, "
f"`{tree_name}{rendered_path}` is a {tree.type}, "
f"but `dynamic_shapes{rendered_path}` is not"
)
if treespec.type != other_treespec.type:
if tree.type != dynamic_shapes.type:
raise_mismatch_error(
f"`{tree_name}{rendered_path}` is a {treespec.type}, "
f"but `dynamic_shapes{rendered_path}` is a {other_treespec.type}"
f"`{tree_name}{rendered_path}` is a {tree.type}, "
f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}"
)
if treespec.num_children != other_treespec.num_children:
if len(tree.children_specs) != len(dynamic_shapes.children_specs):
raise_mismatch_error(
f"`{tree_name}{rendered_path}` has {treespec.num_children} elements, "
f"but `dynamic_shapes{rendered_path}` has {other_treespec.num_children} elements"
f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, "
f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements"
)
if treespec.type is dict:
if tree.type is dict:
# context, children could be out of order
if set(treespec.context) != set(other_treespec.context):
if sorted(tree.context) != sorted(dynamic_shapes.context):
raise_mismatch_error(
f"`{tree_name}{rendered_path}` has keys {treespec.context}, "
f"but `dynamic_shapes{rendered_path}` has keys {other_treespec.context}"
f"`{tree_name}{rendered_path}` has keys {tree.context}, "
f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}"
)
_remap = dict(
zip(other_treespec.context, other_treespec.children())
zip(dynamic_shapes.context, dynamic_shapes.children_specs)
)
other_children = [_remap[k] for k in treespec.context]
dynamic_shapes_children_specs = [_remap[k] for k in tree.context]
else:
other_children = other_treespec.children()
for i, (child, other_child) in enumerate(
zip(treespec.children(), other_children)
dynamic_shapes_children_specs = dynamic_shapes.children_specs
for i, (tree_, dynamic_shapes_) in enumerate(
zip(tree.children_specs, dynamic_shapes_children_specs)
):
_compare(
child,
other_child,
path + (_key(treespec.type, treespec.context, i),),
tree_,
dynamic_shapes_,
path + [_key(tree.type, tree.context, i)],
)
treespec = tree_structure(tree, is_leaf=is_leaf)
_, tree_spec = tree_flatten(tree, is_leaf=is_leaf)
for other_tree in dynamic_shapes:
other_treespec = tree_structure(other_tree, is_leaf)
_compare(treespec, other_treespec, ())
_, other_tree_spec = tree_flatten(other_tree, is_leaf)
_compare(tree_spec, other_tree_spec, [])
raise
@ -1235,7 +1231,10 @@ def _get_dim_name_mapping(
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
):
name_to_dim = {}
for dim in tree_iter(dynamic_shapes, is_leaf=lambda x: isinstance(x, Dim)):
for dim in tree_flatten(
dynamic_shapes,
is_leaf=lambda x: isinstance(x, Dim),
)[0]:
if dim is None:
# NOTE: this must denote a non-Tensor or automatic at this point.
continue

View File

@ -391,7 +391,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
# aot_export expect the return type to always be a tuple.
assert out_spec is not None
if out_spec.type not in (list, tuple):
out_spec = pytree.treespec_tuple([out_spec])
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
mod.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(

View File

@ -1124,10 +1124,10 @@ class _ModuleFrame:
signature = module_call_graph.get(self.child_fqn)
if signature is not None and self.parent is not None:
assert signature.in_spec.num_children == 2
assert signature.in_spec.type is tuple
args_spec, kwargs_spec = signature.in_spec.children()
assert args_spec.type is tuple
assert kwargs_spec.type is dict
args_spec = signature.in_spec.children_specs[0]
kwargs_spec = signature.in_spec.children_specs[1]
assert args_spec.context is None
assert kwargs_spec.context is not None
with self.graph.inserting_after(None):
arg_nodes = [

View File

@ -49,7 +49,7 @@ def tree_flatten_spec(
flatten_fn_spec = SUPPORTED_NODES[spec.type]
child_pytrees = flatten_fn_spec(pytree, spec)
result = []
for child, child_spec in zip(child_pytrees, spec.children()):
for child, child_spec in zip(child_pytrees, spec.children_specs):
flat = tree_flatten_spec(child, child_spec)
result += flat
return result

View File

@ -709,7 +709,7 @@ class Tracer(TracerBase):
root_fn = _patch_function(root_fn, len(args))
flat_args, in_spec = pytree.tree_flatten(tuple(args))
if not all(child.is_leaf() for child in in_spec.children()):
if not all(child.is_leaf() for child in in_spec.children_specs):
# In the case that we have pytree-flattened inputs in
# `concrete_args`, generate a flattening wrapper around the
# original root function and return that.

View File

@ -933,25 +933,24 @@ class _PyTreeCodeGen(CodeGen):
return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
in_spec = self.pytree_info.in_spec
# when kwargs is present, in_spec is tuple(args, kwargs)
has_args_kwargs_tuple = (
in_spec.type is tuple
and in_spec.num_children == 2
and in_spec.child(0).type is tuple
and in_spec.child(1).type is dict
self.pytree_info.in_spec.type is tuple
and self.pytree_info.in_spec.num_children == 2
and self.pytree_info.in_spec.children_specs[0].type is tuple
and self.pytree_info.in_spec.children_specs[1].type is dict
)
fn_kwargs = "{}"
fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
if has_args_kwargs_tuple:
count_args = in_spec.child(0).num_children
count_args = self.pytree_info.in_spec.children_specs[0].num_children
fn_args = self.pytree_info.orig_args[:count_args]
fn_kwargs = (
"{"
+ ", ".join(
f"'{k}':{v}"
for k, v in zip(
in_spec.child(1).context,
self.pytree_info.in_spec.children_specs[1].context,
self.pytree_info.orig_args[count_args:],
)
)

View File

@ -14,9 +14,9 @@ collection support for PyTorch APIs.
import functools
import types
from collections.abc import Callable, Iterable, Mapping
from collections.abc import Callable, Iterable
from typing import Any, Optional, overload, TypeVar, Union
from typing_extensions import deprecated, Self, TypeAlias, TypeIs
from typing_extensions import deprecated, TypeIs
import torch.utils._pytree as python_pytree
from torch.torch_version import TorchVersion as _TorchVersion
@ -42,7 +42,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
import optree
from optree import PyTreeSpec # direct import for type annotations
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
__all__ = [
@ -53,7 +53,6 @@ __all__ = [
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"PyTreeSpec",
"TreeSpec",
"LeafSpec",
"keystr",
@ -101,8 +100,6 @@ U = TypeVar("U")
R = TypeVar("R")
TreeSpec: TypeAlias = PyTreeSpec
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
@ -270,30 +267,6 @@ def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)
def treespec_leaf() -> TreeSpec:
"""Make a treespec representing a leaf node."""
return optree.treespec_leaf(none_is_leaf=True, namespace="torch")
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
"""Make a tuple treespec from an iterable of child treespecs."""
return optree.treespec_tuple(iterable, none_is_leaf=True, namespace="torch")
def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
"""Make a dict treespec from a dict of child treespecs."""
return optree.treespec_dict(
mapping,
**kwargs,
none_is_leaf=True,
namespace="torch",
)
def tree_is_leaf(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1012,14 +985,9 @@ class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
return _is_pytreespec_instance(instance) and instance.is_leaf()
@deprecated(
"`isinstance(treespec, LeafSpec)` is deprecated, "
"use `isinstance(treespec, TreeSpec)` and `treespec.is_leaf()` instead.",
category=FutureWarning,
)
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final]
def __new__(cls) -> Self:
return treespec_leaf() # type: ignore[return-value]
def __new__(cls) -> "LeafSpec":
return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value]
def tree_flatten_with_path(

View File

@ -39,7 +39,7 @@ from typing import (
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple, Self, TypeAlias
from typing_extensions import deprecated, NamedTuple, Self
from torch.torch_version import TorchVersion as _TorchVersion
@ -52,7 +52,6 @@ __all__ = [
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"PyTreeSpec",
"TreeSpec",
"LeafSpec",
"keystr",
@ -472,7 +471,7 @@ class ConstantNode:
def _is_constant_holder(spec: "TreeSpec") -> bool:
"""Checks if the spec is from a pytree registered with register_constant"""
return isinstance(spec._context, ConstantNode)
return isinstance(spec.context, ConstantNode)
def _retrieve_constant(spec: "TreeSpec") -> Any:
@ -1072,53 +1071,35 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
# context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node
# num_leaves: the number of leaves
@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False)
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
class TreeSpec:
type: Any
_context: Context
_children: list[Self]
context: Context
children_specs: list["TreeSpec"]
num_nodes: int = dataclasses.field(init=False)
num_leaves: int = dataclasses.field(init=False)
num_children: int = dataclasses.field(init=False)
def __init__(
self,
type: Any,
context: Context, # keep for backward compatibility
children: list[Self], # keep for backward compatibility
) -> None:
object.__setattr__(self, "type", type)
object.__setattr__(self, "_context", context)
object.__setattr__(self, "_children", children)
self.__post_init__()
def __post_init__(self) -> None:
if self.type is None:
assert self._context is None
assert len(self._children) == 0
num_nodes = 1
num_leaves = 1
num_children = 0
else:
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
num_leaves = sum(spec.num_leaves for spec in self._children)
num_children = len(self._children)
num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
num_leaves = sum(spec.num_leaves for spec in self.children_specs)
num_children = len(self.children_specs)
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
def __repr__(self, indent: int = 0) -> str:
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self._context}, ["
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
children_specs_str: str = ""
if self.num_children > 0:
indent += 2
children_specs_str += self._children[0].__repr__(indent)
children_specs_str += self.children_specs[0].__repr__(indent)
children_specs_str += "," if self.num_children > 1 else ""
children_specs_str += ",".join(
[
"\n" + " " * indent + child.__repr__(indent)
for child in self._children[1:]
for child in self.children_specs[1:]
]
)
repr_suffix: str = f"{children_specs_str}])"
@ -1130,36 +1111,16 @@ class TreeSpec:
elif other.__class__ is self.__class__:
if str(self.type) != str(other.type):
return False
if self._context != other._context:
if self.context != other.context:
return False
elif self._children != other._children:
elif self.children_specs != other.children_specs:
return False
return True
return NotImplemented
@property
def context(self) -> Context:
return self._context
@property
@deprecated(
"`treespec.children_specs` is deprecated. "
"Use `treespec.child(index)` to access a single child, "
"or `treespec.children()` to get all children.",
category=FutureWarning,
)
def children_specs(self) -> list[Self]:
return self._children
def is_leaf(self) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1
def children(self) -> list[Self]:
return self._children.copy()
def child(self, index: int) -> Self:
return self._children[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
if treespec.is_leaf():
@ -1181,7 +1142,7 @@ class TreeSpec:
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(children)}.",
)
if context != treespec._context:
if context != treespec.context:
raise ValueError(
f"Node context mismatch for custom node type {treespec.type!r}.",
)
@ -1206,10 +1167,10 @@ class TreeSpec:
if both_standard_dict:
# dictionary types are compatible with each other
dict_context = (
treespec._context
treespec.context
if treespec.type is not defaultdict
# ignore mismatch of `default_factory` for defaultdict
else treespec._context[1]
else treespec.context[1]
)
expected_keys = dict_context
got_key_set = set(tree)
@ -1230,13 +1191,13 @@ class TreeSpec:
children, context = flatten_fn(tree)
if (
node_type is not deque # ignore mismatch of `maxlen` for deque
) and context != treespec._context:
) and context != treespec.context:
raise ValueError(
f"Node context mismatch for node type {treespec.type!r}; "
f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch
f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec._children, strict=True):
for subtree, subspec in zip(children, treespec.children_specs):
helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = []
@ -1261,24 +1222,24 @@ class TreeSpec:
start = 0
end = 0
child_pytrees = []
for child_spec in self._children:
for child_spec in self.children_specs:
end += child_spec.num_leaves
child_pytrees.append(child_spec.unflatten(leaves[start:end]))
start = end
return unflatten_fn(child_pytrees, self._context)
return unflatten_fn(child_pytrees, self.context)
def __hash__(self) -> int:
node_type = self.type
if node_type is defaultdict:
default_factory, dict_context = self._context
default_factory, dict_context = self.context
hashable_context = (default_factory, tuple(dict_context))
elif node_type in (dict, OrderedDict):
hashable_context = tuple(self._context)
hashable_context = tuple(self.context)
elif node_type is None or node_type in BUILTIN_TYPES:
hashable_context = self._context
elif isinstance(self._context, ConstantNode):
hashable_context = self._context.value
hashable_context = self.context
elif isinstance(self.context, ConstantNode):
hashable_context = self.context.value
else:
# The context for user-defined node types might not be hashable.
# Ignore it for hashing.
@ -1286,26 +1247,20 @@ class TreeSpec:
# same hash. This might increase the hash collision rate, but we
# don't care about that.
hashable_context = None
return hash((node_type, hashable_context, tuple(self._children)))
PyTreeSpec: TypeAlias = TreeSpec
return hash((node_type, hashable_context, tuple(self.children_specs)))
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
# this class with `dataclasses.fields`, etc., while having a simplified
# constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
# again, with fields that have `init=False`.
@deprecated(
"`isinstance(treespec, LeafSpec)` is deprecated, "
"use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.",
category=FutureWarning,
)
@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
class LeafSpec(TreeSpec):
type: Any = dataclasses.field(default=None, init=False)
_context: Context = dataclasses.field(default=None, init=False)
_children: list[Self] = dataclasses.field(default_factory=list, init=False)
context: Context = dataclasses.field(default=None, init=False)
children_specs: list["TreeSpec"] = dataclasses.field(
default_factory=list, init=False
)
def __post_init__(self) -> None:
# Override `__post_init__` for `num_leaves` derivation.
@ -1319,38 +1274,9 @@ class LeafSpec(TreeSpec):
# All leaves are equivalent, so represent with a single object to save on
# object construction time
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=FutureWarning, module=__name__, append=False
)
_LEAF_SPEC = LeafSpec()
def treespec_leaf() -> LeafSpec:
"""Make a treespec representing a leaf node."""
return _LEAF_SPEC
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
"""Make a tuple treespec from an iterable of child treespecs."""
children = list(iterable)
if any(not isinstance(child, TreeSpec) for child in children):
raise ValueError(f"Expected a tuple of TreeSpec values, got: {children!r}.")
return TreeSpec(tuple, None, children)
def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
"""Make a dict treespec from a dict of child treespecs."""
dct = dict(mapping, **kwargs)
if any(not isinstance(child, TreeSpec) for child in dct.values()):
raise ValueError(f"Expected a dictionary of TreeSpec values, got: {dct!r}.")
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1827,15 +1753,15 @@ def _broadcast_to_and_flatten(
return None
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)
child_pytrees, ctx = flatten_fn(tree)
# Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or context != treespec._context:
if len(child_pytrees) != treespec.num_children or ctx != treespec.context:
return None
# Recursively flatten the children
result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec._children, strict=True):
for child, child_spec in zip(child_pytrees, treespec.children_specs):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None:
result += flat
@ -1889,7 +1815,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
if serialize_node_def.to_dumpable_context is None:
try:
serialized_context = json.dumps(treespec._context, cls=EnumEncoder)
serialized_context = json.dumps(treespec.context, cls=EnumEncoder)
except TypeError as e:
raise TypeError(
"Unable to serialize context. "
@ -1897,9 +1823,9 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
"custom serializer using _register_pytree_node."
) from e
else:
serialized_context = serialize_node_def.to_dumpable_context(treespec._context)
serialized_context = serialize_node_def.to_dumpable_context(treespec.context)
child_schemas = [_treespec_to_json(child) for child in treespec._children]
child_schemas = [_treespec_to_json(child) for child in treespec.children_specs]
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)