mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)"
This reverts commit 108bb224f7.
Reverted https://github.com/pytorch/pytorch/pull/160843 on behalf of https://github.com/atalman due to failing internal builds ([comment](https://github.com/pytorch/pytorch/pull/160843#issuecomment-3474354428))
This commit is contained in:
parent
b71966f67b
commit
85b85f6c2c
|
|
@ -91,13 +91,13 @@ from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
|||
from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu
|
||||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
from torch.utils._pytree import (
|
||||
LeafSpec,
|
||||
register_constant,
|
||||
tree_flatten,
|
||||
tree_map,
|
||||
tree_unflatten,
|
||||
TreeSpec,
|
||||
treespec_dumps,
|
||||
treespec_leaf,
|
||||
treespec_loads,
|
||||
)
|
||||
|
||||
|
|
@ -7791,7 +7791,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
|
||||
dt = MyDataClass(x=3, y=4)
|
||||
flat, spec = tree_flatten(dt)
|
||||
self.assertTrue(spec, treespec_leaf())
|
||||
self.assertTrue(spec, LeafSpec())
|
||||
self.assertTrue(len(flat) == 1)
|
||||
|
||||
torch.export.register_dataclass(
|
||||
|
|
@ -7802,9 +7802,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
flat, spec = tree_flatten(dt)
|
||||
self.assertEqual(
|
||||
spec,
|
||||
TreeSpec(
|
||||
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
|
||||
),
|
||||
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
|
||||
)
|
||||
self.assertEqual(flat, [3, 4])
|
||||
|
||||
|
|
@ -7837,7 +7835,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
TreeSpec(
|
||||
MyOtherDataClass,
|
||||
[["x", "y", "z"], []],
|
||||
[treespec_leaf(), treespec_leaf(), treespec_leaf()],
|
||||
[LeafSpec(), LeafSpec(), LeafSpec()],
|
||||
),
|
||||
)
|
||||
self.assertEqual(flat, [3, 4, None])
|
||||
|
|
|
|||
|
|
@ -65,6 +65,9 @@ class TestEnum(enum.Enum):
|
|||
A = auto()
|
||||
|
||||
|
||||
python_leafspec = python_pytree.LeafSpec()
|
||||
|
||||
|
||||
class TestGenericPytree(TestCase):
|
||||
def test_aligned_public_apis(self):
|
||||
public_apis = python_pytree.__all__
|
||||
|
|
@ -194,7 +197,7 @@ class TestGenericPytree(TestCase):
|
|||
def run_test_with_leaf(leaf):
|
||||
values, treespec = pytree.tree_flatten(leaf)
|
||||
self.assertEqual(values, [leaf])
|
||||
self.assertEqual(treespec, pytree.treespec_leaf())
|
||||
self.assertEqual(treespec, pytree.LeafSpec())
|
||||
|
||||
unflattened = pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, leaf)
|
||||
|
|
@ -212,7 +215,7 @@ class TestGenericPytree(TestCase):
|
|||
(
|
||||
python_pytree,
|
||||
lambda tup: python_pytree.TreeSpec(
|
||||
tuple, None, [python_pytree.treespec_leaf() for _ in tup]
|
||||
tuple, None, [python_leafspec for _ in tup]
|
||||
),
|
||||
),
|
||||
name="python",
|
||||
|
|
@ -247,7 +250,7 @@ class TestGenericPytree(TestCase):
|
|||
(
|
||||
python_pytree,
|
||||
lambda lst: python_pytree.TreeSpec(
|
||||
list, None, [python_pytree.treespec_leaf() for _ in lst]
|
||||
list, None, [python_leafspec for _ in lst]
|
||||
),
|
||||
),
|
||||
name="python",
|
||||
|
|
@ -283,7 +286,7 @@ class TestGenericPytree(TestCase):
|
|||
lambda dct: python_pytree.TreeSpec(
|
||||
dict,
|
||||
list(dct.keys()),
|
||||
[python_pytree.treespec_leaf() for _ in dct.values()],
|
||||
[python_leafspec for _ in dct.values()],
|
||||
),
|
||||
),
|
||||
name="python",
|
||||
|
|
@ -324,7 +327,7 @@ class TestGenericPytree(TestCase):
|
|||
lambda odict: python_pytree.TreeSpec(
|
||||
OrderedDict,
|
||||
list(odict.keys()),
|
||||
[python_pytree.treespec_leaf() for _ in odict.values()],
|
||||
[python_leafspec for _ in odict.values()],
|
||||
),
|
||||
),
|
||||
name="python",
|
||||
|
|
@ -368,7 +371,7 @@ class TestGenericPytree(TestCase):
|
|||
lambda ddct: python_pytree.TreeSpec(
|
||||
defaultdict,
|
||||
[ddct.default_factory, list(ddct.keys())],
|
||||
[python_pytree.treespec_leaf() for _ in ddct.values()],
|
||||
[python_leafspec for _ in ddct.values()],
|
||||
),
|
||||
),
|
||||
name="python",
|
||||
|
|
@ -410,7 +413,7 @@ class TestGenericPytree(TestCase):
|
|||
(
|
||||
python_pytree,
|
||||
lambda deq: python_pytree.TreeSpec(
|
||||
deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
|
||||
deque, deq.maxlen, [python_leafspec for _ in deq]
|
||||
),
|
||||
),
|
||||
name="python",
|
||||
|
|
@ -450,7 +453,7 @@ class TestGenericPytree(TestCase):
|
|||
def run_test(tup):
|
||||
if pytree is python_pytree:
|
||||
expected_spec = python_pytree.TreeSpec(
|
||||
namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
|
||||
namedtuple, Point, [python_leafspec for _ in tup]
|
||||
)
|
||||
else:
|
||||
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
|
||||
|
|
@ -845,16 +848,16 @@ if "optree" in sys.modules:
|
|||
|
||||
def test_treespec_equality(self):
|
||||
self.assertEqual(
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.LeafSpec(),
|
||||
python_pytree.LeafSpec(),
|
||||
)
|
||||
self.assertEqual(
|
||||
python_pytree.TreeSpec(list, None, []),
|
||||
python_pytree.TreeSpec(list, None, []),
|
||||
)
|
||||
self.assertEqual(
|
||||
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
|
||||
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
|
||||
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
|
||||
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
|
||||
)
|
||||
self.assertFalse(
|
||||
python_pytree.TreeSpec(tuple, None, [])
|
||||
|
|
@ -889,32 +892,24 @@ if "optree" in sys.modules:
|
|||
# python_pytree.tree_structure({})
|
||||
python_pytree.TreeSpec(dict, [], []),
|
||||
# python_pytree.tree_structure([0])
|
||||
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
|
||||
python_pytree.TreeSpec(list, None, [python_leafspec]),
|
||||
# python_pytree.tree_structure([0, 1])
|
||||
python_pytree.TreeSpec(
|
||||
list,
|
||||
None,
|
||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
||||
[python_leafspec, python_leafspec],
|
||||
),
|
||||
# python_pytree.tree_structure((0, 1, 2))
|
||||
python_pytree.TreeSpec(
|
||||
tuple,
|
||||
None,
|
||||
[
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
],
|
||||
[python_leafspec, python_leafspec, python_leafspec],
|
||||
),
|
||||
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
|
||||
python_pytree.TreeSpec(
|
||||
dict,
|
||||
["a", "b", "c"],
|
||||
[
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
],
|
||||
[python_leafspec, python_leafspec, python_leafspec],
|
||||
),
|
||||
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
|
||||
python_pytree.TreeSpec(
|
||||
|
|
@ -924,17 +919,13 @@ if "optree" in sys.modules:
|
|||
python_pytree.TreeSpec(
|
||||
tuple,
|
||||
None,
|
||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
||||
[python_leafspec, python_leafspec],
|
||||
),
|
||||
python_pytree.treespec_leaf(),
|
||||
python_leafspec,
|
||||
python_pytree.TreeSpec(
|
||||
dict,
|
||||
["a", "b", "c"],
|
||||
[
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
],
|
||||
[python_leafspec, python_leafspec, python_leafspec],
|
||||
),
|
||||
],
|
||||
),
|
||||
|
|
@ -947,15 +938,12 @@ if "optree" in sys.modules:
|
|||
tuple,
|
||||
None,
|
||||
[
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
python_leafspec,
|
||||
python_leafspec,
|
||||
python_pytree.TreeSpec(
|
||||
list,
|
||||
None,
|
||||
[
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
],
|
||||
[python_leafspec, python_leafspec],
|
||||
),
|
||||
],
|
||||
),
|
||||
|
|
@ -969,12 +957,12 @@ if "optree" in sys.modules:
|
|||
python_pytree.TreeSpec(
|
||||
list,
|
||||
None,
|
||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
||||
[python_leafspec, python_leafspec],
|
||||
),
|
||||
python_pytree.TreeSpec(
|
||||
list,
|
||||
None,
|
||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
||||
[python_leafspec, python_leafspec],
|
||||
),
|
||||
python_pytree.TreeSpec(dict, [], []),
|
||||
],
|
||||
|
|
@ -1003,7 +991,7 @@ if "optree" in sys.modules:
|
|||
list,
|
||||
None,
|
||||
[
|
||||
python_pytree.treespec_leaf(),
|
||||
python_leafspec,
|
||||
],
|
||||
),
|
||||
],
|
||||
|
|
@ -1012,7 +1000,7 @@ if "optree" in sys.modules:
|
|||
self.assertIsInstance(serialized_spec, str)
|
||||
|
||||
def test_pytree_serialize_enum(self):
|
||||
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
|
||||
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
|
||||
|
||||
serialized_spec = python_pytree.treespec_dumps(spec)
|
||||
self.assertIsInstance(serialized_spec, str)
|
||||
|
|
@ -1175,20 +1163,12 @@ if "optree" in sys.modules:
|
|||
OrderedDict,
|
||||
[1, 2, 3],
|
||||
[
|
||||
python_pytree.TreeSpec(
|
||||
tuple,
|
||||
None,
|
||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
||||
),
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
|
||||
python_leafspec,
|
||||
python_pytree.TreeSpec(
|
||||
dict,
|
||||
[4, 5, 6],
|
||||
[
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
python_pytree.treespec_leaf(),
|
||||
],
|
||||
[python_leafspec, python_leafspec, python_leafspec],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
@ -1473,7 +1453,7 @@ class TestCxxPytree(TestCase):
|
|||
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
|
||||
|
||||
def test_treespec_equality(self):
|
||||
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
|
||||
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
|
||||
|
||||
def test_treespec_repr(self):
|
||||
# Check that it looks sane
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from ..decorators import substitute_in_graph
|
|||
|
||||
if TYPE_CHECKING:
|
||||
import builtins
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
|
|
@ -349,113 +349,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
|
||||
return isinstance(obj, PyTreeSpec)
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.treespec_leaf,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def treespec_leaf(
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "", # unused
|
||||
) -> PyTreeSpec:
|
||||
return PyTreeSpec(
|
||||
(),
|
||||
None,
|
||||
None,
|
||||
(),
|
||||
None,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace="",
|
||||
)
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.treespec_tuple,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def treespec_tuple(
|
||||
iterable: Iterable[PyTreeSpec] = (),
|
||||
/,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTreeSpec:
|
||||
children = tuple(iterable)
|
||||
if any(not _is_pytreespec_instance(child) for child in children):
|
||||
raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
|
||||
if any(child.none_is_leaf != none_is_leaf for child in children):
|
||||
raise ValueError(
|
||||
"All children PyTreeSpecs must have the same `none_is_leaf` value "
|
||||
f"as the parent; expected {none_is_leaf}, got: {children!r}.",
|
||||
)
|
||||
if any(child.namespace not in (namespace, "") for child in children):
|
||||
raise ValueError(
|
||||
"All children PyTreeSpecs must have the same `namespace` value "
|
||||
f"as the parent; expected {namespace!r}, got: {children!r}.",
|
||||
)
|
||||
handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined]
|
||||
assert handler is not None
|
||||
return PyTreeSpec(
|
||||
tuple(children),
|
||||
tuple,
|
||||
None,
|
||||
tuple(range(len(children))),
|
||||
handler.unflatten_func,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.treespec_dict,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def treespec_dict(
|
||||
mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
|
||||
/,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
**kwargs: PyTreeSpec,
|
||||
) -> PyTreeSpec:
|
||||
dct = dict(mapping, **kwargs)
|
||||
if any(not _is_pytreespec_instance(child) for child in dct.values()):
|
||||
raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
|
||||
if any(child.none_is_leaf != none_is_leaf for child in dct.values()):
|
||||
raise ValueError(
|
||||
"All children PyTreeSpecs must have the same `none_is_leaf` value "
|
||||
f"as the parent; expected {none_is_leaf}, got: {dct!r}.",
|
||||
)
|
||||
if any(child.namespace not in (namespace, "") for child in dct.values()):
|
||||
raise ValueError(
|
||||
"All children PyTreeSpecs must have the same `namespace` value "
|
||||
f"as the parent; expected {namespace!r}, got: {dct!r}.",
|
||||
)
|
||||
|
||||
(
|
||||
children,
|
||||
metadata,
|
||||
entries,
|
||||
unflatten_func,
|
||||
) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated]
|
||||
dct, # type: ignore[arg-type]
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
return PyTreeSpec(
|
||||
tuple(children), # type: ignore[arg-type]
|
||||
dict,
|
||||
metadata,
|
||||
entries,
|
||||
unflatten_func,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.tree_flatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
|
|
|
|||
|
|
@ -3727,7 +3727,9 @@ class SourcelessBuilder:
|
|||
pass # failthrough to unimplemented branch
|
||||
elif isinstance(value, torch.fx.graph_module.GraphModule):
|
||||
return SourcelessGraphModuleVariable(value)
|
||||
elif isinstance(value, torch.utils._pytree.TreeSpec):
|
||||
elif isinstance(
|
||||
value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
|
||||
):
|
||||
return UserDefinedObjectVariable(value)
|
||||
elif PlacementVariable.is_placement(value):
|
||||
return PlacementVariable(value)
|
||||
|
|
|
|||
|
|
@ -1661,8 +1661,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
# We need to have this check this way, because in case init is a TreeSpec and carry
|
||||
# but carry is only a LeafSpec, these two cannot be compared correctly.
|
||||
if (
|
||||
xs_treespec.as_python_constant().is_leaf()
|
||||
!= _combine_treespec.as_python_constant().is_leaf()
|
||||
isinstance(xs_treespec.as_python_constant(), pytree.LeafSpec)
|
||||
!= isinstance(_combine_treespec.as_python_constant(), pytree.LeafSpec)
|
||||
) or not _make_inlined(tx, pytree.TreeSpec.__eq__)(
|
||||
xs_treespec, _combine_treespec
|
||||
).as_python_constant():
|
||||
|
|
|
|||
|
|
@ -1530,10 +1530,10 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
else:
|
||||
raise AssertionError("TODO")
|
||||
|
||||
def serialize_treespec(self, treespec: pytree.TreeSpec) -> str:
|
||||
def serialize_treespec(self, treespec):
|
||||
# We want to additionally save all the field names of the namedtuples in
|
||||
# case users want to check that the treespec types are equivalent
|
||||
def store_namedtuple_fields(ts: pytree.TreeSpec) -> None:
|
||||
def store_namedtuple_fields(ts):
|
||||
if ts.type is None:
|
||||
return
|
||||
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
|
||||
|
|
@ -1555,7 +1555,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
NamedTupleDef(field_names=ts.context._fields)
|
||||
)
|
||||
|
||||
for child in ts.children():
|
||||
for child in ts.children_specs:
|
||||
store_namedtuple_fields(child)
|
||||
|
||||
serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION)
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ class PytreeThunk:
|
|||
assert spec is not None
|
||||
self.spec: pytree.TreeSpec = spec
|
||||
if self.spec.type in {tuple, list} and all(
|
||||
child.is_leaf() for child in spec.children()
|
||||
child.is_leaf() for child in spec.children_specs
|
||||
):
|
||||
self.is_simple = True
|
||||
if self.spec.is_leaf():
|
||||
|
|
|
|||
|
|
@ -1590,7 +1590,7 @@ def aot_export_joint_simple(
|
|||
decompositions=decompositions,
|
||||
trace_joint=trace_joint,
|
||||
)
|
||||
in_spec, _kw_in_spec = in_spec.children()
|
||||
in_spec, _kw_in_spec = in_spec.children_specs
|
||||
# At this point, we can just directly return the (joint or inference graph) that we traced.
|
||||
# First though: a bunch of assertions to make sure that our graph doesn't require
|
||||
# any calling convention changes compared to the original function.
|
||||
|
|
@ -1617,7 +1617,7 @@ def aot_export_joint_simple(
|
|||
raise RuntimeError(
|
||||
f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
|
||||
)
|
||||
if not all(child.is_leaf() for child in in_spec.children()):
|
||||
if not all(child.is_leaf() for child in in_spec.children_specs):
|
||||
raise RuntimeError(
|
||||
f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
|
||||
)
|
||||
|
|
@ -1625,7 +1625,7 @@ def aot_export_joint_simple(
|
|||
raise RuntimeError(
|
||||
f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
|
||||
)
|
||||
if not all(child.is_leaf() for child in out_spec.children()):
|
||||
if not all(child.is_leaf() for child in out_spec.children_specs):
|
||||
raise RuntimeError(
|
||||
f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -469,7 +469,7 @@ def _unlift_graph(
|
|||
gm,
|
||||
lifted_inputs,
|
||||
mutated_outputs,
|
||||
pytree.treespec_leaf(),
|
||||
pytree.LeafSpec(),
|
||||
None,
|
||||
)
|
||||
return unlifted_gm
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from typing import Any, Optional, Union
|
|||
import torch
|
||||
import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
|
||||
import torch.nn.functional as F
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
# Makes sure that quantized_decomposed ops are registered
|
||||
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
||||
|
|
@ -15,6 +14,7 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation
|
|||
from torch.export.unflatten import _assign_attr, _AttrKind
|
||||
from torch.fx import GraphModule, Node
|
||||
from torch.nn.utils.fusion import fuse_conv_bn_weights
|
||||
from torch.utils._pytree import LeafSpec
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -477,10 +477,7 @@ def _replace_literals_with_new_placeholders(
|
|||
exclude_literals = []
|
||||
|
||||
in_spec = gm._in_spec
|
||||
assert in_spec.type is tuple
|
||||
args_spec = in_spec.child(0)
|
||||
assert args_spec.type is tuple
|
||||
args_spec_children = args_spec.children()
|
||||
args_spec = in_spec.children_specs[0]
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
last_ph = node
|
||||
|
|
@ -495,7 +492,7 @@ def _replace_literals_with_new_placeholders(
|
|||
else:
|
||||
ph_node = gm.graph.placeholder("arg" + str(cnt))
|
||||
new_args.append(ph_node)
|
||||
args_spec_children.append(pytree.treespec_leaf())
|
||||
args_spec.children_specs.append(LeafSpec())
|
||||
cnt += 1
|
||||
if merge_dup:
|
||||
literal_to_ph[arg] = ph_node
|
||||
|
|
@ -506,8 +503,8 @@ def _replace_literals_with_new_placeholders(
|
|||
node.args = new_args
|
||||
|
||||
# Update `num_nodes`, `num_leaves`, `num_children`.
|
||||
args_spec = pytree.treespec_tuple(args_spec_children)
|
||||
gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]])
|
||||
args_spec.__post_init__()
|
||||
in_spec.__post_init__()
|
||||
return gm
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -195,16 +195,17 @@ def _construct_inputs(
|
|||
unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)
|
||||
|
||||
assert signature.in_spec.num_children == 2
|
||||
assert signature.in_spec.type is tuple
|
||||
args_spec, kwargs_spec = signature.in_spec.children()
|
||||
assert args_spec.type is tuple
|
||||
assert kwargs_spec.type is dict
|
||||
|
||||
args_spec = signature.in_spec.children_specs[0]
|
||||
assert args_spec.context is None
|
||||
args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0))
|
||||
args_nodes = [
|
||||
gm.graph.call_function(operator.getitem, (args_node, i))
|
||||
for i in range(args_spec.num_children)
|
||||
]
|
||||
|
||||
kwargs_spec = signature.in_spec.children_specs[1]
|
||||
assert kwargs_spec.context is not None
|
||||
kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1))
|
||||
kwargs_nodes = {
|
||||
k: gm.graph.call_function(operator.getitem, (kwargs_node, k))
|
||||
|
|
@ -371,8 +372,8 @@ def _fix_input_output_signature(
|
|||
if forward_arg_names is None:
|
||||
forward_arg_names = []
|
||||
assert signature.in_spec.num_children == 2
|
||||
arg_spec = signature.in_spec.child(0)
|
||||
kwarg_spec = signature.in_spec.child(1)
|
||||
arg_spec = signature.in_spec.children_specs[0]
|
||||
kwarg_spec = signature.in_spec.children_specs[1]
|
||||
assert arg_spec.type is tuple
|
||||
assert kwarg_spec.type is dict
|
||||
for i in range(arg_spec.num_children):
|
||||
|
|
|
|||
|
|
@ -1533,7 +1533,7 @@ def _strict_export(
|
|||
|
||||
# aot_export expect the return type to always be a tuple.
|
||||
if out_spec.type not in (list, tuple):
|
||||
out_spec = pytree.treespec_tuple([out_spec])
|
||||
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
|
||||
|
||||
orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any
|
|||
# Make sure that the spec is actually shaped like (args, kwargs)
|
||||
assert spec.type is tuple
|
||||
assert spec.num_children == 2
|
||||
kwargs_spec = spec.child(1)
|
||||
kwargs_spec = spec.children_specs[1]
|
||||
assert kwargs_spec.type is dict
|
||||
|
||||
if set(user_kwargs) != set(kwargs_spec.context):
|
||||
|
|
@ -55,10 +55,10 @@ def is_equivalent(
|
|||
return False
|
||||
|
||||
# Recurse on children
|
||||
if spec1.num_children != spec2.num_children:
|
||||
if len(spec1.children_specs) != len(spec2.children_specs):
|
||||
return False
|
||||
|
||||
for child_spec1, child_spec2 in zip(spec1.children(), spec2.children()):
|
||||
for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs):
|
||||
if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -57,11 +57,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
|
|||
return False
|
||||
elif a.context != b.context:
|
||||
return False
|
||||
if a.num_children != b.num_children:
|
||||
if len(a.children_specs) != len(b.children_specs):
|
||||
return False
|
||||
return all(
|
||||
_match_normalized_structure(a, b)
|
||||
for a, b in zip(a.children(), b.children())
|
||||
for a, b in zip(a.children_specs, b.children_specs)
|
||||
)
|
||||
|
||||
return _match_normalized_structure(self, other)
|
||||
|
|
@ -357,13 +357,13 @@ def _get_codegen(
|
|||
elif (
|
||||
in_spec.type is tuple
|
||||
and in_spec.num_children == 2
|
||||
and in_spec.child(0).type is tuple
|
||||
and in_spec.child(1).type is dict
|
||||
and in_spec.children_specs[0].type is tuple
|
||||
and in_spec.children_specs[1].type is dict
|
||||
):
|
||||
# if in_spec contains the args (tuple) and kwargs (dict)
|
||||
names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)]
|
||||
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
|
||||
# add kwarg names
|
||||
names.extend(in_spec.child(1).context)
|
||||
names.extend(in_spec.children_specs[1].context)
|
||||
else:
|
||||
names = [f"arg_{i}" for i in range(in_spec.num_children)]
|
||||
|
||||
|
|
|
|||
|
|
@ -12,16 +12,14 @@ import torch
|
|||
from torch.utils._pytree import (
|
||||
_get_node_type,
|
||||
BUILTIN_TYPES,
|
||||
KeyPath,
|
||||
keystr,
|
||||
LeafSpec,
|
||||
MappingKey,
|
||||
SequenceKey,
|
||||
SUPPORTED_NODES,
|
||||
tree_iter,
|
||||
tree_flatten,
|
||||
tree_map,
|
||||
tree_map_with_path,
|
||||
tree_structure,
|
||||
TreeSpec,
|
||||
)
|
||||
|
||||
from .exported_program import ExportedProgram
|
||||
|
|
@ -657,55 +655,53 @@ def _tree_map_with_path(
|
|||
case_name="dynamic_shapes_validation",
|
||||
)
|
||||
|
||||
def _compare(
|
||||
treespec: TreeSpec, other_treespec: TreeSpec, path: KeyPath
|
||||
) -> None:
|
||||
def _compare(tree, dynamic_shapes, path):
|
||||
# raise an error at the point where tree and dynamic_shapes differ,
|
||||
# including the path to that point and the reason for the difference
|
||||
rendered_path = keystr(path)
|
||||
if treespec.is_leaf():
|
||||
if isinstance(tree, LeafSpec):
|
||||
return
|
||||
if other_treespec.is_leaf():
|
||||
if isinstance(dynamic_shapes, LeafSpec):
|
||||
raise_mismatch_error(
|
||||
f"`{tree_name}{rendered_path}` is a {treespec.type}, "
|
||||
f"`{tree_name}{rendered_path}` is a {tree.type}, "
|
||||
f"but `dynamic_shapes{rendered_path}` is not"
|
||||
)
|
||||
if treespec.type != other_treespec.type:
|
||||
if tree.type != dynamic_shapes.type:
|
||||
raise_mismatch_error(
|
||||
f"`{tree_name}{rendered_path}` is a {treespec.type}, "
|
||||
f"but `dynamic_shapes{rendered_path}` is a {other_treespec.type}"
|
||||
f"`{tree_name}{rendered_path}` is a {tree.type}, "
|
||||
f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}"
|
||||
)
|
||||
if treespec.num_children != other_treespec.num_children:
|
||||
if len(tree.children_specs) != len(dynamic_shapes.children_specs):
|
||||
raise_mismatch_error(
|
||||
f"`{tree_name}{rendered_path}` has {treespec.num_children} elements, "
|
||||
f"but `dynamic_shapes{rendered_path}` has {other_treespec.num_children} elements"
|
||||
f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, "
|
||||
f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements"
|
||||
)
|
||||
if treespec.type is dict:
|
||||
if tree.type is dict:
|
||||
# context, children could be out of order
|
||||
if set(treespec.context) != set(other_treespec.context):
|
||||
if sorted(tree.context) != sorted(dynamic_shapes.context):
|
||||
raise_mismatch_error(
|
||||
f"`{tree_name}{rendered_path}` has keys {treespec.context}, "
|
||||
f"but `dynamic_shapes{rendered_path}` has keys {other_treespec.context}"
|
||||
f"`{tree_name}{rendered_path}` has keys {tree.context}, "
|
||||
f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}"
|
||||
)
|
||||
_remap = dict(
|
||||
zip(other_treespec.context, other_treespec.children())
|
||||
zip(dynamic_shapes.context, dynamic_shapes.children_specs)
|
||||
)
|
||||
other_children = [_remap[k] for k in treespec.context]
|
||||
dynamic_shapes_children_specs = [_remap[k] for k in tree.context]
|
||||
else:
|
||||
other_children = other_treespec.children()
|
||||
for i, (child, other_child) in enumerate(
|
||||
zip(treespec.children(), other_children)
|
||||
dynamic_shapes_children_specs = dynamic_shapes.children_specs
|
||||
for i, (tree_, dynamic_shapes_) in enumerate(
|
||||
zip(tree.children_specs, dynamic_shapes_children_specs)
|
||||
):
|
||||
_compare(
|
||||
child,
|
||||
other_child,
|
||||
path + (_key(treespec.type, treespec.context, i),),
|
||||
tree_,
|
||||
dynamic_shapes_,
|
||||
path + [_key(tree.type, tree.context, i)],
|
||||
)
|
||||
|
||||
treespec = tree_structure(tree, is_leaf=is_leaf)
|
||||
_, tree_spec = tree_flatten(tree, is_leaf=is_leaf)
|
||||
for other_tree in dynamic_shapes:
|
||||
other_treespec = tree_structure(other_tree, is_leaf)
|
||||
_compare(treespec, other_treespec, ())
|
||||
_, other_tree_spec = tree_flatten(other_tree, is_leaf)
|
||||
_compare(tree_spec, other_tree_spec, [])
|
||||
raise
|
||||
|
||||
|
||||
|
|
@ -1235,7 +1231,10 @@ def _get_dim_name_mapping(
|
|||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
||||
):
|
||||
name_to_dim = {}
|
||||
for dim in tree_iter(dynamic_shapes, is_leaf=lambda x: isinstance(x, Dim)):
|
||||
for dim in tree_flatten(
|
||||
dynamic_shapes,
|
||||
is_leaf=lambda x: isinstance(x, Dim),
|
||||
)[0]:
|
||||
if dim is None:
|
||||
# NOTE: this must denote a non-Tensor or automatic at this point.
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -391,7 +391,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
|||
# aot_export expect the return type to always be a tuple.
|
||||
assert out_spec is not None
|
||||
if out_spec.type not in (list, tuple):
|
||||
out_spec = pytree.treespec_tuple([out_spec])
|
||||
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
|
||||
|
||||
mod.graph._codegen = _PyTreeCodeGen(
|
||||
_PyTreeInfo(
|
||||
|
|
|
|||
|
|
@ -1124,10 +1124,10 @@ class _ModuleFrame:
|
|||
signature = module_call_graph.get(self.child_fqn)
|
||||
if signature is not None and self.parent is not None:
|
||||
assert signature.in_spec.num_children == 2
|
||||
assert signature.in_spec.type is tuple
|
||||
args_spec, kwargs_spec = signature.in_spec.children()
|
||||
assert args_spec.type is tuple
|
||||
assert kwargs_spec.type is dict
|
||||
args_spec = signature.in_spec.children_specs[0]
|
||||
kwargs_spec = signature.in_spec.children_specs[1]
|
||||
assert args_spec.context is None
|
||||
assert kwargs_spec.context is not None
|
||||
|
||||
with self.graph.inserting_after(None):
|
||||
arg_nodes = [
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def tree_flatten_spec(
|
|||
flatten_fn_spec = SUPPORTED_NODES[spec.type]
|
||||
child_pytrees = flatten_fn_spec(pytree, spec)
|
||||
result = []
|
||||
for child, child_spec in zip(child_pytrees, spec.children()):
|
||||
for child, child_spec in zip(child_pytrees, spec.children_specs):
|
||||
flat = tree_flatten_spec(child, child_spec)
|
||||
result += flat
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -709,7 +709,7 @@ class Tracer(TracerBase):
|
|||
root_fn = _patch_function(root_fn, len(args))
|
||||
|
||||
flat_args, in_spec = pytree.tree_flatten(tuple(args))
|
||||
if not all(child.is_leaf() for child in in_spec.children()):
|
||||
if not all(child.is_leaf() for child in in_spec.children_specs):
|
||||
# In the case that we have pytree-flattened inputs in
|
||||
# `concrete_args`, generate a flattening wrapper around the
|
||||
# original root function and return that.
|
||||
|
|
|
|||
|
|
@ -933,25 +933,24 @@ class _PyTreeCodeGen(CodeGen):
|
|||
return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
|
||||
|
||||
def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
|
||||
in_spec = self.pytree_info.in_spec
|
||||
# when kwargs is present, in_spec is tuple(args, kwargs)
|
||||
has_args_kwargs_tuple = (
|
||||
in_spec.type is tuple
|
||||
and in_spec.num_children == 2
|
||||
and in_spec.child(0).type is tuple
|
||||
and in_spec.child(1).type is dict
|
||||
self.pytree_info.in_spec.type is tuple
|
||||
and self.pytree_info.in_spec.num_children == 2
|
||||
and self.pytree_info.in_spec.children_specs[0].type is tuple
|
||||
and self.pytree_info.in_spec.children_specs[1].type is dict
|
||||
)
|
||||
fn_kwargs = "{}"
|
||||
fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
|
||||
if has_args_kwargs_tuple:
|
||||
count_args = in_spec.child(0).num_children
|
||||
count_args = self.pytree_info.in_spec.children_specs[0].num_children
|
||||
fn_args = self.pytree_info.orig_args[:count_args]
|
||||
fn_kwargs = (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
f"'{k}':{v}"
|
||||
for k, v in zip(
|
||||
in_spec.child(1).context,
|
||||
self.pytree_info.in_spec.children_specs[1].context,
|
||||
self.pytree_info.orig_args[count_args:],
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ collection support for PyTorch APIs.
|
|||
|
||||
import functools
|
||||
import types
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, Optional, overload, TypeVar, Union
|
||||
from typing_extensions import deprecated, Self, TypeAlias, TypeIs
|
||||
from typing_extensions import deprecated, TypeIs
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.torch_version import TorchVersion as _TorchVersion
|
||||
|
|
@ -42,7 +42,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
|
|||
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec # direct import for type annotations
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -53,7 +53,6 @@ __all__ = [
|
|||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"PyTreeSpec",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"keystr",
|
||||
|
|
@ -101,8 +100,6 @@ U = TypeVar("U")
|
|||
R = TypeVar("R")
|
||||
|
||||
|
||||
TreeSpec: TypeAlias = PyTreeSpec
|
||||
|
||||
Context = Any
|
||||
PyTree = Any
|
||||
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
|
||||
|
|
@ -270,30 +267,6 @@ def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
|
|||
return isinstance(obj, TreeSpec)
|
||||
|
||||
|
||||
def treespec_leaf() -> TreeSpec:
|
||||
"""Make a treespec representing a leaf node."""
|
||||
return optree.treespec_leaf(none_is_leaf=True, namespace="torch")
|
||||
|
||||
|
||||
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
|
||||
"""Make a tuple treespec from an iterable of child treespecs."""
|
||||
return optree.treespec_tuple(iterable, none_is_leaf=True, namespace="torch")
|
||||
|
||||
|
||||
def treespec_dict(
|
||||
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
|
||||
/,
|
||||
**kwargs: TreeSpec,
|
||||
) -> TreeSpec:
|
||||
"""Make a dict treespec from a dict of child treespecs."""
|
||||
return optree.treespec_dict(
|
||||
mapping,
|
||||
**kwargs,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
||||
|
||||
def tree_is_leaf(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -1012,14 +985,9 @@ class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
|
|||
return _is_pytreespec_instance(instance) and instance.is_leaf()
|
||||
|
||||
|
||||
@deprecated(
|
||||
"`isinstance(treespec, LeafSpec)` is deprecated, "
|
||||
"use `isinstance(treespec, TreeSpec)` and `treespec.is_leaf()` instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final]
|
||||
def __new__(cls) -> Self:
|
||||
return treespec_leaf() # type: ignore[return-value]
|
||||
def __new__(cls) -> "LeafSpec":
|
||||
return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value]
|
||||
|
||||
|
||||
def tree_flatten_with_path(
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ from typing import (
|
|||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import deprecated, NamedTuple, Self, TypeAlias
|
||||
from typing_extensions import deprecated, NamedTuple, Self
|
||||
|
||||
from torch.torch_version import TorchVersion as _TorchVersion
|
||||
|
||||
|
|
@ -52,7 +52,6 @@ __all__ = [
|
|||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"PyTreeSpec",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"keystr",
|
||||
|
|
@ -472,7 +471,7 @@ class ConstantNode:
|
|||
|
||||
def _is_constant_holder(spec: "TreeSpec") -> bool:
|
||||
"""Checks if the spec is from a pytree registered with register_constant"""
|
||||
return isinstance(spec._context, ConstantNode)
|
||||
return isinstance(spec.context, ConstantNode)
|
||||
|
||||
|
||||
def _retrieve_constant(spec: "TreeSpec") -> Any:
|
||||
|
|
@ -1072,53 +1071,35 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
|
|||
# context: some context that is useful in unflattening the pytree
|
||||
# children_specs: specs for each child of the root Node
|
||||
# num_leaves: the number of leaves
|
||||
@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False)
|
||||
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
|
||||
class TreeSpec:
|
||||
type: Any
|
||||
_context: Context
|
||||
_children: list[Self]
|
||||
context: Context
|
||||
children_specs: list["TreeSpec"]
|
||||
|
||||
num_nodes: int = dataclasses.field(init=False)
|
||||
num_leaves: int = dataclasses.field(init=False)
|
||||
num_children: int = dataclasses.field(init=False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Any,
|
||||
context: Context, # keep for backward compatibility
|
||||
children: list[Self], # keep for backward compatibility
|
||||
) -> None:
|
||||
object.__setattr__(self, "type", type)
|
||||
object.__setattr__(self, "_context", context)
|
||||
object.__setattr__(self, "_children", children)
|
||||
self.__post_init__()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.type is None:
|
||||
assert self._context is None
|
||||
assert len(self._children) == 0
|
||||
num_nodes = 1
|
||||
num_leaves = 1
|
||||
num_children = 0
|
||||
else:
|
||||
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
|
||||
num_leaves = sum(spec.num_leaves for spec in self._children)
|
||||
num_children = len(self._children)
|
||||
num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
|
||||
num_leaves = sum(spec.num_leaves for spec in self.children_specs)
|
||||
num_children = len(self.children_specs)
|
||||
object.__setattr__(self, "num_nodes", num_nodes)
|
||||
object.__setattr__(self, "num_leaves", num_leaves)
|
||||
object.__setattr__(self, "num_children", num_children)
|
||||
|
||||
def __repr__(self, indent: int = 0) -> str:
|
||||
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self._context}, ["
|
||||
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
|
||||
children_specs_str: str = ""
|
||||
if self.num_children > 0:
|
||||
indent += 2
|
||||
children_specs_str += self._children[0].__repr__(indent)
|
||||
children_specs_str += self.children_specs[0].__repr__(indent)
|
||||
children_specs_str += "," if self.num_children > 1 else ""
|
||||
children_specs_str += ",".join(
|
||||
[
|
||||
"\n" + " " * indent + child.__repr__(indent)
|
||||
for child in self._children[1:]
|
||||
for child in self.children_specs[1:]
|
||||
]
|
||||
)
|
||||
repr_suffix: str = f"{children_specs_str}])"
|
||||
|
|
@ -1130,36 +1111,16 @@ class TreeSpec:
|
|||
elif other.__class__ is self.__class__:
|
||||
if str(self.type) != str(other.type):
|
||||
return False
|
||||
if self._context != other._context:
|
||||
if self.context != other.context:
|
||||
return False
|
||||
elif self._children != other._children:
|
||||
elif self.children_specs != other.children_specs:
|
||||
return False
|
||||
return True
|
||||
return NotImplemented
|
||||
|
||||
@property
|
||||
def context(self) -> Context:
|
||||
return self._context
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"`treespec.children_specs` is deprecated. "
|
||||
"Use `treespec.child(index)` to access a single child, "
|
||||
"or `treespec.children()` to get all children.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
def children_specs(self) -> list[Self]:
|
||||
return self._children
|
||||
|
||||
def is_leaf(self) -> bool:
|
||||
return self.num_nodes == 1 and self.num_leaves == 1
|
||||
|
||||
def children(self) -> list[Self]:
|
||||
return self._children.copy()
|
||||
|
||||
def child(self, index: int) -> Self:
|
||||
return self._children[index]
|
||||
|
||||
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
|
||||
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
|
||||
if treespec.is_leaf():
|
||||
|
|
@ -1181,7 +1142,7 @@ class TreeSpec:
|
|||
f"Node arity mismatch; "
|
||||
f"expected {treespec.num_children}, but got {len(children)}.",
|
||||
)
|
||||
if context != treespec._context:
|
||||
if context != treespec.context:
|
||||
raise ValueError(
|
||||
f"Node context mismatch for custom node type {treespec.type!r}.",
|
||||
)
|
||||
|
|
@ -1206,10 +1167,10 @@ class TreeSpec:
|
|||
if both_standard_dict:
|
||||
# dictionary types are compatible with each other
|
||||
dict_context = (
|
||||
treespec._context
|
||||
treespec.context
|
||||
if treespec.type is not defaultdict
|
||||
# ignore mismatch of `default_factory` for defaultdict
|
||||
else treespec._context[1]
|
||||
else treespec.context[1]
|
||||
)
|
||||
expected_keys = dict_context
|
||||
got_key_set = set(tree)
|
||||
|
|
@ -1230,13 +1191,13 @@ class TreeSpec:
|
|||
children, context = flatten_fn(tree)
|
||||
if (
|
||||
node_type is not deque # ignore mismatch of `maxlen` for deque
|
||||
) and context != treespec._context:
|
||||
) and context != treespec.context:
|
||||
raise ValueError(
|
||||
f"Node context mismatch for node type {treespec.type!r}; "
|
||||
f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch
|
||||
f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch
|
||||
)
|
||||
|
||||
for subtree, subspec in zip(children, treespec._children, strict=True):
|
||||
for subtree, subspec in zip(children, treespec.children_specs):
|
||||
helper(subspec, subtree, subtrees)
|
||||
|
||||
subtrees: list[PyTree] = []
|
||||
|
|
@ -1261,24 +1222,24 @@ class TreeSpec:
|
|||
start = 0
|
||||
end = 0
|
||||
child_pytrees = []
|
||||
for child_spec in self._children:
|
||||
for child_spec in self.children_specs:
|
||||
end += child_spec.num_leaves
|
||||
child_pytrees.append(child_spec.unflatten(leaves[start:end]))
|
||||
start = end
|
||||
|
||||
return unflatten_fn(child_pytrees, self._context)
|
||||
return unflatten_fn(child_pytrees, self.context)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
node_type = self.type
|
||||
if node_type is defaultdict:
|
||||
default_factory, dict_context = self._context
|
||||
default_factory, dict_context = self.context
|
||||
hashable_context = (default_factory, tuple(dict_context))
|
||||
elif node_type in (dict, OrderedDict):
|
||||
hashable_context = tuple(self._context)
|
||||
hashable_context = tuple(self.context)
|
||||
elif node_type is None or node_type in BUILTIN_TYPES:
|
||||
hashable_context = self._context
|
||||
elif isinstance(self._context, ConstantNode):
|
||||
hashable_context = self._context.value
|
||||
hashable_context = self.context
|
||||
elif isinstance(self.context, ConstantNode):
|
||||
hashable_context = self.context.value
|
||||
else:
|
||||
# The context for user-defined node types might not be hashable.
|
||||
# Ignore it for hashing.
|
||||
|
|
@ -1286,26 +1247,20 @@ class TreeSpec:
|
|||
# same hash. This might increase the hash collision rate, but we
|
||||
# don't care about that.
|
||||
hashable_context = None
|
||||
return hash((node_type, hashable_context, tuple(self._children)))
|
||||
|
||||
|
||||
PyTreeSpec: TypeAlias = TreeSpec
|
||||
return hash((node_type, hashable_context, tuple(self.children_specs)))
|
||||
|
||||
|
||||
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
|
||||
# this class with `dataclasses.fields`, etc., while having a simplified
|
||||
# constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
|
||||
# again, with fields that have `init=False`.
|
||||
@deprecated(
|
||||
"`isinstance(treespec, LeafSpec)` is deprecated, "
|
||||
"use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
|
||||
class LeafSpec(TreeSpec):
|
||||
type: Any = dataclasses.field(default=None, init=False)
|
||||
_context: Context = dataclasses.field(default=None, init=False)
|
||||
_children: list[Self] = dataclasses.field(default_factory=list, init=False)
|
||||
context: Context = dataclasses.field(default=None, init=False)
|
||||
children_specs: list["TreeSpec"] = dataclasses.field(
|
||||
default_factory=list, init=False
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Override `__post_init__` for `num_leaves` derivation.
|
||||
|
|
@ -1319,36 +1274,7 @@ class LeafSpec(TreeSpec):
|
|||
|
||||
# All leaves are equivalent, so represent with a single object to save on
|
||||
# object construction time
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=FutureWarning, module=__name__, append=False
|
||||
)
|
||||
_LEAF_SPEC = LeafSpec()
|
||||
|
||||
|
||||
def treespec_leaf() -> LeafSpec:
|
||||
"""Make a treespec representing a leaf node."""
|
||||
return _LEAF_SPEC
|
||||
|
||||
|
||||
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
|
||||
"""Make a tuple treespec from an iterable of child treespecs."""
|
||||
children = list(iterable)
|
||||
if any(not isinstance(child, TreeSpec) for child in children):
|
||||
raise ValueError(f"Expected a tuple of TreeSpec values, got: {children!r}.")
|
||||
return TreeSpec(tuple, None, children)
|
||||
|
||||
|
||||
def treespec_dict(
|
||||
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
|
||||
/,
|
||||
**kwargs: TreeSpec,
|
||||
) -> TreeSpec:
|
||||
"""Make a dict treespec from a dict of child treespecs."""
|
||||
dct = dict(mapping, **kwargs)
|
||||
if any(not isinstance(child, TreeSpec) for child in dct.values()):
|
||||
raise ValueError(f"Expected a dictionary of TreeSpec values, got: {dct!r}.")
|
||||
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
|
||||
_LEAF_SPEC = LeafSpec()
|
||||
|
||||
|
||||
def tree_flatten(
|
||||
|
|
@ -1827,15 +1753,15 @@ def _broadcast_to_and_flatten(
|
|||
return None
|
||||
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
child_pytrees, context = flatten_fn(tree)
|
||||
child_pytrees, ctx = flatten_fn(tree)
|
||||
|
||||
# Check if the Node is different from the spec
|
||||
if len(child_pytrees) != treespec.num_children or context != treespec._context:
|
||||
if len(child_pytrees) != treespec.num_children or ctx != treespec.context:
|
||||
return None
|
||||
|
||||
# Recursively flatten the children
|
||||
result: list[Any] = []
|
||||
for child, child_spec in zip(child_pytrees, treespec._children, strict=True):
|
||||
for child, child_spec in zip(child_pytrees, treespec.children_specs):
|
||||
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
|
||||
if flat is not None:
|
||||
result += flat
|
||||
|
|
@ -1889,7 +1815,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
|
|||
|
||||
if serialize_node_def.to_dumpable_context is None:
|
||||
try:
|
||||
serialized_context = json.dumps(treespec._context, cls=EnumEncoder)
|
||||
serialized_context = json.dumps(treespec.context, cls=EnumEncoder)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
"Unable to serialize context. "
|
||||
|
|
@ -1897,9 +1823,9 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
|
|||
"custom serializer using _register_pytree_node."
|
||||
) from e
|
||||
else:
|
||||
serialized_context = serialize_node_def.to_dumpable_context(treespec._context)
|
||||
serialized_context = serialize_node_def.to_dumpable_context(treespec.context)
|
||||
|
||||
child_schemas = [_treespec_to_json(child) for child in treespec._children]
|
||||
child_schemas = [_treespec_to_json(child) for child in treespec.children_specs]
|
||||
|
||||
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user