Fix segfault during NumPy string tensor conversion (#155364)

Fixed by checking the dtype first, and by adding an element_size check as well

Fixes https://github.com/pytorch/pytorch/issues/155328
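
A minimal repro sketch of the behavior this change targets (it mirrors the new test; pre-fix behavior is as described in the linked issue):

```python
import numpy as np
import torch

# dtype=str yields the flexible '<U0' dtype, whose itemsize is 0
x = np.ndarray((3, 3), dtype=str)

# Previously the byte-stride -> element-stride conversion used the zero
# itemsize as a divisor and crashed; with the dtype check running first
# (plus the element_size check), this now raises a TypeError instead.
try:
    torch.from_numpy(x)
except TypeError as e:
    print("raised as expected:", e)
```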

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155364
Approved by: https://github.com/Skylion007
Nikita Shulga 2025-06-06 13:33:34 -07:00 committed by PyTorch MergeBot
parent be2e43264d
commit 5596cefba6
2 changed files with 14 additions and 4 deletions

@@ -286,6 +286,14 @@ class TestNumPyInterop(TestCase):
                 pass
         self.assertTrue(sys.getrefcount(x) == 2)
 
+    @skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.")
+    @onlyCPU
+    def test_from_numpy_zero_element_type(self):
+        # This tests that the dtype check happens before the strides check,
+        # which otherwise results in a div-by-zero on x86
+        x = np.ndarray((3, 3), dtype=str)
+        self.assertRaises(TypeError, lambda: torch.from_numpy(x))
+
     @skipMeta
     def test_from_list_of_ndarray_warning(self, device):
         warning_msg = (

@@ -231,8 +231,13 @@ at::Tensor tensor_from_numpy(
   int ndim = PyArray_NDIM(array);
   auto sizes = to_aten_shape(ndim, PyArray_DIMS(array));
   auto strides = to_aten_shape(ndim, PyArray_STRIDES(array));
+  // This must go before the INCREF and element_size checks
+  // in case the dtype mapping doesn't exist and an exception is thrown
+  auto torch_dtype = numpy_dtype_to_aten(PyArray_TYPE(array));
   // NumPy strides use bytes. Torch strides use element counts.
-  auto element_size_in_bytes = PyArray_ITEMSIZE(array);
+  const auto element_size_in_bytes = PyArray_ITEMSIZE(array);
+  TORCH_CHECK(element_size_in_bytes > 0, "element_size must be positive");
   for (auto& stride : strides) {
     TORCH_CHECK_VALUE(
         stride % element_size_in_bytes == 0,
@@ -255,9 +260,6 @@ at::Tensor tensor_from_numpy(
       PyArray_EquivByteorders(PyArray_DESCR(array)->byteorder, NPY_NATIVE),
       "given numpy array has byte order different from the native byte order. "
       "Conversion between byte orders is currently not supported.");
-  // This has to go before the INCREF in case the dtype mapping doesn't
-  // exist and an exception is thrown
-  auto torch_dtype = numpy_dtype_to_aten(PyArray_TYPE(array));
   Py_INCREF(obj);
   return at::lift_fresh(at::from_blob(
       data_ptr,
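
For context, a rough sketch of the stride conversion that made a zero itemsize fatal (Python pseudocode, not the C++ above; `to_element_strides` is a hypothetical helper, not a PyTorch API): NumPy reports strides in bytes while Torch expects element counts, so each stride is divided by the element size.

```python
import numpy as np

def to_element_strides(array: np.ndarray) -> list[int]:
    # Mirror the post-fix check order: validate the element size before
    # using it as a divisor.
    itemsize = array.itemsize
    if itemsize <= 0:
        raise TypeError(f"unsupported element size: {itemsize}")
    for stride in array.strides:
        if stride % itemsize != 0:
            raise ValueError("byte stride is not a multiple of the element size")
    return [stride // itemsize for stride in array.strides]

print(to_element_strides(np.zeros((3, 3), dtype=np.float32)))  # [3, 1]
```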