mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fixed a memory leak when calling from_numpy on a numpy array with an … (#121156)
…unsupported dtype. Fixes #121138. The lambda function that DECREFs the object is not called when the dtype conversion functions throws. This PR moves the conversion before the INCREF, which prevents the memory leak. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121156 Approved by: https://github.com/soulitzer, https://github.com/albanD
This commit is contained in:
parent
360761f7d0
commit
76f3663efe
|
|
@ -6,6 +6,7 @@ import torch
|
|||
import numpy as np
|
||||
|
||||
from itertools import product
|
||||
import sys
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(skipIfTorchDynamo, TestCase, run_tests)
|
||||
|
|
@ -257,6 +258,18 @@ class TestNumPyInterop(TestCase):
|
|||
x.strides = (3,)
|
||||
self.assertRaises(ValueError, lambda: torch.from_numpy(x))
|
||||
|
||||
@skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.")
|
||||
def test_from_numpy_no_leak_on_invalid_dtype(self):
|
||||
# This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary
|
||||
# object. See https://github.com/pytorch/pytorch/issues/121138
|
||||
x = np.array("value".encode('ascii'))
|
||||
for _ in range(1000):
|
||||
try:
|
||||
torch.from_numpy(x)
|
||||
except TypeError:
|
||||
pass
|
||||
self.assertTrue(sys.getrefcount(x) == 2)
|
||||
|
||||
@skipMeta
|
||||
def test_from_list_of_ndarray_warning(self, device):
|
||||
warning_msg = r"Creating a tensor from a list of numpy.ndarrays is extremely slow"
|
||||
|
|
|
|||
|
|
@ -258,6 +258,9 @@ at::Tensor tensor_from_numpy(
|
|||
PyArray_EquivByteorders(PyArray_DESCR(array)->byteorder, NPY_NATIVE),
|
||||
"given numpy array has byte order different from the native byte order. "
|
||||
"Conversion between byte orders is currently not supported.");
|
||||
// This has to go before the INCREF in case the dtype mapping doesn't
|
||||
// exist and an exception is thrown
|
||||
auto torch_dtype = numpy_dtype_to_aten(PyArray_TYPE(array));
|
||||
Py_INCREF(obj);
|
||||
return at::lift_fresh(at::from_blob(
|
||||
data_ptr,
|
||||
|
|
@ -267,7 +270,7 @@ at::Tensor tensor_from_numpy(
|
|||
pybind11::gil_scoped_acquire gil;
|
||||
Py_DECREF(obj);
|
||||
},
|
||||
at::device(kCPU).dtype(numpy_dtype_to_aten(PyArray_TYPE(array)))));
|
||||
at::device(kCPU).dtype(torch_dtype)));
|
||||
}
|
||||
|
||||
int aten_to_numpy_dtype(const ScalarType scalar_type) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user