mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytree] fix previously failed dynamo tests (#148669)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148669 Approved by: https://github.com/zou3519
This commit is contained in:
parent
28b68b46bc
commit
097b0d372a
|
|
@ -8,7 +8,7 @@ import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
|
from collections import defaultdict, namedtuple, OrderedDict, UserDict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import auto
|
from enum import auto
|
||||||
from typing import Any, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
|
|
@ -21,7 +21,6 @@ from torch.testing._internal.common_utils import (
|
||||||
IS_FBCODE,
|
IS_FBCODE,
|
||||||
parametrize,
|
parametrize,
|
||||||
run_tests,
|
run_tests,
|
||||||
skipIfTorchDynamo,
|
|
||||||
subtest,
|
subtest,
|
||||||
TestCase,
|
TestCase,
|
||||||
)
|
)
|
||||||
|
|
@ -405,7 +404,9 @@ class TestGenericPytree(TestCase):
|
||||||
(
|
(
|
||||||
py_pytree,
|
py_pytree,
|
||||||
lambda deq: py_pytree.TreeSpec(
|
lambda deq: py_pytree.TreeSpec(
|
||||||
deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
|
collections.deque,
|
||||||
|
deq.maxlen,
|
||||||
|
[py_pytree.LeafSpec() for _ in deq],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
name="py",
|
name="py",
|
||||||
|
|
@ -414,7 +415,7 @@ class TestGenericPytree(TestCase):
|
||||||
(
|
(
|
||||||
cxx_pytree,
|
cxx_pytree,
|
||||||
lambda deq: cxx_pytree.tree_structure(
|
lambda deq: cxx_pytree.tree_structure(
|
||||||
deque(deq, maxlen=deq.maxlen)
|
collections.deque(deq, maxlen=deq.maxlen)
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
name="cxx",
|
name="cxx",
|
||||||
|
|
@ -432,11 +433,11 @@ class TestGenericPytree(TestCase):
|
||||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||||
self.assertEqual(unflattened, deq)
|
self.assertEqual(unflattened, deq)
|
||||||
self.assertEqual(unflattened.maxlen, deq.maxlen)
|
self.assertEqual(unflattened.maxlen, deq.maxlen)
|
||||||
self.assertIsInstance(unflattened, deque)
|
self.assertIsInstance(unflattened, collections.deque)
|
||||||
|
|
||||||
run_test(deque([]))
|
run_test(collections.deque([]))
|
||||||
run_test(deque([1.0, 2]))
|
run_test(collections.deque([1.0, 2]))
|
||||||
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
run_test(collections.deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
||||||
|
|
||||||
@parametrize(
|
@parametrize(
|
||||||
"pytree_impl",
|
"pytree_impl",
|
||||||
|
|
@ -1242,7 +1243,6 @@ if "optree" in sys.modules:
|
||||||
from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
|
from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
|
||||||
self.assertEqual(from_two_trees, from_one_tree)
|
self.assertEqual(from_two_trees, from_one_tree)
|
||||||
|
|
||||||
@skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
|
|
||||||
def test_tree_flatten_with_path_is_leaf(self):
|
def test_tree_flatten_with_path_is_leaf(self):
|
||||||
leaf_dict = {"foo": [(3)]}
|
leaf_dict = {"foo": [(3)]}
|
||||||
pytree = (["hello", [1, 2], leaf_dict],)
|
pytree = (["hello", [1, 2], leaf_dict],)
|
||||||
|
|
@ -1331,7 +1331,6 @@ if "optree" in sys.modules:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@skipIfTorchDynamo("AssertionError in dynamo")
|
|
||||||
def test_flatten_flatten_with_key_consistency(self):
|
def test_flatten_flatten_with_key_consistency(self):
|
||||||
"""Check that flatten and flatten_with_key produces consistent leaves/context."""
|
"""Check that flatten and flatten_with_key produces consistent leaves/context."""
|
||||||
reg = py_pytree.SUPPORTED_NODES
|
reg = py_pytree.SUPPORTED_NODES
|
||||||
|
|
@ -1340,10 +1339,10 @@ if "optree" in sys.modules:
|
||||||
list: [1, 2, 3],
|
list: [1, 2, 3],
|
||||||
tuple: (1, 2, 3),
|
tuple: (1, 2, 3),
|
||||||
dict: {"foo": 1, "bar": 2},
|
dict: {"foo": 1, "bar": 2},
|
||||||
namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
|
namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
|
||||||
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
|
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
|
||||||
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
|
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
|
||||||
deque: deque([1, 2, 3]),
|
collections.deque: collections.deque([1, 2, 3]),
|
||||||
torch.Size: torch.Size([1, 2, 3]),
|
torch.Size: torch.Size([1, 2, 3]),
|
||||||
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
|
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
|
||||||
immutable_list: immutable_list([1, 2, 3]),
|
immutable_list: immutable_list([1, 2, 3]),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user