[pytree] fix previously failed dynamo tests (#148669)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148669
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan 2025-03-06 23:18:50 +08:00 committed by PyTorch MergeBot
parent 28b68b46bc
commit 097b0d372a
3 changed files with 11 additions and 12 deletions

View File

@ -8,7 +8,7 @@ import re
import subprocess
import sys
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from collections import defaultdict, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
from enum import auto
from typing import Any, NamedTuple
@ -21,7 +21,6 @@ from torch.testing._internal.common_utils import (
IS_FBCODE,
parametrize,
run_tests,
skipIfTorchDynamo,
subtest,
TestCase,
)
@ -405,7 +404,9 @@ class TestGenericPytree(TestCase):
(
py_pytree,
lambda deq: py_pytree.TreeSpec(
deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
collections.deque,
deq.maxlen,
[py_pytree.LeafSpec() for _ in deq],
),
),
name="py",
@ -414,7 +415,7 @@ class TestGenericPytree(TestCase):
(
cxx_pytree,
lambda deq: cxx_pytree.tree_structure(
deque(deq, maxlen=deq.maxlen)
collections.deque(deq, maxlen=deq.maxlen)
),
),
name="cxx",
@ -432,11 +433,11 @@ class TestGenericPytree(TestCase):
unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, deq)
self.assertEqual(unflattened.maxlen, deq.maxlen)
self.assertIsInstance(unflattened, deque)
self.assertIsInstance(unflattened, collections.deque)
run_test(deque([]))
run_test(deque([1.0, 2]))
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
run_test(collections.deque([]))
run_test(collections.deque([1.0, 2]))
run_test(collections.deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
@parametrize(
"pytree_impl",
@ -1242,7 +1243,6 @@ if "optree" in sys.modules:
from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
self.assertEqual(from_two_trees, from_one_tree)
@skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
def test_tree_flatten_with_path_is_leaf(self):
leaf_dict = {"foo": [(3)]}
pytree = (["hello", [1, 2], leaf_dict],)
@ -1331,7 +1331,6 @@ if "optree" in sys.modules:
],
)
@skipIfTorchDynamo("AssertionError in dynamo")
def test_flatten_flatten_with_key_consistency(self):
"""Check that flatten and flatten_with_key produces consistent leaves/context."""
reg = py_pytree.SUPPORTED_NODES
@ -1340,10 +1339,10 @@ if "optree" in sys.modules:
list: [1, 2, 3],
tuple: (1, 2, 3),
dict: {"foo": 1, "bar": 2},
namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
deque: deque([1, 2, 3]),
collections.deque: collections.deque([1, 2, 3]),
torch.Size: torch.Size([1, 2, 3]),
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
immutable_list: immutable_list([1, 2, 3]),