[pytree] fix previously failed dynamo tests (#148669)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148669
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan 2025-03-06 23:18:50 +08:00 committed by PyTorch MergeBot
parent 28b68b46bc
commit 097b0d372a
3 changed files with 11 additions and 12 deletions

View File

@ -8,7 +8,7 @@ import re
import subprocess import subprocess
import sys import sys
import unittest import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict from collections import defaultdict, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass from dataclasses import dataclass
from enum import auto from enum import auto
from typing import Any, NamedTuple from typing import Any, NamedTuple
@ -21,7 +21,6 @@ from torch.testing._internal.common_utils import (
IS_FBCODE, IS_FBCODE,
parametrize, parametrize,
run_tests, run_tests,
skipIfTorchDynamo,
subtest, subtest,
TestCase, TestCase,
) )
@ -405,7 +404,9 @@ class TestGenericPytree(TestCase):
( (
py_pytree, py_pytree,
lambda deq: py_pytree.TreeSpec( lambda deq: py_pytree.TreeSpec(
deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq] collections.deque,
deq.maxlen,
[py_pytree.LeafSpec() for _ in deq],
), ),
), ),
name="py", name="py",
@ -414,7 +415,7 @@ class TestGenericPytree(TestCase):
( (
cxx_pytree, cxx_pytree,
lambda deq: cxx_pytree.tree_structure( lambda deq: cxx_pytree.tree_structure(
deque(deq, maxlen=deq.maxlen) collections.deque(deq, maxlen=deq.maxlen)
), ),
), ),
name="cxx", name="cxx",
@ -432,11 +433,11 @@ class TestGenericPytree(TestCase):
unflattened = pytree_impl.tree_unflatten(values, treespec) unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, deq) self.assertEqual(unflattened, deq)
self.assertEqual(unflattened.maxlen, deq.maxlen) self.assertEqual(unflattened.maxlen, deq.maxlen)
self.assertIsInstance(unflattened, deque) self.assertIsInstance(unflattened, collections.deque)
run_test(deque([])) run_test(collections.deque([]))
run_test(deque([1.0, 2])) run_test(collections.deque([1.0, 2]))
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8)) run_test(collections.deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
@parametrize( @parametrize(
"pytree_impl", "pytree_impl",
@ -1242,7 +1243,6 @@ if "optree" in sys.modules:
from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1) from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
self.assertEqual(from_two_trees, from_one_tree) self.assertEqual(from_two_trees, from_one_tree)
@skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
def test_tree_flatten_with_path_is_leaf(self): def test_tree_flatten_with_path_is_leaf(self):
leaf_dict = {"foo": [(3)]} leaf_dict = {"foo": [(3)]}
pytree = (["hello", [1, 2], leaf_dict],) pytree = (["hello", [1, 2], leaf_dict],)
@ -1331,7 +1331,6 @@ if "optree" in sys.modules:
], ],
) )
@skipIfTorchDynamo("AssertionError in dynamo")
def test_flatten_flatten_with_key_consistency(self): def test_flatten_flatten_with_key_consistency(self):
"""Check that flatten and flatten_with_key produces consistent leaves/context.""" """Check that flatten and flatten_with_key produces consistent leaves/context."""
reg = py_pytree.SUPPORTED_NODES reg = py_pytree.SUPPORTED_NODES
@ -1340,10 +1339,10 @@ if "optree" in sys.modules:
list: [1, 2, 3], list: [1, 2, 3],
tuple: (1, 2, 3), tuple: (1, 2, 3),
dict: {"foo": 1, "bar": 2}, dict: {"foo": 1, "bar": 2},
namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2), namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]), OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}), defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
deque: deque([1, 2, 3]), collections.deque: collections.deque([1, 2, 3]),
torch.Size: torch.Size([1, 2, 3]), torch.Size: torch.Size([1, 2, 3]),
immutable_dict: immutable_dict({"foo": 1, "bar": 2}), immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
immutable_list: immutable_list([1, 2, 3]), immutable_list: immutable_list([1, 2, 3]),