Allow trace through unittest (#146500)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146500
Approved by: https://github.com/anijain2305
This commit is contained in:
Guilherme Leobas 2025-04-07 19:59:29 +00:00 committed by PyTorch MergeBot
parent 1791b4150b
commit f3b2fb6c66
15 changed files with 739 additions and 113 deletions

View File

@ -1744,10 +1744,13 @@ class GraphModule(torch.nn.Module):
class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase):
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_contextlib
self._u_prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_contextlib = True
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_contextlib = self._prev
torch._dynamo.config.enable_trace_unittest = self._u_prev
def test_ctx_basic0(self):
@contextlib.contextmanager
@ -2691,7 +2694,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(y, t.sin())
class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
# https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70
@ -2721,7 +2724,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
self.assertEqual(state, [1, 42, 999])
self.assertEqual(y, t.sum() + 42)
@unittest.expectedFailure
def test_contextmanager_finally(self):
state = []
@ -2831,7 +2833,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
self.assertEqual(frames[0].line, "raise stop_exc")
@unittest.expectedFailure
def test_contextmanager_no_reraise(self):
@contextmanager
def whee():
@ -2847,7 +2848,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
fn(torch.randn(2, 3))
@unittest.expectedFailure
def test_contextmanager_trap_yield_after_throw(self):
@contextmanager
def whoo():
@ -2866,7 +2866,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
fn(torch.randn(2, 3))
@unittest.expectedFailure
def test_contextmanager_trap_no_yield(self):
@contextmanager
def whoo():
@ -2882,7 +2881,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
fn(torch.randn(2, 3))
@unittest.expectedFailure
def test_contextmanager_trap_second_yield(self):
@contextmanager
def whoo():

View File

