pytorch/torch/csrc/utils/tensor_apply.cpp
Richard Barnes 3979cb0656 irange for size_t (#55320)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55320

Test Plan: Sandcastle

Reviewed By: ngimel

Differential Revision: D27572577

fbshipit-source-id: 97710fd2bb1303006b05828a0d1343b0b59ccb03
2021-06-03 01:04:13 -07:00

109 lines
3.4 KiB
C++

#include <torch/csrc/utils/tensor_apply.h>
#include <ATen/TensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <c10/util/irange.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_scalars.h>
using namespace at;
namespace torch { namespace utils {
// Cursor over one tensor's storage: a raw pointer to the "current" element
// together with the strides and element size needed to advance that pointer.
struct StridedData {
  // Snapshots the tensor's data pointer, strides and element size.
  // The constructor is intentionally implicit: callers in this file build
  // std::array<StridedData, N> directly from braced lists of tensors.
  // NOTE(review): `strides` is an IntArrayRef obtained from the tensor, so
  // presumably the tensor must outlive this object -- confirm.
  StridedData(const Tensor & tensor)
    : data(tensor.data_ptr())
    , strides(tensor.strides())
    , elementSize(tensor.element_size()) {}

  void* data;           // address of the element currently pointed at
  IntArrayRef strides;  // value of tensor.strides() at construction time
  int64_t elementSize;  // value of tensor.element_size() at construction time

  // Advances `data` by one step along dimension `dim`, i.e. by
  // strides[dim] * elementSize chars.
  void step(int dim) {
    const int64_t byte_offset = strides[dim] * elementSize;
    data = static_cast<char*>(data) + byte_offset;
  }
};
// Walks the index space described by `sizes`, one dimension per recursion
// level starting at `dim`, keeping N data cursors in lock-step.
//
// Once `dim` reaches the number of dimensions, the scalar under each cursor is
// loaded (load_scalar), the N scalars are passed to the Python callable `fn`
// as positional arguments, and fn's result is written back (store_scalar) at
// the first cursor's position.
//
// `strided_data` is taken by value on purpose: each level advances its own
// copy, so the caller's cursors are unaffected by deeper levels.
//
// Throws python_error when tuple creation, load_scalar, or the call to `fn`
// yields a null PyObject*.
template<size_t N>
static void recursive_apply(IntArrayRef sizes, ScalarType scalarType, int64_t dim,
                            PyObject* fn, std::array<StridedData, N> strided_data) {
  const int64_t rank = sizes.size();

  // Recursive case: iterate over the current dimension.
  if (dim != rank) {
    const auto extent = sizes[dim];
    for (const auto step_idx : c10::irange(extent)) {
      (void)step_idx; // Only the iteration count matters here.
      recursive_apply(sizes, scalarType, dim + 1, fn, strided_data);
      // Move every cursor to the next position along `dim`.
      for (auto& cursor : strided_data) {
        cursor.step(dim);
      }
    }
    return;
  }

  // Leaf case: build the argument tuple from the current elements.
  THPObjectPtr call_args(PyTuple_New(N));
  if (!call_args) {
    throw python_error();
  }
  for (const auto slot : c10::irange(N)) {
    PyObject* item = load_scalar(strided_data[slot].data, scalarType);
    if (!item) {
      throw python_error();
    }
    // PyTuple_SET_ITEM steals the reference to `item`.
    PyTuple_SET_ITEM(call_args.get(), slot, item);
  }

  THPObjectPtr result(PyObject_CallObject(fn, call_args.get()));
  if (!result) {
    throw python_error();
  }
  // The result always lands in the first operand.
  store_scalar(strided_data[0].data, scalarType, result.get());
}
// Runs the Python callable `fn` over every element of `self` via
// recursive_apply, which stores each result back into `self`.
//
// Returns `self`. Meta tensors are returned without doing any work.
// Throws TypeError if `self` is not on a CPU device.
const Tensor & apply_(const Tensor & self, PyObject* fn) {
  if (!self.is_meta()) {
    if (!self.device().is_cpu()) {
      throw TypeError("apply_ is only implemented on CPU tensors");
    }
    const auto dtype = self.scalar_type();
    recursive_apply<1>(self.sizes(), dtype, 0, fn, {{ self }});
  }
  // Meta tensors fall through to here untouched.
  return self;
}
// Runs the Python callable `fn` over corresponding elements of `self` and
// `other_` via recursive_apply, which stores each result back into `self`.
// `other_` is first passed through expand_inplace against `self`.
//
// Returns `self`. Meta tensors are returned without doing any work (but only
// after the type check below has passed).
// Throws TypeError if the options of `other_` are not type_equal to those of
// `self`, or if `self` is not on a CPU device.
const Tensor & map_(const Tensor & self, const Tensor & other_, PyObject* fn) {
  const bool same_type = other_.options().type_equal(self.options());
  if (!same_type) {
    throw TypeError("map_: expected %s for 'other' (got %s)",
        self.toString().c_str(), other_.toString().c_str());
  }

  if (!self.is_meta()) {
    if (!self.device().is_cpu()) {
      throw TypeError("map_ is only implemented on CPU tensors");
    }
    // Keep the (possibly owning) expansion result alive for the whole traversal.
    c10::MaybeOwned<Tensor> expanded_other = expand_inplace(self, other_, "map_");
    const auto dtype = self.scalar_type();
    recursive_apply<2>(self.sizes(), dtype, 0, fn, {{ self, *expanded_other }});
  }
  // Meta tensors fall through to here untouched.
  return self;
}
// Runs the Python callable `fn` over corresponding elements of `self`, `x_`
// and `y_` via recursive_apply, which stores each result back into `self`.
// `x_` and `y_` are first passed through expand_inplace against `self`.
//
// Returns `self`. Meta tensors are returned without doing any work (but only
// after both type checks below have passed).
// Throws TypeError if the options of `x_` or `y_` are not type_equal to those
// of `self`, or if any of the three tensors is not on a CPU device.
const Tensor & map2_(const Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn) {
  const bool x_matches = x_.options().type_equal(self.options());
  if (!x_matches) {
    throw TypeError("map2_: expected %s for argument 'x' (got %s)",
        self.toString().c_str(), x_.toString().c_str());
  }
  const bool y_matches = y_.options().type_equal(self.options());
  if (!y_matches) {
    throw TypeError("map2_: expected %s for argument 'y' (got %s)",
        self.toString().c_str(), y_.toString().c_str());
  }

  if (!self.is_meta()) {
    // Checked in the order self, x_, y_ with short-circuiting.
    const bool all_on_cpu =
        self.device().is_cpu() && x_.device().is_cpu() && y_.device().is_cpu();
    if (!all_on_cpu) {
      throw TypeError("map2_ is only implemented on CPU tensors");
    }
    // `expanded` must stay alive while the references below are in use.
    auto expanded = expand_inplace(self, x_, y_, "map2_");
    const Tensor & x_expanded = *std::get<0>(expanded);
    const Tensor & y_expanded = *std::get<1>(expanded);
    const auto dtype = self.scalar_type();
    recursive_apply<3>(self.sizes(), dtype, 0, fn, {{ self, x_expanded, y_expanded }});
  }
  // Meta tensors fall through to here untouched.
  return self;
}
}} // namespace torch::utils