Restore test_warn_types (#90810)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90810
Approved by: https://github.com/ngimel
This commit is contained in:
Edward Z. Yang 2022-12-14 10:05:20 +08:00 committed by PyTorch MergeBot
parent e8e591b72f
commit cc504ce292

View File

@ -5722,6 +5722,27 @@ class TestTorch(TestCase):
r"the unspecified dimension size -1 can be any value and is ambiguous"):
torch.randn(2, 0).unflatten(1, (2, -1, 0))
# Test that warnings generated from C++ are translated to the correct type
def test_warn_types(self):
test_cases = [
# function, warning type, message
(torch._C._warn, UserWarning, r"Test message for TORCH_WARN"),
(torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"),
]
for fn, warning_type, message in test_cases:
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
warnings.filterwarnings('always', category=warning_type)
fn()
self.assertEqual(len(w), 1, msg=f'{warning_type} not raised')
warning = w[0].message
self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised')
self.assertTrue(re.search(
message,
str(warning)))
def test_structseq_repr(self):
a = torch.arange(250).reshape(5, 5, 10)
expected = """