mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Renaming `set_fullgraph` to `error_on_graph_break` for now; there are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` truly results in exactly one graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary, since it is still a private API (there are no internal call sites yet, and no significant OSS call sites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users of `set_fullgraph`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739 Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
309 lines
11 KiB
Diff
309 lines
11 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_with.py b/test/dynamo/cpython/3_13/test_with.py
|
|
index 8e9ed8500c7..66c18ad886a 100644
|
|
--- a/test/dynamo/cpython/3_13/test_with.py
|
|
+++ b/test/dynamo/cpython/3_13/test_with.py
|
|
@@ -1,3 +1,23 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_with.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
"""Unit tests for the with statement specified in PEP 343."""
|
|
|
|
|
|
@@ -104,16 +124,17 @@ class MockNested(Nested):
|
|
return Nested.__exit__(self, *exc_info)
|
|
|
|
|
|
-class FailureTestCase(unittest.TestCase):
|
|
+class FailureTestCase(__TestCase):
|
|
def testNameError(self):
|
|
def fooNotDeclared():
|
|
with foo: pass
|
|
self.assertRaises(NameError, fooNotDeclared)
|
|
|
|
def testEnterAttributeError1(self):
|
|
- class LacksEnter(object):
|
|
- def __exit__(self, type, value, traceback):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class LacksEnter(object):
|
|
+ def __exit__(self, type, value, traceback):
|
|
+ pass
|
|
|
|
def fooLacksEnter():
|
|
foo = LacksEnter()
|
|
@@ -121,8 +142,9 @@ class FailureTestCase(unittest.TestCase):
|
|
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)
|
|
|
|
def testEnterAttributeError2(self):
|
|
- class LacksEnterAndExit(object):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class LacksEnterAndExit(object):
|
|
+ pass
|
|
|
|
def fooLacksEnterAndExit():
|
|
foo = LacksEnterAndExit()
|
|
@@ -130,9 +152,10 @@ class FailureTestCase(unittest.TestCase):
|
|
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)
|
|
|
|
def testExitAttributeError(self):
|
|
- class LacksExit(object):
|
|
- def __enter__(self):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class LacksExit(object):
|
|
+ def __enter__(self):
|
|
+ pass
|
|
|
|
def fooLacksExit():
|
|
foo = LacksExit()
|
|
@@ -162,11 +185,12 @@ class FailureTestCase(unittest.TestCase):
|
|
' pass')
|
|
|
|
def testEnterThrows(self):
|
|
- class EnterThrows(object):
|
|
- def __enter__(self):
|
|
- raise RuntimeError("Enter threw")
|
|
- def __exit__(self, *args):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class EnterThrows(object):
|
|
+ def __enter__(self):
|
|
+ raise RuntimeError("Enter threw")
|
|
+ def __exit__(self, *args):
|
|
+ pass
|
|
|
|
def shouldThrow():
|
|
ct = EnterThrows()
|
|
@@ -180,11 +204,12 @@ class FailureTestCase(unittest.TestCase):
|
|
self.assertEqual(self.foo, None)
|
|
|
|
def testExitThrows(self):
|
|
- class ExitThrows(object):
|
|
- def __enter__(self):
|
|
- return
|
|
- def __exit__(self, *args):
|
|
- raise RuntimeError(42)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class ExitThrows(object):
|
|
+ def __enter__(self):
|
|
+ return
|
|
+ def __exit__(self, *args):
|
|
+ raise RuntimeError(42)
|
|
def shouldThrow():
|
|
with ExitThrows():
|
|
pass
|
|
@@ -194,6 +219,7 @@ class ContextmanagerAssertionMixin(object):
|
|
|
|
def setUp(self):
|
|
self.TEST_EXCEPTION = RuntimeError("test exception")
|
|
+ super().setUp()
|
|
|
|
def assertInWithManagerInvariants(self, mock_manager):
|
|
self.assertTrue(mock_manager.enter_called)
|
|
@@ -237,7 +263,7 @@ class ContextmanagerAssertionMixin(object):
|
|
self.assertTrue(mock_generator.stopped)
|
|
|
|
|
|
-class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
|
|
+class NonexceptionalTestCase(__TestCase, ContextmanagerAssertionMixin):
|
|
def testInlineGeneratorSyntax(self):
|
|
with mock_contextmanager_generator():
|
|
pass
|
|
@@ -289,7 +315,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
|
|
self.assertAfterWithGeneratorInvariantsNoError(foo)
|
|
|
|
|
|
-class NestedNonexceptionalTestCase(unittest.TestCase,
|
|
+class NestedNonexceptionalTestCase(__TestCase,
|
|
ContextmanagerAssertionMixin):
|
|
def testSingleArgInlineGeneratorSyntax(self):
|
|
with Nested(mock_contextmanager_generator()):
|
|
@@ -355,7 +381,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase,
|
|
self.assertAfterWithManagerInvariantsNoError(mock_nested)
|
|
|
|
|
|
-class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
|
+class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
|
|
def testSingleResource(self):
|
|
cm = mock_contextmanager_generator()
|
|
def shouldThrow():
|
|
@@ -466,11 +492,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
|
|
|
def testRaisedStopIteration2(self):
|
|
# From bug 1462485
|
|
- class cm(object):
|
|
- def __enter__(self):
|
|
- pass
|
|
- def __exit__(self, type, value, traceback):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class cm(object):
|
|
+ def __enter__(self):
|
|
+ pass
|
|
+ def __exit__(self, type, value, traceback):
|
|
+ pass
|
|
|
|
def shouldThrow():
|
|
with cm():
|
|
@@ -507,11 +534,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
|
|
|
def testRaisedGeneratorExit2(self):
|
|
# From bug 1462485
|
|
- class cm (object):
|
|
- def __enter__(self):
|
|
- pass
|
|
- def __exit__(self, type, value, traceback):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class cm (object):
|
|
+ def __enter__(self):
|
|
+ pass
|
|
+ def __exit__(self, type, value, traceback):
|
|
+ pass
|
|
|
|
def shouldThrow():
|
|
with cm():
|
|
@@ -523,16 +551,17 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
|
# issue4589: __exit__ return code may raise an exception
|
|
# when looking at its truth value.
|
|
|
|
- class cm(object):
|
|
- def __init__(self, bool_conversion):
|
|
- class Bool:
|
|
- def __bool__(self):
|
|
- return bool_conversion()
|
|
- self.exit_result = Bool()
|
|
- def __enter__(self):
|
|
- return 3
|
|
- def __exit__(self, a, b, c):
|
|
- return self.exit_result
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class cm(object):
|
|
+ def __init__(self, bool_conversion):
|
|
+ class Bool:
|
|
+ def __bool__(self):
|
|
+ return bool_conversion()
|
|
+ self.exit_result = Bool()
|
|
+ def __enter__(self):
|
|
+ return 3
|
|
+ def __exit__(self, a, b, c):
|
|
+ return self.exit_result
|
|
|
|
def trueAsBool():
|
|
with cm(lambda: True):
|
|
@@ -550,7 +579,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
|
self.assertRaises(ZeroDivisionError, failAsBool)
|
|
|
|
|
|
-class NonLocalFlowControlTestCase(unittest.TestCase):
|
|
+class NonLocalFlowControlTestCase(__TestCase):
|
|
|
|
def testWithBreak(self):
|
|
counter = 0
|
|
@@ -607,7 +636,7 @@ class NonLocalFlowControlTestCase(unittest.TestCase):
|
|
self.fail("Didn't raise RuntimeError")
|
|
|
|
|
|
-class AssignmentTargetTestCase(unittest.TestCase):
|
|
+class AssignmentTargetTestCase(__TestCase):
|
|
|
|
def testSingleComplexTarget(self):
|
|
targets = {1: [0, 1, 2]}
|
|
@@ -621,15 +650,17 @@ class AssignmentTargetTestCase(unittest.TestCase):
|
|
keys = list(targets.keys())
|
|
keys.sort()
|
|
self.assertEqual(keys, [1, 2])
|
|
- class C: pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class C: pass
|
|
blah = C()
|
|
with mock_contextmanager_generator() as blah.foo:
|
|
self.assertEqual(hasattr(blah, "foo"), True)
|
|
|
|
def testMultipleComplexTargets(self):
|
|
- class C:
|
|
- def __enter__(self): return 1, 2, 3
|
|
- def __exit__(self, t, v, tb): pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class C:
|
|
+ def __enter__(self): return 1, 2, 3
|
|
+ def __exit__(self, t, v, tb): pass
|
|
targets = {1: [0, 1, 2]}
|
|
with C() as (targets[1][0], targets[1][1], targets[1][2]):
|
|
self.assertEqual(targets, {1: [1, 2, 3]})
|
|
@@ -637,7 +668,8 @@ class AssignmentTargetTestCase(unittest.TestCase):
|
|
self.assertEqual(targets, {1: [3, 2, 1]})
|
|
with C() as (targets[1], targets[2], targets[3]):
|
|
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
|
|
- class B: pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class B: pass
|
|
blah = B()
|
|
with C() as (blah.one, blah.two, blah.three):
|
|
self.assertEqual(blah.one, 1)
|
|
@@ -651,12 +683,13 @@ class AssignmentTargetTestCase(unittest.TestCase):
|
|
self.assertEqual(c, 4)
|
|
|
|
|
|
-class ExitSwallowsExceptionTestCase(unittest.TestCase):
|
|
+class ExitSwallowsExceptionTestCase(__TestCase):
|
|
|
|
def testExitTrueSwallowsException(self):
|
|
- class AfricanSwallow:
|
|
- def __enter__(self): pass
|
|
- def __exit__(self, t, v, tb): return True
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class AfricanSwallow:
|
|
+ def __enter__(self): pass
|
|
+ def __exit__(self, t, v, tb): return True
|
|
try:
|
|
with AfricanSwallow():
|
|
1/0
|
|
@@ -664,9 +697,10 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
|
|
self.fail("ZeroDivisionError should have been swallowed")
|
|
|
|
def testExitFalseDoesntSwallowException(self):
|
|
- class EuropeanSwallow:
|
|
- def __enter__(self): pass
|
|
- def __exit__(self, t, v, tb): return False
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class EuropeanSwallow:
|
|
+ def __enter__(self): pass
|
|
+ def __exit__(self, t, v, tb): return False
|
|
try:
|
|
with EuropeanSwallow():
|
|
1/0
|
|
@@ -676,7 +710,7 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
|
|
self.fail("ZeroDivisionError should have been raised")
|
|
|
|
|
|
-class NestedWith(unittest.TestCase):
|
|
+class NestedWith(__TestCase):
|
|
|
|
class Dummy(object):
|
|
def __init__(self, value=None, gobble=False):
|
|
@@ -796,4 +830,4 @@ class NestedWith(unittest.TestCase):
|
|
|
|
|
|
if __name__ == '__main__':
|
|
- unittest.main()
|
|
+ run_tests()
|