mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146498 Approved by: https://github.com/anijain2305, https://github.com/zou3519
111 lines
2.9 KiB
Python
111 lines
2.9 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
from torch.testing._internal.common_utils import make_dynamo_test
|
|
|
|
|
|
class SysTests(torch._dynamo.test_case.TestCase):
|
|
def test_exc_info(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
try:
|
|
raise ValueError
|
|
except Exception:
|
|
typ, _, _ = sys.exc_info()
|
|
if typ is ValueError:
|
|
return t.sin()
|
|
else:
|
|
return t.cos()
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
|
|
|
|
class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
    """Checks of ``sys.exc_info()`` / ``sys.exception()`` ported from CPython.

    Every test is wrapped with ``make_dynamo_test`` (imported from
    ``torch.testing._internal.common_utils``). Presumably that decorator runs
    the test body under Dynamo; see its definition to confirm. Tests marked
    ``unittest.expectedFailure`` are expected to fail in this configuration.
    """

    # Tests taken from CPython source code in cpython/Lib/test/test_sys.py
    # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py

    @make_dynamo_test
    def test_exc_info_no_exception(self):
        # No ``except`` handler is active here, so all three fields are None.
        self.assertEqual(sys.exc_info(), (None, None, None))

    # ``sys.exception()`` only exists on Python 3.11+, hence the skip.
    @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
    @make_dynamo_test
    def test_sys_exception_no_exception(self):
        # With no active exception, ``sys.exception()`` returns None.
        self.assertEqual(sys.exception(), None)

    # NOTE(review): marked as an expected failure. Presumably Dynamo does not
    # yet return the identical exception/traceback objects here. TODO confirm.
    @unittest.expectedFailure
    @make_dynamo_test
    def test_exc_info_with_exception_instance(self):
        # The exception instance ``ValueError(42)`` is raised inside the nested
        # helper ``f`` and propagates to the handler below.
        def f():
            raise ValueError(42)

        try:
            f()
        except Exception as e_:
            # ``e_`` is unbound when the except clause ends, so keep a
            # reference in ``e`` for the assertions after the handler.
            e = e_
            # Captured inside the handler, while the exception is still active.
            exc_info = sys.exc_info()

        # exc_info must be (type, the same instance, that instance's traceback).
        self.assertIsInstance(e, ValueError)
        self.assertIs(exc_info[0], ValueError)
        self.assertIs(exc_info[1], e)
        self.assertIs(exc_info[2], e.__traceback__)

    # NOTE(review): marked as an expected failure. Presumably Dynamo does not
    # yet return the identical exception/traceback objects here. TODO confirm.
    @unittest.expectedFailure
    @make_dynamo_test
    def test_exc_info_with_exception_type(self):
        # Raises the class itself. Python instantiates it, so the handler
        # still receives a ValueError instance.
        def f():
            raise ValueError

        try:
            f()
        except Exception as e_:
            # ``e_`` is unbound when the except clause ends; keep it in ``e``.
            e = e_
            # Captured inside the handler, while the exception is still active.
            exc_info = sys.exc_info()

        # exc_info must be (type, the same instance, that instance's traceback).
        self.assertIsInstance(e, ValueError)
        self.assertIs(exc_info[0], ValueError)
        self.assertIs(exc_info[1], e)
        self.assertIs(exc_info[2], e.__traceback__)

    # NOTE(review): expected failure under this configuration. TODO confirm
    # the reason. ``sys.exception()`` requires Python 3.11+, hence the skip.
    @unittest.expectedFailure
    @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
    @make_dynamo_test
    def test_sys_exception_with_exception_instance(self):
        # Same scenario as test_exc_info_with_exception_instance, but checked
        # via ``sys.exception()``.
        def f():
            raise ValueError(42)

        try:
            f()
        except Exception as e_:
            # ``e_`` is unbound when the except clause ends; keep it in ``e``.
            e = e_
            # Captured inside the handler, while the exception is still active.
            exc = sys.exception()

        # sys.exception() must return the very object that was caught.
        self.assertIsInstance(e, ValueError)
        self.assertIs(exc, e)

    # NOTE(review): expected failure under this configuration. TODO confirm
    # the reason. ``sys.exception()`` requires Python 3.11+, hence the skip.
    @unittest.expectedFailure
    @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
    @make_dynamo_test
    def test_sys_exception_with_exception_type(self):
        # Raises the class itself. Python instantiates it before it reaches
        # the handler.
        def f():
            raise ValueError

        try:
            f()
        except Exception as e_:
            # ``e_`` is unbound when the except clause ends; keep it in ``e``.
            e = e_
            # Captured inside the handler, while the exception is still active.
            exc = sys.exception()

        # sys.exception() must return the very object that was caught.
        self.assertIsInstance(e, ValueError)
        self.assertIs(exc, e)
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed directly, hand off to Dynamo's test runner. The module is
    # already imported at the top of the file.
    torch._dynamo.test_case.run_tests()
|