mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)"
This reverts commit 108bb224f7.
Reverted https://github.com/pytorch/pytorch/pull/160843 on behalf of https://github.com/atalman due to failing internal builds ([comment](https://github.com/pytorch/pytorch/pull/160843#issuecomment-3474354428))
This commit is contained in:
parent
b71966f67b
commit
85b85f6c2c
|
|
@ -91,13 +91,13 @@ from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu
|
from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu
|
||||||
from torch.testing._internal.two_tensor import TwoTensor
|
from torch.testing._internal.two_tensor import TwoTensor
|
||||||
from torch.utils._pytree import (
|
from torch.utils._pytree import (
|
||||||
|
LeafSpec,
|
||||||
register_constant,
|
register_constant,
|
||||||
tree_flatten,
|
tree_flatten,
|
||||||
tree_map,
|
tree_map,
|
||||||
tree_unflatten,
|
tree_unflatten,
|
||||||
TreeSpec,
|
TreeSpec,
|
||||||
treespec_dumps,
|
treespec_dumps,
|
||||||
treespec_leaf,
|
|
||||||
treespec_loads,
|
treespec_loads,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -7791,7 +7791,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||||
|
|
||||||
dt = MyDataClass(x=3, y=4)
|
dt = MyDataClass(x=3, y=4)
|
||||||
flat, spec = tree_flatten(dt)
|
flat, spec = tree_flatten(dt)
|
||||||
self.assertTrue(spec, treespec_leaf())
|
self.assertTrue(spec, LeafSpec())
|
||||||
self.assertTrue(len(flat) == 1)
|
self.assertTrue(len(flat) == 1)
|
||||||
|
|
||||||
torch.export.register_dataclass(
|
torch.export.register_dataclass(
|
||||||
|
|
@ -7802,9 +7802,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||||
flat, spec = tree_flatten(dt)
|
flat, spec = tree_flatten(dt)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
spec,
|
spec,
|
||||||
TreeSpec(
|
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
|
||||||
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.assertEqual(flat, [3, 4])
|
self.assertEqual(flat, [3, 4])
|
||||||
|
|
||||||
|
|
@ -7837,7 +7835,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||||
TreeSpec(
|
TreeSpec(
|
||||||
MyOtherDataClass,
|
MyOtherDataClass,
|
||||||
[["x", "y", "z"], []],
|
[["x", "y", "z"], []],
|
||||||
[treespec_leaf(), treespec_leaf(), treespec_leaf()],
|
[LeafSpec(), LeafSpec(), LeafSpec()],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(flat, [3, 4, None])
|
self.assertEqual(flat, [3, 4, None])
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,9 @@ class TestEnum(enum.Enum):
|
||||||
A = auto()
|
A = auto()
|
||||||
|
|
||||||
|
|
||||||
|
python_leafspec = python_pytree.LeafSpec()
|
||||||
|
|
||||||
|
|
||||||
class TestGenericPytree(TestCase):
|
class TestGenericPytree(TestCase):
|
||||||
def test_aligned_public_apis(self):
|
def test_aligned_public_apis(self):
|
||||||
public_apis = python_pytree.__all__
|
public_apis = python_pytree.__all__
|
||||||
|
|
@ -194,7 +197,7 @@ class TestGenericPytree(TestCase):
|
||||||
def run_test_with_leaf(leaf):
|
def run_test_with_leaf(leaf):
|
||||||
values, treespec = pytree.tree_flatten(leaf)
|
values, treespec = pytree.tree_flatten(leaf)
|
||||||
self.assertEqual(values, [leaf])
|
self.assertEqual(values, [leaf])
|
||||||
self.assertEqual(treespec, pytree.treespec_leaf())
|
self.assertEqual(treespec, pytree.LeafSpec())
|
||||||
|
|
||||||
unflattened = pytree.tree_unflatten(values, treespec)
|
unflattened = pytree.tree_unflatten(values, treespec)
|
||||||
self.assertEqual(unflattened, leaf)
|
self.assertEqual(unflattened, leaf)
|
||||||
|
|
@ -212,7 +215,7 @@ class TestGenericPytree(TestCase):
|
||||||
(
|
(
|
||||||
python_pytree,
|
python_pytree,
|
||||||
lambda tup: python_pytree.TreeSpec(
|
lambda tup: python_pytree.TreeSpec(
|
||||||
tuple, None, [python_pytree.treespec_leaf() for _ in tup]
|
tuple, None, [python_leafspec for _ in tup]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
name="python",
|
name="python",
|
||||||
|
|
@ -247,7 +250,7 @@ class TestGenericPytree(TestCase):
|
||||||
(
|
(
|
||||||
python_pytree,
|
python_pytree,
|
||||||
lambda lst: python_pytree.TreeSpec(
|
lambda lst: python_pytree.TreeSpec(
|
||||||
list, None, [python_pytree.treespec_leaf() for _ in lst]
|
list, None, [python_leafspec for _ in lst]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
name="python",
|
name="python",
|
||||||
|
|
@ -283,7 +286,7 @@ class TestGenericPytree(TestCase):
|
||||||
lambda dct: python_pytree.TreeSpec(
|
lambda dct: python_pytree.TreeSpec(
|
||||||
dict,
|
dict,
|
||||||
list(dct.keys()),
|
list(dct.keys()),
|
||||||
[python_pytree.treespec_leaf() for _ in dct.values()],
|
[python_leafspec for _ in dct.values()],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
name="python",
|
name="python",
|
||||||
|
|
@ -324,7 +327,7 @@ class TestGenericPytree(TestCase):
|
||||||
lambda odict: python_pytree.TreeSpec(
|
lambda odict: python_pytree.TreeSpec(
|
||||||
OrderedDict,
|
OrderedDict,
|
||||||
list(odict.keys()),
|
list(odict.keys()),
|
||||||
[python_pytree.treespec_leaf() for _ in odict.values()],
|
[python_leafspec for _ in odict.values()],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
name="python",
|
name="python",
|
||||||
|
|
@ -368,7 +371,7 @@ class TestGenericPytree(TestCase):
|
||||||
lambda ddct: python_pytree.TreeSpec(
|
lambda ddct: python_pytree.TreeSpec(
|
||||||
defaultdict,
|
defaultdict,
|
||||||
[ddct.default_factory, list(ddct.keys())],
|
[ddct.default_factory, list(ddct.keys())],
|
||||||
[python_pytree.treespec_leaf() for _ in ddct.values()],
|
[python_leafspec for _ in ddct.values()],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
name="python",
|
name="python",
|
||||||
|
|
@ -410,7 +413,7 @@ class TestGenericPytree(TestCase):
|
||||||
(
|
(
|
||||||
python_pytree,
|
python_pytree,
|
||||||
lambda deq: python_pytree.TreeSpec(
|
lambda deq: python_pytree.TreeSpec(
|
||||||
deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
|
deque, deq.maxlen, [python_leafspec for _ in deq]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
name="python",
|
name="python",
|
||||||
|
|
@ -450,7 +453,7 @@ class TestGenericPytree(TestCase):
|
||||||
def run_test(tup):
|
def run_test(tup):
|
||||||
if pytree is python_pytree:
|
if pytree is python_pytree:
|
||||||
expected_spec = python_pytree.TreeSpec(
|
expected_spec = python_pytree.TreeSpec(
|
||||||
namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
|
namedtuple, Point, [python_leafspec for _ in tup]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
|
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
|
||||||
|
|
@ -845,16 +848,16 @@ if "optree" in sys.modules:
|
||||||
|
|
||||||
def test_treespec_equality(self):
|
def test_treespec_equality(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
python_pytree.treespec_leaf(),
|
python_pytree.LeafSpec(),
|
||||||
python_pytree.treespec_leaf(),
|
python_pytree.LeafSpec(),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
python_pytree.TreeSpec(list, None, []),
|
python_pytree.TreeSpec(list, None, []),
|
||||||
python_pytree.TreeSpec(list, None, []),
|
python_pytree.TreeSpec(list, None, []),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
|
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
|
||||||
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
|
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
|
||||||
)
|
)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
python_pytree.TreeSpec(tuple, None, [])
|
python_pytree.TreeSpec(tuple, None, [])
|
||||||
|
|
@ -889,32 +892,24 @@ if "optree" in sys.modules:
|
||||||
# python_pytree.tree_structure({})
|
# python_pytree.tree_structure({})
|
||||||
python_pytree.TreeSpec(dict, [], []),
|
python_pytree.TreeSpec(dict, [], []),
|
||||||
# python_pytree.tree_structure([0])
|
# python_pytree.tree_structure([0])
|
||||||
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
|
python_pytree.TreeSpec(list, None, [python_leafspec]),
|
||||||
# python_pytree.tree_structure([0, 1])
|
# python_pytree.tree_structure([0, 1])
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
list,
|
list,
|
||||||
None,
|
None,
|
||||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
[python_leafspec, python_leafspec],
|
||||||
),
|
),
|
||||||
# python_pytree.tree_structure((0, 1, 2))
|
# python_pytree.tree_structure((0, 1, 2))
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
tuple,
|
tuple,
|
||||||
None,
|
None,
|
||||||
[
|
[python_leafspec, python_leafspec, python_leafspec],
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
|
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
dict,
|
dict,
|
||||||
["a", "b", "c"],
|
["a", "b", "c"],
|
||||||
[
|
[python_leafspec, python_leafspec, python_leafspec],
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
|
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
|
|
@ -924,17 +919,13 @@ if "optree" in sys.modules:
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
tuple,
|
tuple,
|
||||||
None,
|
None,
|
||||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
[python_leafspec, python_leafspec],
|
||||||
),
|
),
|
||||||
python_pytree.treespec_leaf(),
|
python_leafspec,
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
dict,
|
dict,
|
||||||
["a", "b", "c"],
|
["a", "b", "c"],
|
||||||
[
|
[python_leafspec, python_leafspec, python_leafspec],
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
|
@ -947,15 +938,12 @@ if "optree" in sys.modules:
|
||||||
tuple,
|
tuple,
|
||||||
None,
|
None,
|
||||||
[
|
[
|
||||||
python_pytree.treespec_leaf(),
|
python_leafspec,
|
||||||
python_pytree.treespec_leaf(),
|
python_leafspec,
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
list,
|
list,
|
||||||
None,
|
None,
|
||||||
[
|
[python_leafspec, python_leafspec],
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
|
@ -969,12 +957,12 @@ if "optree" in sys.modules:
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
list,
|
list,
|
||||||
None,
|
None,
|
||||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
[python_leafspec, python_leafspec],
|
||||||
),
|
),
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
list,
|
list,
|
||||||
None,
|
None,
|
||||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
[python_leafspec, python_leafspec],
|
||||||
),
|
),
|
||||||
python_pytree.TreeSpec(dict, [], []),
|
python_pytree.TreeSpec(dict, [], []),
|
||||||
],
|
],
|
||||||
|
|
@ -1003,7 +991,7 @@ if "optree" in sys.modules:
|
||||||
list,
|
list,
|
||||||
None,
|
None,
|
||||||
[
|
[
|
||||||
python_pytree.treespec_leaf(),
|
python_leafspec,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
|
@ -1012,7 +1000,7 @@ if "optree" in sys.modules:
|
||||||
self.assertIsInstance(serialized_spec, str)
|
self.assertIsInstance(serialized_spec, str)
|
||||||
|
|
||||||
def test_pytree_serialize_enum(self):
|
def test_pytree_serialize_enum(self):
|
||||||
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
|
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
|
||||||
|
|
||||||
serialized_spec = python_pytree.treespec_dumps(spec)
|
serialized_spec = python_pytree.treespec_dumps(spec)
|
||||||
self.assertIsInstance(serialized_spec, str)
|
self.assertIsInstance(serialized_spec, str)
|
||||||
|
|
@ -1175,20 +1163,12 @@ if "optree" in sys.modules:
|
||||||
OrderedDict,
|
OrderedDict,
|
||||||
[1, 2, 3],
|
[1, 2, 3],
|
||||||
[
|
[
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
|
||||||
tuple,
|
python_leafspec,
|
||||||
None,
|
|
||||||
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
|
|
||||||
),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.TreeSpec(
|
python_pytree.TreeSpec(
|
||||||
dict,
|
dict,
|
||||||
[4, 5, 6],
|
[4, 5, 6],
|
||||||
[
|
[python_leafspec, python_leafspec, python_leafspec],
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
python_pytree.treespec_leaf(),
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -1473,7 +1453,7 @@ class TestCxxPytree(TestCase):
|
||||||
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
|
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
|
||||||
|
|
||||||
def test_treespec_equality(self):
|
def test_treespec_equality(self):
|
||||||
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
|
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
|
||||||
|
|
||||||
def test_treespec_repr(self):
|
def test_treespec_repr(self):
|
||||||
# Check that it looks sane
|
# Check that it looks sane
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from ..decorators import substitute_in_graph
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import builtins
|
import builtins
|
||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -349,113 +349,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
|
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
|
||||||
return isinstance(obj, PyTreeSpec)
|
return isinstance(obj, PyTreeSpec)
|
||||||
|
|
||||||
@substitute_in_graph( # type: ignore[arg-type]
|
|
||||||
optree.treespec_leaf,
|
|
||||||
# We need to disable constant folding here because we want the function to reference the
|
|
||||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
|
||||||
can_constant_fold_through=False,
|
|
||||||
)
|
|
||||||
def treespec_leaf(
|
|
||||||
*,
|
|
||||||
none_is_leaf: bool = False,
|
|
||||||
namespace: str = "", # unused
|
|
||||||
) -> PyTreeSpec:
|
|
||||||
return PyTreeSpec(
|
|
||||||
(),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
(),
|
|
||||||
None,
|
|
||||||
none_is_leaf=none_is_leaf,
|
|
||||||
namespace="",
|
|
||||||
)
|
|
||||||
|
|
||||||
@substitute_in_graph( # type: ignore[arg-type]
|
|
||||||
optree.treespec_tuple,
|
|
||||||
# We need to disable constant folding here because we want the function to reference the
|
|
||||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
|
||||||
can_constant_fold_through=False,
|
|
||||||
)
|
|
||||||
def treespec_tuple(
|
|
||||||
iterable: Iterable[PyTreeSpec] = (),
|
|
||||||
/,
|
|
||||||
*,
|
|
||||||
none_is_leaf: bool = False,
|
|
||||||
namespace: str = "",
|
|
||||||
) -> PyTreeSpec:
|
|
||||||
children = tuple(iterable)
|
|
||||||
if any(not _is_pytreespec_instance(child) for child in children):
|
|
||||||
raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
|
|
||||||
if any(child.none_is_leaf != none_is_leaf for child in children):
|
|
||||||
raise ValueError(
|
|
||||||
"All children PyTreeSpecs must have the same `none_is_leaf` value "
|
|
||||||
f"as the parent; expected {none_is_leaf}, got: {children!r}.",
|
|
||||||
)
|
|
||||||
if any(child.namespace not in (namespace, "") for child in children):
|
|
||||||
raise ValueError(
|
|
||||||
"All children PyTreeSpecs must have the same `namespace` value "
|
|
||||||
f"as the parent; expected {namespace!r}, got: {children!r}.",
|
|
||||||
)
|
|
||||||
handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined]
|
|
||||||
assert handler is not None
|
|
||||||
return PyTreeSpec(
|
|
||||||
tuple(children),
|
|
||||||
tuple,
|
|
||||||
None,
|
|
||||||
tuple(range(len(children))),
|
|
||||||
handler.unflatten_func,
|
|
||||||
none_is_leaf=none_is_leaf,
|
|
||||||
namespace=namespace,
|
|
||||||
)
|
|
||||||
|
|
||||||
@substitute_in_graph( # type: ignore[arg-type]
|
|
||||||
optree.treespec_dict,
|
|
||||||
# We need to disable constant folding here because we want the function to reference the
|
|
||||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
|
||||||
can_constant_fold_through=False,
|
|
||||||
)
|
|
||||||
def treespec_dict(
|
|
||||||
mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
|
|
||||||
/,
|
|
||||||
*,
|
|
||||||
none_is_leaf: bool = False,
|
|
||||||
namespace: str = "",
|
|
||||||
**kwargs: PyTreeSpec,
|
|
||||||
) -> PyTreeSpec:
|
|
||||||
dct = dict(mapping, **kwargs)
|
|
||||||
if any(not _is_pytreespec_instance(child) for child in dct.values()):
|
|
||||||
raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
|
|
||||||
if any(child.none_is_leaf != none_is_leaf for child in dct.values()):
|
|
||||||
raise ValueError(
|
|
||||||
"All children PyTreeSpecs must have the same `none_is_leaf` value "
|
|
||||||
f"as the parent; expected {none_is_leaf}, got: {dct!r}.",
|
|
||||||
)
|
|
||||||
if any(child.namespace not in (namespace, "") for child in dct.values()):
|
|
||||||
raise ValueError(
|
|
||||||
"All children PyTreeSpecs must have the same `namespace` value "
|
|
||||||
f"as the parent; expected {namespace!r}, got: {dct!r}.",
|
|
||||||
)
|
|
||||||
|
|
||||||
(
|
|
||||||
children,
|
|
||||||
metadata,
|
|
||||||
entries,
|
|
||||||
unflatten_func,
|
|
||||||
) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated]
|
|
||||||
dct, # type: ignore[arg-type]
|
|
||||||
none_is_leaf=none_is_leaf,
|
|
||||||
namespace=namespace,
|
|
||||||
)
|
|
||||||
return PyTreeSpec(
|
|
||||||
tuple(children), # type: ignore[arg-type]
|
|
||||||
dict,
|
|
||||||
metadata,
|
|
||||||
entries,
|
|
||||||
unflatten_func,
|
|
||||||
none_is_leaf=none_is_leaf,
|
|
||||||
namespace=namespace,
|
|
||||||
)
|
|
||||||
|
|
||||||
@substitute_in_graph( # type: ignore[arg-type]
|
@substitute_in_graph( # type: ignore[arg-type]
|
||||||
optree.tree_flatten,
|
optree.tree_flatten,
|
||||||
# We need to disable constant folding here because we want the function to reference the
|
# We need to disable constant folding here because we want the function to reference the
|
||||||
|
|
|
||||||
|
|
@ -3727,7 +3727,9 @@ class SourcelessBuilder:
|
||||||
pass # failthrough to unimplemented branch
|
pass # failthrough to unimplemented branch
|
||||||
elif isinstance(value, torch.fx.graph_module.GraphModule):
|
elif isinstance(value, torch.fx.graph_module.GraphModule):
|
||||||
return SourcelessGraphModuleVariable(value)
|
return SourcelessGraphModuleVariable(value)
|
||||||
elif isinstance(value, torch.utils._pytree.TreeSpec):
|
elif isinstance(
|
||||||
|
value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
|
||||||
|
):
|
||||||
return UserDefinedObjectVariable(value)
|
return UserDefinedObjectVariable(value)
|
||||||
elif PlacementVariable.is_placement(value):
|
elif PlacementVariable.is_placement(value):
|
||||||
return PlacementVariable(value)
|
return PlacementVariable(value)
|
||||||
|
|
|
||||||
|
|
@ -1661,8 +1661,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
# We need to have this check this way, because in case init is a TreeSpec and carry
|
# We need to have this check this way, because in case init is a TreeSpec and carry
|
||||||
# but carry is only a LeafSpec, these two cannot be compared correctly.
|
# but carry is only a LeafSpec, these two cannot be compared correctly.
|
||||||
if (
|
if (
|
||||||
xs_treespec.as_python_constant().is_leaf()
|
isinstance(xs_treespec.as_python_constant(), pytree.LeafSpec)
|
||||||
!= _combine_treespec.as_python_constant().is_leaf()
|
!= isinstance(_combine_treespec.as_python_constant(), pytree.LeafSpec)
|
||||||
) or not _make_inlined(tx, pytree.TreeSpec.__eq__)(
|
) or not _make_inlined(tx, pytree.TreeSpec.__eq__)(
|
||||||
xs_treespec, _combine_treespec
|
xs_treespec, _combine_treespec
|
||||||
).as_python_constant():
|
).as_python_constant():
|
||||||
|
|
|
||||||
|
|
@ -1530,10 +1530,10 @@ class GraphModuleSerializer(metaclass=Final):
|
||||||
else:
|
else:
|
||||||
raise AssertionError("TODO")
|
raise AssertionError("TODO")
|
||||||
|
|
||||||
def serialize_treespec(self, treespec: pytree.TreeSpec) -> str:
|
def serialize_treespec(self, treespec):
|
||||||
# We want to additionally save all the field names of the namedtuples in
|
# We want to additionally save all the field names of the namedtuples in
|
||||||
# case users want to check that the treespec types are equivalent
|
# case users want to check that the treespec types are equivalent
|
||||||
def store_namedtuple_fields(ts: pytree.TreeSpec) -> None:
|
def store_namedtuple_fields(ts):
|
||||||
if ts.type is None:
|
if ts.type is None:
|
||||||
return
|
return
|
||||||
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
|
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
|
||||||
|
|
@ -1555,7 +1555,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||||
NamedTupleDef(field_names=ts.context._fields)
|
NamedTupleDef(field_names=ts.context._fields)
|
||||||
)
|
)
|
||||||
|
|
||||||
for child in ts.children():
|
for child in ts.children_specs:
|
||||||
store_namedtuple_fields(child)
|
store_namedtuple_fields(child)
|
||||||
|
|
||||||
serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION)
|
serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION)
|
||||||
|
|
|
||||||
|
|
@ -158,7 +158,7 @@ class PytreeThunk:
|
||||||
assert spec is not None
|
assert spec is not None
|
||||||
self.spec: pytree.TreeSpec = spec
|
self.spec: pytree.TreeSpec = spec
|
||||||
if self.spec.type in {tuple, list} and all(
|
if self.spec.type in {tuple, list} and all(
|
||||||
child.is_leaf() for child in spec.children()
|
child.is_leaf() for child in spec.children_specs
|
||||||
):
|
):
|
||||||
self.is_simple = True
|
self.is_simple = True
|
||||||
if self.spec.is_leaf():
|
if self.spec.is_leaf():
|
||||||
|
|
|
||||||
|
|
@ -1590,7 +1590,7 @@ def aot_export_joint_simple(
|
||||||
decompositions=decompositions,
|
decompositions=decompositions,
|
||||||
trace_joint=trace_joint,
|
trace_joint=trace_joint,
|
||||||
)
|
)
|
||||||
in_spec, _kw_in_spec = in_spec.children()
|
in_spec, _kw_in_spec = in_spec.children_specs
|
||||||
# At this point, we can just directly return the (joint or inference graph) that we traced.
|
# At this point, we can just directly return the (joint or inference graph) that we traced.
|
||||||
# First though: a bunch of assertions to make sure that our graph doesn't require
|
# First though: a bunch of assertions to make sure that our graph doesn't require
|
||||||
# any calling convention changes compared to the original function.
|
# any calling convention changes compared to the original function.
|
||||||
|
|
@ -1617,7 +1617,7 @@ def aot_export_joint_simple(
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
|
f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
|
||||||
)
|
)
|
||||||
if not all(child.is_leaf() for child in in_spec.children()):
|
if not all(child.is_leaf() for child in in_spec.children_specs):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
|
f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
|
||||||
)
|
)
|
||||||
|
|
@ -1625,7 +1625,7 @@ def aot_export_joint_simple(
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
|
f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
|
||||||
)
|
)
|
||||||
if not all(child.is_leaf() for child in out_spec.children()):
|
if not all(child.is_leaf() for child in out_spec.children_specs):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
|
f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -469,7 +469,7 @@ def _unlift_graph(
|
||||||
gm,
|
gm,
|
||||||
lifted_inputs,
|
lifted_inputs,
|
||||||
mutated_outputs,
|
mutated_outputs,
|
||||||
pytree.treespec_leaf(),
|
pytree.LeafSpec(),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
return unlifted_gm
|
return unlifted_gm
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from typing import Any, Optional, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
|
import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils._pytree as pytree
|
|
||||||
|
|
||||||
# Makes sure that quantized_decomposed ops are registered
|
# Makes sure that quantized_decomposed ops are registered
|
||||||
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
||||||
|
|
@ -15,6 +14,7 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation
|
||||||
from torch.export.unflatten import _assign_attr, _AttrKind
|
from torch.export.unflatten import _assign_attr, _AttrKind
|
||||||
from torch.fx import GraphModule, Node
|
from torch.fx import GraphModule, Node
|
||||||
from torch.nn.utils.fusion import fuse_conv_bn_weights
|
from torch.nn.utils.fusion import fuse_conv_bn_weights
|
||||||
|
from torch.utils._pytree import LeafSpec
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -477,10 +477,7 @@ def _replace_literals_with_new_placeholders(
|
||||||
exclude_literals = []
|
exclude_literals = []
|
||||||
|
|
||||||
in_spec = gm._in_spec
|
in_spec = gm._in_spec
|
||||||
assert in_spec.type is tuple
|
args_spec = in_spec.children_specs[0]
|
||||||
args_spec = in_spec.child(0)
|
|
||||||
assert args_spec.type is tuple
|
|
||||||
args_spec_children = args_spec.children()
|
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
if node.op == "placeholder":
|
if node.op == "placeholder":
|
||||||
last_ph = node
|
last_ph = node
|
||||||
|
|
@ -495,7 +492,7 @@ def _replace_literals_with_new_placeholders(
|
||||||
else:
|
else:
|
||||||
ph_node = gm.graph.placeholder("arg" + str(cnt))
|
ph_node = gm.graph.placeholder("arg" + str(cnt))
|
||||||
new_args.append(ph_node)
|
new_args.append(ph_node)
|
||||||
args_spec_children.append(pytree.treespec_leaf())
|
args_spec.children_specs.append(LeafSpec())
|
||||||
cnt += 1
|
cnt += 1
|
||||||
if merge_dup:
|
if merge_dup:
|
||||||
literal_to_ph[arg] = ph_node
|
literal_to_ph[arg] = ph_node
|
||||||
|
|
@ -506,8 +503,8 @@ def _replace_literals_with_new_placeholders(
|
||||||
node.args = new_args
|
node.args = new_args
|
||||||
|
|
||||||
# Update `num_nodes`, `num_leaves`, `num_children`.
|
# Update `num_nodes`, `num_leaves`, `num_children`.
|
||||||
args_spec = pytree.treespec_tuple(args_spec_children)
|
args_spec.__post_init__()
|
||||||
gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]])
|
in_spec.__post_init__()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -195,16 +195,17 @@ def _construct_inputs(
|
||||||
unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)
|
unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)
|
||||||
|
|
||||||
assert signature.in_spec.num_children == 2
|
assert signature.in_spec.num_children == 2
|
||||||
assert signature.in_spec.type is tuple
|
|
||||||
args_spec, kwargs_spec = signature.in_spec.children()
|
|
||||||
assert args_spec.type is tuple
|
|
||||||
assert kwargs_spec.type is dict
|
|
||||||
|
|
||||||
|
args_spec = signature.in_spec.children_specs[0]
|
||||||
|
assert args_spec.context is None
|
||||||
args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0))
|
args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0))
|
||||||
args_nodes = [
|
args_nodes = [
|
||||||
gm.graph.call_function(operator.getitem, (args_node, i))
|
gm.graph.call_function(operator.getitem, (args_node, i))
|
||||||
for i in range(args_spec.num_children)
|
for i in range(args_spec.num_children)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
kwargs_spec = signature.in_spec.children_specs[1]
|
||||||
|
assert kwargs_spec.context is not None
|
||||||
kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1))
|
kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1))
|
||||||
kwargs_nodes = {
|
kwargs_nodes = {
|
||||||
k: gm.graph.call_function(operator.getitem, (kwargs_node, k))
|
k: gm.graph.call_function(operator.getitem, (kwargs_node, k))
|
||||||
|
|
@ -371,8 +372,8 @@ def _fix_input_output_signature(
|
||||||
if forward_arg_names is None:
|
if forward_arg_names is None:
|
||||||
forward_arg_names = []
|
forward_arg_names = []
|
||||||
assert signature.in_spec.num_children == 2
|
assert signature.in_spec.num_children == 2
|
||||||
arg_spec = signature.in_spec.child(0)
|
arg_spec = signature.in_spec.children_specs[0]
|
||||||
kwarg_spec = signature.in_spec.child(1)
|
kwarg_spec = signature.in_spec.children_specs[1]
|
||||||
assert arg_spec.type is tuple
|
assert arg_spec.type is tuple
|
||||||
assert kwarg_spec.type is dict
|
assert kwarg_spec.type is dict
|
||||||
for i in range(arg_spec.num_children):
|
for i in range(arg_spec.num_children):
|
||||||
|
|
|
||||||
|
|
@ -1533,7 +1533,7 @@ def _strict_export(
|
||||||
|
|
||||||
# aot_export expect the return type to always be a tuple.
|
# aot_export expect the return type to always be a tuple.
|
||||||
if out_spec.type not in (list, tuple):
|
if out_spec.type not in (list, tuple):
|
||||||
out_spec = pytree.treespec_tuple([out_spec])
|
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
|
||||||
|
|
||||||
orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]
|
orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any
|
||||||
# Make sure that the spec is actually shaped like (args, kwargs)
|
# Make sure that the spec is actually shaped like (args, kwargs)
|
||||||
assert spec.type is tuple
|
assert spec.type is tuple
|
||||||
assert spec.num_children == 2
|
assert spec.num_children == 2
|
||||||
kwargs_spec = spec.child(1)
|
kwargs_spec = spec.children_specs[1]
|
||||||
assert kwargs_spec.type is dict
|
assert kwargs_spec.type is dict
|
||||||
|
|
||||||
if set(user_kwargs) != set(kwargs_spec.context):
|
if set(user_kwargs) != set(kwargs_spec.context):
|
||||||
|
|
@ -55,10 +55,10 @@ def is_equivalent(
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Recurse on children
|
# Recurse on children
|
||||||
if spec1.num_children != spec2.num_children:
|
if len(spec1.children_specs) != len(spec2.children_specs):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for child_spec1, child_spec2 in zip(spec1.children(), spec2.children()):
|
for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs):
|
||||||
if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
|
if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,11 +57,11 @@ def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
|
||||||
return False
|
return False
|
||||||
elif a.context != b.context:
|
elif a.context != b.context:
|
||||||
return False
|
return False
|
||||||
if a.num_children != b.num_children:
|
if len(a.children_specs) != len(b.children_specs):
|
||||||
return False
|
return False
|
||||||
return all(
|
return all(
|
||||||
_match_normalized_structure(a, b)
|
_match_normalized_structure(a, b)
|
||||||
for a, b in zip(a.children(), b.children())
|
for a, b in zip(a.children_specs, b.children_specs)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _match_normalized_structure(self, other)
|
return _match_normalized_structure(self, other)
|
||||||
|
|
@ -357,13 +357,13 @@ def _get_codegen(
|
||||||
elif (
|
elif (
|
||||||
in_spec.type is tuple
|
in_spec.type is tuple
|
||||||
and in_spec.num_children == 2
|
and in_spec.num_children == 2
|
||||||
and in_spec.child(0).type is tuple
|
and in_spec.children_specs[0].type is tuple
|
||||||
and in_spec.child(1).type is dict
|
and in_spec.children_specs[1].type is dict
|
||||||
):
|
):
|
||||||
# if in_spec contains the args (tuple) and kwargs (dict)
|
# if in_spec contains the args (tuple) and kwargs (dict)
|
||||||
names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)]
|
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
|
||||||
# add kwarg names
|
# add kwarg names
|
||||||
names.extend(in_spec.child(1).context)
|
names.extend(in_spec.children_specs[1].context)
|
||||||
else:
|
else:
|
||||||
names = [f"arg_{i}" for i in range(in_spec.num_children)]
|
names = [f"arg_{i}" for i in range(in_spec.num_children)]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,16 +12,14 @@ import torch
|
||||||
from torch.utils._pytree import (
|
from torch.utils._pytree import (
|
||||||
_get_node_type,
|
_get_node_type,
|
||||||
BUILTIN_TYPES,
|
BUILTIN_TYPES,
|
||||||
KeyPath,
|
|
||||||
keystr,
|
keystr,
|
||||||
|
LeafSpec,
|
||||||
MappingKey,
|
MappingKey,
|
||||||
SequenceKey,
|
SequenceKey,
|
||||||
SUPPORTED_NODES,
|
SUPPORTED_NODES,
|
||||||
tree_iter,
|
tree_flatten,
|
||||||
tree_map,
|
tree_map,
|
||||||
tree_map_with_path,
|
tree_map_with_path,
|
||||||
tree_structure,
|
|
||||||
TreeSpec,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from .exported_program import ExportedProgram
|
from .exported_program import ExportedProgram
|
||||||
|
|
@ -657,55 +655,53 @@ def _tree_map_with_path(
|
||||||
case_name="dynamic_shapes_validation",
|
case_name="dynamic_shapes_validation",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compare(
|
def _compare(tree, dynamic_shapes, path):
|
||||||
treespec: TreeSpec, other_treespec: TreeSpec, path: KeyPath
|
|
||||||
) -> None:
|
|
||||||
# raise an error at the point where tree and dynamic_shapes differ,
|
# raise an error at the point where tree and dynamic_shapes differ,
|
||||||
# including the path to that point and the reason for the difference
|
# including the path to that point and the reason for the difference
|
||||||
rendered_path = keystr(path)
|
rendered_path = keystr(path)
|
||||||
if treespec.is_leaf():
|
if isinstance(tree, LeafSpec):
|
||||||
return
|
return
|
||||||
if other_treespec.is_leaf():
|
if isinstance(dynamic_shapes, LeafSpec):
|
||||||
raise_mismatch_error(
|
raise_mismatch_error(
|
||||||
f"`{tree_name}{rendered_path}` is a {treespec.type}, "
|
f"`{tree_name}{rendered_path}` is a {tree.type}, "
|
||||||
f"but `dynamic_shapes{rendered_path}` is not"
|
f"but `dynamic_shapes{rendered_path}` is not"
|
||||||
)
|
)
|
||||||
if treespec.type != other_treespec.type:
|
if tree.type != dynamic_shapes.type:
|
||||||
raise_mismatch_error(
|
raise_mismatch_error(
|
||||||
f"`{tree_name}{rendered_path}` is a {treespec.type}, "
|
f"`{tree_name}{rendered_path}` is a {tree.type}, "
|
||||||
f"but `dynamic_shapes{rendered_path}` is a {other_treespec.type}"
|
f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}"
|
||||||
)
|
)
|
||||||
if treespec.num_children != other_treespec.num_children:
|
if len(tree.children_specs) != len(dynamic_shapes.children_specs):
|
||||||
raise_mismatch_error(
|
raise_mismatch_error(
|
||||||
f"`{tree_name}{rendered_path}` has {treespec.num_children} elements, "
|
f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, "
|
||||||
f"but `dynamic_shapes{rendered_path}` has {other_treespec.num_children} elements"
|
f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements"
|
||||||
)
|
)
|
||||||
if treespec.type is dict:
|
if tree.type is dict:
|
||||||
# context, children could be out of order
|
# context, children could be out of order
|
||||||
if set(treespec.context) != set(other_treespec.context):
|
if sorted(tree.context) != sorted(dynamic_shapes.context):
|
||||||
raise_mismatch_error(
|
raise_mismatch_error(
|
||||||
f"`{tree_name}{rendered_path}` has keys {treespec.context}, "
|
f"`{tree_name}{rendered_path}` has keys {tree.context}, "
|
||||||
f"but `dynamic_shapes{rendered_path}` has keys {other_treespec.context}"
|
f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}"
|
||||||
)
|
)
|
||||||
_remap = dict(
|
_remap = dict(
|
||||||
zip(other_treespec.context, other_treespec.children())
|
zip(dynamic_shapes.context, dynamic_shapes.children_specs)
|
||||||
)
|
)
|
||||||
other_children = [_remap[k] for k in treespec.context]
|
dynamic_shapes_children_specs = [_remap[k] for k in tree.context]
|
||||||
else:
|
else:
|
||||||
other_children = other_treespec.children()
|
dynamic_shapes_children_specs = dynamic_shapes.children_specs
|
||||||
for i, (child, other_child) in enumerate(
|
for i, (tree_, dynamic_shapes_) in enumerate(
|
||||||
zip(treespec.children(), other_children)
|
zip(tree.children_specs, dynamic_shapes_children_specs)
|
||||||
):
|
):
|
||||||
_compare(
|
_compare(
|
||||||
child,
|
tree_,
|
||||||
other_child,
|
dynamic_shapes_,
|
||||||
path + (_key(treespec.type, treespec.context, i),),
|
path + [_key(tree.type, tree.context, i)],
|
||||||
)
|
)
|
||||||
|
|
||||||
treespec = tree_structure(tree, is_leaf=is_leaf)
|
_, tree_spec = tree_flatten(tree, is_leaf=is_leaf)
|
||||||
for other_tree in dynamic_shapes:
|
for other_tree in dynamic_shapes:
|
||||||
other_treespec = tree_structure(other_tree, is_leaf)
|
_, other_tree_spec = tree_flatten(other_tree, is_leaf)
|
||||||
_compare(treespec, other_treespec, ())
|
_compare(tree_spec, other_tree_spec, [])
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1235,7 +1231,10 @@ def _get_dim_name_mapping(
|
||||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
||||||
):
|
):
|
||||||
name_to_dim = {}
|
name_to_dim = {}
|
||||||
for dim in tree_iter(dynamic_shapes, is_leaf=lambda x: isinstance(x, Dim)):
|
for dim in tree_flatten(
|
||||||
|
dynamic_shapes,
|
||||||
|
is_leaf=lambda x: isinstance(x, Dim),
|
||||||
|
)[0]:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
# NOTE: this must denote a non-Tensor or automatic at this point.
|
# NOTE: this must denote a non-Tensor or automatic at this point.
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -391,7 +391,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||||
# aot_export expect the return type to always be a tuple.
|
# aot_export expect the return type to always be a tuple.
|
||||||
assert out_spec is not None
|
assert out_spec is not None
|
||||||
if out_spec.type not in (list, tuple):
|
if out_spec.type not in (list, tuple):
|
||||||
out_spec = pytree.treespec_tuple([out_spec])
|
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
|
||||||
|
|
||||||
mod.graph._codegen = _PyTreeCodeGen(
|
mod.graph._codegen = _PyTreeCodeGen(
|
||||||
_PyTreeInfo(
|
_PyTreeInfo(
|
||||||
|
|
|
||||||
|
|
@ -1124,10 +1124,10 @@ class _ModuleFrame:
|
||||||
signature = module_call_graph.get(self.child_fqn)
|
signature = module_call_graph.get(self.child_fqn)
|
||||||
if signature is not None and self.parent is not None:
|
if signature is not None and self.parent is not None:
|
||||||
assert signature.in_spec.num_children == 2
|
assert signature.in_spec.num_children == 2
|
||||||
assert signature.in_spec.type is tuple
|
args_spec = signature.in_spec.children_specs[0]
|
||||||
args_spec, kwargs_spec = signature.in_spec.children()
|
kwargs_spec = signature.in_spec.children_specs[1]
|
||||||
assert args_spec.type is tuple
|
assert args_spec.context is None
|
||||||
assert kwargs_spec.type is dict
|
assert kwargs_spec.context is not None
|
||||||
|
|
||||||
with self.graph.inserting_after(None):
|
with self.graph.inserting_after(None):
|
||||||
arg_nodes = [
|
arg_nodes = [
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ def tree_flatten_spec(
|
||||||
flatten_fn_spec = SUPPORTED_NODES[spec.type]
|
flatten_fn_spec = SUPPORTED_NODES[spec.type]
|
||||||
child_pytrees = flatten_fn_spec(pytree, spec)
|
child_pytrees = flatten_fn_spec(pytree, spec)
|
||||||
result = []
|
result = []
|
||||||
for child, child_spec in zip(child_pytrees, spec.children()):
|
for child, child_spec in zip(child_pytrees, spec.children_specs):
|
||||||
flat = tree_flatten_spec(child, child_spec)
|
flat = tree_flatten_spec(child, child_spec)
|
||||||
result += flat
|
result += flat
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -709,7 +709,7 @@ class Tracer(TracerBase):
|
||||||
root_fn = _patch_function(root_fn, len(args))
|
root_fn = _patch_function(root_fn, len(args))
|
||||||
|
|
||||||
flat_args, in_spec = pytree.tree_flatten(tuple(args))
|
flat_args, in_spec = pytree.tree_flatten(tuple(args))
|
||||||
if not all(child.is_leaf() for child in in_spec.children()):
|
if not all(child.is_leaf() for child in in_spec.children_specs):
|
||||||
# In the case that we have pytree-flattened inputs in
|
# In the case that we have pytree-flattened inputs in
|
||||||
# `concrete_args`, generate a flattening wrapper around the
|
# `concrete_args`, generate a flattening wrapper around the
|
||||||
# original root function and return that.
|
# original root function and return that.
|
||||||
|
|
|
||||||
|
|
@ -933,25 +933,24 @@ class _PyTreeCodeGen(CodeGen):
|
||||||
return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
|
return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
|
||||||
|
|
||||||
def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
|
def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
|
||||||
in_spec = self.pytree_info.in_spec
|
|
||||||
# when kwargs is present, in_spec is tuple(args, kwargs)
|
# when kwargs is present, in_spec is tuple(args, kwargs)
|
||||||
has_args_kwargs_tuple = (
|
has_args_kwargs_tuple = (
|
||||||
in_spec.type is tuple
|
self.pytree_info.in_spec.type is tuple
|
||||||
and in_spec.num_children == 2
|
and self.pytree_info.in_spec.num_children == 2
|
||||||
and in_spec.child(0).type is tuple
|
and self.pytree_info.in_spec.children_specs[0].type is tuple
|
||||||
and in_spec.child(1).type is dict
|
and self.pytree_info.in_spec.children_specs[1].type is dict
|
||||||
)
|
)
|
||||||
fn_kwargs = "{}"
|
fn_kwargs = "{}"
|
||||||
fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
|
fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
|
||||||
if has_args_kwargs_tuple:
|
if has_args_kwargs_tuple:
|
||||||
count_args = in_spec.child(0).num_children
|
count_args = self.pytree_info.in_spec.children_specs[0].num_children
|
||||||
fn_args = self.pytree_info.orig_args[:count_args]
|
fn_args = self.pytree_info.orig_args[:count_args]
|
||||||
fn_kwargs = (
|
fn_kwargs = (
|
||||||
"{"
|
"{"
|
||||||
+ ", ".join(
|
+ ", ".join(
|
||||||
f"'{k}':{v}"
|
f"'{k}':{v}"
|
||||||
for k, v in zip(
|
for k, v in zip(
|
||||||
in_spec.child(1).context,
|
self.pytree_info.in_spec.children_specs[1].context,
|
||||||
self.pytree_info.orig_args[count_args:],
|
self.pytree_info.orig_args[count_args:],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,9 @@ collection support for PyTorch APIs.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import types
|
import types
|
||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any, Optional, overload, TypeVar, Union
|
from typing import Any, Optional, overload, TypeVar, Union
|
||||||
from typing_extensions import deprecated, Self, TypeAlias, TypeIs
|
from typing_extensions import deprecated, TypeIs
|
||||||
|
|
||||||
import torch.utils._pytree as python_pytree
|
import torch.utils._pytree as python_pytree
|
||||||
from torch.torch_version import TorchVersion as _TorchVersion
|
from torch.torch_version import TorchVersion as _TorchVersion
|
||||||
|
|
@ -42,7 +42,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
|
|
||||||
|
|
||||||
import optree
|
import optree
|
||||||
from optree import PyTreeSpec # direct import for type annotations
|
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -53,7 +53,6 @@ __all__ = [
|
||||||
"DumpableContext",
|
"DumpableContext",
|
||||||
"ToDumpableContextFn",
|
"ToDumpableContextFn",
|
||||||
"FromDumpableContextFn",
|
"FromDumpableContextFn",
|
||||||
"PyTreeSpec",
|
|
||||||
"TreeSpec",
|
"TreeSpec",
|
||||||
"LeafSpec",
|
"LeafSpec",
|
||||||
"keystr",
|
"keystr",
|
||||||
|
|
@ -101,8 +100,6 @@ U = TypeVar("U")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
TreeSpec: TypeAlias = PyTreeSpec
|
|
||||||
|
|
||||||
Context = Any
|
Context = Any
|
||||||
PyTree = Any
|
PyTree = Any
|
||||||
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
|
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
|
||||||
|
|
@ -270,30 +267,6 @@ def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
|
||||||
return isinstance(obj, TreeSpec)
|
return isinstance(obj, TreeSpec)
|
||||||
|
|
||||||
|
|
||||||
def treespec_leaf() -> TreeSpec:
|
|
||||||
"""Make a treespec representing a leaf node."""
|
|
||||||
return optree.treespec_leaf(none_is_leaf=True, namespace="torch")
|
|
||||||
|
|
||||||
|
|
||||||
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
|
|
||||||
"""Make a tuple treespec from an iterable of child treespecs."""
|
|
||||||
return optree.treespec_tuple(iterable, none_is_leaf=True, namespace="torch")
|
|
||||||
|
|
||||||
|
|
||||||
def treespec_dict(
|
|
||||||
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
|
|
||||||
/,
|
|
||||||
**kwargs: TreeSpec,
|
|
||||||
) -> TreeSpec:
|
|
||||||
"""Make a dict treespec from a dict of child treespecs."""
|
|
||||||
return optree.treespec_dict(
|
|
||||||
mapping,
|
|
||||||
**kwargs,
|
|
||||||
none_is_leaf=True,
|
|
||||||
namespace="torch",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def tree_is_leaf(
|
def tree_is_leaf(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
|
|
@ -1012,14 +985,9 @@ class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
|
||||||
return _is_pytreespec_instance(instance) and instance.is_leaf()
|
return _is_pytreespec_instance(instance) and instance.is_leaf()
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
|
||||||
"`isinstance(treespec, LeafSpec)` is deprecated, "
|
|
||||||
"use `isinstance(treespec, TreeSpec)` and `treespec.is_leaf()` instead.",
|
|
||||||
category=FutureWarning,
|
|
||||||
)
|
|
||||||
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final]
|
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final]
|
||||||
def __new__(cls) -> Self:
|
def __new__(cls) -> "LeafSpec":
|
||||||
return treespec_leaf() # type: ignore[return-value]
|
return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
def tree_flatten_with_path(
|
def tree_flatten_with_path(
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ from typing import (
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from typing_extensions import deprecated, NamedTuple, Self, TypeAlias
|
from typing_extensions import deprecated, NamedTuple, Self
|
||||||
|
|
||||||
from torch.torch_version import TorchVersion as _TorchVersion
|
from torch.torch_version import TorchVersion as _TorchVersion
|
||||||
|
|
||||||
|
|
@ -52,7 +52,6 @@ __all__ = [
|
||||||
"DumpableContext",
|
"DumpableContext",
|
||||||
"ToDumpableContextFn",
|
"ToDumpableContextFn",
|
||||||
"FromDumpableContextFn",
|
"FromDumpableContextFn",
|
||||||
"PyTreeSpec",
|
|
||||||
"TreeSpec",
|
"TreeSpec",
|
||||||
"LeafSpec",
|
"LeafSpec",
|
||||||
"keystr",
|
"keystr",
|
||||||
|
|
@ -472,7 +471,7 @@ class ConstantNode:
|
||||||
|
|
||||||
def _is_constant_holder(spec: "TreeSpec") -> bool:
|
def _is_constant_holder(spec: "TreeSpec") -> bool:
|
||||||
"""Checks if the spec is from a pytree registered with register_constant"""
|
"""Checks if the spec is from a pytree registered with register_constant"""
|
||||||
return isinstance(spec._context, ConstantNode)
|
return isinstance(spec.context, ConstantNode)
|
||||||
|
|
||||||
|
|
||||||
def _retrieve_constant(spec: "TreeSpec") -> Any:
|
def _retrieve_constant(spec: "TreeSpec") -> Any:
|
||||||
|
|
@ -1072,53 +1071,35 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
|
||||||
# context: some context that is useful in unflattening the pytree
|
# context: some context that is useful in unflattening the pytree
|
||||||
# children_specs: specs for each child of the root Node
|
# children_specs: specs for each child of the root Node
|
||||||
# num_leaves: the number of leaves
|
# num_leaves: the number of leaves
|
||||||
@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False)
|
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
|
||||||
class TreeSpec:
|
class TreeSpec:
|
||||||
type: Any
|
type: Any
|
||||||
_context: Context
|
context: Context
|
||||||
_children: list[Self]
|
children_specs: list["TreeSpec"]
|
||||||
|
|
||||||
num_nodes: int = dataclasses.field(init=False)
|
num_nodes: int = dataclasses.field(init=False)
|
||||||
num_leaves: int = dataclasses.field(init=False)
|
num_leaves: int = dataclasses.field(init=False)
|
||||||
num_children: int = dataclasses.field(init=False)
|
num_children: int = dataclasses.field(init=False)
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
type: Any,
|
|
||||||
context: Context, # keep for backward compatibility
|
|
||||||
children: list[Self], # keep for backward compatibility
|
|
||||||
) -> None:
|
|
||||||
object.__setattr__(self, "type", type)
|
|
||||||
object.__setattr__(self, "_context", context)
|
|
||||||
object.__setattr__(self, "_children", children)
|
|
||||||
self.__post_init__()
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.type is None:
|
num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
|
||||||
assert self._context is None
|
num_leaves = sum(spec.num_leaves for spec in self.children_specs)
|
||||||
assert len(self._children) == 0
|
num_children = len(self.children_specs)
|
||||||
num_nodes = 1
|
|
||||||
num_leaves = 1
|
|
||||||
num_children = 0
|
|
||||||
else:
|
|
||||||
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
|
|
||||||
num_leaves = sum(spec.num_leaves for spec in self._children)
|
|
||||||
num_children = len(self._children)
|
|
||||||
object.__setattr__(self, "num_nodes", num_nodes)
|
object.__setattr__(self, "num_nodes", num_nodes)
|
||||||
object.__setattr__(self, "num_leaves", num_leaves)
|
object.__setattr__(self, "num_leaves", num_leaves)
|
||||||
object.__setattr__(self, "num_children", num_children)
|
object.__setattr__(self, "num_children", num_children)
|
||||||
|
|
||||||
def __repr__(self, indent: int = 0) -> str:
|
def __repr__(self, indent: int = 0) -> str:
|
||||||
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self._context}, ["
|
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
|
||||||
children_specs_str: str = ""
|
children_specs_str: str = ""
|
||||||
if self.num_children > 0:
|
if self.num_children > 0:
|
||||||
indent += 2
|
indent += 2
|
||||||
children_specs_str += self._children[0].__repr__(indent)
|
children_specs_str += self.children_specs[0].__repr__(indent)
|
||||||
children_specs_str += "," if self.num_children > 1 else ""
|
children_specs_str += "," if self.num_children > 1 else ""
|
||||||
children_specs_str += ",".join(
|
children_specs_str += ",".join(
|
||||||
[
|
[
|
||||||
"\n" + " " * indent + child.__repr__(indent)
|
"\n" + " " * indent + child.__repr__(indent)
|
||||||
for child in self._children[1:]
|
for child in self.children_specs[1:]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
repr_suffix: str = f"{children_specs_str}])"
|
repr_suffix: str = f"{children_specs_str}])"
|
||||||
|
|
@ -1130,36 +1111,16 @@ class TreeSpec:
|
||||||
elif other.__class__ is self.__class__:
|
elif other.__class__ is self.__class__:
|
||||||
if str(self.type) != str(other.type):
|
if str(self.type) != str(other.type):
|
||||||
return False
|
return False
|
||||||
if self._context != other._context:
|
if self.context != other.context:
|
||||||
return False
|
return False
|
||||||
elif self._children != other._children:
|
elif self.children_specs != other.children_specs:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
@property
|
|
||||||
def context(self) -> Context:
|
|
||||||
return self._context
|
|
||||||
|
|
||||||
@property
|
|
||||||
@deprecated(
|
|
||||||
"`treespec.children_specs` is deprecated. "
|
|
||||||
"Use `treespec.child(index)` to access a single child, "
|
|
||||||
"or `treespec.children()` to get all children.",
|
|
||||||
category=FutureWarning,
|
|
||||||
)
|
|
||||||
def children_specs(self) -> list[Self]:
|
|
||||||
return self._children
|
|
||||||
|
|
||||||
def is_leaf(self) -> bool:
|
def is_leaf(self) -> bool:
|
||||||
return self.num_nodes == 1 and self.num_leaves == 1
|
return self.num_nodes == 1 and self.num_leaves == 1
|
||||||
|
|
||||||
def children(self) -> list[Self]:
|
|
||||||
return self._children.copy()
|
|
||||||
|
|
||||||
def child(self, index: int) -> Self:
|
|
||||||
return self._children[index]
|
|
||||||
|
|
||||||
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
|
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
|
||||||
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
|
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
|
||||||
if treespec.is_leaf():
|
if treespec.is_leaf():
|
||||||
|
|
@ -1181,7 +1142,7 @@ class TreeSpec:
|
||||||
f"Node arity mismatch; "
|
f"Node arity mismatch; "
|
||||||
f"expected {treespec.num_children}, but got {len(children)}.",
|
f"expected {treespec.num_children}, but got {len(children)}.",
|
||||||
)
|
)
|
||||||
if context != treespec._context:
|
if context != treespec.context:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Node context mismatch for custom node type {treespec.type!r}.",
|
f"Node context mismatch for custom node type {treespec.type!r}.",
|
||||||
)
|
)
|
||||||
|
|
@ -1206,10 +1167,10 @@ class TreeSpec:
|
||||||
if both_standard_dict:
|
if both_standard_dict:
|
||||||
# dictionary types are compatible with each other
|
# dictionary types are compatible with each other
|
||||||
dict_context = (
|
dict_context = (
|
||||||
treespec._context
|
treespec.context
|
||||||
if treespec.type is not defaultdict
|
if treespec.type is not defaultdict
|
||||||
# ignore mismatch of `default_factory` for defaultdict
|
# ignore mismatch of `default_factory` for defaultdict
|
||||||
else treespec._context[1]
|
else treespec.context[1]
|
||||||
)
|
)
|
||||||
expected_keys = dict_context
|
expected_keys = dict_context
|
||||||
got_key_set = set(tree)
|
got_key_set = set(tree)
|
||||||
|
|
@ -1230,13 +1191,13 @@ class TreeSpec:
|
||||||
children, context = flatten_fn(tree)
|
children, context = flatten_fn(tree)
|
||||||
if (
|
if (
|
||||||
node_type is not deque # ignore mismatch of `maxlen` for deque
|
node_type is not deque # ignore mismatch of `maxlen` for deque
|
||||||
) and context != treespec._context:
|
) and context != treespec.context:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Node context mismatch for node type {treespec.type!r}; "
|
f"Node context mismatch for node type {treespec.type!r}; "
|
||||||
f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch
|
f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch
|
||||||
)
|
)
|
||||||
|
|
||||||
for subtree, subspec in zip(children, treespec._children, strict=True):
|
for subtree, subspec in zip(children, treespec.children_specs):
|
||||||
helper(subspec, subtree, subtrees)
|
helper(subspec, subtree, subtrees)
|
||||||
|
|
||||||
subtrees: list[PyTree] = []
|
subtrees: list[PyTree] = []
|
||||||
|
|
@ -1261,24 +1222,24 @@ class TreeSpec:
|
||||||
start = 0
|
start = 0
|
||||||
end = 0
|
end = 0
|
||||||
child_pytrees = []
|
child_pytrees = []
|
||||||
for child_spec in self._children:
|
for child_spec in self.children_specs:
|
||||||
end += child_spec.num_leaves
|
end += child_spec.num_leaves
|
||||||
child_pytrees.append(child_spec.unflatten(leaves[start:end]))
|
child_pytrees.append(child_spec.unflatten(leaves[start:end]))
|
||||||
start = end
|
start = end
|
||||||
|
|
||||||
return unflatten_fn(child_pytrees, self._context)
|
return unflatten_fn(child_pytrees, self.context)
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
node_type = self.type
|
node_type = self.type
|
||||||
if node_type is defaultdict:
|
if node_type is defaultdict:
|
||||||
default_factory, dict_context = self._context
|
default_factory, dict_context = self.context
|
||||||
hashable_context = (default_factory, tuple(dict_context))
|
hashable_context = (default_factory, tuple(dict_context))
|
||||||
elif node_type in (dict, OrderedDict):
|
elif node_type in (dict, OrderedDict):
|
||||||
hashable_context = tuple(self._context)
|
hashable_context = tuple(self.context)
|
||||||
elif node_type is None or node_type in BUILTIN_TYPES:
|
elif node_type is None or node_type in BUILTIN_TYPES:
|
||||||
hashable_context = self._context
|
hashable_context = self.context
|
||||||
elif isinstance(self._context, ConstantNode):
|
elif isinstance(self.context, ConstantNode):
|
||||||
hashable_context = self._context.value
|
hashable_context = self.context.value
|
||||||
else:
|
else:
|
||||||
# The context for user-defined node types might not be hashable.
|
# The context for user-defined node types might not be hashable.
|
||||||
# Ignore it for hashing.
|
# Ignore it for hashing.
|
||||||
|
|
@ -1286,26 +1247,20 @@ class TreeSpec:
|
||||||
# same hash. This might increase the hash collision rate, but we
|
# same hash. This might increase the hash collision rate, but we
|
||||||
# don't care about that.
|
# don't care about that.
|
||||||
hashable_context = None
|
hashable_context = None
|
||||||
return hash((node_type, hashable_context, tuple(self._children)))
|
return hash((node_type, hashable_context, tuple(self.children_specs)))
|
||||||
|
|
||||||
|
|
||||||
PyTreeSpec: TypeAlias = TreeSpec
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
|
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
|
||||||
# this class with `dataclasses.fields`, etc., while having a simplified
|
# this class with `dataclasses.fields`, etc., while having a simplified
|
||||||
# constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
|
# constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
|
||||||
# again, with fields that have `init=False`.
|
# again, with fields that have `init=False`.
|
||||||
@deprecated(
|
|
||||||
"`isinstance(treespec, LeafSpec)` is deprecated, "
|
|
||||||
"use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.",
|
|
||||||
category=FutureWarning,
|
|
||||||
)
|
|
||||||
@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
|
@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
|
||||||
class LeafSpec(TreeSpec):
|
class LeafSpec(TreeSpec):
|
||||||
type: Any = dataclasses.field(default=None, init=False)
|
type: Any = dataclasses.field(default=None, init=False)
|
||||||
_context: Context = dataclasses.field(default=None, init=False)
|
context: Context = dataclasses.field(default=None, init=False)
|
||||||
_children: list[Self] = dataclasses.field(default_factory=list, init=False)
|
children_specs: list["TreeSpec"] = dataclasses.field(
|
||||||
|
default_factory=list, init=False
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Override `__post_init__` for `num_leaves` derivation.
|
# Override `__post_init__` for `num_leaves` derivation.
|
||||||
|
|
@ -1319,36 +1274,7 @@ class LeafSpec(TreeSpec):
|
||||||
|
|
||||||
# All leaves are equivalent, so represent with a single object to save on
|
# All leaves are equivalent, so represent with a single object to save on
|
||||||
# object construction time
|
# object construction time
|
||||||
with warnings.catch_warnings():
|
_LEAF_SPEC = LeafSpec()
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore", category=FutureWarning, module=__name__, append=False
|
|
||||||
)
|
|
||||||
_LEAF_SPEC = LeafSpec()
|
|
||||||
|
|
||||||
|
|
||||||
def treespec_leaf() -> LeafSpec:
|
|
||||||
"""Make a treespec representing a leaf node."""
|
|
||||||
return _LEAF_SPEC
|
|
||||||
|
|
||||||
|
|
||||||
def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
|
|
||||||
"""Make a tuple treespec from an iterable of child treespecs."""
|
|
||||||
children = list(iterable)
|
|
||||||
if any(not isinstance(child, TreeSpec) for child in children):
|
|
||||||
raise ValueError(f"Expected a tuple of TreeSpec values, got: {children!r}.")
|
|
||||||
return TreeSpec(tuple, None, children)
|
|
||||||
|
|
||||||
|
|
||||||
def treespec_dict(
|
|
||||||
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
|
|
||||||
/,
|
|
||||||
**kwargs: TreeSpec,
|
|
||||||
) -> TreeSpec:
|
|
||||||
"""Make a dict treespec from a dict of child treespecs."""
|
|
||||||
dct = dict(mapping, **kwargs)
|
|
||||||
if any(not isinstance(child, TreeSpec) for child in dct.values()):
|
|
||||||
raise ValueError(f"Expected a dictionary of TreeSpec values, got: {dct!r}.")
|
|
||||||
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
|
|
||||||
|
|
||||||
|
|
||||||
def tree_flatten(
|
def tree_flatten(
|
||||||
|
|
@ -1827,15 +1753,15 @@ def _broadcast_to_and_flatten(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||||
child_pytrees, context = flatten_fn(tree)
|
child_pytrees, ctx = flatten_fn(tree)
|
||||||
|
|
||||||
# Check if the Node is different from the spec
|
# Check if the Node is different from the spec
|
||||||
if len(child_pytrees) != treespec.num_children or context != treespec._context:
|
if len(child_pytrees) != treespec.num_children or ctx != treespec.context:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Recursively flatten the children
|
# Recursively flatten the children
|
||||||
result: list[Any] = []
|
result: list[Any] = []
|
||||||
for child, child_spec in zip(child_pytrees, treespec._children, strict=True):
|
for child, child_spec in zip(child_pytrees, treespec.children_specs):
|
||||||
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
|
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
|
||||||
if flat is not None:
|
if flat is not None:
|
||||||
result += flat
|
result += flat
|
||||||
|
|
@ -1889,7 +1815,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
|
||||||
|
|
||||||
if serialize_node_def.to_dumpable_context is None:
|
if serialize_node_def.to_dumpable_context is None:
|
||||||
try:
|
try:
|
||||||
serialized_context = json.dumps(treespec._context, cls=EnumEncoder)
|
serialized_context = json.dumps(treespec.context, cls=EnumEncoder)
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Unable to serialize context. "
|
"Unable to serialize context. "
|
||||||
|
|
@ -1897,9 +1823,9 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
|
||||||
"custom serializer using _register_pytree_node."
|
"custom serializer using _register_pytree_node."
|
||||||
) from e
|
) from e
|
||||||
else:
|
else:
|
||||||
serialized_context = serialize_node_def.to_dumpable_context(treespec._context)
|
serialized_context = serialize_node_def.to_dumpable_context(treespec.context)
|
||||||
|
|
||||||
child_schemas = [_treespec_to_json(child) for child in treespec._children]
|
child_schemas = [_treespec_to_json(child) for child in treespec.children_specs]
|
||||||
|
|
||||||
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
|
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user