pytorch/test/dynamo/cpython/3_13/seq_tests.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` truly produces exactly one graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still a private API (there are no internal call sites yet, and no significant OSS call sites yet).

 cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

184 lines
5.9 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py
index 719c9434a16..290e57c04a0 100644
--- a/test/dynamo/cpython/3_13/seq_tests.py
+++ b/test/dynamo/cpython/3_13/seq_tests.py
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
"""
Tests common to tuple, list and UserList.UserList
"""
@@ -95,7 +149,7 @@ class LyingList(list):
def __iter__(self):
yield 1
-class CommonTest(unittest.TestCase):
+class CommonTest(__TestCase):
# The type to be tested
type2test = None
@@ -115,13 +169,14 @@ class CommonTest(unittest.TestCase):
uu2 = self.type2test(u2)
v = self.type2test(tuple(u))
- class OtherSeq:
- def __init__(self, initseq):
- self.__data = initseq
- def __len__(self):
- return len(self.__data)
- def __getitem__(self, i):
- return self.__data[i]
+ with torch._dynamo.error_on_graph_break(False):
+ class OtherSeq:
+ def __init__(self, initseq):
+ self.__data = initseq
+ def __len__(self):
+ return len(self.__data)
+ def __getitem__(self, i):
+ return self.__data[i]
s = OtherSeq(u0)
v0 = self.type2test(s)
self.assertEqual(len(v0), len(s))
@@ -239,11 +294,12 @@ class CommonTest(unittest.TestCase):
# Sequences must test in-order. If a rich comparison has side
# effects, these will be visible to tests against later members.
# In this test, the "side effect" is a short-circuiting raise.
- class DoNotTestEq(Exception):
- pass
- class StopCompares:
- def __eq__(self, other):
- raise DoNotTestEq
+ with torch._dynamo.error_on_graph_break(False):
+ class DoNotTestEq(Exception):
+ pass
+ class StopCompares:
+ def __eq__(self, other):
+ raise DoNotTestEq
checkfirst = self.type2test([1, StopCompares()])
self.assertIn(1, checkfirst)
@@ -283,8 +339,9 @@ class CommonTest(unittest.TestCase):
self.assertEqual(u2+u2+u2, u2*3)
self.assertEqual(u2+u2+u2, 3*u2)
- class subclass(self.type2test):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(self.type2test):
+ pass
u3 = subclass([0, 1])
self.assertEqual(u3, u3*1)
self.assertIsNot(u3, u3*1)
@@ -311,9 +368,10 @@ class CommonTest(unittest.TestCase):
def test_getitemoverwriteiter(self):
# Verify that __getitem__ overrides are not recognized by __iter__
- class T(self.type2test):
- def __getitem__(self, key):
- return str(key) + '!!!'
+ with torch._dynamo.error_on_graph_break(False):
+ class T(self.type2test):
+ def __getitem__(self, key):
+ return str(key) + '!!!'
self.assertEqual(next(iter(T((1,2)))), 1)
def test_repeat(self):
@@ -361,14 +419,15 @@ class CommonTest(unittest.TestCase):
self.assertRaises(TypeError, a.count)
- class BadExc(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class BadExc(Exception):
+ pass
- class BadCmp:
- def __eq__(self, other):
- if other == 2:
- raise BadExc()
- return False
+ class BadCmp:
+ def __eq__(self, other):
+ if other == 2:
+ raise BadExc()
+ return False
self.assertRaises(BadExc, a.count, BadCmp())
@@ -394,14 +453,15 @@ class CommonTest(unittest.TestCase):
self.assertRaises(TypeError, u.index)
- class BadExc(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class BadExc(Exception):
+ pass
- class BadCmp:
- def __eq__(self, other):
- if other == 2:
- raise BadExc()
- return False
+ class BadCmp:
+ def __eq__(self, other):
+ if other == 2:
+ raise BadExc()
+ return False
a = self.type2test([0, 1, 2, 3])
self.assertRaises(BadExc, a.index, BadCmp())