@ -307,7 +307,7 @@ Attempted to inline function marked as skipped
Hint: Remove the function `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.
Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup SKIP_DIRS
Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup unittest
from user code:

View File

@ -177,7 +177,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertEqual(ref, res)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_raise_match(self):
a = AttributeError
@ -259,7 +258,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
opt_fn = torch.compile(fn, backend="eager")
opt_fn(x)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
def test_exception_with_ctx_manager(self):
def fn(x):
x = torch.cos(x)
@ -853,7 +851,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
t = torch.randn(2)
fn(t)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
def test_user_defined_exception_with_args(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
@ -889,6 +886,12 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_exceptions.py
def setUp(self):
self._u_prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._u_prev
@make_dynamo_test
def testChainingAttrs(self):
@ -976,7 +979,6 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
assert exc is oe
assert exc.__context__ is ve
@unittest.expectedFailure
@make_dynamo_test
def test_raise_does_not_create_context_chain_cycle(self):
A = AssertionError
@ -1015,7 +1017,6 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(c.__context__, b)
self.assertIsNone(b.__context__)
@unittest.expectedFailure
@make_dynamo_test
def test_no_hang_on_context_chain_cycle1(self):
# See issue 25782. Cycle in context chain.
@ -1071,7 +1072,6 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(b.__context__, a)
self.assertIs(a.__context__, c)
@unittest.expectedFailure
@make_dynamo_test
def test_no_hang_on_context_chain_cycle3(self):
# See issue 25782. Longer context chain with cycle.

View File

@ -8,19 +8,9 @@ import torch._dynamo.test_case
from torch.testing._internal.common_utils import make_dynamo_test
class TestPEP479(torch._dynamo.test_case.TestCase):
class TestPEP479(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_generator_stop.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_generator_stop.py
def assertTrue(self, expr, msg=None):
assert bool(expr) is True, msg
def assertIs(self, expr1, expr2, msg=None):
assert expr1 is expr2, msg
def assertEqual(self, x, y):
assert x == y
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
@make_dynamo_test
def test_stopiteration_wrapping(self):
@ -30,16 +20,9 @@ class TestPEP479(torch._dynamo.test_case.TestCase):
def g():
yield f()
try:
with self.assertRaises(RuntimeError) as cm:
next(g())
except RuntimeError as cm:
self.assertEqual("generator raised StopIteration", cm.args[0])
except Exception:
self.fail("Error!")
# with self.assertRaises(RuntimeError) as cm:
# next(g())
# self.assertEqual("generator raised StopIteration", str(cm.exception))
self.assertEqual("generator raised StopIteration", str(cm.exception))
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
@make_dynamo_test

View File

@ -44,32 +44,9 @@ class ContextManager:
raise NameError
class TestRaise(torch._dynamo.test_case.TestCase):
class TestRaise(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def assertIn(self, member, container, msg=None):
assert member in container, msg
def assertIs(self, expr1, expr2, msg=None):
assert expr1 is expr2, msg
def assertRaises(self, expected_exception, *args, **kwargs):
z = 0
try:
yield
except expected_exception:
z = 1
except Exception:
z = 2
assert z == 1
def assertIsInstance(self, obj, cls, msg=None):
assert isinstance(obj, cls), msg
def assertIsNone(self, obj, msg=None):
assert obj is None, msg
@make_dynamo_test
def test_invalid_reraise(self):
try:
@ -213,34 +190,12 @@ class TestRaise(torch._dynamo.test_case.TestCase):
class TestCause(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def assertIn(self, member, container, msg=None):
assert member in container, msg
def assertIs(self, expr1, expr2, msg=None):
assert expr1 is expr2, msg
def assertRaises(self, expected_exception, *args, **kwargs):
z = 0
try:
yield
except expected_exception:
z = 1
except Exception:
z = 2
assert z == 1
def assertIsInstance(self, obj, cls, msg=None):
assert isinstance(obj, cls), msg
def assertIsNone(self, obj, msg=None):
assert obj is None, msg
def assertTrue(self, expr, msg=None):
assert bool(expr) is True, msg
def assertFalse(self, expr, msg=None):
assert bool(expr) is False, msg
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
@make_dynamo_test
def testCauseSyntax(self):
@ -303,6 +258,12 @@ class TestCause(torch._dynamo.test_case.TestCase):
class TestTraceback(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
@unittest.expectedFailure # Dynamo doesn't track traceback
@make_dynamo_test
@ -330,6 +291,12 @@ class TestTraceback(torch._dynamo.test_case.TestCase):
class TestTracebackType(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
def raiser(self):
raise ValueError
@ -402,28 +369,12 @@ class TestTracebackType(torch._dynamo.test_case.TestCase):
class TestContext(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def assertIn(self, member, container, msg=None):
assert member in container, msg
def assertIs(self, expr1, expr2, msg=None):
assert expr1 is expr2, msg
def assertRaises(self, expected_exception, *args, **kwargs):
z = 0
try:
yield
except expected_exception:
z = 1
except Exception:
z = 2
assert z == 1
def assertIsInstance(self, obj, cls, msg=None):
assert isinstance(obj, cls), msg
def assertIsNone(self, obj, msg=None):
assert obj is None, msg
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
@unittest.expectedFailure # missing Exception.__eq__
@make_dynamo_test

View File

@ -25,9 +25,10 @@ class SysTests(torch._dynamo.test_case.TestCase):
self.assertEqual(y, t.sin())
class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
class CPythonActiveExceptionTests(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_sys.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py
@make_dynamo_test
def test_exc_info_no_exception(self):
self.assertEqual(sys.exc_info(), (None, None, None))
@ -37,7 +38,6 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
def test_sys_exception_no_exception(self):
self.assertEqual(sys.exception(), None)
@unittest.expectedFailure
@make_dynamo_test
def test_exc_info_with_exception_instance(self):
def f():
@ -54,7 +54,6 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(exc_info[1], e)
self.assertIs(exc_info[2], e.__traceback__)
@unittest.expectedFailure
@make_dynamo_test
def test_exc_info_with_exception_type(self):
def f():
@ -71,7 +70,6 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(exc_info[1], e)
self.assertIs(exc_info[2], e.__traceback__)
@unittest.expectedFailure
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_sys_exception_with_exception_instance(self):
@ -87,7 +85,6 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIsInstance(e, ValueError)
self.assertIs(exc, e)
@unittest.expectedFailure
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_sys_exception_with_exception_type(self):

View File

@ -0,0 +1,619 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
import warnings
from itertools import product
import torch
import torch._dynamo.test_case
from torch.testing._internal.common_utils import make_dynamo_test
class TestUnittest(torch._dynamo.test_case.TestCase):
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
@make_dynamo_test
def test_SkipTest(self):
z = 0
SkipTest = unittest.SkipTest
try:
raise SkipTest("abcd")
except Exception:
z = 1
self.assertEqual(z, 1)
class CPythonTest_Assertions(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py
# https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py
@make_dynamo_test
def test_AlmostEqual(self):
self.assertAlmostEqual(1.00000001, 1.0)
self.assertNotAlmostEqual(1.0000001, 1.0)
self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0
)
self.assertAlmostEqual(1.1, 1.0, places=0)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1
)
self.assertAlmostEqual(0, 0.1 + 0.1j, places=0)
self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1
)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0
)
self.assertAlmostEqual(float("inf"), float("inf"))
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf")
)
@make_dynamo_test
def test_AmostEqualWithDelta(self):
self.assertAlmostEqual(1.1, 1.0, delta=0.5)
self.assertAlmostEqual(1.0, 1.1, delta=0.5)
self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)
self.assertAlmostEqual(1.0, 1.0, delta=0.5)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5
)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05
)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5
)
self.assertRaises(
TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2
)
self.assertRaises(
TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2
)
@make_dynamo_test
def test_assertRaises(self):
def _raise(e):
raise e
self.assertRaises(KeyError, _raise, KeyError)
self.assertRaises(KeyError, _raise, KeyError("key"))
try:
self.assertRaises(KeyError, lambda: None)
except self.failureException as e:
self.assertIn("KeyError not raised", str(e))
else:
self.fail("assertRaises() didn't fail")
try:
self.assertRaises(KeyError, _raise, ValueError)
except ValueError:
pass
else:
self.fail("assertRaises() didn't let exception pass through")
with self.assertRaises(KeyError) as cm:
try:
raise KeyError
except Exception as e:
exc = e
raise
self.assertIs(cm.exception, exc)
with self.assertRaises(KeyError):
raise KeyError("key")
try:
with self.assertRaises(KeyError):
pass
except self.failureException as e:
self.assertIn("KeyError not raised", str(e))
else:
self.fail("assertRaises() didn't fail")
try:
with self.assertRaises(KeyError):
raise ValueError
except ValueError:
pass
else:
self.fail("assertRaises() didn't let exception pass through")
@make_dynamo_test
def testAssertNotRegex(self):
self.assertNotRegex("Ala ma kota", r"r+")
try:
self.assertNotRegex("Ala ma kota", r"k.t", "Message")
except self.failureException as e:
self.assertIn("Message", e.args[0])
else:
self.fail("assertNotRegex should have failed.")
class CPythonTestLongMessage(torch._dynamo.test_case.CPythonTestCase):
"""Test that the individual asserts honour longMessage.
This actually tests all the message behaviour for
asserts that use longMessage."""
def setUp(self):
super().setUp()
class TestableTestFalse(unittest.TestCase):
longMessage = False
failureException = self.failureException
def testTest(self):
pass
class TestableTestTrue(unittest.TestCase):
longMessage = True
failureException = self.failureException
def testTest(self):
pass
self.testableTrue = TestableTestTrue("testTest")
self.testableFalse = TestableTestFalse("testTest")
def testDefault(self):
self.assertTrue(unittest.TestCase.longMessage)
def test_formatMsg(self):
self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo")
self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo")
self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo")
self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo")
# This blows up if _formatMessage uses string concatenation
self.testableTrue._formatMessage(object(), "foo")
def test_formatMessage_unicode_error(self):
one = "".join(chr(i) for i in range(255))
# this used to cause a UnicodeDecodeError constructing msg
self.testableTrue._formatMessage(one, "\uFFFD")
def assertMessages(self, methodName, args, errors):
"""
Check that methodName(*args) raises the correct error messages.
errors should be a list of 4 regex that match the error when:
1) longMessage = False and no msg passed;
2) longMessage = False and msg passed;
3) longMessage = True and no msg passed;
4) longMessage = True and msg passed;
"""
def getMethod(i):
useTestableFalse = i < 2
if useTestableFalse:
test = self.testableFalse
else:
test = self.testableTrue
return getattr(test, methodName)
for i, expected_regex in enumerate(errors):
testMethod = getMethod(i)
kwargs = {}
withMsg = i % 2
if withMsg:
kwargs = {"msg": "oops"}
# with self.assertRaisesRegex(
# self.failureException, expected_regex=expected_regex
# ):
# testMethod(*args, **kwargs)
with self.assertRaises(self.failureException) as cm:
testMethod(*args, **kwargs)
self.assertRegex(str(cm.exception), expected_regex)
@make_dynamo_test
def testAssertTrue(self):
self.assertMessages(
"assertTrue",
(False,),
[
"False is not true",
"oops",
"False is not true",
"False is not true : oops",
],
)
@make_dynamo_test
def testAssertFalse(self):
self.assertMessages(
"assertFalse",
(True,),
[
"True is not false",
"oops",
"True is not false",
"True is not false : oops",
],
)
@make_dynamo_test
def testNotEqual(self):
self.assertMessages(
"assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"]
)
@make_dynamo_test
def testAlmostEqual(self):
self.assertMessages(
"assertAlmostEqual",
(1, 2),
[
r"^1 != 2 within 7 places \(1 difference\)$",
"^oops$",
r"^1 != 2 within 7 places \(1 difference\)$",
r"^1 != 2 within 7 places \(1 difference\) : oops$",
],
)
@make_dynamo_test
def testNotAlmostEqual(self):
self.assertMessages(
"assertNotAlmostEqual",
(1, 1),
[
"^1 == 1 within 7 places$",
"^oops$",
"^1 == 1 within 7 places$",
"^1 == 1 within 7 places : oops$",
],
)
@make_dynamo_test
def test_baseAssertEqual(self):
self.assertMessages(
"_baseAssertEqual",
(1, 2),
["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertSequenceEqual(self):
# Error messages are multiline so not testing on full message
# assertTupleEqual and assertListEqual delegate to this method
self.assertMessages(
"assertSequenceEqual",
([], [None]),
[r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"],
)
@make_dynamo_test
def testAssertSetEqual(self):
self.assertMessages(
"assertSetEqual",
(set(), set([None])), # noqa: C405
["None$", "^oops$", "None$", "None : oops$"],
)
@make_dynamo_test
def testAssertIn(self):
self.assertMessages(
"assertIn",
(None, []),
[
r"^None not found in \[\]$",
"^oops$",
r"^None not found in \[\]$",
r"^None not found in \[\] : oops$",
],
)
@make_dynamo_test
def testAssertNotIn(self):
self.assertMessages(
"assertNotIn",
(None, [None]),
[
r"^None unexpectedly found in \[None\]$",
"^oops$",
r"^None unexpectedly found in \[None\]$",
r"^None unexpectedly found in \[None\] : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertDictEqual(self):
self.assertMessages(
"assertDictEqual",
({}, {"key": "value"}),
[
r"\+ \{'key': 'value'\}$",
"^oops$",
r"\+ \{'key': 'value'\}$",
r"\+ \{'key': 'value'\} : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertMultiLineEqual(self):
self.assertMessages(
"assertMultiLineEqual",
("", "foo"),
[r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"],
)
@make_dynamo_test
def testAssertLess(self):
self.assertMessages(
"assertLess",
(2, 1),
[
"^2 not less than 1$",
"^oops$",
"^2 not less than 1$",
"^2 not less than 1 : oops$",
],
)
@make_dynamo_test
def testAssertLessEqual(self):
self.assertMessages(
"assertLessEqual",
(2, 1),
[
"^2 not less than or equal to 1$",
"^oops$",
"^2 not less than or equal to 1$",
"^2 not less than or equal to 1 : oops$",
],
)
@make_dynamo_test
def testAssertGreater(self):
self.assertMessages(
"assertGreater",
(1, 2),
[
"^1 not greater than 2$",
"^oops$",
"^1 not greater than 2$",
"^1 not greater than 2 : oops$",
],
)
@make_dynamo_test
def testAssertGreaterEqual(self):
self.assertMessages(
"assertGreaterEqual",
(1, 2),
[
"^1 not greater than or equal to 2$",
"^oops$",
"^1 not greater than or equal to 2$",
"^1 not greater than or equal to 2 : oops$",
],
)
@make_dynamo_test
def testAssertIsNone(self):
self.assertMessages(
"assertIsNone",
("not None",),
[
"^'not None' is not None$",
"^oops$",
"^'not None' is not None$",
"^'not None' is not None : oops$",
],
)
@make_dynamo_test
def testAssertIsNotNone(self):
self.assertMessages(
"assertIsNotNone",
(None,),
[
"^unexpectedly None$",
"^oops$",
"^unexpectedly None$",
"^unexpectedly None : oops$",
],
)
@make_dynamo_test
def testAssertIs(self):
self.assertMessages(
"assertIs",
(None, "foo"),
[
"^None is not 'foo'$",
"^oops$",
"^None is not 'foo'$",
"^None is not 'foo' : oops$",
],
)
@make_dynamo_test
def testAssertIsNot(self):
self.assertMessages(
"assertIsNot",
(None, None),
[
"^unexpectedly identical: None$",
"^oops$",
"^unexpectedly identical: None$",
"^unexpectedly identical: None : oops$",
],
)
@make_dynamo_test
def testAssertRegex(self):
self.assertMessages(
"assertRegex",
("foo", "bar"),
[
"^Regex didn't match:",
"^oops$",
"^Regex didn't match:",
"^Regex didn't match: (.*) : oops$",
],
)
@make_dynamo_test
def testAssertNotRegex(self):
self.assertMessages(
"assertNotRegex",
("foo", "foo"),
[
"^Regex matched:",
"^oops$",
"^Regex matched:",
"^Regex matched: (.*) : oops$",
],
)
def assertMessagesCM(self, methodName, args, func, errors):
"""
Check that the correct error messages are raised while executing:
with method(*args):
func()
*errors* should be a list of 4 regex that match the error when:
1) longMessage = False and no msg passed;
2) longMessage = False and msg passed;
3) longMessage = True and no msg passed;
4) longMessage = True and msg passed;
"""
p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"}))
for (cls, kwargs), err in zip(p, errors):
method = getattr(cls, methodName)
# with self.assertRaisesRegex(cls.failureException, err):
with self.assertRaises(cls.failureException) as c:
with method(*args, **kwargs) as cm: # noqa: F841
func()
self.assertRegex(str(c.exception), err)
@make_dynamo_test
def testAssertRaises(self):
self.assertMessagesCM(
"assertRaises",
(TypeError,),
lambda: None,
[
"^TypeError not raised$",
"^oops$",
"^TypeError not raised$",
"^TypeError not raised : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertRaisesRegex(self):
self.assertMessagesCM(
"assertRaisesRegex",
(TypeError, "unused regex"),
lambda: None,
[
"^TypeError not raised$",
"^oops$",
"^TypeError not raised$",
"^TypeError not raised : oops$",
],
)
# test error raised but with wrong message
def raise_wrong_message():
raise TypeError("foo")
self.assertMessagesCM(
"assertRaisesRegex",
(TypeError, "regex"),
raise_wrong_message,
[
'^"regex" does not match "foo"$',
"^oops$",
'^"regex" does not match "foo"$',
'^"regex" does not match "foo" : oops$',
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertWarns(self):
self.assertMessagesCM(
"assertWarns",
(UserWarning,),
lambda: None,
[
"^UserWarning not triggered$",
"^oops$",
"^UserWarning not triggered$",
"^UserWarning not triggered : oops$",
],
)
@unittest.expectedFailure
@unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13")
@make_dynamo_test
def test_assertNotWarns(self):
def warn_future():
warnings.warn("xyz", FutureWarning, stacklevel=2)
self.assertMessagesCM(
"_assertNotWarns",
(FutureWarning,),
warn_future,
[
"^FutureWarning triggered$",
"^oops$",
"^FutureWarning triggered$",
"^FutureWarning triggered : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertWarnsRegex(self):
# test error not raised
self.assertMessagesCM(
"assertWarnsRegex",
(UserWarning, "unused regex"),
lambda: None,
[
"^UserWarning not triggered$",
"^oops$",
"^UserWarning not triggered$",
"^UserWarning not triggered : oops$",
],
)
# test warning raised but with wrong message
def raise_wrong_message():
warnings.warn("foo")
self.assertMessagesCM(
"assertWarnsRegex",
(UserWarning, "regex"),
raise_wrong_message,
[
'^"regex" does not match "foo"$',
"^oops$",
'^"regex" does not match "foo"$',
'^"regex" does not match "foo" : oops$',
],
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -401,6 +401,9 @@ enable_cpp_symbolic_shape_guards = False
# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True
# Enable tracing through unittest
enable_trace_unittest = False
# Enable tracing generator functions lazily. If False, Dynamo will exhaust
# generators upon first execution. And if True, the generator will be accessed lazily
enable_faithful_generator_behavior = True

View File

@ -95,3 +95,22 @@ class TestCase(TorchTestCase):
if self._prior_is_grad_enabled is not torch.is_grad_enabled():
log.warning("Running test changed grad mode")
torch.set_grad_enabled(self._prior_is_grad_enabled)
class CPythonTestCase(TestCase):
_stack: contextlib.ExitStack
@classmethod
def tearDownClass(cls) -> None:
cls._stack.close()
super().tearDownClass()
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls._stack = contextlib.ExitStack() # type: ignore[attr-defined]
cls._stack.enter_context( # type: ignore[attr-defined]
config.patch(
enable_trace_unittest=True,
),
)

View File

@ -3173,7 +3173,6 @@ BUILTIN_SKIPLIST = (
random,
traceback,
linecache,
unittest,
)
# third party libraries skiplist is defined by str, because users may not use these libraries.
@ -3580,6 +3579,12 @@ def check_file(filename, is_inlined_call=False):
):
return SkipResult(True, "FBCODE_SKIP_TORCHREC_DIRS")
if (
filename.startswith(_module_dir(unittest))
and not torch._dynamo.config.enable_trace_unittest
):
return SkipResult(True, "unittest")
if bool(SKIP_DIRS_RE.match(filename)):
return SkipResult(True, "SKIP_DIRS")

View File

@ -10,6 +10,7 @@ import operator
import sys
import types
import typing
import unittest
from collections import defaultdict, OrderedDict
from collections.abc import KeysView, Sequence
from typing import Callable, TYPE_CHECKING, Union
@ -1657,7 +1658,10 @@ class BuiltinVariable(VariableTracker):
)
def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "__len__", args[1:], kwargs)
try:
return args[0].call_method(tx, "__len__", args[1:], kwargs)
except AttributeError as e:
raise_observed_exception(type(e), tx, args=list(e.args))
def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
@ -1871,6 +1875,30 @@ class BuiltinVariable(VariableTracker):
variables.UserDefinedObjectVariable,
),
):
if (
isinstance(obj, variables.UserDefinedObjectVariable)
and issubclass(obj.value.__class__, unittest.TestCase)
and config.enable_trace_unittest
and name
in (
"assertRaisesRegex",
"assertNotWarns",
"assertWarnsRegex",
"assertDictEqual",
"assertSequenceEqual",
"assertWarns",
)
):
unimplemented_v2(
gb_type="Failed to trace builtin operator",
context=f"function: unittest.TestCase.{name}",
explanation=f"Dynamo does not know how to trace builtin operator `{name}` ",
hints=[
f"Avoid calling builtin `{name}`. "
"Please report an issue to PyTorch.",
],
)
try:
return obj.var_getattr(tx, name)
except NotImplementedError:

View File

@ -22,7 +22,9 @@ None values for efficiency and code reuse.
import collections
import functools
import inspect
import types
from collections.abc import Hashable as py_Hashable
from typing import Optional, TYPE_CHECKING
from torch._subclasses.fake_tensor import is_fake
@ -53,6 +55,10 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def was_instancecheck_override(obj):
return type(obj).__dict__.get("__instancecheck__", False)
def is_hashable(x):
# NB - performing isinstance check on a LazVT realizes the VT, accidentally
# inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at
@ -72,6 +78,13 @@ def is_hashable(x):
return x.as_proxy().node.meta.get("example_value") is not None
elif isinstance(x, variables.TupleVariable):
return all(is_hashable(e) for e in x.items)
elif (
isinstance(x, variables.UserDefinedObjectVariable)
and not was_instancecheck_override(x.value)
and inspect.getattr_static(x.value, "__hash__") is int.__hash__
and isinstance(x.value, int)
):
return isinstance(x.value, py_Hashable)
else:
return isinstance(
x,
@ -80,7 +93,7 @@ def is_hashable(x):
variables.SymNodeVariable,
variables.ConstantVariable,
variables.EnumVariable,
variables.user_defined.UserDefinedClassVariable,
variables.UserDefinedClassVariable,
variables.UserFunctionVariable,
variables.SkipFunctionVariable,
variables.misc.NumpyVariable,
@ -140,6 +153,11 @@ class ConstDictVariable(VariableTracker):
# Access the underlying value inside the referent_vt for the key representation
Hashable = ConstDictVariable._HashableTracker
return Hashable(self.vt.referent_vt).underlying_value
elif isinstance(self.vt, variables.UserDefinedObjectVariable):
# The re module in Python 3.13+ has a dictionary (_cache2) with
# an object as key (`class _ZeroSentinel(int): ...`):
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
return self.vt.value
else:
x = self.vt.as_python_constant()
return x

View File

@ -1080,6 +1080,11 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
def has_closure(self):
return self.closure is not None
def const_getattr(self, tx, name):
if name == "__name__":
return self.fn_name.as_python_constant()
return super().const_getattr(tx, name)
def has_self(self):
return False