mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Rename `set_fullgraph` to `error_on_graph_break`. For now, there are no semantic differences. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` really does return exactly one graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary, since it is still a private API (there are no internal call sites yet, and no significant OSS call sites yet).
cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as the primary users of `set_fullgraph`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
134 lines
4.4 KiB
Diff
134 lines
4.4 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py
|
|
index 9ce80c5e8ea..1080e85e31a 100644
|
|
--- a/test/dynamo/cpython/3_13/test_tuple.py
|
|
+++ b/test/dynamo/cpython/3_13/test_tuple.py
|
|
@@ -1,4 +1,58 @@
|
|
-from test import support, seq_tests
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+# redirect import statements
|
|
+import sys
|
|
+import importlib.abc
|
|
+
|
|
+redirect_imports = (
|
|
+ "test.mapping_tests",
|
|
+ "test.typinganndata",
|
|
+ "test.test_grammar",
|
|
+ "test.test_math",
|
|
+ "test.test_iter",
|
|
+ "test.typinganndata.ann_module",
|
|
+)
|
|
+
|
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
|
+ def find_spec(self, fullname, path, target=None):
|
|
+ # Check if the import is the problematic one
|
|
+ if fullname in redirect_imports:
|
|
+ try:
|
|
+ # Attempt to import the standalone module
|
|
+ name = fullname.removeprefix("test.")
|
|
+ r = importlib.import_module(name)
|
|
+ # Redirect the module in sys.modules
|
|
+ sys.modules[fullname] = r
|
|
+ # Return a module spec from the found module
|
|
+ return importlib.util.find_spec(name)
|
|
+ except ImportError:
|
|
+ return None
|
|
+ return None
|
|
+
|
|
+# Add the custom finder to sys.meta_path
|
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
|
+
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
+from test import support
|
|
+import seq_tests
|
|
import unittest
|
|
|
|
import gc
|
|
@@ -43,27 +97,30 @@ class TupleTest(seq_tests.CommonTest):
|
|
tuple(sequence=())
|
|
|
|
def test_keywords_in_subclass(self):
|
|
- class subclass(tuple):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass(tuple):
|
|
+ pass
|
|
u = subclass([1, 2])
|
|
self.assertIs(type(u), subclass)
|
|
self.assertEqual(list(u), [1, 2])
|
|
with self.assertRaises(TypeError):
|
|
subclass(sequence=())
|
|
|
|
- class subclass_with_init(tuple):
|
|
- def __init__(self, arg, newarg=None):
|
|
- self.newarg = newarg
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass_with_init(tuple):
|
|
+ def __init__(self, arg, newarg=None):
|
|
+ self.newarg = newarg
|
|
u = subclass_with_init([1, 2], newarg=3)
|
|
self.assertIs(type(u), subclass_with_init)
|
|
self.assertEqual(list(u), [1, 2])
|
|
self.assertEqual(u.newarg, 3)
|
|
|
|
- class subclass_with_new(tuple):
|
|
- def __new__(cls, arg, newarg=None):
|
|
- self = super().__new__(cls, arg)
|
|
- self.newarg = newarg
|
|
- return self
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass_with_new(tuple):
|
|
+ def __new__(cls, arg, newarg=None):
|
|
+ self = super().__new__(cls, arg)
|
|
+ self.newarg = newarg
|
|
+ return self
|
|
u = subclass_with_new([1, 2], newarg=3)
|
|
self.assertIs(type(u), subclass_with_new)
|
|
self.assertEqual(list(u), [1, 2])
|
|
@@ -351,8 +408,9 @@ class TupleTest(seq_tests.CommonTest):
|
|
@support.cpython_only
|
|
def test_track_subtypes(self):
|
|
# Tuple subtypes must always be tracked
|
|
- class MyTuple(tuple):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyTuple(tuple):
|
|
+ pass
|
|
self.check_track_dynamic(MyTuple, True)
|
|
|
|
@support.cpython_only
|
|
@@ -404,7 +462,8 @@ class TupleTest(seq_tests.CommonTest):
|
|
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
|
|
# optimization causes failures in code that relies on distinct
|
|
# function addresses.
|
|
- class T(tuple): pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class T(tuple): pass
|
|
with self.assertRaises(TypeError):
|
|
[3,] + T((1,2))
|
|
|
|
@@ -510,4 +569,4 @@ class TupleTest(seq_tests.CommonTest):
|
|
# pileup 262,143 mean 8.0 coll 262,143 z +92683.6
|
|
|
|
if __name__ == "__main__":
|
|
- unittest.main()
|
|
+ run_tests()
|