mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph`, so that `fullgraph` really returns 1 graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still a private API (there are no internal call sites yet, and no significant OSS call sites yet).

cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
184 lines
5.9 KiB
Diff
184 lines
5.9 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py
|
|
index 719c9434a16..290e57c04a0 100644
|
|
--- a/test/dynamo/cpython/3_13/seq_tests.py
|
|
+++ b/test/dynamo/cpython/3_13/seq_tests.py
|
|
@@ -1,3 +1,57 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+
|
|
+# redirect import statements
|
|
+import sys
|
|
+import importlib.abc
|
|
+
|
|
+redirect_imports = (
|
|
+ "test.mapping_tests",
|
|
+ "test.typinganndata",
|
|
+ "test.test_grammar",
|
|
+ "test.test_math",
|
|
+ "test.test_iter",
|
|
+ "test.typinganndata.ann_module",
|
|
+)
|
|
+
|
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
|
+ def find_spec(self, fullname, path, target=None):
|
|
+ # Check if the import is the problematic one
|
|
+ if fullname in redirect_imports:
|
|
+ try:
|
|
+ # Attempt to import the standalone module
|
|
+ name = fullname.removeprefix("test.")
|
|
+ r = importlib.import_module(name)
|
|
+ # Redirect the module in sys.modules
|
|
+ sys.modules[fullname] = r
|
|
+ # Return a module spec from the found module
|
|
+ return importlib.util.find_spec(name)
|
|
+ except ImportError:
|
|
+ return None
|
|
+ return None
|
|
+
|
|
+# Add the custom finder to sys.meta_path
|
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
|
+
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
"""
|
|
Tests common to tuple, list and UserList.UserList
|
|
"""
|
|
@@ -95,7 +149,7 @@ class LyingList(list):
|
|
def __iter__(self):
|
|
yield 1
|
|
|
|
-class CommonTest(unittest.TestCase):
|
|
+class CommonTest(__TestCase):
|
|
# The type to be tested
|
|
type2test = None
|
|
|
|
@@ -115,13 +169,14 @@ class CommonTest(unittest.TestCase):
|
|
uu2 = self.type2test(u2)
|
|
|
|
v = self.type2test(tuple(u))
|
|
- class OtherSeq:
|
|
- def __init__(self, initseq):
|
|
- self.__data = initseq
|
|
- def __len__(self):
|
|
- return len(self.__data)
|
|
- def __getitem__(self, i):
|
|
- return self.__data[i]
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class OtherSeq:
|
|
+ def __init__(self, initseq):
|
|
+ self.__data = initseq
|
|
+ def __len__(self):
|
|
+ return len(self.__data)
|
|
+ def __getitem__(self, i):
|
|
+ return self.__data[i]
|
|
s = OtherSeq(u0)
|
|
v0 = self.type2test(s)
|
|
self.assertEqual(len(v0), len(s))
|
|
@@ -239,11 +294,12 @@ class CommonTest(unittest.TestCase):
|
|
# Sequences must test in-order. If a rich comparison has side
|
|
# effects, these will be visible to tests against later members.
|
|
# In this test, the "side effect" is a short-circuiting raise.
|
|
- class DoNotTestEq(Exception):
|
|
- pass
|
|
- class StopCompares:
|
|
- def __eq__(self, other):
|
|
- raise DoNotTestEq
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class DoNotTestEq(Exception):
|
|
+ pass
|
|
+ class StopCompares:
|
|
+ def __eq__(self, other):
|
|
+ raise DoNotTestEq
|
|
|
|
checkfirst = self.type2test([1, StopCompares()])
|
|
self.assertIn(1, checkfirst)
|
|
@@ -283,8 +339,9 @@ class CommonTest(unittest.TestCase):
|
|
self.assertEqual(u2+u2+u2, u2*3)
|
|
self.assertEqual(u2+u2+u2, 3*u2)
|
|
|
|
- class subclass(self.type2test):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass(self.type2test):
|
|
+ pass
|
|
u3 = subclass([0, 1])
|
|
self.assertEqual(u3, u3*1)
|
|
self.assertIsNot(u3, u3*1)
|
|
@@ -311,9 +368,10 @@ class CommonTest(unittest.TestCase):
|
|
|
|
def test_getitemoverwriteiter(self):
|
|
# Verify that __getitem__ overrides are not recognized by __iter__
|
|
- class T(self.type2test):
|
|
- def __getitem__(self, key):
|
|
- return str(key) + '!!!'
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class T(self.type2test):
|
|
+ def __getitem__(self, key):
|
|
+ return str(key) + '!!!'
|
|
self.assertEqual(next(iter(T((1,2)))), 1)
|
|
|
|
def test_repeat(self):
|
|
@@ -361,14 +419,15 @@ class CommonTest(unittest.TestCase):
|
|
|
|
self.assertRaises(TypeError, a.count)
|
|
|
|
- class BadExc(Exception):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class BadExc(Exception):
|
|
+ pass
|
|
|
|
- class BadCmp:
|
|
- def __eq__(self, other):
|
|
- if other == 2:
|
|
- raise BadExc()
|
|
- return False
|
|
+ class BadCmp:
|
|
+ def __eq__(self, other):
|
|
+ if other == 2:
|
|
+ raise BadExc()
|
|
+ return False
|
|
|
|
self.assertRaises(BadExc, a.count, BadCmp())
|
|
|
|
@@ -394,14 +453,15 @@ class CommonTest(unittest.TestCase):
|
|
|
|
self.assertRaises(TypeError, u.index)
|
|
|
|
- class BadExc(Exception):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class BadExc(Exception):
|
|
+ pass
|
|
|
|
- class BadCmp:
|
|
- def __eq__(self, other):
|
|
- if other == 2:
|
|
- raise BadExc()
|
|
- return False
|
|
+ class BadCmp:
|
|
+ def __eq__(self, other):
|
|
+ if other == 2:
|
|
+ raise BadExc()
|
|
+ return False
|
|
|
|
a = self.type2test([0, 1, 2, 3])
|
|
self.assertRaises(BadExc, a.index, BadCmp())
|