#include <torch/csrc/utils/tensor_apply.h>

#include <ATen/ExpandUtils.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_scalars.h>

using namespace at;

namespace torch {
namespace utils {

// Raw data pointer plus stride metadata for one tensor. step(dim) advances
// the pointer by one element along `dim`.
struct StridedData {
  StridedData(const Tensor& tensor)
      : data(tensor.data_ptr()),
        strides(tensor.strides()),
        elementSize(tensor.element_size()) {}

  void* data;
  IntArrayRef strides;
  int64_t elementSize;

  void step(int dim) {
    data = (char*)data + (strides[dim] * elementSize);
  }
};

// Walks all N tensors in lockstep over every element. At each position it
// calls `fn` with the N scalars found there and stores the result back into
// the first tensor.
template <size_t N>
static void recursive_apply(
    IntArrayRef sizes,
    ScalarType scalarType,
    int64_t dim,
    PyObject* fn,
    std::array<StridedData, N> strided_data) {
  int64_t ndim = static_cast<int64_t>(sizes.size());
  if (dim == ndim) {
    auto args = THPObjectPtr(PyTuple_New(N));
    if (!args)
      throw python_error();
    for (const auto i : c10::irange(N)) {
      PyObject* arg = load_scalar(strided_data[i].data, scalarType);
      if (!arg)
        throw python_error();
      PyTuple_SET_ITEM(args.get(), i, arg);
    }
    auto ret = THPObjectPtr(PyObject_CallObject(fn, args.get()));
    if (!ret)
      throw python_error();
    store_scalar(strided_data[0].data, scalarType, ret.get());
    return;
  }

  auto n = sizes[dim];
  for (const auto i : c10::irange(n)) {
    (void)i; // Suppress unused variable warning
    recursive_apply(sizes, scalarType, dim + 1, fn, strided_data);
    for (auto& td : strided_data) {
      td.step(dim);
    }
  }
}

const Tensor& apply_(const Tensor& self, PyObject* fn) {
  if (self.is_meta()) {
    return self; // Just skip
  }
  if (!self.device().is_cpu()) {
    throw TypeError("apply_ is only implemented on CPU tensors");
  }
  auto scalarType = self.scalar_type();
  recursive_apply<1>(self.sizes(), scalarType, 0, fn, {{self}});
  return self;
}

const Tensor& map_(const Tensor& self, const Tensor& other_, PyObject* fn) {
  if (!other_.options().type_equal(self.options())) {
    throw TypeError(
        "map_: expected %s for 'other' (got %s)",
        self.toString().c_str(),
        other_.toString().c_str());
  }
  if (self.is_meta()) {
    return self; // Just skip
  }
  if (!self.device().is_cpu()) {
    throw TypeError("map_ is only implemented on CPU tensors");
  }
  c10::MaybeOwned<Tensor> other = expand_inplace(self, other_, "map_");
  auto scalarType = self.scalar_type();
  recursive_apply<2>(self.sizes(), scalarType, 0, fn, {{self, *other}});
  return self;
}

const Tensor& map2_(
    const Tensor& self,
    const Tensor& x_,
    const Tensor& y_,
    PyObject* fn) {
  if (!x_.options().type_equal(self.options())) {
    throw TypeError(
        "map2_: expected %s for argument 'x' (got %s)",
        self.toString().c_str(),
        x_.toString().c_str());
  }
  if (!y_.options().type_equal(self.options())) {
    throw TypeError(
        "map2_: expected %s for argument 'y' (got %s)",
        self.toString().c_str(),
        y_.toString().c_str());
  }
  if (self.is_meta()) {
    return self; // Just skip
  }
  if (!self.device().is_cpu() || !x_.device().is_cpu() ||
      !y_.device().is_cpu()) {
    throw TypeError("map2_ is only implemented on CPU tensors");
  }
  auto others = expand_inplace(self, x_, y_, "map2_");
  auto scalarType = self.scalar_type();
  recursive_apply<3>(
      self.sizes(),
      scalarType,
      0,
      fn,
      {{self, *std::get<0>(others), *std::get<1>(others)}});
  return self;
}

} // namespace utils
} // namespace torch
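
// Usage sketch: these helpers back the Python-level Tensor.apply_,
// Tensor.map_, and Tensor.map2_ methods, each of which invokes a Python
// callable once per element of a CPU tensor. The values below are
// illustrative assumptions, not taken from this file:
//
//   t = torch.arange(4.)               # CPU float tensor
//   t.apply_(lambda x: x * 2)          # fn gets one Python scalar at a time
//   t.map_(other, lambda a, b: a + b)  # other is broadcast via expand_inplace
//
// Because every element round-trips through a Python call, these methods are
// far slower than vectorized ops and are intended for prototyping only.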