mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Restore test_warn_types (#90810)
Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/90810 Approved by: https://github.com/ngimel
This commit is contained in:
parent
e8e591b72f
commit
cc504ce292
|
|
@ -5722,6 +5722,27 @@ class TestTorch(TestCase):
|
|||
r"the unspecified dimension size -1 can be any value and is ambiguous"):
|
||||
torch.randn(2, 0).unflatten(1, (2, -1, 0))
|
||||
|
||||
# Test that warnings generated from C++ are translated to the correct type
|
||||
def test_warn_types(self):
|
||||
test_cases = [
|
||||
# function, warning type, message
|
||||
(torch._C._warn, UserWarning, r"Test message for TORCH_WARN"),
|
||||
(torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"),
|
||||
]
|
||||
|
||||
for fn, warning_type, message in test_cases:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.resetwarnings()
|
||||
warnings.filterwarnings('always', category=warning_type)
|
||||
fn()
|
||||
|
||||
self.assertEqual(len(w), 1, msg=f'{warning_type} not raised')
|
||||
warning = w[0].message
|
||||
self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised')
|
||||
self.assertTrue(re.search(
|
||||
message,
|
||||
str(warning)))
|
||||
|
||||
def test_structseq_repr(self):
|
||||
a = torch.arange(250).reshape(5, 5, 10)
|
||||
expected = """
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user