pytorch/test/dynamo/cpython/3_13/test_with.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Renaming `set_fullgraph` to `error_on_graph_break` for now; there are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` truly guarantees a single graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still a private API (there are no internal call sites yet, and no significant OSS call sites yet).

cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users of `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

309 lines
11 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_with.py b/test/dynamo/cpython/3_13/test_with.py
index 8e9ed8500c7..66c18ad886a 100644
--- a/test/dynamo/cpython/3_13/test_with.py
+++ b/test/dynamo/cpython/3_13/test_with.py
@@ -1,3 +1,23 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_with.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+# ======= END DYNAMO PATCH =======
+
"""Unit tests for the with statement specified in PEP 343."""
@@ -104,16 +124,17 @@ class MockNested(Nested):
return Nested.__exit__(self, *exc_info)
-class FailureTestCase(unittest.TestCase):
+class FailureTestCase(__TestCase):
def testNameError(self):
def fooNotDeclared():
with foo: pass
self.assertRaises(NameError, fooNotDeclared)
def testEnterAttributeError1(self):
- class LacksEnter(object):
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnter(object):
+ def __exit__(self, type, value, traceback):
+ pass
def fooLacksEnter():
foo = LacksEnter()
@@ -121,8 +142,9 @@ class FailureTestCase(unittest.TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)
def testEnterAttributeError2(self):
- class LacksEnterAndExit(object):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnterAndExit(object):
+ pass
def fooLacksEnterAndExit():
foo = LacksEnterAndExit()
@@ -130,9 +152,10 @@ class FailureTestCase(unittest.TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)
def testExitAttributeError(self):
- class LacksExit(object):
- def __enter__(self):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksExit(object):
+ def __enter__(self):
+ pass
def fooLacksExit():
foo = LacksExit()
@@ -162,11 +185,12 @@ class FailureTestCase(unittest.TestCase):
' pass')
def testEnterThrows(self):
- class EnterThrows(object):
- def __enter__(self):
- raise RuntimeError("Enter threw")
- def __exit__(self, *args):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class EnterThrows(object):
+ def __enter__(self):
+ raise RuntimeError("Enter threw")
+ def __exit__(self, *args):
+ pass
def shouldThrow():
ct = EnterThrows()
@@ -180,11 +204,12 @@ class FailureTestCase(unittest.TestCase):
self.assertEqual(self.foo, None)
def testExitThrows(self):
- class ExitThrows(object):
- def __enter__(self):
- return
- def __exit__(self, *args):
- raise RuntimeError(42)
+ with torch._dynamo.error_on_graph_break(False):
+ class ExitThrows(object):
+ def __enter__(self):
+ return
+ def __exit__(self, *args):
+ raise RuntimeError(42)
def shouldThrow():
with ExitThrows():
pass
@@ -194,6 +219,7 @@ class ContextmanagerAssertionMixin(object):
def setUp(self):
self.TEST_EXCEPTION = RuntimeError("test exception")
+ super().setUp()
def assertInWithManagerInvariants(self, mock_manager):
self.assertTrue(mock_manager.enter_called)
@@ -237,7 +263,7 @@ class ContextmanagerAssertionMixin(object):
self.assertTrue(mock_generator.stopped)
-class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
+class NonexceptionalTestCase(__TestCase, ContextmanagerAssertionMixin):
def testInlineGeneratorSyntax(self):
with mock_contextmanager_generator():
pass
@@ -289,7 +315,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
self.assertAfterWithGeneratorInvariantsNoError(foo)
-class NestedNonexceptionalTestCase(unittest.TestCase,
+class NestedNonexceptionalTestCase(__TestCase,
ContextmanagerAssertionMixin):
def testSingleArgInlineGeneratorSyntax(self):
with Nested(mock_contextmanager_generator()):
@@ -355,7 +381,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase,
self.assertAfterWithManagerInvariantsNoError(mock_nested)
-class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
+class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
def testSingleResource(self):
cm = mock_contextmanager_generator()
def shouldThrow():
@@ -466,11 +492,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
def testRaisedStopIteration2(self):
# From bug 1462485
- class cm(object):
- def __enter__(self):
- pass
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class cm(object):
+ def __enter__(self):
+ pass
+ def __exit__(self, type, value, traceback):
+ pass
def shouldThrow():
with cm():
@@ -507,11 +534,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
def testRaisedGeneratorExit2(self):
# From bug 1462485
- class cm (object):
- def __enter__(self):
- pass
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class cm (object):
+ def __enter__(self):
+ pass
+ def __exit__(self, type, value, traceback):
+ pass
def shouldThrow():
with cm():
@@ -523,16 +551,17 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
# issue4589: __exit__ return code may raise an exception
# when looking at its truth value.
- class cm(object):
- def __init__(self, bool_conversion):
- class Bool:
- def __bool__(self):
- return bool_conversion()
- self.exit_result = Bool()
- def __enter__(self):
- return 3
- def __exit__(self, a, b, c):
- return self.exit_result
+ with torch._dynamo.error_on_graph_break(False):
+ class cm(object):
+ def __init__(self, bool_conversion):
+ class Bool:
+ def __bool__(self):
+ return bool_conversion()
+ self.exit_result = Bool()
+ def __enter__(self):
+ return 3
+ def __exit__(self, a, b, c):
+ return self.exit_result
def trueAsBool():
with cm(lambda: True):
@@ -550,7 +579,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
self.assertRaises(ZeroDivisionError, failAsBool)
-class NonLocalFlowControlTestCase(unittest.TestCase):
+class NonLocalFlowControlTestCase(__TestCase):
def testWithBreak(self):
counter = 0
@@ -607,7 +636,7 @@ class NonLocalFlowControlTestCase(unittest.TestCase):
self.fail("Didn't raise RuntimeError")
-class AssignmentTargetTestCase(unittest.TestCase):
+class AssignmentTargetTestCase(__TestCase):
def testSingleComplexTarget(self):
targets = {1: [0, 1, 2]}
@@ -621,15 +650,17 @@ class AssignmentTargetTestCase(unittest.TestCase):
keys = list(targets.keys())
keys.sort()
self.assertEqual(keys, [1, 2])
- class C: pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C: pass
blah = C()
with mock_contextmanager_generator() as blah.foo:
self.assertEqual(hasattr(blah, "foo"), True)
def testMultipleComplexTargets(self):
- class C:
- def __enter__(self): return 1, 2, 3
- def __exit__(self, t, v, tb): pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __enter__(self): return 1, 2, 3
+ def __exit__(self, t, v, tb): pass
targets = {1: [0, 1, 2]}
with C() as (targets[1][0], targets[1][1], targets[1][2]):
self.assertEqual(targets, {1: [1, 2, 3]})
@@ -637,7 +668,8 @@ class AssignmentTargetTestCase(unittest.TestCase):
self.assertEqual(targets, {1: [3, 2, 1]})
with C() as (targets[1], targets[2], targets[3]):
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
- class B: pass
+ with torch._dynamo.error_on_graph_break(False):
+ class B: pass
blah = B()
with C() as (blah.one, blah.two, blah.three):
self.assertEqual(blah.one, 1)
@@ -651,12 +683,13 @@ class AssignmentTargetTestCase(unittest.TestCase):
self.assertEqual(c, 4)
-class ExitSwallowsExceptionTestCase(unittest.TestCase):
+class ExitSwallowsExceptionTestCase(__TestCase):
def testExitTrueSwallowsException(self):
- class AfricanSwallow:
- def __enter__(self): pass
- def __exit__(self, t, v, tb): return True
+ with torch._dynamo.error_on_graph_break(False):
+ class AfricanSwallow:
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): return True
try:
with AfricanSwallow():
1/0
@@ -664,9 +697,10 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
self.fail("ZeroDivisionError should have been swallowed")
def testExitFalseDoesntSwallowException(self):
- class EuropeanSwallow:
- def __enter__(self): pass
- def __exit__(self, t, v, tb): return False
+ with torch._dynamo.error_on_graph_break(False):
+ class EuropeanSwallow:
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): return False
try:
with EuropeanSwallow():
1/0
@@ -676,7 +710,7 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
self.fail("ZeroDivisionError should have been raised")
-class NestedWith(unittest.TestCase):
+class NestedWith(__TestCase):
class Dummy(object):
def __init__(self, value=None, gobble=False):
@@ -796,4 +830,4 @@ class NestedWith(unittest.TestCase):
if __name__ == '__main__':
- unittest.main()
+ run_tests()