Add itemsize and nbytes properties to Tensor (#98322)

Adds itemsize and nbytes properties to Tensor, matching the properties of the same names on NumPy's ndarray.

Fixes https://github.com/pytorch/pytorch/issues/12728
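
For illustration only (not part of the commit): a minimal usage sketch of the two new properties, following the definitions given in the docstrings below.

    import torch

    t = torch.zeros(2, 3, dtype=torch.float32)
    print(t.itemsize)  # 4 -- bytes per element; alias for t.element_size()
    print(t.nbytes)    # 24 -- t.numel() * t.element_size()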

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98322
Approved by: https://github.com/ezyang
Authored by BJ Hargrave on 2023-04-05 12:11:55 +00:00; committed by PyTorch MergeBot
parent 14ccad73b4
commit 555ab310dc
10 changed files with 85 additions and 0 deletions

docs/source/name_inference.rst

@@ -129,6 +129,7 @@ If you don't see an operation listed here, but it would help your use case, plea
   :attr:`Tensor.is_sparse_csr`,None
   :func:`torch.is_tensor`,None
   :meth:`Tensor.item`,None
   :attr:`Tensor.itemsize`,None
   ":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc`
   ":meth:`Tensor.le`, :func:`torch.le`",:ref:`unifies_names_from_inputs-doc`
   ":meth:`Tensor.log`, :func:`torch.log`",:ref:`keeps_input_names-doc`
@@ -160,6 +161,7 @@ If you don't see an operation listed here, but it would help your use case, plea
   ":meth:`Tensor.mv`, :func:`torch.mv`",:ref:`contracts_away_dims-doc`
   :attr:`Tensor.names`,See documentation
   ":meth:`Tensor.narrow`, :func:`torch.narrow`",:ref:`keeps_input_names-doc`
   :attr:`Tensor.nbytes`,None
   :attr:`Tensor.ndim`,None
   :meth:`Tensor.ndimension`,None
   ":meth:`Tensor.ne`, :func:`torch.ne`",:ref:`unifies_names_from_inputs-doc`

docs/source/tensors.rst

@@ -199,6 +199,8 @@ Tensor class reference
    Tensor.ndim
    Tensor.real
    Tensor.imag
    Tensor.nbytes
    Tensor.itemsize
    Tensor.abs
    Tensor.abs_

test/test_torch.py

@@ -6110,16 +6110,27 @@ class TestTorch(TestCase):
        complexdouble = torch.ComplexDoubleStorage().element_size()

        self.assertEqual(byte, torch.ByteTensor().element_size())
        self.assertEqual(byte, torch.ByteTensor().itemsize)
        self.assertEqual(char, torch.CharTensor().element_size())
        self.assertEqual(char, torch.CharTensor().itemsize)
        self.assertEqual(short, torch.ShortTensor().element_size())
        self.assertEqual(short, torch.ShortTensor().itemsize)
        self.assertEqual(int, torch.IntTensor().element_size())
        self.assertEqual(int, torch.IntTensor().itemsize)
        self.assertEqual(long, torch.LongTensor().element_size())
        self.assertEqual(long, torch.LongTensor().itemsize)
        self.assertEqual(float, torch.FloatTensor().element_size())
        self.assertEqual(float, torch.FloatTensor().itemsize)
        self.assertEqual(double, torch.DoubleTensor().element_size())
        self.assertEqual(double, torch.DoubleTensor().itemsize)
        self.assertEqual(bool, torch.BoolTensor().element_size())
        self.assertEqual(bool, torch.BoolTensor().itemsize)
        self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).element_size())
        self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).itemsize)
        self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).element_size())
        self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).itemsize)
        self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).element_size())
        self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).itemsize)

        self.assertGreater(byte, 0)
        self.assertGreater(char, 0)
@@ -7469,6 +7480,14 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        c = torch.randn(1, 0)
        self.assertEqual(2, c.ndim)

    def test_nbytes(self):
        a = torch.randn(1, 2, 3, dtype=torch.float64)
        self.assertEqual(a.numel() * a.element_size(), a.nbytes)
        b = torch.randn(())
        self.assertEqual(b.numel() * b.element_size(), b.nbytes)
        c = torch.randn(1, 0)
        self.assertEqual(c.numel() * c.element_size(), c.nbytes)

    def test_fill_diagonal(self):
        a1 = torch.randn(7, 3)
        a2 = a1.clone()

tools/autograd/gen_python_functions.py

@@ -157,6 +157,8 @@ _SKIP_PYTHON_BINDINGS = [
    "_nested_view_from_buffer", # View only version of _nested_from_buffer. This will force users to only use the "safe" version.
    "_nested_view_from_buffer_copy",
    "_nested_view_from_buffer_copy_out",
    "nbytes",
    "itemsize",
]
SKIP_PYTHON_BINDINGS = [

torch/_C/__init__.pyi.in

@@ -1423,6 +1423,8 @@ class _TensorBase(metaclass=_TensorMeta):
    _grad: Optional[Tensor]
    grad: Optional[Tensor]
    _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
    nbytes: _int
    itemsize: _int
    ${tensor_method_hints}
# Defined in torch/csrc/multiprocessing/init.cpp

torch/_tensor_docs.py

@@ -6615,6 +6615,22 @@ Alias for :meth:`~Tensor.dim()`
""",
)

add_docstr_all(
    "itemsize",
    r"""
Alias for :meth:`~Tensor.element_size()`
""",
)

add_docstr_all(
    "nbytes",
    r"""
Returns the number of bytes consumed by the "view" of elements of the Tensor
if the Tensor does not use sparse storage layout.
Defined to be :meth:`~Tensor.numel()` * :meth:`~Tensor.element_size()`
""",
)

add_docstr_all(
    "T",
    r"""

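Worked through against the test above: a (1, 2, 3) float64 tensor has numel() == 6 and element_size() == 8, so itemsize is 8 and nbytes is 6 * 8 == 48.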
torch/csrc/autograd/python_variable.cpp

@@ -1345,6 +1345,24 @@ static PyObject* THPVariable_device(THPVariable* self, void* unused) {
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "nbytes");
  }
  return PyLong_FromSize_t(THPVariable_Unpack(self).nbytes());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "itemsize");
  }
  return PyLong_FromSize_t(THPVariable_Unpack(self).itemsize());
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_real(PyObject* self, PyObject* real, void* unused) {
  HANDLE_TH_ERRORS
  auto& self_ = THPVariable_Unpack(self);
@@ -1463,6 +1481,8 @@ static struct PyGetSetDef THPVariable_properties[] = {
    {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
    {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
    {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
    {"nbytes", (getter)THPVariable_get_nbytes, nullptr, nullptr, nullptr},
    {"itemsize", (getter)THPVariable_get_itemsize, nullptr, nullptr, nullptr},
    {"names",
     (getter)THPVariable_get_names,
     (setter)THPVariable_set_names,

torch/csrc/jit/frontend/sugared_value.cpp

@@ -132,6 +132,8 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
      {"mT", "aten"},
      {"mH", "aten"},
      {"is_ort", "prim"},
      {"itemsize", "prim"},
      {"nbytes", "prim"},
      {"ndim", "prim"},
      {"name", "prim"},
      {"real", "aten"},

torch/csrc/jit/runtime/register_prim_ops.cpp

@@ -2423,6 +2423,24 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
          }
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA("prim::nbytes(Tensor a) -> int"),
        [](Stack& stack) {
          at::Tensor a;
          pop(stack, a);
          const auto nbytes = static_cast<int64_t>(a.nbytes());
          push(stack, nbytes);
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA("prim::itemsize(Tensor a) -> int"),
        [](Stack& stack) {
          at::Tensor a;
          pop(stack, a);
          const auto itemsize = static_cast<int64_t>(a.itemsize());
          push(stack, itemsize);
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA("prim::index(Device self) -> int?"),
        [](Stack& stack) {

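Together with the sugared-value entries in the previous file, these prim ops make both properties usable from TorchScript. A minimal sketch, assuming a build that contains this commit:

    import torch

    @torch.jit.script
    def payload_bytes(t: torch.Tensor) -> int:
        # inside TorchScript, the attribute access lowers to prim::nbytes
        return t.nbytes

    print(payload_bytes(torch.ones(4, dtype=torch.float64)))  # 32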
torch/overrides.py

@@ -1215,9 +1215,11 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.is_sparse.__get__: lambda self: -1,
        Tensor.is_sparse_csr.__get__: lambda self: -1,
        Tensor.is_vulkan.__get__: lambda self: -1,
        Tensor.itemsize.__get__: lambda self: -1,
        Tensor.layout.__get__: lambda self: -1,
        Tensor.name.__get__: lambda self: -1,
        Tensor.names.__get__: lambda self: -1,
        Tensor.nbytes.__get__: lambda self: -1,
        Tensor.ndim.__get__: lambda self: -1,
        Tensor.output_nr.__get__: lambda self: -1,
        Tensor.requires_grad.__get__: lambda self: -1,
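
Because the C++ getters route through handle_torch_function_getter, and get_testing_overrides now lists both getters, the new properties participate in the __torch_function__ protocol. A minimal sketch (the subclass and its behavior are illustrative, not part of the commit):

    import torch

    class Meta(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # property reads arrive as the unbound getter, e.g. Tensor.nbytes.__get__
            if func is torch.Tensor.nbytes.__get__:
                return 0  # pretend the tensor consumes no bytes
            return super().__torch_function__(func, types, args, kwargs or {})

    m = torch.ones(2).as_subclass(Meta)
    print(m.nbytes)  # 0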