pytorch/test/dynamo/cpython/3_13/test_contextlib.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Rename `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` truly guarantees a single graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I hope that won't be necessary, since it is still a private API (there are no internal call sites yet, and there are no significant OSS call sites yet).

 cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as the primary users of `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

623 lines
20 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_contextlib.py b/test/dynamo/cpython/3_13/test_contextlib.py
index cf651959803..256a824932d 100644
--- a/test/dynamo/cpython/3_13/test_contextlib.py
+++ b/test/dynamo/cpython/3_13/test_contextlib.py
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_contextlib.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
"""Unit tests for contextlib.py, and other context managers."""
import io
@@ -14,60 +68,67 @@ from test.support.testcase import ExceptionIsLikeMixin
import weakref
-class TestAbstractContextManager(unittest.TestCase):
+class TestAbstractContextManager(__TestCase):
def test_enter(self):
- class DefaultEnter(AbstractContextManager):
- def __exit__(self, *args):
- super().__exit__(*args)
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultEnter(AbstractContextManager):
+ def __exit__(self, *args):
+ super().__exit__(*args)
manager = DefaultEnter()
self.assertIs(manager.__enter__(), manager)
def test_slots(self):
- class DefaultContextManager(AbstractContextManager):
- __slots__ = ()
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultContextManager(AbstractContextManager):
+ __slots__ = ()
- def __exit__(self, *args):
- super().__exit__(*args)
+ def __exit__(self, *args):
+ super().__exit__(*args)
with self.assertRaises(AttributeError):
DefaultContextManager().var = 42
def test_exit_is_abstract(self):
- class MissingExit(AbstractContextManager):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MissingExit(AbstractContextManager):
+ pass
with self.assertRaises(TypeError):
MissingExit()
def test_structural_subclassing(self):
- class ManagerFromScratch:
- def __enter__(self):
- return self
- def __exit__(self, exc_type, exc_value, traceback):
- return None
+ with torch._dynamo.error_on_graph_break(False):
+ class ManagerFromScratch:
+ def __enter__(self):
+ return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ return None
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
- class DefaultEnter(AbstractContextManager):
- def __exit__(self, *args):
- super().__exit__(*args)
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultEnter(AbstractContextManager):
+ def __exit__(self, *args):
+ super().__exit__(*args)
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
- class NoEnter(ManagerFromScratch):
- __enter__ = None
+ with torch._dynamo.error_on_graph_break(False):
+ class NoEnter(ManagerFromScratch):
+ __enter__ = None
self.assertFalse(issubclass(NoEnter, AbstractContextManager))
- class NoExit(ManagerFromScratch):
- __exit__ = None
+ with torch._dynamo.error_on_graph_break(False):
+ class NoExit(ManagerFromScratch):
+ __exit__ = None
self.assertFalse(issubclass(NoExit, AbstractContextManager))
-class ContextManagerTestCase(unittest.TestCase):
+class ContextManagerTestCase(__TestCase):
def test_contextmanager_plain(self):
state = []
@@ -115,8 +176,9 @@ class ContextManagerTestCase(unittest.TestCase):
self.assertEqual(frames[0].line, '1/0')
# Repeat with RuntimeError (which goes through a different code path)
- class RuntimeErrorSubclass(RuntimeError):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class RuntimeErrorSubclass(RuntimeError):
+ pass
try:
with f():
@@ -128,8 +190,9 @@ class ContextManagerTestCase(unittest.TestCase):
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
- class StopIterationSubclass(StopIteration):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class StopIterationSubclass(StopIteration):
+ pass
for stop_exc in (
StopIteration('spam'),
@@ -169,9 +232,9 @@ class ContextManagerTestCase(unittest.TestCase):
ctx.__enter__()
with self.assertRaises(RuntimeError):
ctx.__exit__(TypeError, TypeError("foo"), None)
- if support.check_impl_detail(cpython=True):
- # The "gen" attribute is an implementation detail.
- self.assertFalse(ctx.gen.gi_suspended)
+ # if support.check_impl_detail(cpython=True):
+ # # The "gen" attribute is an implementation detail.
+ # self.assertFalse(ctx.gen.gi_suspended)
def test_contextmanager_trap_no_yield(self):
@contextmanager
@@ -191,9 +254,9 @@ class ContextManagerTestCase(unittest.TestCase):
ctx.__enter__()
with self.assertRaises(RuntimeError):
ctx.__exit__(None, None, None)
- if support.check_impl_detail(cpython=True):
- # The "gen" attribute is an implementation detail.
- self.assertFalse(ctx.gen.gi_suspended)
+ # if support.check_impl_detail(cpython=True):
+ # # The "gen" attribute is an implementation detail.
+ # self.assertFalse(ctx.gen.gi_suspended)
def test_contextmanager_non_normalised(self):
@contextmanager
@@ -230,8 +293,9 @@ class ContextManagerTestCase(unittest.TestCase):
def woohoo():
yield
- class StopIterationSubclass(StopIteration):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class StopIterationSubclass(StopIteration):
+ pass
for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
with self.subTest(type=type(stop_exc)):
@@ -344,8 +408,9 @@ def woohoo():
self.assertEqual(target, (11, 22, 33, 44))
def test_nokeepref(self):
- class A:
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
@contextmanager
def woohoo(a, b):
@@ -396,7 +461,7 @@ def woohoo():
self.assertEqual(depth, 0)
-class ClosingTestCase(unittest.TestCase):
+class ClosingTestCase(__TestCase):
@support.requires_docstrings
def test_instance_docs(self):
@@ -407,9 +472,10 @@ class ClosingTestCase(unittest.TestCase):
def test_closing(self):
state = []
- class C:
- def close(self):
- state.append(1)
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def close(self):
+ state.append(1)
x = C()
self.assertEqual(state, [])
with closing(x) as y:
@@ -418,9 +484,10 @@ class ClosingTestCase(unittest.TestCase):
def test_closing_error(self):
state = []
- class C:
- def close(self):
- state.append(1)
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def close(self):
+ state.append(1)
x = C()
self.assertEqual(state, [])
with self.assertRaises(ZeroDivisionError):
@@ -430,16 +497,17 @@ class ClosingTestCase(unittest.TestCase):
self.assertEqual(state, [1])
-class NullcontextTestCase(unittest.TestCase):
+class NullcontextTestCase(__TestCase):
def test_nullcontext(self):
- class C:
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ pass
c = C()
with nullcontext(c) as c_in:
self.assertIs(c_in, c)
-class FileContextTestCase(unittest.TestCase):
+class FileContextTestCase(__TestCase):
def testWithOpen(self):
tfn = tempfile.mktemp()
@@ -457,7 +525,7 @@ class FileContextTestCase(unittest.TestCase):
finally:
os_helper.unlink(tfn)
-class LockContextTestCase(unittest.TestCase):
+class LockContextTestCase(__TestCase):
def boilerPlate(self, lock, locked):
self.assertFalse(locked())
@@ -520,7 +588,7 @@ class mycontext(ContextDecorator):
return self.catch
-class TestContextDecorator(unittest.TestCase):
+class TestContextDecorator(__TestCase):
@support.requires_docstrings
def test_instance_docs(self):
@@ -584,13 +652,14 @@ class TestContextDecorator(unittest.TestCase):
def test_decorating_method(self):
context = mycontext()
- class Test(object):
+ with torch._dynamo.error_on_graph_break(False):
+ class Test(object):
- @context
- def method(self, a, b, c=None):
- self.a = a
- self.b = b
- self.c = c
+ @context
+ def method(self, a, b, c=None):
+ self.a = a
+ self.b = b
+ self.c = c
# these tests are for argument passing when used as a decorator
test = Test()
@@ -612,11 +681,12 @@ class TestContextDecorator(unittest.TestCase):
def test_typo_enter(self):
- class mycontext(ContextDecorator):
- def __unter__(self):
- pass
- def __exit__(self, *exc):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class mycontext(ContextDecorator):
+ def __unter__(self):
+ pass
+ def __exit__(self, *exc):
+ pass
with self.assertRaisesRegex(TypeError, 'the context manager'):
with mycontext():
@@ -624,11 +694,12 @@ class TestContextDecorator(unittest.TestCase):
def test_typo_exit(self):
- class mycontext(ContextDecorator):
- def __enter__(self):
- pass
- def __uxit__(self, *exc):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class mycontext(ContextDecorator):
+ def __enter__(self):
+ pass
+ def __uxit__(self, *exc):
+ pass
with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'):
with mycontext():
@@ -636,19 +707,20 @@ class TestContextDecorator(unittest.TestCase):
def test_contextdecorator_as_mixin(self):
- class somecontext(object):
- started = False
- exc = None
+ with torch._dynamo.error_on_graph_break(False):
+ class somecontext(object):
+ started = False
+ exc = None
- def __enter__(self):
- self.started = True
- return self
+ def __enter__(self):
+ self.started = True
+ return self
- def __exit__(self, *exc):
- self.exc = exc
+ def __exit__(self, *exc):
+ self.exc = exc
- class mycontext(somecontext, ContextDecorator):
- pass
+ class mycontext(somecontext, ContextDecorator):
+ pass
context = mycontext()
@context
@@ -680,7 +752,7 @@ class TestContextDecorator(unittest.TestCase):
self.assertEqual(state, [1, 'something else', 999])
-class TestBaseExitStack:
+class _TestBaseExitStack:
exit_stack = None
@support.requires_docstrings
@@ -745,13 +817,14 @@ class TestBaseExitStack:
self.assertIsNone(exc_type)
self.assertIsNone(exc)
self.assertIsNone(exc_tb)
- class ExitCM(object):
- def __init__(self, check_exc):
- self.check_exc = check_exc
- def __enter__(self):
- self.fail("Should not be called!")
- def __exit__(self, *exc_details):
- self.check_exc(*exc_details)
+ with torch._dynamo.error_on_graph_break(False):
+ class ExitCM(object):
+ def __init__(self, check_exc):
+ self.check_exc = check_exc
+ def __enter__(self):
+ self.fail("Should not be called!")
+ def __exit__(self, *exc_details):
+ self.check_exc(*exc_details)
with self.exit_stack() as stack:
stack.push(_expect_ok)
self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
@@ -770,11 +843,12 @@ class TestBaseExitStack:
1/0
def test_enter_context(self):
- class TestCM(object):
- def __enter__(self):
- result.append(1)
- def __exit__(self, *exc_details):
- result.append(3)
+ with torch._dynamo.error_on_graph_break(False):
+ class TestCM(object):
+ def __enter__(self):
+ result.append(1)
+ def __exit__(self, *exc_details):
+ result.append(3)
result = []
cm = TestCM()
@@ -789,14 +863,15 @@ class TestBaseExitStack:
self.assertEqual(result, [1, 2, 3, 4])
def test_enter_context_errors(self):
- class LacksEnterAndExit:
- pass
- class LacksEnter:
- def __exit__(self, *exc_info):
- pass
- class LacksExit:
- def __enter__(self):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnterAndExit:
pass
+ class LacksEnter:
+ def __exit__(self, *exc_info):
+ pass
+ class LacksExit:
+ def __enter__(self):
+ pass
with self.exit_stack() as stack:
with self.assertRaisesRegex(TypeError, 'the context manager'):
@@ -877,32 +952,33 @@ class TestBaseExitStack:
def test_exit_exception_chaining_reference(self):
# Sanity check to make sure that ExitStack chaining matches
# actual nested with statements
- class RaiseExc:
- def __init__(self, exc):
- self.exc = exc
- def __enter__(self):
- return self
- def __exit__(self, *exc_details):
- raise self.exc
-
- class RaiseExcWithContext:
- def __init__(self, outer, inner):
- self.outer = outer
- self.inner = inner
- def __enter__(self):
- return self
- def __exit__(self, *exc_details):
- try:
- raise self.inner
- except:
- raise self.outer
-
- class SuppressExc:
- def __enter__(self):
- return self
- def __exit__(self, *exc_details):
- type(self).saved_details = exc_details
- return True
+ with torch._dynamo.error_on_graph_break(False):
+ class RaiseExc:
+ def __init__(self, exc):
+ self.exc = exc
+ def __enter__(self):
+ return self
+ def __exit__(self, *exc_details):
+ raise self.exc
+
+ class RaiseExcWithContext:
+ def __init__(self, outer, inner):
+ self.outer = outer
+ self.inner = inner
+ def __enter__(self):
+ return self
+ def __exit__(self, *exc_details):
+ try:
+ raise self.inner
+ except:
+ raise self.outer
+
+ class SuppressExc:
+ def __enter__(self):
+ return self
+ def __exit__(self, *exc_details):
+ type(self).saved_details = exc_details
+ return True
try:
with RaiseExc(IndexError):
@@ -957,8 +1033,9 @@ class TestBaseExitStack:
# Ensure ExitStack chaining matches actual nested `with` statements
# regarding explicit __context__ = None.
- class MyException(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ pass
@contextmanager
def my_cm():
@@ -1096,7 +1173,8 @@ class TestBaseExitStack:
stack.callback(int)
def test_instance_bypass(self):
- class Example(object): pass
+ with torch._dynamo.error_on_graph_break(False):
+ class Example(object): pass
cm = Example()
cm.__enter__ = object()
cm.__exit__ = object()
@@ -1108,8 +1186,9 @@ class TestBaseExitStack:
def test_dont_reraise_RuntimeError(self):
# https://bugs.python.org/issue27122
- class UniqueException(Exception): pass
- class UniqueRuntimeError(RuntimeError): pass
+ with torch._dynamo.error_on_graph_break(False):
+ class UniqueException(Exception): pass
+ class UniqueRuntimeError(RuntimeError): pass
@contextmanager
def second():
@@ -1141,7 +1220,7 @@ class TestBaseExitStack:
self.assertIs(exc.__cause__, exc.__context__)
-class TestExitStack(TestBaseExitStack, unittest.TestCase):
+class TestExitStack(_TestBaseExitStack, __TestCase):
exit_stack = ExitStack
callback_error_internal_frames = [
('__exit__', 'raise exc'),
@@ -1149,7 +1228,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase):
]
-class TestRedirectStream:
+class _TestRedirectStream:
redirect_stream = None
orig_stream = None
@@ -1206,19 +1285,19 @@ class TestRedirectStream:
self.assertEqual(s, "Hello World!\n")
-class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
+class TestRedirectStdout(_TestRedirectStream, __TestCase):
redirect_stream = redirect_stdout
orig_stream = "stdout"
-class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
+class TestRedirectStderr(_TestRedirectStream, __TestCase):
redirect_stream = redirect_stderr
orig_stream = "stderr"
-class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
+class TestSuppress(ExceptionIsLikeMixin, __TestCase):
@support.requires_docstrings
def test_instance_docs(self):
@@ -1315,7 +1394,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
)
-class TestChdir(unittest.TestCase):
+class TestChdir(__TestCase):
def make_relative_path(self, *parts):
return os.path.join(
os.path.dirname(os.path.realpath(__file__)),
@@ -1331,6 +1410,7 @@ class TestChdir(unittest.TestCase):
self.assertEqual(os.getcwd(), target)
self.assertEqual(os.getcwd(), old_cwd)
+ @unittest.skip("Missing archivetestdata")
def test_reentrant(self):
old_cwd = os.getcwd()
target1 = self.make_relative_path('data')
@@ -1363,4 +1443,4 @@ class TestChdir(unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()