mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Add infra to run CPython tests under Dynamo (#150787)"
This reverts commit 7c96dd8f0c.
Reverted https://github.com/pytorch/pytorch/pull/150787 on behalf of https://github.com/huydhn due to Sorry for reverting your change but a failed test is showing up in trunk ([comment](https://github.com/pytorch/pytorch/pull/150787#issuecomment-2852818113))
This commit is contained in:
parent
0e9874849f
commit
103fe856e1
|
|
@ -18,8 +18,6 @@ exclude_patterns = [
|
|||
'torch/_inductor/autoheuristic/artifacts/**',
|
||||
'scripts/**',
|
||||
'test/generated_type_hints_smoketest.py',
|
||||
# CPython tests
|
||||
'test/dynamo/cpython/**',
|
||||
# Tests from the NumPy test suite
|
||||
'test/torch_np/numpy_test/**/*.py',
|
||||
'third_party/**',
|
||||
|
|
@ -400,7 +398,6 @@ exclude_patterns=[
|
|||
'tools/clang_format_hash/**',
|
||||
'test/cpp/jit/upgrader_models/*.ptl',
|
||||
'test/cpp/jit/upgrader_models/*.ptl.ff',
|
||||
'test/dynamo/cpython/**',
|
||||
'**/*.png',
|
||||
'**/*.gz',
|
||||
'**/*.patch',
|
||||
|
|
@ -939,7 +936,6 @@ include_patterns = [
|
|||
exclude_patterns = [
|
||||
'test/run_test.py',
|
||||
'**/fb/**',
|
||||
'test/dynamo/cpython/3.13/**',
|
||||
'test/quantization/**', # should be run through test/test_quantization.py
|
||||
'test/jit/**', # should be run through test/test_jit.py
|
||||
'test/ao/sparsity/**', # should be run through test/test_ao_sparsity.py
|
||||
|
|
@ -1135,7 +1131,6 @@ exclude_patterns = [
|
|||
'caffe2/**/*.pyi',
|
||||
'fb/**',
|
||||
'**/fb/**',
|
||||
'test/dynamo/cpython/**',
|
||||
'third_party/**/*.py',
|
||||
'third_party/**/*.pyi',
|
||||
'torch/_vendor/**',
|
||||
|
|
@ -1541,7 +1536,6 @@ exclude_patterns = [
|
|||
'functorch/notebooks/**',
|
||||
'torch/_inductor/fx_passes/serialized_patterns/**',
|
||||
'torch/_inductor/autoheuristic/artifacts/**',
|
||||
'test/dynamo/cpython/**',
|
||||
'scripts/**',
|
||||
'third_party/**',
|
||||
'fb/**',
|
||||
|
|
|
|||
|
|
@ -1,9 +0,0 @@
|
|||
This subdirectory contains a selection of tests from the CPython repository (branch: v3.13.0):\
|
||||
https://github.com/python/cpython/releases/tag/v3.13.0
|
||||
|
||||
Modifications were made to ensure compatibility with the Dynamo infrastructure:
|
||||
+ Monkey-patched `unittest.TestCase` to `torch._dynamo.test_case.CPythonTestCase`.
|
||||
+ Replaced `unittest.main()` with `torch._dynamo.test_case.run_tests()`.
|
||||
+ Assigned test "owners."
|
||||
+ Annotated CPU-intensive tests with the `@slowTest` decorator.
|
||||
+ Adjusted imports to use `import module` instead of `from test import module`.
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
|
||||
--------------------------------------------
|
||||
|
||||
1. This LICENSE AGREEMENT is between the Python Software Foundation
|
||||
("PSF"), and the Individual or Organization ("Licensee") accessing and
|
||||
otherwise using this software ("Python") in source or binary form and
|
||||
its associated documentation.
|
||||
|
||||
2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
||||
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
|
||||
analyze, test, perform and/or display publicly, prepare derivative works,
|
||||
distribute, and otherwise use Python alone or in any derivative version,
|
||||
provided, however, that PSF's License Agreement and PSF's notice of copyright,
|
||||
i.e., "Copyright (c) 2001 Python Software Foundation; All Rights Reserved"
|
||||
are retained in Python alone or in any derivative version prepared by Licensee.
|
||||
|
||||
3. In the event Licensee prepares a derivative work that is based on
|
||||
or incorporates Python or any part thereof, and wants to make
|
||||
the derivative work available to others as provided herein, then
|
||||
Licensee hereby agrees to include in any such work a brief summary of
|
||||
the changes made to Python.
|
||||
|
||||
4. PSF is making Python available to Licensee on an "AS IS"
|
||||
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
||||
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
|
||||
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
||||
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
|
||||
INFRINGE ANY THIRD PARTY RIGHTS.
|
||||
|
||||
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
||||
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
||||
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
|
||||
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
||||
|
||||
6. This License Agreement will automatically terminate upon a material
|
||||
breach of its terms and conditions.
|
||||
|
||||
7. Nothing in this License Agreement shall be deemed to create any
|
||||
relationship of agency, partnership, or joint venture between PSF and
|
||||
Licensee. This License Agreement does not grant permission to use PSF
|
||||
trademarks or trade name in a trademark sense to endorse or promote
|
||||
products or services of Licensee, or any third party.
|
||||
|
||||
8. By copying, installing or otherwise using Python, Licensee
|
||||
agrees to be bound by the terms and conditions of this License
|
||||
Agreement.
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
import contextlib
|
||||
import sys
|
||||
import traceback
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
|
@ -8,12 +9,18 @@ import torch
|
|||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo.exc import InternalTorchDynamoError
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
|
||||
from torch._dynamo.testing import (
|
||||
EagerAndRecordGraphs,
|
||||
normalize_gm,
|
||||
same,
|
||||
skipIfNotPy311,
|
||||
)
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
make_dynamo_test,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
|
@ -30,16 +37,6 @@ z_glb = 0
|
|||
k_glb = 0
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_dtype(dtype):
|
||||
old_dtype = torch.get_default_dtype()
|
||||
try:
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
finally:
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
class CustomizedCtxManager:
|
||||
def __init__(self, mode):
|
||||
self.prev = torch.is_grad_enabled()
|
||||
|
|
@ -2703,6 +2700,319 @@ class GraphModule(torch.nn.Module):
|
|||
self.assertEqual(y, t.sin())
|
||||
|
||||
|
||||
class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_plain(self):
|
||||
state = []
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
state.append(1)
|
||||
yield 42
|
||||
state.append(999)
|
||||
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@skipIfNotPy311
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_finally(self):
|
||||
state = []
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
state.append(1)
|
||||
try:
|
||||
yield 42
|
||||
finally:
|
||||
state.append(999)
|
||||
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_traceback(self):
|
||||
@contextmanager
|
||||
def f():
|
||||
yield
|
||||
|
||||
try:
|
||||
with f():
|
||||
1 / 0
|
||||
except ZeroDivisionError as e:
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
||||
self.assertEqual(frames[0].line, "1/0")
|
||||
|
||||
# Repeat with RuntimeError (which goes through a different code path)
|
||||
try:
|
||||
with f():
|
||||
raise NotImplementedError(42)
|
||||
except NotImplementedError as e:
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
||||
self.assertEqual(frames[0].line, "raise NotImplementedError(42)")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_no_reraise(self):
|
||||
@contextmanager
|
||||
def whee():
|
||||
yield
|
||||
|
||||
ctx = whee()
|
||||
ctx.__enter__()
|
||||
# Calling __exit__ should not result in an exception
|
||||
self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_yield_after_throw(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
try:
|
||||
yield
|
||||
except Exception: # noqa: E722
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__exit__(TypeError, TypeError("foo"), None)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
def test_contextmanager_except(self):
|
||||
state = []
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
state.append(1)
|
||||
try:
|
||||
yield 42
|
||||
except ZeroDivisionError as e:
|
||||
state.append(e.args[0])
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError(999)
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_except_stopiter(self):
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
yield
|
||||
|
||||
class StopIterationSubclass(StopIteration):
|
||||
pass
|
||||
|
||||
for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")):
|
||||
with self.subTest(type=type(stop_exc)):
|
||||
try:
|
||||
with woohoo():
|
||||
raise stop_exc
|
||||
except Exception as ex:
|
||||
self.assertIs(ex, stop_exc)
|
||||
else:
|
||||
self.fail(f"{stop_exc} was suppressed")
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_except_pep479(self):
|
||||
code = """\
|
||||
from __future__ import generator_stop
|
||||
from contextlib import contextmanager
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
yield
|
||||
"""
|
||||
locals = {}
|
||||
exec(code, locals, locals)
|
||||
woohoo = locals["woohoo"]
|
||||
|
||||
stop_exc = StopIteration("spam")
|
||||
try:
|
||||
with woohoo():
|
||||
raise stop_exc
|
||||
except Exception as ex:
|
||||
self.assertIs(ex, stop_exc)
|
||||
else:
|
||||
self.fail("StopIteration was suppressed")
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
|
||||
@contextmanager
|
||||
def test_issue29692():
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
raise RuntimeError("issue29692:Chained") from exc
|
||||
|
||||
try:
|
||||
with test_issue29692():
|
||||
raise ZeroDivisionError
|
||||
except Exception as ex:
|
||||
self.assertIs(type(ex), RuntimeError)
|
||||
self.assertEqual(ex.args[0], "issue29692:Chained")
|
||||
self.assertIsInstance(ex.__cause__, ZeroDivisionError)
|
||||
|
||||
try:
|
||||
with test_issue29692():
|
||||
raise StopIteration("issue29692:Unchained")
|
||||
except Exception as ex:
|
||||
self.assertIs(type(ex), StopIteration)
|
||||
self.assertEqual(ex.args[0], "issue29692:Unchained")
|
||||
self.assertIsNone(ex.__cause__)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def _create_contextmanager_attribs(self):
|
||||
def attribs(**kw):
|
||||
def decorate(func):
|
||||
for k, v in kw.items():
|
||||
setattr(func, k, v)
|
||||
return func
|
||||
|
||||
return decorate
|
||||
|
||||
@contextmanager
|
||||
@attribs(foo="bar")
|
||||
def baz(spam):
|
||||
"""Whee!"""
|
||||
|
||||
return baz
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_attribs(self):
|
||||
baz = self._create_contextmanager_attribs()
|
||||
self.assertEqual(baz.__name__, "baz")
|
||||
self.assertEqual(baz.foo, "bar")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_keywords(self):
|
||||
# Ensure no keyword arguments are inhibited
|
||||
@contextmanager
|
||||
def woohoo(self, func, args, kwds):
|
||||
yield (self, func, args, kwds)
|
||||
|
||||
with woohoo(self=11, func=22, args=33, kwds=44) as target:
|
||||
self.assertEqual(target, (11, 22, 33, 44))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_param_errors(self):
|
||||
@contextmanager
|
||||
def woohoo(a, *, b):
|
||||
yield
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo()
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo(3, 5)
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo(b=3)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_recursive(self):
|
||||
depth = 0
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
nonlocal depth
|
||||
before = depth
|
||||
depth += 1
|
||||
yield
|
||||
depth -= 1
|
||||
self.assertEqual(depth, before)
|
||||
|
||||
@woohoo()
|
||||
def recursive():
|
||||
if depth < 10:
|
||||
recursive()
|
||||
|
||||
recursive()
|
||||
self.assertEqual(depth, 0)
|
||||
|
||||
@skipIfNotPy311
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_no_yield(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
if False:
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__enter__()
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_second_yield(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
yield
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__exit__(None, None, None)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_wrap_runtimeerror(self):
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"caught {exc}") from exc
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
with woohoo():
|
||||
1 / 0
|
||||
|
||||
# If the context manager wrapped StopIteration in a RuntimeError,
|
||||
# we also unwrap it, because we can't tell whether the wrapping was
|
||||
# done by the generator machinery or by the generator itself.
|
||||
with self.assertRaises(StopIteration):
|
||||
with woohoo():
|
||||
raise StopIteration
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_non_normalised(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
try:
|
||||
yield
|
||||
except RuntimeError:
|
||||
raise SyntaxError # noqa: B904
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(SyntaxError):
|
||||
ctx.__exit__(RuntimeError, None, None)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(CtxManagerTests)
|
||||
instantiate_parametrized_tests(ContextlibContextManagerTests)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import contextlib
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.config
|
||||
|
|
@ -904,6 +905,238 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
assert exc2.__context__ is None
|
||||
|
||||
|
||||
class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_exceptions.py
|
||||
def setUp(self):
|
||||
self._u_prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._u_prev
|
||||
|
||||
@make_dynamo_test
|
||||
def testChainingAttrs(self):
|
||||
e = Exception()
|
||||
assert e.__context__ is None
|
||||
assert e.__cause__ is None
|
||||
|
||||
e = TypeError()
|
||||
assert e.__context__ is None
|
||||
assert e.__cause__ is None
|
||||
|
||||
e = MyException()
|
||||
assert e.__context__ is None
|
||||
assert e.__cause__ is None
|
||||
|
||||
@make_dynamo_test
|
||||
def testChainingDescriptors(self):
|
||||
try:
|
||||
raise Exception # noqa: TRY002
|
||||
except Exception as exc:
|
||||
e = exc
|
||||
|
||||
assert e.__context__ is None
|
||||
assert e.__cause__ is None
|
||||
assert e.__suppress_context__ is False
|
||||
|
||||
e.__context__ = NameError()
|
||||
e.__cause__ = None
|
||||
assert isinstance(e.__context__, NameError)
|
||||
assert e.__cause__ is None
|
||||
assert e.__suppress_context__ is True
|
||||
e.__suppress_context__ = False
|
||||
assert e.__suppress_context__ is False
|
||||
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_try_and_finally(self):
|
||||
try:
|
||||
try:
|
||||
te = TypeError(1)
|
||||
raise te
|
||||
finally:
|
||||
ve = ValueError(2)
|
||||
raise ve
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
assert exc is ve
|
||||
assert exc.__context__ is te
|
||||
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_except_and_finally(self):
|
||||
try:
|
||||
try:
|
||||
te = TypeError(1)
|
||||
raise te
|
||||
except Exception: # noqa: E722
|
||||
ve = ValueError(2)
|
||||
raise ve # noqa: B904
|
||||
finally:
|
||||
oe = OSError(3)
|
||||
raise oe
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
assert exc is oe
|
||||
assert exc.__context__ is ve
|
||||
assert exc.__context__.__context__ is te
|
||||
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_else_and_finally(self):
|
||||
try:
|
||||
try:
|
||||
pass
|
||||
except Exception: # noqa: E722
|
||||
pass
|
||||
else:
|
||||
ve = ValueError(1)
|
||||
raise ve
|
||||
finally:
|
||||
oe = OSError(2)
|
||||
raise oe
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
assert exc is oe
|
||||
assert exc.__context__ is ve
|
||||
|
||||
@make_dynamo_test
|
||||
def test_raise_does_not_create_context_chain_cycle(self):
|
||||
A = AssertionError
|
||||
B = BytesWarning
|
||||
C = ConnectionError
|
||||
|
||||
# Create a context chain:
|
||||
# C -> B -> A
|
||||
# Then raise A in context of C.
|
||||
try:
|
||||
try:
|
||||
raise A
|
||||
except A as a_:
|
||||
a = a_
|
||||
try:
|
||||
raise B
|
||||
except B as b_:
|
||||
b = b_
|
||||
try:
|
||||
raise C
|
||||
except C as c_:
|
||||
c = c_
|
||||
self.assertIsInstance(a, A)
|
||||
self.assertIsInstance(b, B)
|
||||
self.assertIsInstance(c, C)
|
||||
self.assertIsNone(a.__context__)
|
||||
self.assertIs(b.__context__, a)
|
||||
self.assertIs(c.__context__, b)
|
||||
raise a # noqa: B904
|
||||
except A as e:
|
||||
exc = e
|
||||
|
||||
# Expect A -> C -> B, without cycle
|
||||
self.assertIs(exc, a)
|
||||
self.assertIs(a.__context__, c)
|
||||
self.assertIs(c.__context__, b)
|
||||
self.assertIsNone(b.__context__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_no_hang_on_context_chain_cycle1(self):
|
||||
# See issue 25782. Cycle in context chain.
|
||||
|
||||
def cycle():
|
||||
try:
|
||||
raise ValueError(1)
|
||||
except ValueError as ex:
|
||||
ex.__context__ = ex
|
||||
raise TypeError(2) # noqa: B904
|
||||
|
||||
try:
|
||||
cycle()
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
self.assertIsInstance(exc, TypeError)
|
||||
self.assertIsInstance(exc.__context__, ValueError)
|
||||
self.assertIs(exc.__context__.__context__, exc.__context__)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_no_hang_on_context_chain_cycle2(self):
|
||||
# See issue 25782. Cycle at head of context chain.
|
||||
|
||||
A = AssertionError
|
||||
B = BytesWarning
|
||||
C = ConnectionError
|
||||
|
||||
# Context cycle:
|
||||
# +-----------+
|
||||
# V |
|
||||
# C --> B --> A
|
||||
with self.assertRaises(C) as cm:
|
||||
try:
|
||||
raise A() # noqa: RSE102
|
||||
except A as _a:
|
||||
a = _a
|
||||
try:
|
||||
raise B() # noqa: RSE102
|
||||
except B as _b:
|
||||
b = _b
|
||||
try:
|
||||
raise C() # noqa: RSE102
|
||||
except C as _c:
|
||||
c = _c
|
||||
a.__context__ = c
|
||||
raise c # noqa: B904
|
||||
|
||||
self.assertIs(cm.exception, c)
|
||||
# Verify the expected context chain cycle
|
||||
self.assertIs(c.__context__, b)
|
||||
self.assertIs(b.__context__, a)
|
||||
self.assertIs(a.__context__, c)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_no_hang_on_context_chain_cycle3(self):
|
||||
# See issue 25782. Longer context chain with cycle.
|
||||
A = AssertionError
|
||||
B = BytesWarning
|
||||
C = ConnectionError
|
||||
D = DeprecationWarning
|
||||
E = Exception
|
||||
|
||||
# Context cycle:
|
||||
# +-----------+
|
||||
# V |
|
||||
# E --> D --> C --> B --> A
|
||||
with self.assertRaises(E) as cm:
|
||||
try:
|
||||
raise A
|
||||
except A as _a:
|
||||
a = _a
|
||||
try:
|
||||
raise B
|
||||
except B as _b:
|
||||
b = _b
|
||||
try:
|
||||
raise C
|
||||
except C as _c:
|
||||
c = _c
|
||||
a.__context__ = c
|
||||
try:
|
||||
raise D
|
||||
except D as _d:
|
||||
d = _d
|
||||
e = E()
|
||||
raise e # noqa: B904
|
||||
|
||||
self.assertIs(cm.exception, e)
|
||||
# Verify the expected context chain cycle
|
||||
self.assertIs(e.__context__, d)
|
||||
self.assertIs(d.__context__, c)
|
||||
self.assertIs(c.__context__, b)
|
||||
self.assertIs(b.__context__, a)
|
||||
self.assertIs(a.__context__, c)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ExceptionTests)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1481,6 +1481,331 @@ class TestGeneratorThrow(GeneratorTestsBase):
|
|||
self._compile_check(fn)
|
||||
|
||||
|
||||
class GeneratorCloseCPythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
# changed the tests a little bit to run them inside dynamo
|
||||
# + replaced all self.assert* calls to plain assert statements
|
||||
|
||||
def test_close_no_return_value(self):
|
||||
def f():
|
||||
yield
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_return_value(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
# close() raises GeneratorExit here, which is caught
|
||||
except GeneratorExit:
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
assert gen.close() == 0
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_not_catching_exit(self):
|
||||
def f():
|
||||
yield
|
||||
# close() raises GeneratorExit here, which isn't caught and
|
||||
# therefore propagates -- no return value
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_not_started(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
except GeneratorExit:
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_exhausted(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
except GeneratorExit:
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
next(gen)
|
||||
z = 0
|
||||
try:
|
||||
next(gen) # -> StopIteration
|
||||
except StopIteration:
|
||||
z = 1
|
||||
except Exception as e:
|
||||
# anything other than StopIteration should fail
|
||||
raise AssertionError from e
|
||||
assert z == 1
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_closed(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
except GeneratorExit:
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
assert gen.close() == 0
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_raises(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
except GeneratorExit:
|
||||
pass
|
||||
raise RuntimeError
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
z = 0
|
||||
try:
|
||||
gen.close() # -> RuntimeError
|
||||
except RuntimeError:
|
||||
z = 1
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
assert z == 1
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
|
||||
class GeneratorThrowCpythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
# changed the tests a little bit to run them inside dynamo
|
||||
# + replaced all self.assert* calls to plain assert statements
|
||||
|
||||
def test_exception_context_with_yield(self):
|
||||
def f():
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
yield
|
||||
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except ValueError as e:
|
||||
context = e.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_exception_context_with_yield_inside_generator(self):
|
||||
# Check that the context is also available from inside the generator
|
||||
# with yield, as opposed to outside.
|
||||
def f():
|
||||
z = 0
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
z = 1
|
||||
assert type(exc) == ValueError
|
||||
context = exc.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
yield "b"
|
||||
finally:
|
||||
assert z == 1
|
||||
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
actual = gen.throw(ValueError)
|
||||
# This ensures that the assertions inside were executed.
|
||||
assert actual == "b"
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_exception_context_with_yield_from(self):
|
||||
def f():
|
||||
yield
|
||||
|
||||
def g():
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
yield from f()
|
||||
|
||||
def fn(t):
|
||||
gen = g()
|
||||
gen.send(None)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except ValueError as e:
|
||||
context = e.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_exception_context_with_yield_from_with_context_cycle(self):
|
||||
# Check trying to create an exception context cycle:
|
||||
# https://bugs.python.org/issue40696
|
||||
has_cycle = None
|
||||
|
||||
def f():
|
||||
yield
|
||||
|
||||
def g(exc):
|
||||
nonlocal has_cycle
|
||||
try:
|
||||
raise exc
|
||||
except Exception:
|
||||
try:
|
||||
yield from f()
|
||||
except Exception as exc:
|
||||
has_cycle = exc is exc.__context__
|
||||
yield
|
||||
|
||||
def fn(t):
|
||||
exc = KeyError("a")
|
||||
gen = g(exc)
|
||||
gen.send(None)
|
||||
gen.throw(exc)
|
||||
# This also distinguishes from the initial has_cycle=None.
|
||||
assert has_cycle is False
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_throw_after_none_exc_type(self):
|
||||
def g():
|
||||
try:
|
||||
raise KeyError
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield
|
||||
except Exception:
|
||||
raise RuntimeError # noqa: B904
|
||||
|
||||
def fn(t):
|
||||
gen = g()
|
||||
gen.send(None)
|
||||
z = 0
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except RuntimeError:
|
||||
z += 1
|
||||
except Exception:
|
||||
raise AssertionError # noqa: B904
|
||||
assert z == 1
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
|
||||
class GeneratorCPythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
# changed the tests a little bit to run them inside dynamo
|
||||
# + replaced all self.assert* calls to plain assert statements
|
||||
|
||||
def test_send_non_none_to_new_gen(self):
|
||||
def f():
|
||||
yield 1
|
||||
|
||||
def fn(t):
|
||||
g = f()
|
||||
z = 0
|
||||
try:
|
||||
g.send(0)
|
||||
except TypeError:
|
||||
z += 1
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
assert z == 1
|
||||
assert next(g) == 1
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_issue103488(self):
|
||||
def gen_raises():
|
||||
yield 1
|
||||
raise ValueError
|
||||
|
||||
def loop():
|
||||
try:
|
||||
for _ in gen_raises():
|
||||
if True is False: # noqa: PLR0133
|
||||
return
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def fn(t):
|
||||
# This should not raise
|
||||
loop()
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(GeneratorTests)
|
||||
instantiate_parametrized_tests(TestGeneratorSend)
|
||||
instantiate_parametrized_tests(TestGeneratorClose)
|
||||
|
|
|
|||
52
test/dynamo/test_generator_stop.py
Normal file
52
test/dynamo/test_generator_stop.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
|
||||
|
||||
class TestPEP479(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_generator_stop.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_generator_stop.py
|
||||
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
|
||||
@make_dynamo_test
|
||||
def test_stopiteration_wrapping(self):
|
||||
def f():
|
||||
raise StopIteration
|
||||
|
||||
def g():
|
||||
yield f()
|
||||
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
next(g())
|
||||
self.assertEqual("generator raised StopIteration", str(cm.exception))
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
|
||||
@make_dynamo_test
|
||||
def test_stopiteration_wrapping_context(self):
|
||||
def f():
|
||||
raise StopIteration
|
||||
|
||||
def g():
|
||||
yield f()
|
||||
|
||||
try:
|
||||
next(g())
|
||||
except RuntimeError as exc:
|
||||
self.assertIs(type(exc.__cause__), StopIteration)
|
||||
self.assertIs(type(exc.__context__), StopIteration)
|
||||
self.assertTrue(exc.__suppress_context__)
|
||||
else:
|
||||
self.fail(
|
||||
"__cause__, __context__, or __suppress_context__ "
|
||||
"were not properly set"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
563
test/dynamo/test_raise.py
Normal file
563
test/dynamo/test_raise.py
Normal file
|
|
@ -0,0 +1,563 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
# ruff: noqa
|
||||
# flake8: noqa
|
||||
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.config
|
||||
import torch._dynamo.test_case
|
||||
import torch._functorch.config
|
||||
import torch.nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
|
||||
|
||||
def get_tb():
|
||||
try:
|
||||
raise OSError()
|
||||
except:
|
||||
return sys.exc_info()[2]
|
||||
|
||||
|
||||
class Context:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||
return True
|
||||
|
||||
|
||||
class MyException(Exception):
|
||||
def __init__(self):
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
class ContextManager:
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, t, v, tb):
|
||||
raise NameError
|
||||
|
||||
|
||||
class TestRaise(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
@make_dynamo_test
|
||||
def test_invalid_reraise(self):
|
||||
try:
|
||||
raise
|
||||
except RuntimeError as e:
|
||||
self.assertIn("No active exception", str(e))
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_reraise(self):
|
||||
try:
|
||||
try:
|
||||
raise IndexError
|
||||
except IndexError as e:
|
||||
exc1 = e
|
||||
raise
|
||||
except IndexError as exc2:
|
||||
self.assertIs(exc1, exc2)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_except_reraise(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
try:
|
||||
raise KeyError("caught")
|
||||
except KeyError:
|
||||
pass
|
||||
raise
|
||||
|
||||
self.assertRaises(TypeError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_finally_reraise(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
try:
|
||||
raise KeyError("caught")
|
||||
finally:
|
||||
raise
|
||||
|
||||
self.assertRaises(KeyError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_nested_reraise(self):
|
||||
def nested_reraise():
|
||||
raise
|
||||
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
nested_reraise()
|
||||
|
||||
self.assertRaises(TypeError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_raise_from_None(self):
|
||||
try:
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
raise ValueError() from None
|
||||
except ValueError as e:
|
||||
self.assertIsInstance(e.__context__, TypeError)
|
||||
self.assertIsNone(e.__cause__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_with_reraise1(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
with Context():
|
||||
pass
|
||||
raise
|
||||
|
||||
self.assertRaises(TypeError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_with_reraise2(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
with Context():
|
||||
raise KeyError("caught")
|
||||
raise
|
||||
|
||||
self.assertRaises(TypeError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_yield_reraise(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
yield 1
|
||||
raise
|
||||
|
||||
g = reraise()
|
||||
next(g)
|
||||
self.assertRaises(TypeError, lambda: next(g))
|
||||
self.assertRaises(StopIteration, lambda: next(g))
|
||||
|
||||
@make_dynamo_test
|
||||
def test_erroneous_exception(self):
|
||||
try:
|
||||
raise MyException
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@unittest.expectedFailure # object
|
||||
@make_dynamo_test
|
||||
def test_new_returns_invalid_instance(self):
|
||||
# See issue #11627.
|
||||
class MyException2(Exception):
|
||||
def __new__(cls, *args):
|
||||
return object()
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
raise MyException2
|
||||
|
||||
@unittest.expectedFailure # Assertion with non-string message
|
||||
@make_dynamo_test
|
||||
def test_assert_with_tuple_arg(self):
|
||||
try:
|
||||
assert False, (3,)
|
||||
except AssertionError as e:
|
||||
self.assertEqual(str(e), "(3,)")
|
||||
|
||||
|
||||
class TestCause(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
def setUp(self):
|
||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
||||
|
||||
@make_dynamo_test
|
||||
def testCauseSyntax(self):
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
raise TypeError
|
||||
except Exception:
|
||||
raise ValueError from None
|
||||
except ValueError as exc:
|
||||
self.assertIsNone(exc.__cause__)
|
||||
self.assertTrue(exc.__suppress_context__)
|
||||
exc.__suppress_context__ = False
|
||||
raise exc
|
||||
except ValueError as exc:
|
||||
e = exc
|
||||
|
||||
self.assertIsNone(e.__cause__)
|
||||
self.assertFalse(e.__suppress_context__)
|
||||
self.assertIsInstance(e.__context__, TypeError)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_invalid_cause(self):
|
||||
try:
|
||||
raise IndexError from 5
|
||||
except TypeError as e:
|
||||
self.assertIn("exception cause", str(e))
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_class_cause(self):
|
||||
try:
|
||||
raise IndexError from KeyError
|
||||
except IndexError as e:
|
||||
self.assertIsInstance(e.__cause__, KeyError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_instance_cause(self):
|
||||
cause = KeyError()
|
||||
try:
|
||||
raise IndexError from cause
|
||||
except IndexError as e:
|
||||
self.assertIs(e.__cause__, cause)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_erroneous_cause(self):
|
||||
try:
|
||||
raise IndexError from MyException
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
|
||||
class TestTraceback(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
def setUp(self):
|
||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
||||
|
||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
||||
@make_dynamo_test
|
||||
def test_sets_traceback(self):
|
||||
try:
|
||||
raise IndexError()
|
||||
except IndexError as e:
|
||||
self.assertIsInstance(e.__traceback__, types.TracebackType)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
||||
@make_dynamo_test
|
||||
def test_accepts_traceback(self):
|
||||
tb = get_tb()
|
||||
try:
|
||||
raise IndexError().with_traceback(tb)
|
||||
except IndexError as e:
|
||||
self.assertNotEqual(e.__traceback__, tb)
|
||||
self.assertEqual(e.__traceback__.tb_next, tb)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
|
||||
class TestTracebackType(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
def setUp(self):
|
||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
||||
|
||||
def raiser(self):
|
||||
raise ValueError
|
||||
|
||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
||||
@make_dynamo_test
|
||||
def test_attrs(self):
|
||||
try:
|
||||
self.raiser()
|
||||
except Exception as exc:
|
||||
tb = exc.__traceback__
|
||||
|
||||
self.assertIsInstance(tb.tb_next, types.TracebackType)
|
||||
self.assertIs(tb.tb_frame, sys._getframe())
|
||||
self.assertIsInstance(tb.tb_lasti, int)
|
||||
self.assertIsInstance(tb.tb_lineno, int)
|
||||
|
||||
self.assertIs(tb.tb_next.tb_next, None)
|
||||
|
||||
# Invalid assignments
|
||||
with self.assertRaises(TypeError):
|
||||
del tb.tb_next
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
tb.tb_next = "asdf"
|
||||
|
||||
# Loops
|
||||
with self.assertRaises(ValueError):
|
||||
tb.tb_next = tb
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tb.tb_next.tb_next = tb
|
||||
|
||||
# Valid assignments
|
||||
tb.tb_next = None
|
||||
self.assertIs(tb.tb_next, None)
|
||||
|
||||
new_tb = get_tb()
|
||||
tb.tb_next = new_tb
|
||||
self.assertIs(tb.tb_next, new_tb)
|
||||
|
||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
||||
@make_dynamo_test
|
||||
def test_constructor(self):
|
||||
other_tb = get_tb()
|
||||
frame = sys._getframe()
|
||||
|
||||
tb = types.TracebackType(other_tb, frame, 1, 2)
|
||||
self.assertEqual(tb.tb_next, other_tb)
|
||||
self.assertEqual(tb.tb_frame, frame)
|
||||
self.assertEqual(tb.tb_lasti, 1)
|
||||
self.assertEqual(tb.tb_lineno, 2)
|
||||
|
||||
tb = types.TracebackType(None, frame, 1, 2)
|
||||
self.assertEqual(tb.tb_next, None)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
types.TracebackType("no", frame, 1, 2)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
types.TracebackType(other_tb, "no", 1, 2)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
types.TracebackType(other_tb, frame, "no", 2)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
types.TracebackType(other_tb, frame, 1, "nuh-uh")
|
||||
|
||||
|
||||
class TestContext(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
def setUp(self):
|
||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
||||
|
||||
@unittest.expectedFailure # missing Exception.__eq__
|
||||
@make_dynamo_test
|
||||
def test_instance_context_instance_raise(self):
|
||||
context = IndexError()
|
||||
try:
|
||||
try:
|
||||
raise context
|
||||
except:
|
||||
raise OSError()
|
||||
except OSError as e:
|
||||
self.assertEqual(e.__context__, context)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__
|
||||
@make_dynamo_test
|
||||
def test_class_context_instance_raise(self):
|
||||
context = IndexError
|
||||
try:
|
||||
try:
|
||||
raise context
|
||||
except:
|
||||
raise OSError()
|
||||
except OSError as e:
|
||||
self.assertNotEqual(e.__context__, context)
|
||||
self.assertIsInstance(e.__context__, context)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__
|
||||
@make_dynamo_test
|
||||
def test_class_context_class_raise(self):
|
||||
context = IndexError
|
||||
try:
|
||||
try:
|
||||
raise context
|
||||
except:
|
||||
raise OSError
|
||||
except OSError as e:
|
||||
self.assertNotEqual(e.__context__, context)
|
||||
self.assertIsInstance(e.__context__, context)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_c_exception_context(self):
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except:
|
||||
raise OSError
|
||||
except OSError as e:
|
||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_c_exception_raise(self):
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except:
|
||||
raise NameError
|
||||
except NameError as e:
|
||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_noraise_finally(self):
|
||||
try:
|
||||
try:
|
||||
pass
|
||||
finally:
|
||||
raise OSError
|
||||
except OSError as e:
|
||||
self.assertIsNone(e.__context__)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_raise_finally(self):
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
finally:
|
||||
raise OSError
|
||||
except OSError as e:
|
||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_context_manager(self):
|
||||
try:
|
||||
with ContextManager():
|
||||
raise ZeroDivisionError
|
||||
except NameError as e:
|
||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_cycle_broken(self):
|
||||
# Self-cycles (when re-raising a caught exception) are broken
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError as e:
|
||||
raise e
|
||||
except ZeroDivisionError as e:
|
||||
self.assertIsNone(e.__context__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_reraise_cycle_broken(self):
|
||||
# Non-trivial context cycles (through re-raising a previous exception)
|
||||
# are broken too.
|
||||
try:
|
||||
try:
|
||||
raise NameError
|
||||
except NameError as a:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError:
|
||||
raise a
|
||||
except NameError as e:
|
||||
self.assertIsNone(e.__context__.__context__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_3118(self):
|
||||
# deleting the generator caused the __context__ to be cleared
|
||||
def gen():
|
||||
try:
|
||||
yield 1
|
||||
finally:
|
||||
pass
|
||||
|
||||
def f():
|
||||
g = gen()
|
||||
next(g)
|
||||
try:
|
||||
try:
|
||||
raise ValueError
|
||||
except:
|
||||
del g
|
||||
raise KeyError
|
||||
except Exception as e:
|
||||
self.assertIsInstance(e.__context__, ValueError)
|
||||
|
||||
f()
|
||||
|
||||
@unittest.expectedFailure # too CPython specific(?)
|
||||
@make_dynamo_test
|
||||
def test_3611(self):
|
||||
# A re-raised exception in a __del__ caused the __context__
|
||||
# to be cleared
|
||||
class C:
|
||||
def __del__(self):
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except:
|
||||
raise
|
||||
|
||||
def f():
|
||||
x = C()
|
||||
try:
|
||||
try:
|
||||
x.x
|
||||
except AttributeError:
|
||||
del x
|
||||
raise TypeError
|
||||
except Exception as e:
|
||||
self.assertNotEqual(e.__context__, None)
|
||||
self.assertIsInstance(e.__context__, AttributeError)
|
||||
|
||||
with support.catch_unraisable_exception() as cm:
|
||||
f()
|
||||
|
||||
self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
107
test/dynamo/test_sys.py
Normal file
107
test/dynamo/test_sys.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
|
||||
|
||||
class SysTests(torch._dynamo.test_case.TestCase):
|
||||
def test_exc_info(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
try:
|
||||
raise ValueError
|
||||
except Exception:
|
||||
typ, _, _ = sys.exc_info()
|
||||
if typ is ValueError:
|
||||
return t.sin()
|
||||
else:
|
||||
return t.cos()
|
||||
|
||||
t = torch.randn(2)
|
||||
y = fn(t)
|
||||
self.assertEqual(y, t.sin())
|
||||
|
||||
|
||||
class CPythonActiveExceptionTests(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_sys.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py
|
||||
|
||||
@make_dynamo_test
|
||||
def test_exc_info_no_exception(self):
|
||||
self.assertEqual(sys.exc_info(), (None, None, None))
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_sys_exception_no_exception(self):
|
||||
self.assertEqual(sys.exception(), None)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_exc_info_with_exception_instance(self):
|
||||
def f():
|
||||
raise ValueError(42)
|
||||
|
||||
try:
|
||||
f()
|
||||
except Exception as e_:
|
||||
e = e_
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc_info[0], ValueError)
|
||||
self.assertIs(exc_info[1], e)
|
||||
self.assertIs(exc_info[2], e.__traceback__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_exc_info_with_exception_type(self):
|
||||
def f():
|
||||
raise ValueError
|
||||
|
||||
try:
|
||||
f()
|
||||
except Exception as e_:
|
||||
e = e_
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc_info[0], ValueError)
|
||||
self.assertIs(exc_info[1], e)
|
||||
self.assertIs(exc_info[2], e.__traceback__)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_sys_exception_with_exception_instance(self):
|
||||
def f():
|
||||
raise ValueError(42)
|
||||
|
||||
try:
|
||||
f()
|
||||
except Exception as e_:
|
||||
e = e_
|
||||
exc = sys.exception()
|
||||
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc, e)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_sys_exception_with_exception_type(self):
|
||||
def f():
|
||||
raise ValueError
|
||||
|
||||
try:
|
||||
f()
|
||||
except Exception as e_:
|
||||
e = e_
|
||||
exc = sys.exception()
|
||||
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc, e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
|
|
@ -1593,13 +1593,6 @@ def get_selected_tests(options) -> list[str]:
|
|||
]
|
||||
)
|
||||
|
||||
if sys.version_info[:2] < (3, 13):
|
||||
# Skip tests for older Python versions as they may use syntax or features
|
||||
# not supported in those versions
|
||||
options.exclude.extend(
|
||||
[test for test in selected_tests if test.startswith("dynamo/cpython/3_13/")]
|
||||
)
|
||||
|
||||
selected_tests = exclude_tests(options.exclude, selected_tests)
|
||||
|
||||
if sys.platform == "win32" and not options.ignore_win_blocklist:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
|
||||
|
||||
This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
|
||||
|
|
@ -12,12 +10,8 @@ It includes:
|
|||
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
|
@ -104,67 +98,7 @@ class TestCase(TorchTestCase):
|
|||
|
||||
|
||||
class CPythonTestCase(TestCase):
|
||||
"""
|
||||
Enable certain features that are off by default (i.e. tracing through unittest)
|
||||
"""
|
||||
|
||||
_stack: contextlib.ExitStack
|
||||
dynamo_strict_nopython = True
|
||||
|
||||
# Restore original unittest methods to simplify tracing CPython test cases.
|
||||
assertEqual = unittest.TestCase.assertEqual # type: ignore[assignment]
|
||||
assertNotEqual = unittest.TestCase.assertNotEqual # type: ignore[assignment]
|
||||
assertTrue = unittest.TestCase.assertTrue
|
||||
assertFalse = unittest.TestCase.assertFalse
|
||||
assertIs = unittest.TestCase.assertIs
|
||||
assertIsNot = unittest.TestCase.assertIsNot
|
||||
assertIsNone = unittest.TestCase.assertIsNone
|
||||
assertIsNotNone = unittest.TestCase.assertIsNotNone
|
||||
assertIn = unittest.TestCase.assertIn
|
||||
assertNotIn = unittest.TestCase.assertNotIn
|
||||
assertIsInstance = unittest.TestCase.assertIsInstance
|
||||
assertNotIsInstance = unittest.TestCase.assertNotIsInstance
|
||||
assertAlmostEqual = unittest.TestCase.assertAlmostEqual
|
||||
assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
|
||||
assertGreater = unittest.TestCase.assertGreater
|
||||
assertGreaterEqual = unittest.TestCase.assertGreaterEqual
|
||||
assertLess = unittest.TestCase.assertLess
|
||||
assertLessEqual = unittest.TestCase.assertLessEqual
|
||||
assertRegex = unittest.TestCase.assertRegex
|
||||
assertNotRegex = unittest.TestCase.assertNotRegex
|
||||
assertCountEqual = unittest.TestCase.assertCountEqual
|
||||
assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual
|
||||
assertSequenceEqual = unittest.TestCase.assertSequenceEqual
|
||||
assertListEqual = unittest.TestCase.assertListEqual
|
||||
assertTupleEqual = unittest.TestCase.assertTupleEqual
|
||||
assertSetEqual = unittest.TestCase.assertSetEqual
|
||||
assertDictEqual = unittest.TestCase.assertDictEqual
|
||||
assertRaises = unittest.TestCase.assertRaises
|
||||
assertRaisesRegex = unittest.TestCase.assertRaisesRegex
|
||||
assertWarns = unittest.TestCase.assertWarns
|
||||
assertWarnsRegex = unittest.TestCase.assertWarnsRegex
|
||||
assertLogs = unittest.TestCase.assertLogs
|
||||
fail = unittest.TestCase.fail
|
||||
failureException = unittest.TestCase.failureException
|
||||
|
||||
def compile_fn(self, fn, backend, nopython):
|
||||
# We want to compile only the test function, excluding any setup code
|
||||
# from unittest
|
||||
method = getattr(self, self._testMethodName)
|
||||
method = torch._dynamo.optimize(backend, nopython=nopython)(method)
|
||||
setattr(self, self._testMethodName, method)
|
||||
return fn
|
||||
|
||||
def _dynamo_test_key(self):
|
||||
suffix = super()._dynamo_test_key()
|
||||
test_cls = self.__class__
|
||||
test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0]
|
||||
py_ver = re.search(r"/([\d_]+)/", inspect.getfile(test_cls))
|
||||
if py_ver:
|
||||
py_ver = py_ver.group().strip(os.sep).replace("_", "") # type: ignore[assignment]
|
||||
else:
|
||||
return suffix
|
||||
return f"CPython{py_ver}-{test_file}-{suffix}"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
|
|
@ -173,22 +107,6 @@ class CPythonTestCase(TestCase):
|
|||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# Skip test if python versions doesn't match
|
||||
m = re.search(r"\b\d+_\d+\b", inspect.getfile(cls))
|
||||
if m:
|
||||
test_py_ver = tuple(map(int, m.group().split("_")))
|
||||
else:
|
||||
raise unittest.TestCase.failureException(
|
||||
f"Test file {inspect.getfile(cls)} does not contain a valid Python version"
|
||||
)
|
||||
py_ver = sys.version_info[:2]
|
||||
if py_ver != test_py_ver:
|
||||
expected = ".".join(map(str, test_py_ver))
|
||||
got = ".".join(map(str, py_ver))
|
||||
raise unittest.SkipTest(
|
||||
f"Test requires Python {expected} but got Python {got}"
|
||||
)
|
||||
|
||||
super().setUpClass()
|
||||
cls._stack = contextlib.ExitStack() # type: ignore[attr-defined]
|
||||
cls._stack.enter_context( # type: ignore[attr-defined]
|
||||
|
|
|
|||
|
|
@ -1989,11 +1989,11 @@ class BuiltinVariable(VariableTracker):
|
|||
)
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to trace unittest method",
|
||||
gb_type="Failed to trace builtin operator",
|
||||
context=f"function: unittest.TestCase.{name}",
|
||||
explanation=f"Dynamo does not know how to trace unittest method `{name}` ",
|
||||
explanation=f"Dynamo does not know how to trace builtin operator `{name}` ",
|
||||
hints=[
|
||||
f"Avoid calling `TestCase.{name}`. "
|
||||
f"Avoid calling builtin `{name}`. "
|
||||
"Please report an issue to PyTorch.",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3157,13 +3157,6 @@ class TestCase(expecttest.TestCase):
|
|||
def wrap_with_cuda_memory_check(self, method):
|
||||
return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
|
||||
|
||||
def _dynamo_test_key(self):
|
||||
return f"{self.__class__.__name__}.{self._testMethodName}"
|
||||
|
||||
def compile_fn(self, fn, backend, nopython):
|
||||
# Allows subclasses to control compilation
|
||||
return torch._dynamo.optimize(backend, nopython=nopython)(fn)
|
||||
|
||||
def _run_custom(self, result=None):
|
||||
using_unittest = isinstance(result, unittest.TestResult)
|
||||
|
||||
|
|
@ -3239,16 +3232,16 @@ class TestCase(expecttest.TestCase):
|
|||
|
||||
with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
|
||||
if TEST_WITH_AOT_EAGER:
|
||||
super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython)
|
||||
super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
|
||||
elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
|
||||
if TEST_WITH_TORCHINDUCTOR:
|
||||
super_run = self.compile_fn(super_run, "inductor", nopython)
|
||||
super_run = torch._dynamo.optimize("inductor")(super_run)
|
||||
else:
|
||||
# Assume eager-generated GraphModules will not error out.
|
||||
# If we do, this is probably a Dynamo bug!
|
||||
super_run = self.compile_fn(super_run, "eager_noexcept", nopython)
|
||||
super_run = torch._dynamo.optimize("eager_noexcept", nopython=nopython)(super_run)
|
||||
|
||||
key = self._dynamo_test_key()
|
||||
key = f"{self.__class__.__name__}.{self._testMethodName}"
|
||||
|
||||
def expect_failure(f, file_name):
|
||||
@wraps(f)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user