Add Python binding resizable to class {Untyped,Typed}Storage (#119286)

This PR exposes `resizable` method of `StorageImpl` to Python frontend to make it accessible for users.

Fixes #119233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119286
Approved by: https://github.com/ezyang, https://github.com/mikaylagawarecki
This commit is contained in:
Hirochika Matsumoto 2024-02-07 19:15:50 +00:00 committed by PyTorch MergeBot
parent d054cd3e44
commit 02c24b0b5e
3 changed files with 21 additions and 0 deletions

View File

@ -8305,6 +8305,13 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10) self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0) self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_resizable(self) -> None:
x = torch.randn(5)
self.assertTrue(x.storage().resizable())
x.numpy()
self.assertFalse(x.storage().resizable())
def test_iter(self) -> None: def test_iter(self) -> None:
x = torch.randn(5, 5) x = torch.randn(5, 5)
for i, sub in enumerate(x): for i, sub in enumerate(x):

View File

@ -63,6 +63,13 @@ static PyObject* THPStorage_dataPtr(PyObject* self, PyObject* noargs) {
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
static PyObject* THPStorage_resizable(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
THPStorage_assertNotNull(self);
return PyBool_FromLong(THPStorage_Unpack(self).resizable());
END_HANDLE_TH_ERRORS
}
static PyObject* THPStorage_copy_( static PyObject* THPStorage_copy_(
PyObject* self, PyObject* self,
PyObject* args, PyObject* args,
@ -668,6 +675,7 @@ static PyMethodDef THPStorage_methods[] = {
{"resize_", THPStorage_resize_, METH_O, nullptr}, {"resize_", THPStorage_resize_, METH_O, nullptr},
{"nbytes", THPStorage_nbytes, METH_NOARGS, nullptr}, {"nbytes", THPStorage_nbytes, METH_NOARGS, nullptr},
{"data_ptr", THPStorage_dataPtr, METH_NOARGS, nullptr}, {"data_ptr", THPStorage_dataPtr, METH_NOARGS, nullptr},
{"resizable", THPStorage_resizable, METH_NOARGS, nullptr},
{"_write_file", THPStorage_writeFile, METH_VARARGS, nullptr}, {"_write_file", THPStorage_writeFile, METH_VARARGS, nullptr},
{"_new_with_file", {"_new_with_file",
THPStorage_newWithFile, THPStorage_newWithFile,

View File

@ -47,6 +47,8 @@ class _StorageBase:
def data_ptr(self) -> int: ... # type: ignore[empty-body] # noqa: E704 def data_ptr(self) -> int: ... # type: ignore[empty-body] # noqa: E704
def resizable(self) -> bool: ... # type: ignore[empty-body] # noqa: E704
# Defined in torch/csrc/generic/StorageSharing.cpp # Defined in torch/csrc/generic/StorageSharing.cpp
def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704 def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704 def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
@ -957,6 +959,10 @@ class TypedStorage:
def _data_ptr(self): def _data_ptr(self):
return self._untyped_storage.data_ptr() return self._untyped_storage.data_ptr()
def resizable(self):
_warn_typed_storage_removal()
return self._untyped_storage.resizable()
def resize_(self, size): def resize_(self, size):
_warn_typed_storage_removal() _warn_typed_storage_removal()
self._resize_(size) self._resize_(size)