mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Implement tensor.size(Dimname), tensor.stride(Dimname)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22989
Test Plan: Imported from OSS
Differential Revision: D16364437
Pulled By: zou3519
fbshipit-source-id: 393a93fecac27b5d3b1a7f7692590d8fd5e95a5d
This commit is contained in:
parent 965b97f5f0
commit b4b51ed5ec
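
What this adds, in user-facing terms: a tensor's size and stride can now be looked up by dimension name rather than positional index. A minimal usage sketch, mirroring the tests added below (requires a build with BUILD_NAMEDTENSOR enabled):

    import torch

    # Dimensions can be named at construction; None leaves a dim unnamed.
    t = torch.empty(2, 3, 5, names=('N', None, 'C'))
    t.size('N')    # 2  -- looks up the dim named 'N'
    t.size('C')    # 5
    t.stride('N')  # 15 -- contiguous stride for dim 0: 3 * 5
    t.stride('C')  # 1
    # A name the tensor does not carry raises, e.g.:
    # RuntimeError: Name 'channels' not found in ...
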
@@ -516,6 +516,9 @@ class CAFFE2_API Tensor {
   Tensor detach() const;
   Tensor & detach_();
   int64_t size(int64_t dim) const;
+#ifdef BUILD_NAMEDTENSOR
+  int64_t size(Dimname dim) const;
+#endif
   Tensor slice(int64_t dim=0, int64_t start=0, int64_t end=9223372036854775807, int64_t step=1) const;
   std::tuple<Tensor,Tensor> slogdet() const;
   Tensor smm(const Tensor & mat2) const;

@@ -529,6 +532,9 @@ class CAFFE2_API Tensor {
   Tensor sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
   Tensor stft(int64_t n_fft, c10::optional<int64_t> hop_length=c10::nullopt, c10::optional<int64_t> win_length=c10::nullopt, const Tensor & window={}, bool normalized=false, bool onesided=true) const;
   int64_t stride(int64_t dim) const;
+#ifdef BUILD_NAMEDTENSOR
+  int64_t stride(Dimname dim) const;
+#endif
   Tensor sum(c10::optional<ScalarType> dtype=c10::nullopt) const;
   Tensor sum(IntArrayRef dim, bool keepdim=false, c10::optional<ScalarType> dtype=c10::nullopt) const;
   Tensor sum_to_size(IntArrayRef size) const;

@@ -715,6 +715,12 @@ inline int64_t Tensor::size(int64_t dim) const {
   static auto table = globalATenDispatch().getOpTable("aten::size(Tensor self, int dim) -> int");
   return table->getOp<int64_t (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
 }
+#ifdef BUILD_NAMEDTENSOR
+inline int64_t Tensor::size(Dimname dim) const {
+  static auto table = globalATenDispatch().getOpTable("aten::size(Tensor self, Dimname dim) -> int");
+  return table->getOp<int64_t (const Tensor &, Dimname)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
+}
+#endif
 inline Tensor Tensor::slice(int64_t dim, int64_t start, int64_t end, int64_t step) const {
   static auto table = globalATenDispatch().getOpTable("aten::slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)");
   return table->getOp<Tensor (const Tensor &, int64_t, int64_t, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim, start, end, step);

@@ -767,6 +773,12 @@ inline int64_t Tensor::stride(int64_t dim) const {
   static auto table = globalATenDispatch().getOpTable("aten::stride(Tensor self, int dim) -> int");
   return table->getOp<int64_t (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
 }
+#ifdef BUILD_NAMEDTENSOR
+inline int64_t Tensor::stride(Dimname dim) const {
+  static auto table = globalATenDispatch().getOpTable("aten::stride(Tensor self, Dimname dim) -> int");
+  return table->getOp<int64_t (const Tensor &, Dimname)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
+}
+#endif
 inline Tensor Tensor::sum(c10::optional<ScalarType> dtype) const {
   static auto table = globalATenDispatch().getOpTable("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor");
   return table->getOp<Tensor (const Tensor &, c10::optional<ScalarType>)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dtype);

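The inline methods above do not hard-code an implementation: each looks up an operator table keyed by the full schema string, then selects a kernel by backend and by whether the tensor is a variable. A rough Python model of that lookup pattern (the names here are illustrative, not the real dispatcher API):

    # Hypothetical model of the schema-keyed dispatch used in these inlines.
    op_tables = {}  # schema string -> {(backend, is_variable): kernel}

    def get_op(schema, backend, is_variable):
        # Mirrors getOpTable(schema)->getOp<...>(backend, is_variable):
        # one table per schema, one kernel per (backend, variable) pair.
        return op_tables[schema][(backend, is_variable)]

    # Because each overload has its own schema string, size(int) and
    # size(Dimname) resolve through entirely separate tables:
    #   get_op("aten::size(Tensor self, int dim) -> int", backend, is_var)
    #   get_op("aten::size(Tensor self, Dimname dim) -> int", backend, is_var)
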
@@ -2,6 +2,9 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/WrapDimUtils.h>
 #include <ATen/detail/CUDAHooksInterface.h>
+#ifdef BUILD_NAMEDTENSOR
+#include <ATen/NamedTensorUtils.h>
+#endif
 
 #include <ATen/Config.h>
 namespace at {

@@ -23,6 +26,18 @@ int64_t stride(const Tensor& self, int64_t dim) {
   return self.strides()[dim];
 }
 
+#ifdef BUILD_NAMEDTENSOR
+int64_t size(const Tensor& self, Dimname dim) {
+  size_t pos_dim = dimname_to_position(self, dim);
+  return self.sizes()[pos_dim];
+}
+
+int64_t stride(const Tensor& self, Dimname dim) {
+  size_t pos_dim = dimname_to_position(self, dim);
+  return self.strides()[pos_dim];
+}
+#endif
+
 bool cudnn_is_acceptable(const Tensor& self) {
   if (!globalContext().userEnabledCuDNN()) return false;
   if (!self.is_cuda()) return false;

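Both native implementations funnel through dimname_to_position, which resolves a name to a positional index; size and stride then just index sizes()/strides() with the result. A rough Python sketch of that resolution, with error behavior inferred from the tests below (this is not the actual C++ helper):

    def dimname_to_position(names, dim):
        # Hypothetical sketch; 'names' stands in for the tensor's
        # per-dimension name list.
        if dim is None:
            raise RuntimeError('Please look up dimensions by name, got: None')
        if dim not in names:
            raise RuntimeError("Name '%s' not found in %s" % (dim, names))
        return names.index(dim)

    names = ('N', None, 'C')
    assert dimname_to_position(names, 'C') == 2
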
@@ -1854,6 +1854,11 @@
   device_guard: False
   named_guard: False
 
+- func: size(Tensor self, Dimname dim) -> int
+  variants: function, method
+  device_guard: False
+  named_guard: False
+
 - func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
   variants: function, method
   device_guard: False

@@ -1965,6 +1970,12 @@
 - func: stride(Tensor self, int dim) -> int
   variants: function, method
   device_guard: False
+  named_guard: False
 
+- func: stride(Tensor self, Dimname dim) -> int
+  variants: function, method
+  device_guard: False
+  named_guard: False
+
 - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
   variants: function, method

@@ -71,6 +71,28 @@ class TestNamedTensor(TestCase):
     def test_empty_cuda(self):
         self._test_factory(torch.empty, 'cuda')
 
+    def test_size(self):
+        t = torch.empty(2, 3, 5, names=('N', None, 'C'))
+        self.assertEqual(t.size('N'), 2)
+        self.assertEqual(t.size('C'), 5)
+        with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name*'):
+            t.size(None)
+        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
+            t.size('channels')
+        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
+            torch.empty(2, 3, 4).size('N')
+
+    def test_stride(self):
+        t = torch.empty(2, 3, 5, names=('N', None, 'C'))
+        self.assertEqual(t.stride('N'), 3 * 5)
+        self.assertEqual(t.stride('C'), 1)
+        with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
+            t.stride(None)
+        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
+            t.stride('channels')
+        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
+            torch.empty(2, 3, 4).stride('N')
+
     def test_info_smoke(self):
         # Smoke test for info functions / methods / attributes on named tensors.
         tensor = torch.empty(1, 1, names=('N', 'D'))

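A note on the expected values in test_stride: for a contiguous tensor, each dimension's stride is the product of all sizes to its right, so shape (2, 3, 5) gives strides (15, 5, 1); hence stride('N') == 3 * 5 and stride('C') == 1. A quick standalone check:

    shape = (2, 3, 5)
    strides = []
    acc = 1
    for s in reversed(shape):  # build right-to-left: 1, 5, 15
        strides.append(acc)
        acc *= s
    strides.reverse()
    assert tuple(strides) == (15, 5, 1)
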
@@ -97,10 +119,12 @@ class TestNamedTensor(TestCase):
         tensor.nelement()
         tensor.shape
         tensor.size()
+        tensor.size(1)
         tensor.storage()
         tensor.storage_offset()
         tensor.storage_type()
         tensor.stride()
+        tensor.stride(1)
         tensor.data
         tensor.data_ptr()
         tensor.ndim

@@ -77,6 +77,9 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
   static PythonArgParser parser({
     "size(int64_t dim)",
     "size()",
+#ifdef BUILD_NAMEDTENSOR
+    "size(Dimname dim)",
+#endif
   });
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   ParsedArgs<3> parsed_args;

@@ -92,6 +95,14 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
     // torch.Size and tuple in python.
     return THPSize_New(self_);
   }
+#ifdef BUILD_NAMEDTENSOR
+  else if (r.idx == 2) {
+    if (jit::tracer::isTracing()) {
+      TORCH_INTERNAL_ASSERT("NYI: Named tensors w/ JIT");
+    }
+    return wrap(self_.size(r.dimname(0)));
+  }
+#endif
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }

@@ -102,6 +113,9 @@ static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* kwargs)
   static PythonArgParser parser({
     "stride(int64_t dim)",
     "stride()",
+#ifdef BUILD_NAMEDTENSOR
+    "stride(Dimname dim)",
+#endif
   });
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   ParsedArgs<3> parsed_args;

@@ -115,6 +129,11 @@ static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* kwargs)
     // torch.Size and tuple in python
     return THPUtils_packInt64Array(strides.size(), strides.data());
   }
+#ifdef BUILD_NAMEDTENSOR
+  else if (r.idx == 2) {
+    return wrap(self_.stride(r.dimname(0)));
+  }
+#endif
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }

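On the binding side, PythonArgParser tries the listed signatures in order and reports which one matched through r.idx: "size(int64_t dim)" is index 0, "size()" is index 1, and the new "size(Dimname dim)" is index 2, which is why the named-tensor branch checks r.idx == 2. A toy model of that first-match dispatch (illustrative only, not the real parser):

    # Hypothetical model: try signatures in declaration order and
    # return the index of the first one that accepts the arguments.
    def match_signature(args):
        signatures = [
            lambda a: len(a) == 1 and isinstance(a[0], int),  # size(int64_t dim)
            lambda a: len(a) == 0,                            # size()
            lambda a: len(a) == 1 and isinstance(a[0], str),  # size(Dimname dim)
        ]
        for idx, accepts in enumerate(signatures):
            if accepts(args):
                return idx
        raise TypeError('invalid arguments')

    assert match_signature((0,)) == 0    # positional dim -> idx 0
    assert match_signature(()) == 1      # no args -> idx 1
    assert match_signature(('N',)) == 2  # name string -> idx 2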