mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add infra to run CPython tests under Dynamo (#150787)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150787 Approved by: https://github.com/zou3519
This commit is contained in:
parent
13fbf21a76
commit
ae1e51b6ad
|
|
@ -18,6 +18,8 @@ exclude_patterns = [
|
||||||
'torch/_inductor/autoheuristic/artifacts/**',
|
'torch/_inductor/autoheuristic/artifacts/**',
|
||||||
'scripts/**',
|
'scripts/**',
|
||||||
'test/generated_type_hints_smoketest.py',
|
'test/generated_type_hints_smoketest.py',
|
||||||
|
# CPython tests
|
||||||
|
'test/dynamo/cpython/**',
|
||||||
# Tests from the NumPy test suite
|
# Tests from the NumPy test suite
|
||||||
'test/torch_np/numpy_test/**/*.py',
|
'test/torch_np/numpy_test/**/*.py',
|
||||||
'third_party/**',
|
'third_party/**',
|
||||||
|
|
@ -398,6 +400,7 @@ exclude_patterns=[
|
||||||
'tools/clang_format_hash/**',
|
'tools/clang_format_hash/**',
|
||||||
'test/cpp/jit/upgrader_models/*.ptl',
|
'test/cpp/jit/upgrader_models/*.ptl',
|
||||||
'test/cpp/jit/upgrader_models/*.ptl.ff',
|
'test/cpp/jit/upgrader_models/*.ptl.ff',
|
||||||
|
'test/dynamo/cpython/**',
|
||||||
'**/*.png',
|
'**/*.png',
|
||||||
'**/*.gz',
|
'**/*.gz',
|
||||||
'**/*.patch',
|
'**/*.patch',
|
||||||
|
|
@ -936,6 +939,7 @@ include_patterns = [
|
||||||
exclude_patterns = [
|
exclude_patterns = [
|
||||||
'test/run_test.py',
|
'test/run_test.py',
|
||||||
'**/fb/**',
|
'**/fb/**',
|
||||||
|
'test/dynamo/cpython/3.13/**',
|
||||||
'test/quantization/**', # should be run through test/test_quantization.py
|
'test/quantization/**', # should be run through test/test_quantization.py
|
||||||
'test/jit/**', # should be run through test/test_jit.py
|
'test/jit/**', # should be run through test/test_jit.py
|
||||||
'test/ao/sparsity/**', # should be run through test/test_ao_sparsity.py
|
'test/ao/sparsity/**', # should be run through test/test_ao_sparsity.py
|
||||||
|
|
@ -1131,6 +1135,7 @@ exclude_patterns = [
|
||||||
'caffe2/**/*.pyi',
|
'caffe2/**/*.pyi',
|
||||||
'fb/**',
|
'fb/**',
|
||||||
'**/fb/**',
|
'**/fb/**',
|
||||||
|
'test/dynamo/cpython/**',
|
||||||
'third_party/**/*.py',
|
'third_party/**/*.py',
|
||||||
'third_party/**/*.pyi',
|
'third_party/**/*.pyi',
|
||||||
'torch/_vendor/**',
|
'torch/_vendor/**',
|
||||||
|
|
@ -1536,6 +1541,7 @@ exclude_patterns = [
|
||||||
'functorch/notebooks/**',
|
'functorch/notebooks/**',
|
||||||
'torch/_inductor/fx_passes/serialized_patterns/**',
|
'torch/_inductor/fx_passes/serialized_patterns/**',
|
||||||
'torch/_inductor/autoheuristic/artifacts/**',
|
'torch/_inductor/autoheuristic/artifacts/**',
|
||||||
|
'test/dynamo/cpython/**',
|
||||||
'scripts/**',
|
'scripts/**',
|
||||||
'third_party/**',
|
'third_party/**',
|
||||||
'fb/**',
|
'fb/**',
|
||||||
|
|
|
||||||
9
test/dynamo/cpython/3_13/CHANGES.txt
Normal file
9
test/dynamo/cpython/3_13/CHANGES.txt
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
This subdirectory contains a selection of tests from the CPython repository (branch: v3.13.0):\
|
||||||
|
https://github.com/python/cpython/releases/tag/v3.13.0
|
||||||
|
|
||||||
|
Modifications were made to ensure compatibility with the Dynamo infrastructure:
|
||||||
|
+ Monkey-patched `unittest.TestCase` to `torch._dynamo.test_case.CPythonTestCase`.
|
||||||
|
+ Replaced `unittest.main()` with `torch._dynamo.test_case.run_tests()`.
|
||||||
|
+ Assigned test "owners."
|
||||||
|
+ Annotated CPU-intensive tests with the `@slowTest` decorator.
|
||||||
|
+ Adjusted imports to use `import module` instead of `from test import module`.
|
||||||
46
test/dynamo/cpython/3_13/LICENSE
Normal file
46
test/dynamo/cpython/3_13/LICENSE
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
|
||||||
|
--------------------------------------------
|
||||||
|
|
||||||
|
1. This LICENSE AGREEMENT is between the Python Software Foundation
|
||||||
|
("PSF"), and the Individual or Organization ("Licensee") accessing and
|
||||||
|
otherwise using this software ("Python") in source or binary form and
|
||||||
|
its associated documentation.
|
||||||
|
|
||||||
|
2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
||||||
|
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
|
||||||
|
analyze, test, perform and/or display publicly, prepare derivative works,
|
||||||
|
distribute, and otherwise use Python alone or in any derivative version,
|
||||||
|
provided, however, that PSF's License Agreement and PSF's notice of copyright,
|
||||||
|
i.e., "Copyright (c) 2001 Python Software Foundation; All Rights Reserved"
|
||||||
|
are retained in Python alone or in any derivative version prepared by Licensee.
|
||||||
|
|
||||||
|
3. In the event Licensee prepares a derivative work that is based on
|
||||||
|
or incorporates Python or any part thereof, and wants to make
|
||||||
|
the derivative work available to others as provided herein, then
|
||||||
|
Licensee hereby agrees to include in any such work a brief summary of
|
||||||
|
the changes made to Python.
|
||||||
|
|
||||||
|
4. PSF is making Python available to Licensee on an "AS IS"
|
||||||
|
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
||||||
|
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
|
||||||
|
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
||||||
|
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
|
||||||
|
INFRINGE ANY THIRD PARTY RIGHTS.
|
||||||
|
|
||||||
|
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
||||||
|
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
||||||
|
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
|
||||||
|
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
||||||
|
|
||||||
|
6. This License Agreement will automatically terminate upon a material
|
||||||
|
breach of its terms and conditions.
|
||||||
|
|
||||||
|
7. Nothing in this License Agreement shall be deemed to create any
|
||||||
|
relationship of agency, partnership, or joint venture between PSF and
|
||||||
|
Licensee. This License Agreement does not grant permission to use PSF
|
||||||
|
trademarks or trade name in a trademark sense to endorse or promote
|
||||||
|
products or services of Licensee, or any third party.
|
||||||
|
|
||||||
|
8. By copying, installing or otherwise using Python, Licensee
|
||||||
|
agrees to be bound by the terms and conditions of this License
|
||||||
|
Agreement.
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
import contextlib
|
import contextlib
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
import unittest
|
import unittest
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
@ -9,18 +8,12 @@ import torch
|
||||||
import torch._dynamo.test_case
|
import torch._dynamo.test_case
|
||||||
import torch._dynamo.testing
|
import torch._dynamo.testing
|
||||||
from torch._dynamo.exc import InternalTorchDynamoError
|
from torch._dynamo.exc import InternalTorchDynamoError
|
||||||
from torch._dynamo.testing import (
|
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
|
||||||
EagerAndRecordGraphs,
|
|
||||||
normalize_gm,
|
|
||||||
same,
|
|
||||||
skipIfNotPy311,
|
|
||||||
)
|
|
||||||
from torch._dynamo.utils import counters
|
from torch._dynamo.utils import counters
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
make_dynamo_test,
|
|
||||||
parametrize,
|
parametrize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -37,6 +30,16 @@ z_glb = 0
|
||||||
k_glb = 0
|
k_glb = 0
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def set_default_dtype(dtype):
|
||||||
|
old_dtype = torch.get_default_dtype()
|
||||||
|
try:
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
|
|
||||||
class CustomizedCtxManager:
|
class CustomizedCtxManager:
|
||||||
def __init__(self, mode):
|
def __init__(self, mode):
|
||||||
self.prev = torch.is_grad_enabled()
|
self.prev = torch.is_grad_enabled()
|
||||||
|
|
@ -2700,319 +2703,6 @@ class GraphModule(torch.nn.Module):
|
||||||
self.assertEqual(y, t.sin())
|
self.assertEqual(y, t.sin())
|
||||||
|
|
||||||
|
|
||||||
class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_plain(self):
|
|
||||||
state = []
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def woohoo():
|
|
||||||
state.append(1)
|
|
||||||
yield 42
|
|
||||||
state.append(999)
|
|
||||||
|
|
||||||
with woohoo() as x:
|
|
||||||
self.assertEqual(state, [1])
|
|
||||||
self.assertEqual(x, 42)
|
|
||||||
state.append(x)
|
|
||||||
self.assertEqual(state, [1, 42, 999])
|
|
||||||
|
|
||||||
@skipIfNotPy311
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_finally(self):
|
|
||||||
state = []
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def woohoo():
|
|
||||||
state.append(1)
|
|
||||||
try:
|
|
||||||
yield 42
|
|
||||||
finally:
|
|
||||||
state.append(999)
|
|
||||||
|
|
||||||
with self.assertRaises(ZeroDivisionError):
|
|
||||||
with woohoo() as x:
|
|
||||||
self.assertEqual(state, [1])
|
|
||||||
self.assertEqual(x, 42)
|
|
||||||
state.append(x)
|
|
||||||
raise ZeroDivisionError
|
|
||||||
self.assertEqual(state, [1, 42, 999])
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_traceback(self):
|
|
||||||
@contextmanager
|
|
||||||
def f():
|
|
||||||
yield
|
|
||||||
|
|
||||||
try:
|
|
||||||
with f():
|
|
||||||
1 / 0
|
|
||||||
except ZeroDivisionError as e:
|
|
||||||
frames = traceback.extract_tb(e.__traceback__)
|
|
||||||
|
|
||||||
self.assertEqual(len(frames), 1)
|
|
||||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
|
||||||
self.assertEqual(frames[0].line, "1/0")
|
|
||||||
|
|
||||||
# Repeat with RuntimeError (which goes through a different code path)
|
|
||||||
try:
|
|
||||||
with f():
|
|
||||||
raise NotImplementedError(42)
|
|
||||||
except NotImplementedError as e:
|
|
||||||
frames = traceback.extract_tb(e.__traceback__)
|
|
||||||
|
|
||||||
self.assertEqual(len(frames), 1)
|
|
||||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
|
||||||
self.assertEqual(frames[0].line, "raise NotImplementedError(42)")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_no_reraise(self):
|
|
||||||
@contextmanager
|
|
||||||
def whee():
|
|
||||||
yield
|
|
||||||
|
|
||||||
ctx = whee()
|
|
||||||
ctx.__enter__()
|
|
||||||
# Calling __exit__ should not result in an exception
|
|
||||||
self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_trap_yield_after_throw(self):
|
|
||||||
@contextmanager
|
|
||||||
def whoo():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except Exception: # noqa: E722
|
|
||||||
yield
|
|
||||||
|
|
||||||
ctx = whoo()
|
|
||||||
ctx.__enter__()
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
ctx.__exit__(TypeError, TypeError("foo"), None)
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
|
||||||
def test_contextmanager_except(self):
|
|
||||||
state = []
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def woohoo():
|
|
||||||
state.append(1)
|
|
||||||
try:
|
|
||||||
yield 42
|
|
||||||
except ZeroDivisionError as e:
|
|
||||||
state.append(e.args[0])
|
|
||||||
self.assertEqual(state, [1, 42, 999])
|
|
||||||
|
|
||||||
with woohoo() as x:
|
|
||||||
self.assertEqual(state, [1])
|
|
||||||
self.assertEqual(x, 42)
|
|
||||||
state.append(x)
|
|
||||||
raise ZeroDivisionError(999)
|
|
||||||
self.assertEqual(state, [1, 42, 999])
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_except_stopiter(self):
|
|
||||||
@contextmanager
|
|
||||||
def woohoo():
|
|
||||||
yield
|
|
||||||
|
|
||||||
class StopIterationSubclass(StopIteration):
|
|
||||||
pass
|
|
||||||
|
|
||||||
for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")):
|
|
||||||
with self.subTest(type=type(stop_exc)):
|
|
||||||
try:
|
|
||||||
with woohoo():
|
|
||||||
raise stop_exc
|
|
||||||
except Exception as ex:
|
|
||||||
self.assertIs(ex, stop_exc)
|
|
||||||
else:
|
|
||||||
self.fail(f"{stop_exc} was suppressed")
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_except_pep479(self):
|
|
||||||
code = """\
|
|
||||||
from __future__ import generator_stop
|
|
||||||
from contextlib import contextmanager
|
|
||||||
@contextmanager
|
|
||||||
def woohoo():
|
|
||||||
yield
|
|
||||||
"""
|
|
||||||
locals = {}
|
|
||||||
exec(code, locals, locals)
|
|
||||||
woohoo = locals["woohoo"]
|
|
||||||
|
|
||||||
stop_exc = StopIteration("spam")
|
|
||||||
try:
|
|
||||||
with woohoo():
|
|
||||||
raise stop_exc
|
|
||||||
except Exception as ex:
|
|
||||||
self.assertIs(ex, stop_exc)
|
|
||||||
else:
|
|
||||||
self.fail("StopIteration was suppressed")
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
|
|
||||||
@contextmanager
|
|
||||||
def test_issue29692():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except Exception as exc:
|
|
||||||
raise RuntimeError("issue29692:Chained") from exc
|
|
||||||
|
|
||||||
try:
|
|
||||||
with test_issue29692():
|
|
||||||
raise ZeroDivisionError
|
|
||||||
except Exception as ex:
|
|
||||||
self.assertIs(type(ex), RuntimeError)
|
|
||||||
self.assertEqual(ex.args[0], "issue29692:Chained")
|
|
||||||
self.assertIsInstance(ex.__cause__, ZeroDivisionError)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with test_issue29692():
|
|
||||||
raise StopIteration("issue29692:Unchained")
|
|
||||||
except Exception as ex:
|
|
||||||
self.assertIs(type(ex), StopIteration)
|
|
||||||
self.assertEqual(ex.args[0], "issue29692:Unchained")
|
|
||||||
self.assertIsNone(ex.__cause__)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def _create_contextmanager_attribs(self):
|
|
||||||
def attribs(**kw):
|
|
||||||
def decorate(func):
|
|
||||||
for k, v in kw.items():
|
|
||||||
setattr(func, k, v)
|
|
||||||
return func
|
|
||||||
|
|
||||||
return decorate
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
@attribs(foo="bar")
|
|
||||||
def baz(spam):
|
|
||||||
"""Whee!"""
|
|
||||||
|
|
||||||
return baz
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_attribs(self):
|
|
||||||
baz = self._create_contextmanager_attribs()
|
|
||||||
self.assertEqual(baz.__name__, "baz")
|
|
||||||
self.assertEqual(baz.foo, "bar")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_keywords(self):
|
|
||||||
# Ensure no keyword arguments are inhibited
|
|
||||||
@contextmanager
|
|
||||||
def woohoo(self, func, args, kwds):
|
|
||||||
yield (self, func, args, kwds)
|
|
||||||
|
|
||||||
with woohoo(self=11, func=22, args=33, kwds=44) as target:
|
|
||||||
self.assertEqual(target, (11, 22, 33, 44))
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_param_errors(self):
|
|
||||||
@contextmanager
|
|
||||||
def woohoo(a, *, b):
|
|
||||||
yield
|
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
woohoo()
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
woohoo(3, 5)
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
woohoo(b=3)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_recursive(self):
|
|
||||||
depth = 0
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def woohoo():
|
|
||||||
nonlocal depth
|
|
||||||
before = depth
|
|
||||||
depth += 1
|
|
||||||
yield
|
|
||||||
depth -= 1
|
|
||||||
self.assertEqual(depth, before)
|
|
||||||
|
|
||||||
@woohoo()
|
|
||||||
def recursive():
|
|
||||||
if depth < 10:
|
|
||||||
recursive()
|
|
||||||
|
|
||||||
recursive()
|
|
||||||
self.assertEqual(depth, 0)
|
|
||||||
|
|
||||||
@skipIfNotPy311
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_trap_no_yield(self):
|
|
||||||
@contextmanager
|
|
||||||
def whoo():
|
|
||||||
if False:
|
|
||||||
yield
|
|
||||||
|
|
||||||
ctx = whoo()
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
ctx.__enter__()
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_trap_second_yield(self):
|
|
||||||
@contextmanager
|
|
||||||
def whoo():
|
|
||||||
yield
|
|
||||||
yield
|
|
||||||
|
|
||||||
ctx = whoo()
|
|
||||||
ctx.__enter__()
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
ctx.__exit__(None, None, None)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_wrap_runtimeerror(self):
|
|
||||||
@contextmanager
|
|
||||||
def woohoo():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except Exception as exc:
|
|
||||||
raise RuntimeError(f"caught {exc}") from exc
|
|
||||||
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
with woohoo():
|
|
||||||
1 / 0
|
|
||||||
|
|
||||||
# If the context manager wrapped StopIteration in a RuntimeError,
|
|
||||||
# we also unwrap it, because we can't tell whether the wrapping was
|
|
||||||
# done by the generator machinery or by the generator itself.
|
|
||||||
with self.assertRaises(StopIteration):
|
|
||||||
with woohoo():
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_contextmanager_non_normalised(self):
|
|
||||||
@contextmanager
|
|
||||||
def whoo():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except RuntimeError:
|
|
||||||
raise SyntaxError # noqa: B904
|
|
||||||
|
|
||||||
ctx = whoo()
|
|
||||||
ctx.__enter__()
|
|
||||||
with self.assertRaises(SyntaxError):
|
|
||||||
ctx.__exit__(RuntimeError, None, None)
|
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(CtxManagerTests)
|
instantiate_parametrized_tests(CtxManagerTests)
|
||||||
instantiate_parametrized_tests(ContextlibContextManagerTests)
|
instantiate_parametrized_tests(ContextlibContextManagerTests)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._dynamo.config
|
import torch._dynamo.config
|
||||||
|
|
@ -905,238 +904,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||||
assert exc2.__context__ is None
|
assert exc2.__context__ is None
|
||||||
|
|
||||||
|
|
||||||
class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_exceptions.py
|
|
||||||
def setUp(self):
|
|
||||||
self._u_prev = torch._dynamo.config.enable_trace_unittest
|
|
||||||
torch._dynamo.config.enable_trace_unittest = True
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
torch._dynamo.config.enable_trace_unittest = self._u_prev
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testChainingAttrs(self):
|
|
||||||
e = Exception()
|
|
||||||
assert e.__context__ is None
|
|
||||||
assert e.__cause__ is None
|
|
||||||
|
|
||||||
e = TypeError()
|
|
||||||
assert e.__context__ is None
|
|
||||||
assert e.__cause__ is None
|
|
||||||
|
|
||||||
e = MyException()
|
|
||||||
assert e.__context__ is None
|
|
||||||
assert e.__cause__ is None
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testChainingDescriptors(self):
|
|
||||||
try:
|
|
||||||
raise Exception # noqa: TRY002
|
|
||||||
except Exception as exc:
|
|
||||||
e = exc
|
|
||||||
|
|
||||||
assert e.__context__ is None
|
|
||||||
assert e.__cause__ is None
|
|
||||||
assert e.__suppress_context__ is False
|
|
||||||
|
|
||||||
e.__context__ = NameError()
|
|
||||||
e.__cause__ = None
|
|
||||||
assert isinstance(e.__context__, NameError)
|
|
||||||
assert e.__cause__ is None
|
|
||||||
assert e.__suppress_context__ is True
|
|
||||||
e.__suppress_context__ = False
|
|
||||||
assert e.__suppress_context__ is False
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_context_of_exception_in_try_and_finally(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
te = TypeError(1)
|
|
||||||
raise te
|
|
||||||
finally:
|
|
||||||
ve = ValueError(2)
|
|
||||||
raise ve
|
|
||||||
except Exception as e:
|
|
||||||
exc = e
|
|
||||||
|
|
||||||
assert exc is ve
|
|
||||||
assert exc.__context__ is te
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_context_of_exception_in_except_and_finally(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
te = TypeError(1)
|
|
||||||
raise te
|
|
||||||
except Exception: # noqa: E722
|
|
||||||
ve = ValueError(2)
|
|
||||||
raise ve # noqa: B904
|
|
||||||
finally:
|
|
||||||
oe = OSError(3)
|
|
||||||
raise oe
|
|
||||||
except Exception as e:
|
|
||||||
exc = e
|
|
||||||
|
|
||||||
assert exc is oe
|
|
||||||
assert exc.__context__ is ve
|
|
||||||
assert exc.__context__.__context__ is te
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_context_of_exception_in_else_and_finally(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
pass
|
|
||||||
except Exception: # noqa: E722
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
ve = ValueError(1)
|
|
||||||
raise ve
|
|
||||||
finally:
|
|
||||||
oe = OSError(2)
|
|
||||||
raise oe
|
|
||||||
except Exception as e:
|
|
||||||
exc = e
|
|
||||||
|
|
||||||
assert exc is oe
|
|
||||||
assert exc.__context__ is ve
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_raise_does_not_create_context_chain_cycle(self):
|
|
||||||
A = AssertionError
|
|
||||||
B = BytesWarning
|
|
||||||
C = ConnectionError
|
|
||||||
|
|
||||||
# Create a context chain:
|
|
||||||
# C -> B -> A
|
|
||||||
# Then raise A in context of C.
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise A
|
|
||||||
except A as a_:
|
|
||||||
a = a_
|
|
||||||
try:
|
|
||||||
raise B
|
|
||||||
except B as b_:
|
|
||||||
b = b_
|
|
||||||
try:
|
|
||||||
raise C
|
|
||||||
except C as c_:
|
|
||||||
c = c_
|
|
||||||
self.assertIsInstance(a, A)
|
|
||||||
self.assertIsInstance(b, B)
|
|
||||||
self.assertIsInstance(c, C)
|
|
||||||
self.assertIsNone(a.__context__)
|
|
||||||
self.assertIs(b.__context__, a)
|
|
||||||
self.assertIs(c.__context__, b)
|
|
||||||
raise a # noqa: B904
|
|
||||||
except A as e:
|
|
||||||
exc = e
|
|
||||||
|
|
||||||
# Expect A -> C -> B, without cycle
|
|
||||||
self.assertIs(exc, a)
|
|
||||||
self.assertIs(a.__context__, c)
|
|
||||||
self.assertIs(c.__context__, b)
|
|
||||||
self.assertIsNone(b.__context__)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_no_hang_on_context_chain_cycle1(self):
|
|
||||||
# See issue 25782. Cycle in context chain.
|
|
||||||
|
|
||||||
def cycle():
|
|
||||||
try:
|
|
||||||
raise ValueError(1)
|
|
||||||
except ValueError as ex:
|
|
||||||
ex.__context__ = ex
|
|
||||||
raise TypeError(2) # noqa: B904
|
|
||||||
|
|
||||||
try:
|
|
||||||
cycle()
|
|
||||||
except Exception as e:
|
|
||||||
exc = e
|
|
||||||
|
|
||||||
self.assertIsInstance(exc, TypeError)
|
|
||||||
self.assertIsInstance(exc.__context__, ValueError)
|
|
||||||
self.assertIs(exc.__context__.__context__, exc.__context__)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_no_hang_on_context_chain_cycle2(self):
|
|
||||||
# See issue 25782. Cycle at head of context chain.
|
|
||||||
|
|
||||||
A = AssertionError
|
|
||||||
B = BytesWarning
|
|
||||||
C = ConnectionError
|
|
||||||
|
|
||||||
# Context cycle:
|
|
||||||
# +-----------+
|
|
||||||
# V |
|
|
||||||
# C --> B --> A
|
|
||||||
with self.assertRaises(C) as cm:
|
|
||||||
try:
|
|
||||||
raise A() # noqa: RSE102
|
|
||||||
except A as _a:
|
|
||||||
a = _a
|
|
||||||
try:
|
|
||||||
raise B() # noqa: RSE102
|
|
||||||
except B as _b:
|
|
||||||
b = _b
|
|
||||||
try:
|
|
||||||
raise C() # noqa: RSE102
|
|
||||||
except C as _c:
|
|
||||||
c = _c
|
|
||||||
a.__context__ = c
|
|
||||||
raise c # noqa: B904
|
|
||||||
|
|
||||||
self.assertIs(cm.exception, c)
|
|
||||||
# Verify the expected context chain cycle
|
|
||||||
self.assertIs(c.__context__, b)
|
|
||||||
self.assertIs(b.__context__, a)
|
|
||||||
self.assertIs(a.__context__, c)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_no_hang_on_context_chain_cycle3(self):
|
|
||||||
# See issue 25782. Longer context chain with cycle.
|
|
||||||
A = AssertionError
|
|
||||||
B = BytesWarning
|
|
||||||
C = ConnectionError
|
|
||||||
D = DeprecationWarning
|
|
||||||
E = Exception
|
|
||||||
|
|
||||||
# Context cycle:
|
|
||||||
# +-----------+
|
|
||||||
# V |
|
|
||||||
# E --> D --> C --> B --> A
|
|
||||||
with self.assertRaises(E) as cm:
|
|
||||||
try:
|
|
||||||
raise A
|
|
||||||
except A as _a:
|
|
||||||
a = _a
|
|
||||||
try:
|
|
||||||
raise B
|
|
||||||
except B as _b:
|
|
||||||
b = _b
|
|
||||||
try:
|
|
||||||
raise C
|
|
||||||
except C as _c:
|
|
||||||
c = _c
|
|
||||||
a.__context__ = c
|
|
||||||
try:
|
|
||||||
raise D
|
|
||||||
except D as _d:
|
|
||||||
d = _d
|
|
||||||
e = E()
|
|
||||||
raise e # noqa: B904
|
|
||||||
|
|
||||||
self.assertIs(cm.exception, e)
|
|
||||||
# Verify the expected context chain cycle
|
|
||||||
self.assertIs(e.__context__, d)
|
|
||||||
self.assertIs(d.__context__, c)
|
|
||||||
self.assertIs(c.__context__, b)
|
|
||||||
self.assertIs(b.__context__, a)
|
|
||||||
self.assertIs(a.__context__, c)
|
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(ExceptionTests)
|
instantiate_parametrized_tests(ExceptionTests)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1481,331 +1481,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
|
||||||
self._compile_check(fn)
|
self._compile_check(fn)
|
||||||
|
|
||||||
|
|
||||||
class GeneratorCloseCPythonTests(GeneratorTestsBase):
|
|
||||||
# Taken from commit
|
|
||||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
|
||||||
# changed the tests a little bit to run them inside dynamo
|
|
||||||
# + replaced all self.assert* calls to plain assert statements
|
|
||||||
|
|
||||||
def test_close_no_return_value(self):
|
|
||||||
def f():
|
|
||||||
yield
|
|
||||||
|
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
gen.send(None)
|
|
||||||
assert gen.close() is None
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
t = torch.randn(2)
|
|
||||||
fn(t)
|
|
||||||
|
|
||||||
def test_close_return_value(self):
|
|
||||||
def f():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
# close() raises GeneratorExit here, which is caught
|
|
||||||
except GeneratorExit:
|
|
||||||
return 0 # noqa: B901
|
|
||||||
|
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
gen.send(None)
|
|
||||||
assert gen.close() == 0
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
t = torch.randn(2)
|
|
||||||
fn(t)
|
|
||||||
|
|
||||||
def test_close_not_catching_exit(self):
|
|
||||||
def f():
|
|
||||||
yield
|
|
||||||
# close() raises GeneratorExit here, which isn't caught and
|
|
||||||
# therefore propagates -- no return value
|
|
||||||
return 0 # noqa: B901
|
|
||||||
|
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
gen.send(None)
|
|
||||||
assert gen.close() is None
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
t = torch.randn(2)
|
|
||||||
fn(t)
|
|
||||||
|
|
||||||
def test_close_not_started(self):
|
|
||||||
def f():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except GeneratorExit:
|
|
||||||
return 0 # noqa: B901
|
|
||||||
|
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
assert gen.close() is None
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
t = torch.randn(2)
|
|
||||||
fn(t)
|
|
||||||
|
|
||||||
def test_close_exhausted(self):
|
|
||||||
def f():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except GeneratorExit:
|
|
||||||
return 0 # noqa: B901
|
|
||||||
|
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
next(gen)
|
|
||||||
z = 0
|
|
||||||
try:
|
|
||||||
next(gen) # -> StopIteration
|
|
||||||
except StopIteration:
|
|
||||||
z = 1
|
|
||||||
except Exception as e:
|
|
||||||
# anything other than StopIteration should fail
|
|
||||||
raise AssertionError from e
|
|
||||||
assert z == 1
|
|
||||||
assert gen.close() is None
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
t = torch.randn(2)
|
|
||||||
fn(t)
|
|
||||||
|
|
||||||
def test_close_closed(self):
|
|
||||||
def f():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except GeneratorExit:
|
|
||||||
return 0 # noqa: B901
|
|
||||||
|
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
gen.send(None)
|
|
||||||
assert gen.close() == 0
|
|
||||||
assert gen.close() is None
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
t = torch.randn(2)
|
|
||||||
fn(t)
|
|
||||||
|
|
||||||
def test_close_raises(self):
|
|
||||||
def f():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except GeneratorExit:
|
|
||||||
pass
|
|
||||||
raise RuntimeError
|
|
||||||
|
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
gen.send(None)
|
|
||||||
z = 0
|
|
||||||
try:
|
|
||||||
gen.close() # -> RuntimeError
|
|
||||||
except RuntimeError:
|
|
||||||
z = 1
|
|
||||||
except Exception as e:
|
|
||||||
raise AssertionError from e
|
|
||||||
assert z == 1
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
t = torch.randn(2)
|
|
||||||
fn(t)
|
|
||||||
|
|
||||||
|
|
||||||
class GeneratorThrowCpythonTests(GeneratorTestsBase):
|
|
||||||
# Taken from commit
|
|
||||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
|
||||||
# changed the tests a little bit to run them inside dynamo
|
|
||||||
# + replaced all self.assert* calls to plain assert statements
|
|
||||||
|
|
||||||
def test_exception_context_with_yield(self):
|
|
||||||
def f():
|
|
||||||
try:
|
|
||||||
raise KeyError("a")
|
|
||||||
except Exception:
|
|
||||||
yield
|
|
||||||
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
gen.send(None)
|
|
||||||
try:
|
|
||||||
gen.throw(ValueError)
|
|
||||||
except ValueError as e:
|
|
||||||
context = e.__context__
|
|
||||||
assert (type(context), context.args) == (KeyError, ("a",))
|
|
||||||
except Exception as e:
|
|
||||||
raise AssertionError from e
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
self._compile_check(fn)
|
|
||||||
|
|
||||||
def test_exception_context_with_yield_inside_generator(self):
|
|
||||||
# Check that the context is also available from inside the generator
|
|
||||||
# with yield, as opposed to outside.
|
|
||||||
def f():
|
|
||||||
z = 0
|
|
||||||
try:
|
|
||||||
raise KeyError("a")
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except Exception as exc:
|
|
||||||
z = 1
|
|
||||||
assert type(exc) == ValueError
|
|
||||||
context = exc.__context__
|
|
||||||
assert (type(context), context.args) == (KeyError, ("a",))
|
|
||||||
yield "b"
|
|
||||||
finally:
|
|
||||||
assert z == 1
|
|
||||||
|
|
||||||
def fn(t):
|
|
||||||
gen = f()
|
|
||||||
gen.send(None)
|
|
||||||
actual = gen.throw(ValueError)
|
|
||||||
# This ensures that the assertions inside were executed.
|
|
||||||
assert actual == "b"
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
self._compile_check(fn)
|
|
||||||
|
|
||||||
def test_exception_context_with_yield_from(self):
|
|
||||||
def f():
|
|
||||||
yield
|
|
||||||
|
|
||||||
def g():
|
|
||||||
try:
|
|
||||||
raise KeyError("a")
|
|
||||||
except Exception:
|
|
||||||
yield from f()
|
|
||||||
|
|
||||||
def fn(t):
|
|
||||||
gen = g()
|
|
||||||
gen.send(None)
|
|
||||||
try:
|
|
||||||
gen.throw(ValueError)
|
|
||||||
except ValueError as e:
|
|
||||||
context = e.__context__
|
|
||||||
assert (type(context), context.args) == (KeyError, ("a",))
|
|
||||||
except Exception as e:
|
|
||||||
raise AssertionError from e
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
self._compile_check(fn)
|
|
||||||
|
|
||||||
def test_exception_context_with_yield_from_with_context_cycle(self):
|
|
||||||
# Check trying to create an exception context cycle:
|
|
||||||
# https://bugs.python.org/issue40696
|
|
||||||
has_cycle = None
|
|
||||||
|
|
||||||
def f():
|
|
||||||
yield
|
|
||||||
|
|
||||||
def g(exc):
|
|
||||||
nonlocal has_cycle
|
|
||||||
try:
|
|
||||||
raise exc
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
yield from f()
|
|
||||||
except Exception as exc:
|
|
||||||
has_cycle = exc is exc.__context__
|
|
||||||
yield
|
|
||||||
|
|
||||||
def fn(t):
|
|
||||||
exc = KeyError("a")
|
|
||||||
gen = g(exc)
|
|
||||||
gen.send(None)
|
|
||||||
gen.throw(exc)
|
|
||||||
# This also distinguishes from the initial has_cycle=None.
|
|
||||||
assert has_cycle is False
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
self._compile_check(fn)
|
|
||||||
|
|
||||||
def test_throw_after_none_exc_type(self):
|
|
||||||
def g():
|
|
||||||
try:
|
|
||||||
raise KeyError
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except Exception:
|
|
||||||
raise RuntimeError # noqa: B904
|
|
||||||
|
|
||||||
def fn(t):
|
|
||||||
gen = g()
|
|
||||||
gen.send(None)
|
|
||||||
z = 0
|
|
||||||
try:
|
|
||||||
gen.throw(ValueError)
|
|
||||||
except RuntimeError:
|
|
||||||
z += 1
|
|
||||||
except Exception:
|
|
||||||
raise AssertionError # noqa: B904
|
|
||||||
assert z == 1
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
self._compile_check(fn)
|
|
||||||
|
|
||||||
|
|
||||||
class GeneratorCPythonTests(GeneratorTestsBase):
|
|
||||||
# Taken from commit
|
|
||||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
|
||||||
# changed the tests a little bit to run them inside dynamo
|
|
||||||
# + replaced all self.assert* calls to plain assert statements
|
|
||||||
|
|
||||||
def test_send_non_none_to_new_gen(self):
|
|
||||||
def f():
|
|
||||||
yield 1
|
|
||||||
|
|
||||||
def fn(t):
|
|
||||||
g = f()
|
|
||||||
z = 0
|
|
||||||
try:
|
|
||||||
g.send(0)
|
|
||||||
except TypeError:
|
|
||||||
z += 1
|
|
||||||
except Exception as e:
|
|
||||||
raise AssertionError from e
|
|
||||||
assert z == 1
|
|
||||||
assert next(g) == 1
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
self._compile_check(fn)
|
|
||||||
|
|
||||||
def test_issue103488(self):
|
|
||||||
def gen_raises():
|
|
||||||
yield 1
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
def loop():
|
|
||||||
try:
|
|
||||||
for _ in gen_raises():
|
|
||||||
if True is False: # noqa: PLR0133
|
|
||||||
return
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def fn(t):
|
|
||||||
# This should not raise
|
|
||||||
loop()
|
|
||||||
return t.sin()
|
|
||||||
|
|
||||||
self._compile_check(fn)
|
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(GeneratorTests)
|
instantiate_parametrized_tests(GeneratorTests)
|
||||||
instantiate_parametrized_tests(TestGeneratorSend)
|
instantiate_parametrized_tests(TestGeneratorSend)
|
||||||
instantiate_parametrized_tests(TestGeneratorClose)
|
instantiate_parametrized_tests(TestGeneratorClose)
|
||||||
|
|
|
||||||
|
|
@ -1,52 +0,0 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch._dynamo.test_case
|
|
||||||
from torch.testing._internal.common_utils import make_dynamo_test
|
|
||||||
|
|
||||||
|
|
||||||
class TestPEP479(torch._dynamo.test_case.CPythonTestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_generator_stop.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_generator_stop.py
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_stopiteration_wrapping(self):
|
|
||||||
def f():
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
def g():
|
|
||||||
yield f()
|
|
||||||
|
|
||||||
with self.assertRaises(RuntimeError) as cm:
|
|
||||||
next(g())
|
|
||||||
self.assertEqual("generator raised StopIteration", str(cm.exception))
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_stopiteration_wrapping_context(self):
|
|
||||||
def f():
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
def g():
|
|
||||||
yield f()
|
|
||||||
|
|
||||||
try:
|
|
||||||
next(g())
|
|
||||||
except RuntimeError as exc:
|
|
||||||
self.assertIs(type(exc.__cause__), StopIteration)
|
|
||||||
self.assertIs(type(exc.__context__), StopIteration)
|
|
||||||
self.assertTrue(exc.__suppress_context__)
|
|
||||||
else:
|
|
||||||
self.fail(
|
|
||||||
"__cause__, __context__, or __suppress_context__ "
|
|
||||||
"were not properly set"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from torch._dynamo.test_case import run_tests
|
|
||||||
|
|
||||||
run_tests()
|
|
||||||
|
|
@ -1,563 +0,0 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
|
||||||
|
|
||||||
# ruff: noqa
|
|
||||||
# flake8: noqa
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch._dynamo.config
|
|
||||||
import torch._dynamo.test_case
|
|
||||||
import torch._functorch.config
|
|
||||||
import torch.nn
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from torch.testing._internal.common_utils import make_dynamo_test
|
|
||||||
|
|
||||||
|
|
||||||
def get_tb():
|
|
||||||
try:
|
|
||||||
raise OSError()
|
|
||||||
except:
|
|
||||||
return sys.exc_info()[2]
|
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class MyException(Exception):
|
|
||||||
def __init__(self):
|
|
||||||
raise RuntimeError()
|
|
||||||
|
|
||||||
|
|
||||||
class ContextManager:
|
|
||||||
def __enter__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __exit__(self, t, v, tb):
|
|
||||||
raise NameError
|
|
||||||
|
|
||||||
|
|
||||||
class TestRaise(torch._dynamo.test_case.CPythonTestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_invalid_reraise(self):
|
|
||||||
try:
|
|
||||||
raise
|
|
||||||
except RuntimeError as e:
|
|
||||||
self.assertIn("No active exception", str(e))
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_reraise(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise IndexError
|
|
||||||
except IndexError as e:
|
|
||||||
exc1 = e
|
|
||||||
raise
|
|
||||||
except IndexError as exc2:
|
|
||||||
self.assertIs(exc1, exc2)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_except_reraise(self):
|
|
||||||
def reraise():
|
|
||||||
try:
|
|
||||||
raise TypeError("foo")
|
|
||||||
except:
|
|
||||||
try:
|
|
||||||
raise KeyError("caught")
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
|
|
||||||
self.assertRaises(TypeError, reraise)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_finally_reraise(self):
|
|
||||||
def reraise():
|
|
||||||
try:
|
|
||||||
raise TypeError("foo")
|
|
||||||
except:
|
|
||||||
try:
|
|
||||||
raise KeyError("caught")
|
|
||||||
finally:
|
|
||||||
raise
|
|
||||||
|
|
||||||
self.assertRaises(KeyError, reraise)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_nested_reraise(self):
|
|
||||||
def nested_reraise():
|
|
||||||
raise
|
|
||||||
|
|
||||||
def reraise():
|
|
||||||
try:
|
|
||||||
raise TypeError("foo")
|
|
||||||
except:
|
|
||||||
nested_reraise()
|
|
||||||
|
|
||||||
self.assertRaises(TypeError, reraise)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_raise_from_None(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise TypeError("foo")
|
|
||||||
except:
|
|
||||||
raise ValueError() from None
|
|
||||||
except ValueError as e:
|
|
||||||
self.assertIsInstance(e.__context__, TypeError)
|
|
||||||
self.assertIsNone(e.__cause__)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_with_reraise1(self):
|
|
||||||
def reraise():
|
|
||||||
try:
|
|
||||||
raise TypeError("foo")
|
|
||||||
except:
|
|
||||||
with Context():
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
|
|
||||||
self.assertRaises(TypeError, reraise)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_with_reraise2(self):
|
|
||||||
def reraise():
|
|
||||||
try:
|
|
||||||
raise TypeError("foo")
|
|
||||||
except:
|
|
||||||
with Context():
|
|
||||||
raise KeyError("caught")
|
|
||||||
raise
|
|
||||||
|
|
||||||
self.assertRaises(TypeError, reraise)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_yield_reraise(self):
|
|
||||||
def reraise():
|
|
||||||
try:
|
|
||||||
raise TypeError("foo")
|
|
||||||
except:
|
|
||||||
yield 1
|
|
||||||
raise
|
|
||||||
|
|
||||||
g = reraise()
|
|
||||||
next(g)
|
|
||||||
self.assertRaises(TypeError, lambda: next(g))
|
|
||||||
self.assertRaises(StopIteration, lambda: next(g))
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_erroneous_exception(self):
|
|
||||||
try:
|
|
||||||
raise MyException
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@unittest.expectedFailure # object
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_new_returns_invalid_instance(self):
|
|
||||||
# See issue #11627.
|
|
||||||
class MyException2(Exception):
|
|
||||||
def __new__(cls, *args):
|
|
||||||
return object()
|
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
raise MyException2
|
|
||||||
|
|
||||||
@unittest.expectedFailure # Assertion with non-string message
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_assert_with_tuple_arg(self):
|
|
||||||
try:
|
|
||||||
assert False, (3,)
|
|
||||||
except AssertionError as e:
|
|
||||||
self.assertEqual(str(e), "(3,)")
|
|
||||||
|
|
||||||
|
|
||||||
class TestCause(torch._dynamo.test_case.TestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
|
||||||
def setUp(self):
|
|
||||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
|
||||||
torch._dynamo.config.enable_trace_unittest = True
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testCauseSyntax(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise TypeError
|
|
||||||
except Exception:
|
|
||||||
raise ValueError from None
|
|
||||||
except ValueError as exc:
|
|
||||||
self.assertIsNone(exc.__cause__)
|
|
||||||
self.assertTrue(exc.__suppress_context__)
|
|
||||||
exc.__suppress_context__ = False
|
|
||||||
raise exc
|
|
||||||
except ValueError as exc:
|
|
||||||
e = exc
|
|
||||||
|
|
||||||
self.assertIsNone(e.__cause__)
|
|
||||||
self.assertFalse(e.__suppress_context__)
|
|
||||||
self.assertIsInstance(e.__context__, TypeError)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_invalid_cause(self):
|
|
||||||
try:
|
|
||||||
raise IndexError from 5
|
|
||||||
except TypeError as e:
|
|
||||||
self.assertIn("exception cause", str(e))
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_class_cause(self):
|
|
||||||
try:
|
|
||||||
raise IndexError from KeyError
|
|
||||||
except IndexError as e:
|
|
||||||
self.assertIsInstance(e.__cause__, KeyError)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_instance_cause(self):
|
|
||||||
cause = KeyError()
|
|
||||||
try:
|
|
||||||
raise IndexError from cause
|
|
||||||
except IndexError as e:
|
|
||||||
self.assertIs(e.__cause__, cause)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_erroneous_cause(self):
|
|
||||||
try:
|
|
||||||
raise IndexError from MyException
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
|
|
||||||
class TestTraceback(torch._dynamo.test_case.TestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
|
||||||
def setUp(self):
|
|
||||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
|
||||||
torch._dynamo.config.enable_trace_unittest = True
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
|
||||||
|
|
||||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_sets_traceback(self):
|
|
||||||
try:
|
|
||||||
raise IndexError()
|
|
||||||
except IndexError as e:
|
|
||||||
self.assertIsInstance(e.__traceback__, types.TracebackType)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_accepts_traceback(self):
|
|
||||||
tb = get_tb()
|
|
||||||
try:
|
|
||||||
raise IndexError().with_traceback(tb)
|
|
||||||
except IndexError as e:
|
|
||||||
self.assertNotEqual(e.__traceback__, tb)
|
|
||||||
self.assertEqual(e.__traceback__.tb_next, tb)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
|
|
||||||
class TestTracebackType(torch._dynamo.test_case.TestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
|
||||||
def setUp(self):
|
|
||||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
|
||||||
torch._dynamo.config.enable_trace_unittest = True
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
|
||||||
|
|
||||||
def raiser(self):
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_attrs(self):
|
|
||||||
try:
|
|
||||||
self.raiser()
|
|
||||||
except Exception as exc:
|
|
||||||
tb = exc.__traceback__
|
|
||||||
|
|
||||||
self.assertIsInstance(tb.tb_next, types.TracebackType)
|
|
||||||
self.assertIs(tb.tb_frame, sys._getframe())
|
|
||||||
self.assertIsInstance(tb.tb_lasti, int)
|
|
||||||
self.assertIsInstance(tb.tb_lineno, int)
|
|
||||||
|
|
||||||
self.assertIs(tb.tb_next.tb_next, None)
|
|
||||||
|
|
||||||
# Invalid assignments
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
del tb.tb_next
|
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
tb.tb_next = "asdf"
|
|
||||||
|
|
||||||
# Loops
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
tb.tb_next = tb
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
tb.tb_next.tb_next = tb
|
|
||||||
|
|
||||||
# Valid assignments
|
|
||||||
tb.tb_next = None
|
|
||||||
self.assertIs(tb.tb_next, None)
|
|
||||||
|
|
||||||
new_tb = get_tb()
|
|
||||||
tb.tb_next = new_tb
|
|
||||||
self.assertIs(tb.tb_next, new_tb)
|
|
||||||
|
|
||||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_constructor(self):
|
|
||||||
other_tb = get_tb()
|
|
||||||
frame = sys._getframe()
|
|
||||||
|
|
||||||
tb = types.TracebackType(other_tb, frame, 1, 2)
|
|
||||||
self.assertEqual(tb.tb_next, other_tb)
|
|
||||||
self.assertEqual(tb.tb_frame, frame)
|
|
||||||
self.assertEqual(tb.tb_lasti, 1)
|
|
||||||
self.assertEqual(tb.tb_lineno, 2)
|
|
||||||
|
|
||||||
tb = types.TracebackType(None, frame, 1, 2)
|
|
||||||
self.assertEqual(tb.tb_next, None)
|
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
types.TracebackType("no", frame, 1, 2)
|
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
types.TracebackType(other_tb, "no", 1, 2)
|
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
types.TracebackType(other_tb, frame, "no", 2)
|
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
types.TracebackType(other_tb, frame, 1, "nuh-uh")
|
|
||||||
|
|
||||||
|
|
||||||
class TestContext(torch._dynamo.test_case.TestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
|
||||||
def setUp(self):
|
|
||||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
|
||||||
torch._dynamo.config.enable_trace_unittest = True
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
|
||||||
|
|
||||||
@unittest.expectedFailure # missing Exception.__eq__
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_instance_context_instance_raise(self):
|
|
||||||
context = IndexError()
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise context
|
|
||||||
except:
|
|
||||||
raise OSError()
|
|
||||||
except OSError as e:
|
|
||||||
self.assertEqual(e.__context__, context)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_class_context_instance_raise(self):
|
|
||||||
context = IndexError
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise context
|
|
||||||
except:
|
|
||||||
raise OSError()
|
|
||||||
except OSError as e:
|
|
||||||
self.assertNotEqual(e.__context__, context)
|
|
||||||
self.assertIsInstance(e.__context__, context)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_class_context_class_raise(self):
|
|
||||||
context = IndexError
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise context
|
|
||||||
except:
|
|
||||||
raise OSError
|
|
||||||
except OSError as e:
|
|
||||||
self.assertNotEqual(e.__context__, context)
|
|
||||||
self.assertIsInstance(e.__context__, context)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_c_exception_context(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise ZeroDivisionError
|
|
||||||
except:
|
|
||||||
raise OSError
|
|
||||||
except OSError as e:
|
|
||||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_c_exception_raise(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise ZeroDivisionError
|
|
||||||
except:
|
|
||||||
raise NameError
|
|
||||||
except NameError as e:
|
|
||||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_noraise_finally(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
raise OSError
|
|
||||||
except OSError as e:
|
|
||||||
self.assertIsNone(e.__context__)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_raise_finally(self):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise ZeroDivisionError
|
|
||||||
finally:
|
|
||||||
raise OSError
|
|
||||||
except OSError as e:
|
|
||||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_context_manager(self):
|
|
||||||
try:
|
|
||||||
with ContextManager():
|
|
||||||
raise ZeroDivisionError
|
|
||||||
except NameError as e:
|
|
||||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
|
||||||
else:
|
|
||||||
self.fail("No exception raised")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_cycle_broken(self):
|
|
||||||
# Self-cycles (when re-raising a caught exception) are broken
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise ZeroDivisionError
|
|
||||||
except ZeroDivisionError as e:
|
|
||||||
raise e
|
|
||||||
except ZeroDivisionError as e:
|
|
||||||
self.assertIsNone(e.__context__)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_reraise_cycle_broken(self):
|
|
||||||
# Non-trivial context cycles (through re-raising a previous exception)
|
|
||||||
# are broken too.
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise NameError
|
|
||||||
except NameError as a:
|
|
||||||
try:
|
|
||||||
raise ZeroDivisionError
|
|
||||||
except ZeroDivisionError:
|
|
||||||
raise a
|
|
||||||
except NameError as e:
|
|
||||||
self.assertIsNone(e.__context__.__context__)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_3118(self):
|
|
||||||
# deleting the generator caused the __context__ to be cleared
|
|
||||||
def gen():
|
|
||||||
try:
|
|
||||||
yield 1
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def f():
|
|
||||||
g = gen()
|
|
||||||
next(g)
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
raise ValueError
|
|
||||||
except:
|
|
||||||
del g
|
|
||||||
raise KeyError
|
|
||||||
except Exception as e:
|
|
||||||
self.assertIsInstance(e.__context__, ValueError)
|
|
||||||
|
|
||||||
f()
|
|
||||||
|
|
||||||
@unittest.expectedFailure # too CPython specific(?)
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_3611(self):
|
|
||||||
# A re-raised exception in a __del__ caused the __context__
|
|
||||||
# to be cleared
|
|
||||||
class C:
|
|
||||||
def __del__(self):
|
|
||||||
try:
|
|
||||||
raise ZeroDivisionError
|
|
||||||
except:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def f():
|
|
||||||
x = C()
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
x.x
|
|
||||||
except AttributeError:
|
|
||||||
del x
|
|
||||||
raise TypeError
|
|
||||||
except Exception as e:
|
|
||||||
self.assertNotEqual(e.__context__, None)
|
|
||||||
self.assertIsInstance(e.__context__, AttributeError)
|
|
||||||
|
|
||||||
with support.catch_unraisable_exception() as cm:
|
|
||||||
f()
|
|
||||||
|
|
||||||
self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from torch._dynamo.test_case import run_tests
|
|
||||||
|
|
||||||
run_tests()
|
|
||||||
|
|
@ -1,107 +0,0 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
|
||||||
import sys
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch._dynamo.test_case
|
|
||||||
from torch.testing._internal.common_utils import make_dynamo_test
|
|
||||||
|
|
||||||
|
|
||||||
class SysTests(torch._dynamo.test_case.TestCase):
|
|
||||||
def test_exc_info(self):
|
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
|
||||||
def fn(t):
|
|
||||||
try:
|
|
||||||
raise ValueError
|
|
||||||
except Exception:
|
|
||||||
typ, _, _ = sys.exc_info()
|
|
||||||
if typ is ValueError:
|
|
||||||
return t.sin()
|
|
||||||
else:
|
|
||||||
return t.cos()
|
|
||||||
|
|
||||||
t = torch.randn(2)
|
|
||||||
y = fn(t)
|
|
||||||
self.assertEqual(y, t.sin())
|
|
||||||
|
|
||||||
|
|
||||||
class CPythonActiveExceptionTests(torch._dynamo.test_case.CPythonTestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_sys.py
|
|
||||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_exc_info_no_exception(self):
|
|
||||||
self.assertEqual(sys.exc_info(), (None, None, None))
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_sys_exception_no_exception(self):
|
|
||||||
self.assertEqual(sys.exception(), None)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_exc_info_with_exception_instance(self):
|
|
||||||
def f():
|
|
||||||
raise ValueError(42)
|
|
||||||
|
|
||||||
try:
|
|
||||||
f()
|
|
||||||
except Exception as e_:
|
|
||||||
e = e_
|
|
||||||
exc_info = sys.exc_info()
|
|
||||||
|
|
||||||
self.assertIsInstance(e, ValueError)
|
|
||||||
self.assertIs(exc_info[0], ValueError)
|
|
||||||
self.assertIs(exc_info[1], e)
|
|
||||||
self.assertIs(exc_info[2], e.__traceback__)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_exc_info_with_exception_type(self):
|
|
||||||
def f():
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
try:
|
|
||||||
f()
|
|
||||||
except Exception as e_:
|
|
||||||
e = e_
|
|
||||||
exc_info = sys.exc_info()
|
|
||||||
|
|
||||||
self.assertIsInstance(e, ValueError)
|
|
||||||
self.assertIs(exc_info[0], ValueError)
|
|
||||||
self.assertIs(exc_info[1], e)
|
|
||||||
self.assertIs(exc_info[2], e.__traceback__)
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_sys_exception_with_exception_instance(self):
|
|
||||||
def f():
|
|
||||||
raise ValueError(42)
|
|
||||||
|
|
||||||
try:
|
|
||||||
f()
|
|
||||||
except Exception as e_:
|
|
||||||
e = e_
|
|
||||||
exc = sys.exception()
|
|
||||||
|
|
||||||
self.assertIsInstance(e, ValueError)
|
|
||||||
self.assertIs(exc, e)
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_sys_exception_with_exception_type(self):
|
|
||||||
def f():
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
try:
|
|
||||||
f()
|
|
||||||
except Exception as e_:
|
|
||||||
e = e_
|
|
||||||
exc = sys.exception()
|
|
||||||
|
|
||||||
self.assertIsInstance(e, ValueError)
|
|
||||||
self.assertIs(exc, e)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from torch._dynamo.test_case import run_tests
|
|
||||||
|
|
||||||
run_tests()
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
import sys
|
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
|
||||||
from itertools import product
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._dynamo.test_case
|
import torch._dynamo.test_case
|
||||||
|
|
@ -28,591 +25,6 @@ class TestUnittest(torch._dynamo.test_case.TestCase):
|
||||||
self.assertEqual(z, 1)
|
self.assertEqual(z, 1)
|
||||||
|
|
||||||
|
|
||||||
class CPythonTest_Assertions(torch._dynamo.test_case.CPythonTestCase):
|
|
||||||
# Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py
|
|
||||||
# https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_AlmostEqual(self):
|
|
||||||
self.assertAlmostEqual(1.00000001, 1.0)
|
|
||||||
self.assertNotAlmostEqual(1.0000001, 1.0)
|
|
||||||
self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0)
|
|
||||||
self.assertRaises(
|
|
||||||
self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertAlmostEqual(1.1, 1.0, places=0)
|
|
||||||
self.assertRaises(
|
|
||||||
self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertAlmostEqual(0, 0.1 + 0.1j, places=0)
|
|
||||||
self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1)
|
|
||||||
self.assertRaises(
|
|
||||||
self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1
|
|
||||||
)
|
|
||||||
self.assertRaises(
|
|
||||||
self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertAlmostEqual(float("inf"), float("inf"))
|
|
||||||
self.assertRaises(
|
|
||||||
self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf")
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_AmostEqualWithDelta(self):
|
|
||||||
self.assertAlmostEqual(1.1, 1.0, delta=0.5)
|
|
||||||
self.assertAlmostEqual(1.0, 1.1, delta=0.5)
|
|
||||||
self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
|
|
||||||
self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)
|
|
||||||
|
|
||||||
self.assertAlmostEqual(1.0, 1.0, delta=0.5)
|
|
||||||
self.assertRaises(
|
|
||||||
self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertRaises(
|
|
||||||
self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05
|
|
||||||
)
|
|
||||||
self.assertRaises(
|
|
||||||
self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertRaises(
|
|
||||||
TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2
|
|
||||||
)
|
|
||||||
self.assertRaises(
|
|
||||||
TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_assertRaises(self):
|
|
||||||
def _raise(e):
|
|
||||||
raise e
|
|
||||||
|
|
||||||
self.assertRaises(KeyError, _raise, KeyError)
|
|
||||||
self.assertRaises(KeyError, _raise, KeyError("key"))
|
|
||||||
try:
|
|
||||||
self.assertRaises(KeyError, lambda: None)
|
|
||||||
except self.failureException as e:
|
|
||||||
self.assertIn("KeyError not raised", str(e))
|
|
||||||
else:
|
|
||||||
self.fail("assertRaises() didn't fail")
|
|
||||||
try:
|
|
||||||
self.assertRaises(KeyError, _raise, ValueError)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self.fail("assertRaises() didn't let exception pass through")
|
|
||||||
with self.assertRaises(KeyError) as cm:
|
|
||||||
try:
|
|
||||||
raise KeyError
|
|
||||||
except Exception as e:
|
|
||||||
exc = e
|
|
||||||
raise
|
|
||||||
self.assertIs(cm.exception, exc)
|
|
||||||
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
raise KeyError("key")
|
|
||||||
try:
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
pass
|
|
||||||
except self.failureException as e:
|
|
||||||
self.assertIn("KeyError not raised", str(e))
|
|
||||||
else:
|
|
||||||
self.fail("assertRaises() didn't fail")
|
|
||||||
try:
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
raise ValueError
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self.fail("assertRaises() didn't let exception pass through")
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertNotRegex(self):
|
|
||||||
self.assertNotRegex("Ala ma kota", r"r+")
|
|
||||||
try:
|
|
||||||
self.assertNotRegex("Ala ma kota", r"k.t", "Message")
|
|
||||||
except self.failureException as e:
|
|
||||||
self.assertIn("Message", e.args[0])
|
|
||||||
else:
|
|
||||||
self.fail("assertNotRegex should have failed.")
|
|
||||||
|
|
||||||
|
|
||||||
class CPythonTestLongMessage(torch._dynamo.test_case.CPythonTestCase):
|
|
||||||
"""Test that the individual asserts honour longMessage.
|
|
||||||
This actually tests all the message behaviour for
|
|
||||||
asserts that use longMessage."""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
super().setUp()
|
|
||||||
|
|
||||||
class TestableTestFalse(unittest.TestCase):
|
|
||||||
longMessage = False
|
|
||||||
failureException = self.failureException
|
|
||||||
|
|
||||||
def testTest(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class TestableTestTrue(unittest.TestCase):
|
|
||||||
longMessage = True
|
|
||||||
failureException = self.failureException
|
|
||||||
|
|
||||||
def testTest(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.testableTrue = TestableTestTrue("testTest")
|
|
||||||
self.testableFalse = TestableTestFalse("testTest")
|
|
||||||
|
|
||||||
def testDefault(self):
|
|
||||||
self.assertTrue(unittest.TestCase.longMessage)
|
|
||||||
|
|
||||||
def test_formatMsg(self):
|
|
||||||
self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo")
|
|
||||||
self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo")
|
|
||||||
|
|
||||||
self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo")
|
|
||||||
self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo")
|
|
||||||
|
|
||||||
# This blows up if _formatMessage uses string concatenation
|
|
||||||
self.testableTrue._formatMessage(object(), "foo")
|
|
||||||
|
|
||||||
def test_formatMessage_unicode_error(self):
|
|
||||||
one = "".join(chr(i) for i in range(255))
|
|
||||||
# this used to cause a UnicodeDecodeError constructing msg
|
|
||||||
self.testableTrue._formatMessage(one, "\uFFFD")
|
|
||||||
|
|
||||||
def assertMessages(self, methodName, args, errors):
|
|
||||||
"""
|
|
||||||
Check that methodName(*args) raises the correct error messages.
|
|
||||||
errors should be a list of 4 regex that match the error when:
|
|
||||||
1) longMessage = False and no msg passed;
|
|
||||||
2) longMessage = False and msg passed;
|
|
||||||
3) longMessage = True and no msg passed;
|
|
||||||
4) longMessage = True and msg passed;
|
|
||||||
"""
|
|
||||||
|
|
||||||
def getMethod(i):
|
|
||||||
useTestableFalse = i < 2
|
|
||||||
if useTestableFalse:
|
|
||||||
test = self.testableFalse
|
|
||||||
else:
|
|
||||||
test = self.testableTrue
|
|
||||||
return getattr(test, methodName)
|
|
||||||
|
|
||||||
for i, expected_regex in enumerate(errors):
|
|
||||||
testMethod = getMethod(i)
|
|
||||||
kwargs = {}
|
|
||||||
withMsg = i % 2
|
|
||||||
if withMsg:
|
|
||||||
kwargs = {"msg": "oops"}
|
|
||||||
|
|
||||||
# with self.assertRaisesRegex(
|
|
||||||
# self.failureException, expected_regex=expected_regex
|
|
||||||
# ):
|
|
||||||
# testMethod(*args, **kwargs)
|
|
||||||
with self.assertRaises(self.failureException) as cm:
|
|
||||||
testMethod(*args, **kwargs)
|
|
||||||
self.assertRegex(str(cm.exception), expected_regex)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertTrue(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertTrue",
|
|
||||||
(False,),
|
|
||||||
[
|
|
||||||
"False is not true",
|
|
||||||
"oops",
|
|
||||||
"False is not true",
|
|
||||||
"False is not true : oops",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertFalse(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertFalse",
|
|
||||||
(True,),
|
|
||||||
[
|
|
||||||
"True is not false",
|
|
||||||
"oops",
|
|
||||||
"True is not false",
|
|
||||||
"True is not false : oops",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testNotEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAlmostEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertAlmostEqual",
|
|
||||||
(1, 2),
|
|
||||||
[
|
|
||||||
r"^1 != 2 within 7 places \(1 difference\)$",
|
|
||||||
"^oops$",
|
|
||||||
r"^1 != 2 within 7 places \(1 difference\)$",
|
|
||||||
r"^1 != 2 within 7 places \(1 difference\) : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testNotAlmostEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertNotAlmostEqual",
|
|
||||||
(1, 1),
|
|
||||||
[
|
|
||||||
"^1 == 1 within 7 places$",
|
|
||||||
"^oops$",
|
|
||||||
"^1 == 1 within 7 places$",
|
|
||||||
"^1 == 1 within 7 places : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_baseAssertEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"_baseAssertEqual",
|
|
||||||
(1, 2),
|
|
||||||
["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"],
|
|
||||||
)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertSequenceEqual(self):
|
|
||||||
# Error messages are multiline so not testing on full message
|
|
||||||
# assertTupleEqual and assertListEqual delegate to this method
|
|
||||||
self.assertMessages(
|
|
||||||
"assertSequenceEqual",
|
|
||||||
([], [None]),
|
|
||||||
[r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertSetEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertSetEqual",
|
|
||||||
(set(), set([None])), # noqa: C405
|
|
||||||
["None$", "^oops$", "None$", "None : oops$"],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertIn(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertIn",
|
|
||||||
(None, []),
|
|
||||||
[
|
|
||||||
r"^None not found in \[\]$",
|
|
||||||
"^oops$",
|
|
||||||
r"^None not found in \[\]$",
|
|
||||||
r"^None not found in \[\] : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertNotIn(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertNotIn",
|
|
||||||
(None, [None]),
|
|
||||||
[
|
|
||||||
r"^None unexpectedly found in \[None\]$",
|
|
||||||
"^oops$",
|
|
||||||
r"^None unexpectedly found in \[None\]$",
|
|
||||||
r"^None unexpectedly found in \[None\] : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertDictEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertDictEqual",
|
|
||||||
({}, {"key": "value"}),
|
|
||||||
[
|
|
||||||
r"\+ \{'key': 'value'\}$",
|
|
||||||
"^oops$",
|
|
||||||
r"\+ \{'key': 'value'\}$",
|
|
||||||
r"\+ \{'key': 'value'\} : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertMultiLineEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertMultiLineEqual",
|
|
||||||
("", "foo"),
|
|
||||||
[r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertLess(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertLess",
|
|
||||||
(2, 1),
|
|
||||||
[
|
|
||||||
"^2 not less than 1$",
|
|
||||||
"^oops$",
|
|
||||||
"^2 not less than 1$",
|
|
||||||
"^2 not less than 1 : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertLessEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertLessEqual",
|
|
||||||
(2, 1),
|
|
||||||
[
|
|
||||||
"^2 not less than or equal to 1$",
|
|
||||||
"^oops$",
|
|
||||||
"^2 not less than or equal to 1$",
|
|
||||||
"^2 not less than or equal to 1 : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertGreater(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertGreater",
|
|
||||||
(1, 2),
|
|
||||||
[
|
|
||||||
"^1 not greater than 2$",
|
|
||||||
"^oops$",
|
|
||||||
"^1 not greater than 2$",
|
|
||||||
"^1 not greater than 2 : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertGreaterEqual(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertGreaterEqual",
|
|
||||||
(1, 2),
|
|
||||||
[
|
|
||||||
"^1 not greater than or equal to 2$",
|
|
||||||
"^oops$",
|
|
||||||
"^1 not greater than or equal to 2$",
|
|
||||||
"^1 not greater than or equal to 2 : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertIsNone(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertIsNone",
|
|
||||||
("not None",),
|
|
||||||
[
|
|
||||||
"^'not None' is not None$",
|
|
||||||
"^oops$",
|
|
||||||
"^'not None' is not None$",
|
|
||||||
"^'not None' is not None : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertIsNotNone(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertIsNotNone",
|
|
||||||
(None,),
|
|
||||||
[
|
|
||||||
"^unexpectedly None$",
|
|
||||||
"^oops$",
|
|
||||||
"^unexpectedly None$",
|
|
||||||
"^unexpectedly None : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertIs(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertIs",
|
|
||||||
(None, "foo"),
|
|
||||||
[
|
|
||||||
"^None is not 'foo'$",
|
|
||||||
"^oops$",
|
|
||||||
"^None is not 'foo'$",
|
|
||||||
"^None is not 'foo' : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertIsNot(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertIsNot",
|
|
||||||
(None, None),
|
|
||||||
[
|
|
||||||
"^unexpectedly identical: None$",
|
|
||||||
"^oops$",
|
|
||||||
"^unexpectedly identical: None$",
|
|
||||||
"^unexpectedly identical: None : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertRegex(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertRegex",
|
|
||||||
("foo", "bar"),
|
|
||||||
[
|
|
||||||
"^Regex didn't match:",
|
|
||||||
"^oops$",
|
|
||||||
"^Regex didn't match:",
|
|
||||||
"^Regex didn't match: (.*) : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertNotRegex(self):
|
|
||||||
self.assertMessages(
|
|
||||||
"assertNotRegex",
|
|
||||||
("foo", "foo"),
|
|
||||||
[
|
|
||||||
"^Regex matched:",
|
|
||||||
"^oops$",
|
|
||||||
"^Regex matched:",
|
|
||||||
"^Regex matched: (.*) : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def assertMessagesCM(self, methodName, args, func, errors):
|
|
||||||
"""
|
|
||||||
Check that the correct error messages are raised while executing:
|
|
||||||
with method(*args):
|
|
||||||
func()
|
|
||||||
*errors* should be a list of 4 regex that match the error when:
|
|
||||||
1) longMessage = False and no msg passed;
|
|
||||||
2) longMessage = False and msg passed;
|
|
||||||
3) longMessage = True and no msg passed;
|
|
||||||
4) longMessage = True and msg passed;
|
|
||||||
"""
|
|
||||||
p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"}))
|
|
||||||
for (cls, kwargs), err in zip(p, errors):
|
|
||||||
method = getattr(cls, methodName)
|
|
||||||
# with self.assertRaisesRegex(cls.failureException, err):
|
|
||||||
with self.assertRaises(cls.failureException) as c:
|
|
||||||
with method(*args, **kwargs) as cm: # noqa: F841
|
|
||||||
func()
|
|
||||||
self.assertRegex(str(c.exception), err)
|
|
||||||
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertRaises(self):
|
|
||||||
self.assertMessagesCM(
|
|
||||||
"assertRaises",
|
|
||||||
(TypeError,),
|
|
||||||
lambda: None,
|
|
||||||
[
|
|
||||||
"^TypeError not raised$",
|
|
||||||
"^oops$",
|
|
||||||
"^TypeError not raised$",
|
|
||||||
"^TypeError not raised : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertRaisesRegex(self):
|
|
||||||
self.assertMessagesCM(
|
|
||||||
"assertRaisesRegex",
|
|
||||||
(TypeError, "unused regex"),
|
|
||||||
lambda: None,
|
|
||||||
[
|
|
||||||
"^TypeError not raised$",
|
|
||||||
"^oops$",
|
|
||||||
"^TypeError not raised$",
|
|
||||||
"^TypeError not raised : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# test error raised but with wrong message
|
|
||||||
def raise_wrong_message():
|
|
||||||
raise TypeError("foo")
|
|
||||||
|
|
||||||
self.assertMessagesCM(
|
|
||||||
"assertRaisesRegex",
|
|
||||||
(TypeError, "regex"),
|
|
||||||
raise_wrong_message,
|
|
||||||
[
|
|
||||||
'^"regex" does not match "foo"$',
|
|
||||||
"^oops$",
|
|
||||||
'^"regex" does not match "foo"$',
|
|
||||||
'^"regex" does not match "foo" : oops$',
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertWarns(self):
|
|
||||||
self.assertMessagesCM(
|
|
||||||
"assertWarns",
|
|
||||||
(UserWarning,),
|
|
||||||
lambda: None,
|
|
||||||
[
|
|
||||||
"^UserWarning not triggered$",
|
|
||||||
"^oops$",
|
|
||||||
"^UserWarning not triggered$",
|
|
||||||
"^UserWarning not triggered : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13")
|
|
||||||
@make_dynamo_test
|
|
||||||
def test_assertNotWarns(self):
|
|
||||||
def warn_future():
|
|
||||||
warnings.warn("xyz", FutureWarning, stacklevel=2)
|
|
||||||
|
|
||||||
self.assertMessagesCM(
|
|
||||||
"_assertNotWarns",
|
|
||||||
(FutureWarning,),
|
|
||||||
warn_future,
|
|
||||||
[
|
|
||||||
"^FutureWarning triggered$",
|
|
||||||
"^oops$",
|
|
||||||
"^FutureWarning triggered$",
|
|
||||||
"^FutureWarning triggered : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_dynamo_test
|
|
||||||
def testAssertWarnsRegex(self):
|
|
||||||
# test error not raised
|
|
||||||
self.assertMessagesCM(
|
|
||||||
"assertWarnsRegex",
|
|
||||||
(UserWarning, "unused regex"),
|
|
||||||
lambda: None,
|
|
||||||
[
|
|
||||||
"^UserWarning not triggered$",
|
|
||||||
"^oops$",
|
|
||||||
"^UserWarning not triggered$",
|
|
||||||
"^UserWarning not triggered : oops$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# test warning raised but with wrong message
|
|
||||||
def raise_wrong_message():
|
|
||||||
warnings.warn("foo")
|
|
||||||
|
|
||||||
self.assertMessagesCM(
|
|
||||||
"assertWarnsRegex",
|
|
||||||
(UserWarning, "regex"),
|
|
||||||
raise_wrong_message,
|
|
||||||
[
|
|
||||||
'^"regex" does not match "foo"$',
|
|
||||||
"^oops$",
|
|
||||||
'^"regex" does not match "foo"$',
|
|
||||||
'^"regex" does not match "foo" : oops$',
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1593,6 +1593,13 @@ def get_selected_tests(options) -> list[str]:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if sys.version_info[:2] < (3, 13):
|
||||||
|
# Skip tests for older Python versions as they may use syntax or features
|
||||||
|
# not supported in those versions
|
||||||
|
options.exclude.extend(
|
||||||
|
[test for test in selected_tests if test.startswith("dynamo/cpython/3_13/")]
|
||||||
|
)
|
||||||
|
|
||||||
selected_tests = exclude_tests(options.exclude, selected_tests)
|
selected_tests = exclude_tests(options.exclude, selected_tests)
|
||||||
|
|
||||||
if sys.platform == "win32" and not options.ignore_win_blocklist:
|
if sys.platform == "win32" and not options.ignore_win_blocklist:
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# mypy: allow-untyped-defs
|
||||||
|
|
||||||
"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
|
"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
|
||||||
|
|
||||||
This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
|
This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
|
||||||
|
|
@ -10,8 +12,13 @@ It includes:
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -98,7 +105,70 @@ class TestCase(TorchTestCase):
|
||||||
|
|
||||||
|
|
||||||
class CPythonTestCase(TestCase):
|
class CPythonTestCase(TestCase):
|
||||||
|
"""
|
||||||
|
Test class for CPython tests located in "test/dynamo/CPython/Py_version/*".
|
||||||
|
|
||||||
|
This class enables specific features that are disabled by default, such as
|
||||||
|
tracing through unittest methods.
|
||||||
|
"""
|
||||||
|
|
||||||
_stack: contextlib.ExitStack
|
_stack: contextlib.ExitStack
|
||||||
|
dynamo_strict_nopython = True
|
||||||
|
|
||||||
|
# Restore original unittest methods to simplify tracing CPython test cases.
|
||||||
|
assertEqual = unittest.TestCase.assertEqual # type: ignore[assignment]
|
||||||
|
assertNotEqual = unittest.TestCase.assertNotEqual # type: ignore[assignment]
|
||||||
|
assertTrue = unittest.TestCase.assertTrue
|
||||||
|
assertFalse = unittest.TestCase.assertFalse
|
||||||
|
assertIs = unittest.TestCase.assertIs
|
||||||
|
assertIsNot = unittest.TestCase.assertIsNot
|
||||||
|
assertIsNone = unittest.TestCase.assertIsNone
|
||||||
|
assertIsNotNone = unittest.TestCase.assertIsNotNone
|
||||||
|
assertIn = unittest.TestCase.assertIn
|
||||||
|
assertNotIn = unittest.TestCase.assertNotIn
|
||||||
|
assertIsInstance = unittest.TestCase.assertIsInstance
|
||||||
|
assertNotIsInstance = unittest.TestCase.assertNotIsInstance
|
||||||
|
assertAlmostEqual = unittest.TestCase.assertAlmostEqual
|
||||||
|
assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
|
||||||
|
assertGreater = unittest.TestCase.assertGreater
|
||||||
|
assertGreaterEqual = unittest.TestCase.assertGreaterEqual
|
||||||
|
assertLess = unittest.TestCase.assertLess
|
||||||
|
assertLessEqual = unittest.TestCase.assertLessEqual
|
||||||
|
assertRegex = unittest.TestCase.assertRegex
|
||||||
|
assertNotRegex = unittest.TestCase.assertNotRegex
|
||||||
|
assertCountEqual = unittest.TestCase.assertCountEqual
|
||||||
|
assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual
|
||||||
|
assertSequenceEqual = unittest.TestCase.assertSequenceEqual
|
||||||
|
assertListEqual = unittest.TestCase.assertListEqual
|
||||||
|
assertTupleEqual = unittest.TestCase.assertTupleEqual
|
||||||
|
assertSetEqual = unittest.TestCase.assertSetEqual
|
||||||
|
assertDictEqual = unittest.TestCase.assertDictEqual
|
||||||
|
assertRaises = unittest.TestCase.assertRaises
|
||||||
|
assertRaisesRegex = unittest.TestCase.assertRaisesRegex
|
||||||
|
assertWarns = unittest.TestCase.assertWarns
|
||||||
|
assertWarnsRegex = unittest.TestCase.assertWarnsRegex
|
||||||
|
assertLogs = unittest.TestCase.assertLogs
|
||||||
|
fail = unittest.TestCase.fail
|
||||||
|
failureException = unittest.TestCase.failureException
|
||||||
|
|
||||||
|
def compile_fn(self, fn, backend, nopython):
|
||||||
|
# We want to compile only the test function, excluding any setup code
|
||||||
|
# from unittest
|
||||||
|
method = getattr(self, self._testMethodName)
|
||||||
|
method = torch._dynamo.optimize(backend, nopython=nopython)(method)
|
||||||
|
setattr(self, self._testMethodName, method)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
def _dynamo_test_key(self):
|
||||||
|
suffix = super()._dynamo_test_key()
|
||||||
|
test_cls = self.__class__
|
||||||
|
test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0]
|
||||||
|
py_ver = re.search(r"/([\d_]+)/", inspect.getfile(test_cls))
|
||||||
|
if py_ver:
|
||||||
|
py_ver = py_ver.group().strip(os.sep).replace("_", "") # type: ignore[assignment]
|
||||||
|
else:
|
||||||
|
return suffix
|
||||||
|
return f"CPython{py_ver}-{test_file}-{suffix}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls) -> None:
|
def tearDownClass(cls) -> None:
|
||||||
|
|
@ -107,6 +177,24 @@ class CPythonTestCase(TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls) -> None:
|
def setUpClass(cls) -> None:
|
||||||
|
# Skip test if python versions doesn't match
|
||||||
|
normalized_path = pathlib.PurePath("dynamo/cpython").as_posix()
|
||||||
|
regex = re.escape(normalized_path) + r"\b\d+_\d{2}\b"
|
||||||
|
m = re.search(regex, inspect.getfile(cls))
|
||||||
|
if m:
|
||||||
|
test_py_ver = tuple(map(int, m.group().split("_")))
|
||||||
|
py_ver = sys.version_info[:2]
|
||||||
|
if py_ver != test_py_ver:
|
||||||
|
expected = ".".join(map(str, test_py_ver))
|
||||||
|
got = ".".join(map(str, py_ver))
|
||||||
|
raise unittest.SkipTest(
|
||||||
|
f"Test requires Python {expected} but got Python {got}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise unittest.SkipTest(
|
||||||
|
f"Test requires a specific Python version but not found in path {inspect.getfile(cls)}"
|
||||||
|
)
|
||||||
|
|
||||||
super().setUpClass()
|
super().setUpClass()
|
||||||
cls._stack = contextlib.ExitStack() # type: ignore[attr-defined]
|
cls._stack = contextlib.ExitStack() # type: ignore[attr-defined]
|
||||||
cls._stack.enter_context( # type: ignore[attr-defined]
|
cls._stack.enter_context( # type: ignore[attr-defined]
|
||||||
|
|
|
||||||
|
|
@ -1989,11 +1989,11 @@ class BuiltinVariable(VariableTracker):
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
unimplemented_v2(
|
unimplemented_v2(
|
||||||
gb_type="Failed to trace builtin operator",
|
gb_type="Failed to trace unittest method",
|
||||||
context=f"function: unittest.TestCase.{name}",
|
context=f"function: unittest.TestCase.{name}",
|
||||||
explanation=f"Dynamo does not know how to trace builtin operator `{name}` ",
|
explanation=f"Dynamo does not know how to trace unittest method `{name}` ",
|
||||||
hints=[
|
hints=[
|
||||||
f"Avoid calling builtin `{name}`. "
|
f"Avoid calling `TestCase.{name}`. "
|
||||||
"Please report an issue to PyTorch.",
|
"Please report an issue to PyTorch.",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3157,6 +3157,13 @@ class TestCase(expecttest.TestCase):
|
||||||
def wrap_with_cuda_memory_check(self, method):
|
def wrap_with_cuda_memory_check(self, method):
|
||||||
return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
|
return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
|
||||||
|
|
||||||
|
def _dynamo_test_key(self):
|
||||||
|
return f"{self.__class__.__name__}.{self._testMethodName}"
|
||||||
|
|
||||||
|
def compile_fn(self, fn, backend, nopython):
|
||||||
|
# Allows subclasses to control compilation
|
||||||
|
return torch._dynamo.optimize(backend, nopython=nopython)(fn)
|
||||||
|
|
||||||
def _run_custom(self, result=None):
|
def _run_custom(self, result=None):
|
||||||
using_unittest = isinstance(result, unittest.TestResult)
|
using_unittest = isinstance(result, unittest.TestResult)
|
||||||
|
|
||||||
|
|
@ -3232,16 +3239,16 @@ class TestCase(expecttest.TestCase):
|
||||||
|
|
||||||
with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
|
with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
|
||||||
if TEST_WITH_AOT_EAGER:
|
if TEST_WITH_AOT_EAGER:
|
||||||
super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
|
super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython)
|
||||||
elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
|
elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
|
||||||
if TEST_WITH_TORCHINDUCTOR:
|
if TEST_WITH_TORCHINDUCTOR:
|
||||||
super_run = torch._dynamo.optimize("inductor")(super_run)
|
super_run = self.compile_fn(super_run, "inductor", nopython)
|
||||||
else:
|
else:
|
||||||
# Assume eager-generated GraphModules will not error out.
|
# Assume eager-generated GraphModules will not error out.
|
||||||
# If we do, this is probably a Dynamo bug!
|
# If we do, this is probably a Dynamo bug!
|
||||||
super_run = torch._dynamo.optimize("eager_noexcept", nopython=nopython)(super_run)
|
super_run = self.compile_fn(super_run, "eager_noexcept", nopython)
|
||||||
|
|
||||||
key = f"{self.__class__.__name__}.{self._testMethodName}"
|
key = self._dynamo_test_key()
|
||||||
|
|
||||||
def expect_failure(f, file_name):
|
def expect_failure(f, file_name):
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user