mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytree] fix previously failed dynamo tests (#148669)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148669 Approved by: https://github.com/zou3519
This commit is contained in:
parent
28b68b46bc
commit
097b0d372a
|
|
@ -8,7 +8,7 @@ import re
|
|||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
|
||||
from collections import defaultdict, namedtuple, OrderedDict, UserDict
|
||||
from dataclasses import dataclass
|
||||
from enum import auto
|
||||
from typing import Any, NamedTuple
|
||||
|
|
@ -21,7 +21,6 @@ from torch.testing._internal.common_utils import (
|
|||
IS_FBCODE,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
subtest,
|
||||
TestCase,
|
||||
)
|
||||
|
|
@ -405,7 +404,9 @@ class TestGenericPytree(TestCase):
|
|||
(
|
||||
py_pytree,
|
||||
lambda deq: py_pytree.TreeSpec(
|
||||
deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
|
||||
collections.deque,
|
||||
deq.maxlen,
|
||||
[py_pytree.LeafSpec() for _ in deq],
|
||||
),
|
||||
),
|
||||
name="py",
|
||||
|
|
@ -414,7 +415,7 @@ class TestGenericPytree(TestCase):
|
|||
(
|
||||
cxx_pytree,
|
||||
lambda deq: cxx_pytree.tree_structure(
|
||||
deque(deq, maxlen=deq.maxlen)
|
||||
collections.deque(deq, maxlen=deq.maxlen)
|
||||
),
|
||||
),
|
||||
name="cxx",
|
||||
|
|
@ -432,11 +433,11 @@ class TestGenericPytree(TestCase):
|
|||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, deq)
|
||||
self.assertEqual(unflattened.maxlen, deq.maxlen)
|
||||
self.assertIsInstance(unflattened, deque)
|
||||
self.assertIsInstance(unflattened, collections.deque)
|
||||
|
||||
run_test(deque([]))
|
||||
run_test(deque([1.0, 2]))
|
||||
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
||||
run_test(collections.deque([]))
|
||||
run_test(collections.deque([1.0, 2]))
|
||||
run_test(collections.deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
|
|
@ -1242,7 +1243,6 @@ if "optree" in sys.modules:
|
|||
from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
|
||||
self.assertEqual(from_two_trees, from_one_tree)
|
||||
|
||||
@skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
|
||||
def test_tree_flatten_with_path_is_leaf(self):
|
||||
leaf_dict = {"foo": [(3)]}
|
||||
pytree = (["hello", [1, 2], leaf_dict],)
|
||||
|
|
@ -1331,7 +1331,6 @@ if "optree" in sys.modules:
|
|||
],
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("AssertionError in dynamo")
|
||||
def test_flatten_flatten_with_key_consistency(self):
|
||||
"""Check that flatten and flatten_with_key produces consistent leaves/context."""
|
||||
reg = py_pytree.SUPPORTED_NODES
|
||||
|
|
@ -1340,10 +1339,10 @@ if "optree" in sys.modules:
|
|||
list: [1, 2, 3],
|
||||
tuple: (1, 2, 3),
|
||||
dict: {"foo": 1, "bar": 2},
|
||||
namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
|
||||
namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
|
||||
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
|
||||
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
|
||||
deque: deque([1, 2, 3]),
|
||||
collections.deque: collections.deque([1, 2, 3]),
|
||||
torch.Size: torch.Size([1, 2, 3]),
|
||||
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
|
||||
immutable_list: immutable_list([1, 2, 3]),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user