From f5cf05c98358edd4aee58374d2a594954137078b Mon Sep 17 00:00:00 2001 From: gaoyufeng <15834128411@126.com> Date: Fri, 25 Jul 2025 23:49:42 +0000 Subject: [PATCH] =?UTF-8?q?Throw=20invalid=5Fargument=20instead=20of=20Run?= =?UTF-8?q?timeError=20when=20parameters=20exceed=E2=80=A6=20(#158267)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Throw invalid_argument instead of RuntimeError when parameters exceed limits (for torch.int32 dtype) Fixes #157707 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158267 Approved by: https://github.com/albanD --- test/test_numpy_interop.py | 2 +- test/test_torch.py | 6 ++--- .../numpy_tests/core/test_indexing.py | 2 +- torch/csrc/utils/python_numbers.h | 22 ++++++++----------- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index 5673010fcc4..286882dfdb3 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -488,7 +488,7 @@ class TestNumPyInterop(TestCase): ) # type: ignore[call-overload] else: self.assertRaisesRegex( - RuntimeError, + ValueError, "(Overflow|an integer is required)", lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)), ) # type: ignore[call-overload] diff --git a/test/test_torch.py b/test/test_torch.py index 7af57f23b8f..a44831eb4ac 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9432,7 +9432,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead") self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg) for invalid_seed in [min_int64 - 1, max_uint64 + 1]: - with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long long'): + with self.assertRaisesRegex(ValueError, r'Overflow when unpacking long long'): torch.manual_seed(invalid_seed) torch.set_rng_state(rng_state) @@ -10851,8 +10851,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], def test_invalid_arg_error_handling(self) -> None: """ Tests that errors from old TH functions are propagated back """ for invalid_val in [-1, 2**65]: - self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val)) - self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val)) + self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_threads(invalid_val)) + self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_interop_threads(invalid_val)) def _get_tensor_prop(self, t): preserved = ( diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index ed402bd8595..91dae968683 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -219,7 +219,7 @@ class TestIndexing(TestCase): assert_raises(IndexError, a.__getitem__, 1 << 30) # Index overflow produces IndexError # Note torch raises RuntimeError here - assert_raises((IndexError, RuntimeError), a.__getitem__, 1 << 64) + assert_raises((IndexError, ValueError), a.__getitem__, 1 << 64) def test_single_bool_index(self): # Single boolean index diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index e331559f4a0..25ca2692b32 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -62,13 +62,11 @@ inline int32_t THPUtils_unpackInt(PyObject* obj) { if (value == -1 && PyErr_Occurred()) { throw python_error(); } - if (overflow != 0) { - throw std::runtime_error("Overflow when unpacking long"); - } - if (value > std::numeric_limits::max() || - value < std::numeric_limits::min()) { - throw std::runtime_error("Overflow when unpacking long"); - } + TORCH_CHECK_VALUE(overflow == 0, "Overflow when unpacking long long"); + TORCH_CHECK_VALUE( + value <= std::numeric_limits::max() && + value >= std::numeric_limits::min(), + "Overflow when unpacking long"); return (int32_t)value; } @@ -78,9 +76,7 @@ inline int64_t THPUtils_unpackLong(PyObject* obj) { if (value == -1 && PyErr_Occurred()) { throw python_error(); } - if (overflow != 0) { - throw std::runtime_error("Overflow when unpacking long long"); - } + TORCH_CHECK_VALUE(overflow == 0, "Overflow when unpacking long long"); return (int64_t)value; } @@ -89,9 +85,9 @@ inline uint32_t THPUtils_unpackUInt32(PyObject* obj) { if (PyErr_Occurred()) { throw python_error(); } - if (value > std::numeric_limits::max()) { - throw std::runtime_error("Overflow when unpacking unsigned long"); - } + TORCH_CHECK_VALUE( + value <= std::numeric_limits::max(), + "Overflow when unpacking long long"); return (uint32_t)value; }