Implement tensor.size(Dimname), tensor.stride(Dimname)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22989

Test Plan: Imported from OSS

Differential Revision: D16364437

Pulled By: zou3519

fbshipit-source-id: 393a93fecac27b5d3b1a7f7692590d8fd5e95a5d
This commit is contained in:
Richard Zou 2019-07-22 12:53:15 -07:00 committed by Facebook Github Bot
parent 965b97f5f0
commit b4b51ed5ec
6 changed files with 87 additions and 0 deletions

View File

@ -516,6 +516,9 @@ class CAFFE2_API Tensor {
Tensor detach() const;
Tensor & detach_();
int64_t size(int64_t dim) const;
#ifdef BUILD_NAMEDTENSOR
int64_t size(Dimname dim) const;
#endif
Tensor slice(int64_t dim=0, int64_t start=0, int64_t end=9223372036854775807, int64_t step=1) const;
std::tuple<Tensor,Tensor> slogdet() const;
Tensor smm(const Tensor & mat2) const;
@ -529,6 +532,9 @@ class CAFFE2_API Tensor {
Tensor sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
Tensor stft(int64_t n_fft, c10::optional<int64_t> hop_length=c10::nullopt, c10::optional<int64_t> win_length=c10::nullopt, const Tensor & window={}, bool normalized=false, bool onesided=true) const;
int64_t stride(int64_t dim) const;
#ifdef BUILD_NAMEDTENSOR
int64_t stride(Dimname dim) const;
#endif
Tensor sum(c10::optional<ScalarType> dtype=c10::nullopt) const;
Tensor sum(IntArrayRef dim, bool keepdim=false, c10::optional<ScalarType> dtype=c10::nullopt) const;
Tensor sum_to_size(IntArrayRef size) const;

View File

@ -715,6 +715,12 @@ inline int64_t Tensor::size(int64_t dim) const {
static auto table = globalATenDispatch().getOpTable("aten::size(Tensor self, int dim) -> int");
return table->getOp<int64_t (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
}
#ifdef BUILD_NAMEDTENSOR
inline int64_t Tensor::size(Dimname dim) const {
static auto table = globalATenDispatch().getOpTable("aten::size(Tensor self, Dimname dim) -> int");
return table->getOp<int64_t (const Tensor &, Dimname)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
}
#endif
inline Tensor Tensor::slice(int64_t dim, int64_t start, int64_t end, int64_t step) const {
static auto table = globalATenDispatch().getOpTable("aten::slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)");
return table->getOp<Tensor (const Tensor &, int64_t, int64_t, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim, start, end, step);
@ -767,6 +773,12 @@ inline int64_t Tensor::stride(int64_t dim) const {
static auto table = globalATenDispatch().getOpTable("aten::stride(Tensor self, int dim) -> int");
return table->getOp<int64_t (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
}
#ifdef BUILD_NAMEDTENSOR
inline int64_t Tensor::stride(Dimname dim) const {
static auto table = globalATenDispatch().getOpTable("aten::stride(Tensor self, Dimname dim) -> int");
return table->getOp<int64_t (const Tensor &, Dimname)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
}
#endif
inline Tensor Tensor::sum(c10::optional<ScalarType> dtype) const {
static auto table = globalATenDispatch().getOpTable("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor");
return table->getOp<Tensor (const Tensor &, c10::optional<ScalarType>)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dtype);

View File

@ -2,6 +2,9 @@
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#ifdef BUILD_NAMEDTENSOR
#include <ATen/NamedTensorUtils.h>
#endif
#include <ATen/Config.h>
namespace at {
@ -23,6 +26,18 @@ int64_t stride(const Tensor& self, int64_t dim) {
return self.strides()[dim];
}
#ifdef BUILD_NAMEDTENSOR
int64_t size(const Tensor& self, Dimname dim) {
size_t pos_dim = dimname_to_position(self, dim);
return self.sizes()[pos_dim];
}
int64_t stride(const Tensor& self, Dimname dim) {
size_t pos_dim = dimname_to_position(self, dim);
return self.strides()[pos_dim];
}
#endif
bool cudnn_is_acceptable(const Tensor& self) {
if (!globalContext().userEnabledCuDNN()) return false;
if (!self.is_cuda()) return false;

View File

@ -1854,6 +1854,11 @@
device_guard: False
named_guard: False
- func: size(Tensor self, Dimname dim) -> int
variants: function, method
device_guard: False
named_guard: False
- func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
variants: function, method
device_guard: False
@ -1965,6 +1970,12 @@
- func: stride(Tensor self, int dim) -> int
variants: function, method
device_guard: False
named_guard: False
- func: stride(Tensor self, Dimname dim) -> int
variants: function, method
device_guard: False
named_guard: False
- func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
variants: function, method

View File

@ -71,6 +71,28 @@ class TestNamedTensor(TestCase):
def test_empty_cuda(self):
self._test_factory(torch.empty, 'cuda')
def test_size(self):
t = torch.empty(2, 3, 5, names=('N', None, 'C'))
self.assertEqual(t.size('N'), 2)
self.assertEqual(t.size('C'), 5)
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name*'):
t.size(None)
with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
t.size('channels')
with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
torch.empty(2, 3, 4).size('N')
def test_stride(self):
t = torch.empty(2, 3, 5, names=('N', None, 'C'))
self.assertEqual(t.stride('N'), 3 * 5)
self.assertEqual(t.stride('C'), 1)
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
t.stride(None)
with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
t.stride('channels')
with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
torch.empty(2, 3, 4).stride('N')
def test_info_smoke(self):
# Smoke test for info functions / methods / attributes on named tensors.
tensor = torch.empty(1, 1, names=('N', 'D'))
@ -97,10 +119,12 @@ class TestNamedTensor(TestCase):
tensor.nelement()
tensor.shape
tensor.size()
tensor.size(1)
tensor.storage()
tensor.storage_offset()
tensor.storage_type()
tensor.stride()
tensor.stride(1)
tensor.data
tensor.data_ptr()
tensor.ndim

View File

@ -77,6 +77,9 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwa
static PythonArgParser parser({
"size(int64_t dim)",
"size()",
#ifdef BUILD_NAMEDTENSOR
"size(Dimname dim)",
#endif
});
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
ParsedArgs<3> parsed_args;
@ -92,6 +95,14 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwa
// torch.Size and tuple in python.
return THPSize_New(self_);
}
#ifdef BUILD_NAMEDTENSOR
else if (r.idx == 2) {
if (jit::tracer::isTracing()) {
TORCH_INTERNAL_ASSERT("NYI: Named tensors w/ JIT");
}
return wrap(self_.size(r.dimname(0)));
}
#endif
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -102,6 +113,9 @@ static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* k
static PythonArgParser parser({
"stride(int64_t dim)",
"stride()",
#ifdef BUILD_NAMEDTENSOR
"stride(Dimname dim)",
#endif
});
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
ParsedArgs<3> parsed_args;
@ -115,6 +129,11 @@ static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* k
// torch.Size and tuple in python
return THPUtils_packInt64Array(strides.size(), strides.data());
}
#ifdef BUILD_NAMEDTENSOR
else if (r.idx == 2) {
return wrap(self_.stride(r.dimname(0)));
}
#endif
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}