mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pytree][Easy] preserve dict keys in insertion order in CXX pytree (#130140)
`optree` and JAX pytree traversal the `dict` in sorted key ordering (see [Key Ordering for Dictionaries](https://github.com/metaopt/optree#key-ordering-for-dictionaries)). While in PyTorch Python pytree, we traversal the `dict` in insertion order. See also: - #114392 This aligns the behavior of CXX pytree with Python pytree. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130140 Approved by: https://github.com/zou3519
This commit is contained in:
parent
1f8ff94d4f
commit
9abaaad6a8
|
|
@ -23,7 +23,6 @@ from torch.testing._internal.common_utils import (
|
||||||
run_tests,
|
run_tests,
|
||||||
skipIfTorchDynamo,
|
skipIfTorchDynamo,
|
||||||
subtest,
|
subtest,
|
||||||
TEST_WITH_TORCHDYNAMO,
|
|
||||||
TestCase,
|
TestCase,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -805,7 +804,6 @@ if "optree" in sys.modules:
|
||||||
py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
|
py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
|
|
||||||
def test_treespec_repr(self):
|
def test_treespec_repr(self):
|
||||||
# Check that it looks sane
|
# Check that it looks sane
|
||||||
pytree = (0, [0, 0, [0]])
|
pytree = (0, [0, 0, [0]])
|
||||||
|
|
@ -820,20 +818,6 @@ if "optree" in sys.modules:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
|
|
||||||
def test_treespec_repr_dynamo(self):
|
|
||||||
# Check that it looks sane
|
|
||||||
pytree = (0, [0, 0, [0]])
|
|
||||||
_, spec = py_pytree.tree_flatten(pytree)
|
|
||||||
self.assertExpectedInline(
|
|
||||||
repr(spec),
|
|
||||||
"""\
|
|
||||||
TreeSpec(tuple, None, [*,
|
|
||||||
TreeSpec(list, None, [*,
|
|
||||||
*,
|
|
||||||
TreeSpec(list, None, [*])])])""",
|
|
||||||
)
|
|
||||||
|
|
||||||
@parametrize(
|
@parametrize(
|
||||||
"spec",
|
"spec",
|
||||||
[
|
[
|
||||||
|
|
@ -1365,21 +1349,12 @@ class TestCxxPytree(TestCase):
|
||||||
def test_treespec_equality(self):
|
def test_treespec_equality(self):
|
||||||
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
|
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
|
||||||
|
|
||||||
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
|
|
||||||
def test_treespec_repr(self):
|
def test_treespec_repr(self):
|
||||||
# Check that it looks sane
|
# Check that it looks sane
|
||||||
pytree = (0, [0, 0, [0]])
|
pytree = (0, [0, 0, [0]])
|
||||||
_, spec = cxx_pytree.tree_flatten(pytree)
|
_, spec = cxx_pytree.tree_flatten(pytree)
|
||||||
self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)")
|
self.assertEqual(
|
||||||
|
repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')"
|
||||||
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
|
|
||||||
def test_treespec_repr_dynamo(self):
|
|
||||||
# Check that it looks sane
|
|
||||||
pytree = (0, [0, 0, [0]])
|
|
||||||
_, spec = cxx_pytree.tree_flatten(pytree)
|
|
||||||
self.assertExpectedInline(
|
|
||||||
repr(spec),
|
|
||||||
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@parametrize(
|
@parametrize(
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,10 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
|
||||||
|
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
S = TypeVar("S")
|
S = TypeVar("S")
|
||||||
U = TypeVar("U")
|
U = TypeVar("U")
|
||||||
|
|
@ -285,20 +289,15 @@ def tree_flatten(
|
||||||
|
|
||||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||||
>>> tree_flatten(tree)
|
>>> tree_flatten(tree)
|
||||||
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
|
([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch'))
|
||||||
>>> tree_flatten(1)
|
>>> tree_flatten(1)
|
||||||
([1], PyTreeSpec(*, NoneIsLeaf))
|
([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
|
||||||
>>> tree_flatten(None)
|
>>> tree_flatten(None)
|
||||||
([None], PyTreeSpec(*, NoneIsLeaf))
|
([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
|
||||||
|
|
||||||
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
|
|
||||||
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
|
|
||||||
if you want to keep the keys in the insertion order.
|
|
||||||
|
|
||||||
>>> from collections import OrderedDict
|
>>> from collections import OrderedDict
|
||||||
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
|
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
|
||||||
>>> tree_flatten(tree)
|
>>> tree_flatten(tree)
|
||||||
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
|
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch'))
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree (pytree): A pytree to flatten.
|
tree (pytree): A pytree to flatten.
|
||||||
|
|
@ -357,7 +356,7 @@ def tree_iter(
|
||||||
|
|
||||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||||
>>> list(tree_iter(tree))
|
>>> list(tree_iter(tree))
|
||||||
[1, 2, 3, 4, None, 5]
|
[2, 3, 4, 1, None, 5]
|
||||||
>>> list(tree_iter(1))
|
>>> list(tree_iter(1))
|
||||||
[1]
|
[1]
|
||||||
>>> list(tree_iter(None))
|
>>> list(tree_iter(None))
|
||||||
|
|
@ -392,7 +391,7 @@ def tree_leaves(
|
||||||
|
|
||||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||||
>>> tree_leaves(tree)
|
>>> tree_leaves(tree)
|
||||||
[1, 2, 3, 4, None, 5]
|
[2, 3, 4, 1, None, 5]
|
||||||
>>> tree_leaves(1)
|
>>> tree_leaves(1)
|
||||||
[1]
|
[1]
|
||||||
>>> tree_leaves(None)
|
>>> tree_leaves(None)
|
||||||
|
|
@ -427,11 +426,11 @@ def tree_structure(
|
||||||
|
|
||||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||||
>>> tree_structure(tree)
|
>>> tree_structure(tree)
|
||||||
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
|
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')
|
||||||
>>> tree_structure(1)
|
>>> tree_structure(1)
|
||||||
PyTreeSpec(*, NoneIsLeaf)
|
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
|
||||||
>>> tree_structure(None)
|
>>> tree_structure(None)
|
||||||
PyTreeSpec(*, NoneIsLeaf)
|
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree (pytree): A pytree to flatten.
|
tree (pytree): A pytree to flatten.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user