Support non-contiguous tensors for unary ops (#6119)
This commit is contained in:
parent
a6bfa16c17
commit
ae35e0e924
@ -1,6 +1,6 @@
#pragma once

#include <sstream>
#include "ATen/Parallel.h"
#include "ATen/TensorUtils.h"

namespace at {
@ -8,11 +8,11 @@ namespace at {
/*
* The basic strategy for apply is as follows:
*
* 1. Starting with the outermost index, loop until we reach a dimension where the
* data is no longer contiguous, i.e. the stride at that dimension is not equal to
* the size of the tensor defined by the outer dimensions. Let's call this outer
* (contiguous) tensor A. Note that if the Tensor is contiguous, then A is equal
* to the entire Tensor. Let's call the inner tensor B.
* 1. Starting with the outermost index, loop until we reach a dimension where
* the data is no longer contiguous, i.e. the stride at that dimension is not
* equal to the size of the tensor defined by the outer dimensions. Let's call
* this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
* A is equal to the entire Tensor. Let's call the inner tensor B.
*
* 2. We loop through the indices in B, starting at its outermost dimension. For
* example, if B is a 2x2 matrix, then we do:
@ -22,289 +22,370 @@ namespace at {
* B[1][0]
* B[1][1]
*
* We set the offset into the underlying storage as (storageOffset + stride_B * index_B),
* i.e. basically we compute the offset into the storage as we would normally for a
* Tensor. But because we are guaranteed the subsequent data is contiguous in memory, we
* can simply loop for sizeof(A) iterations and perform the operation, without having to
* follow the order described by the strides of A.
* We set the offset into the underlying storage as (storageOffset + stride_B *
* index_B), i.e. basically we compute the offset into the storage as we would
* normally for a Tensor. But because we are guaranteed the subsequent data is
* contiguous in memory, we can simply loop for sizeof(A) iterations and perform
* the operation, without having to follow the order described by the strides of
* A.
*
* 3. As an optimization, we merge dimensions of A that are contiguous in memory. For
* example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor, then the first two
* dimensions can be merged for the purposes of APPLY, reducing the number of nested
* loops.
* 3. As an optimization, we merge dimensions of A that are contiguous in
* memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
* then the first two dimensions can be merged for the purposes of APPLY,
* reducing the number of nested loops.
*/
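The step-3 merging above can be tried in isolation with a short standalone sketch (illustrative only, not part of this header); it uses the comment's own example of a 3x3x3x3 tensor narrowed from a 3x3x4x3 buffer and essentially the rule that _setup_arrays applies below (ignoring its size-1 special case): dimension i folds into dimension i+1 when stride[i] == size[i+1] * stride[i+1].

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Sizes/strides of a 3x3x3x3 tensor narrowed out of a contiguous 3x3x4x3 buffer.
  std::vector<int64_t> sizes = {3, 3, 3, 3};
  std::vector<int64_t> strides = {36, 12, 3, 1};
  std::vector<int64_t> merged_sizes, merged_strides;
  for (size_t i = 0; i < sizes.size(); ++i) {
    int64_t size = sizes[i];
    int64_t stride = strides[i];
    // Fold dimension i+1 into dimension i while the two are contiguous with each other.
    while (i + 1 < sizes.size() && strides[i] == sizes[i + 1] * strides[i + 1]) {
      size *= sizes[i + 1];
      stride = strides[i + 1];
      ++i;
    }
    merged_sizes.push_back(size);
    merged_strides.push_back(stride);
  }
  // Prints "9 (stride 12)" and "9 (stride 1)": four nested loops collapse to two.
  for (size_t i = 0; i < merged_sizes.size(); ++i) {
    std::printf("%lld (stride %lld)\n",
                (long long)merged_sizes[i], (long long)merged_strides[i]);
  }
  return 0;
}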
|
||||
|
||||
// TODO: turn this macro into a proper template
|
||||
#define __ATH_TENSOR_APPLYX_PREAMBLE(TYPE, ATENSOR, DIM, ALLOW_CONTIGUOUS) \
|
||||
TYPE *ATENSOR##_data = NULL; \
|
||||
int64_t *ATENSOR##_counter = NULL, *ATENSOR##_sizes = NULL, *ATENSOR##_strides = NULL, *ATENSOR##_dimOffset = NULL; \
|
||||
int64_t ATENSOR##_stride = 0, ATENSOR##_size = 0, ATENSOR##_dim = 0, ATENSOR##_i; \
|
||||
int ATENSOR##_contiguous = ALLOW_CONTIGUOUS && DIM < 0; \
|
||||
\
|
||||
if(ATENSOR.sizes().equals({0})) \
|
||||
TH_TENSOR_APPLY_hasFinished = true; \
|
||||
else \
|
||||
{ \
|
||||
ATENSOR##_data = ATENSOR.data<TYPE>(); \
|
||||
ATENSOR##_size = 1; \
|
||||
ATENSOR##_stride = 1; \
|
||||
for(ATENSOR##_i = ATENSOR.dim() - 1; ATENSOR##_i >= 0; ATENSOR##_i--) { \
|
||||
if(ATENSOR.sizes()[ATENSOR##_i] != 1) { \
|
||||
if(ATENSOR.strides()[ATENSOR##_i] == ATENSOR##_size && ATENSOR##_i != DIM) \
|
||||
ATENSOR##_size *= ATENSOR.sizes()[ATENSOR##_i]; \
|
||||
else{ \
|
||||
ATENSOR##_contiguous = 0; \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
if (!ATENSOR##_contiguous) { \
|
||||
/* Find the dimension of contiguous sections */ \
|
||||
ATENSOR##_dim = 1; \
|
||||
for(ATENSOR##_i = ATENSOR.dim() - 2; ATENSOR##_i >= 0; ATENSOR##_i--) \
|
||||
{ \
|
||||
if(ATENSOR.strides()[ATENSOR##_i] != ATENSOR.strides()[ATENSOR##_i+1] * ATENSOR.sizes()[ATENSOR##_i+1] || ATENSOR##_i == DIM || ATENSOR##_i+1 == DIM) \
|
||||
ATENSOR##_dim++; \
|
||||
} \
|
||||
/* Allocate an array of 3*dim elements, where dim is the number of contiguous sections */ \
|
||||
ATENSOR##_counter = new int64_t[3*ATENSOR##_dim]; \
|
||||
ATENSOR##_sizes = ATENSOR##_counter + ATENSOR##_dim; \
|
||||
ATENSOR##_strides = ATENSOR##_counter + 2*ATENSOR##_dim; \
|
||||
TH_TENSOR_dim_index = ATENSOR##_dim-1; \
|
||||
ATENSOR##_dimOffset = (DIM == ATENSOR.dim()-1) ? &ATENSOR##_i : &ATENSOR##_counter[DIM]; \
|
||||
ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR.dim()-1]; \
|
||||
ATENSOR##_strides[TH_TENSOR_dim_index] = ATENSOR.strides()[ATENSOR.dim()-1]; \
|
||||
/* ATENSOR##_counter tracks where we are in the storage. The offset into the */ \
|
||||
/* storage is given by storage_offset + (i * j), where i is the stride */ \
|
||||
/* vector and j is tensor_counter vector. This sets the starting position for the loop. */ \
|
||||
for(ATENSOR##_i = ATENSOR##_dim-1; ATENSOR##_i >= 0; --ATENSOR##_i) { \
|
||||
ATENSOR##_counter[ATENSOR##_i] = 0; \
|
||||
} \
|
||||
for(ATENSOR##_i = ATENSOR.dim()-2; ATENSOR##_i >= 0; --ATENSOR##_i) { \
|
||||
if (ATENSOR.strides()[ATENSOR##_i] == ATENSOR.strides()[ATENSOR##_i+1] * ATENSOR.sizes()[ATENSOR##_i+1] && ATENSOR##_i != DIM && ATENSOR##_i+1 != DIM) { \
|
||||
ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR##_i] * ATENSOR##_sizes[TH_TENSOR_dim_index]; \
|
||||
if (DIM != ATENSOR.dim()-1 && ATENSOR##_i < DIM) \
|
||||
ATENSOR##_dimOffset--; \
|
||||
} else { \
|
||||
--TH_TENSOR_dim_index; \
|
||||
ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR##_i]; \
|
||||
ATENSOR##_strides[TH_TENSOR_dim_index] = ATENSOR.strides()[ATENSOR##_i]; \
|
||||
} \
|
||||
} \
|
||||
/* Size of the inner most section */ \
|
||||
ATENSOR##_size = ATENSOR##_sizes[ATENSOR##_dim-1]; \
|
||||
/* Stride of the inner most section */ \
|
||||
ATENSOR##_stride = ATENSOR##_strides[ATENSOR##_dim-1]; \
|
||||
} \
|
||||
} \
|
||||
ATENSOR##_i = 0;
|
||||
|
||||
// TODO: turn this macro into a proper template
|
||||
#define __ATH_TENSOR_APPLYX_UPDATE_COUNTERS(ATENSOR, ALWAYS_UPDATE) \
|
||||
if(ATENSOR##_i == ATENSOR##_size || ALWAYS_UPDATE) \
|
||||
{ \
|
||||
if(ATENSOR##_contiguous) \
|
||||
break; \
|
||||
\
|
||||
if(ATENSOR##_dim == 1) \
|
||||
break; \
|
||||
\
|
||||
/* Reset pointer to beginning of loop */ \
|
||||
ATENSOR##_data -= ATENSOR##_size*ATENSOR##_stride; \
|
||||
for(ATENSOR##_i = ATENSOR##_dim-2; ATENSOR##_i >= 0; ATENSOR##_i--) \
|
||||
{ \
|
||||
ATENSOR##_counter[ATENSOR##_i]++; \
|
||||
/* Jump ahead by the stride of this dimension */ \
|
||||
ATENSOR##_data += ATENSOR##_strides[ATENSOR##_i]; \
|
||||
\
|
||||
if(ATENSOR##_counter[ATENSOR##_i] == ATENSOR##_sizes[ATENSOR##_i]) \
|
||||
{ \
|
||||
if(ATENSOR##_i == 0) \
|
||||
{ \
|
||||
TH_TENSOR_APPLY_hasFinished = true; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Reset the pointer to the beginning of the chunk defined by this dimension */ \
|
||||
ATENSOR##_data -= ATENSOR##_counter[ATENSOR##_i]*ATENSOR##_strides[ATENSOR##_i]; \
|
||||
ATENSOR##_counter[ATENSOR##_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
ATENSOR##_i = 0; \
|
||||
inline Tensor sort_strides(Tensor& tensor_) {
|
||||
IntList strides = tensor_.strides();
|
||||
std::vector<int64_t> indices;
|
||||
indices.reserve(tensor_.ndimension());
|
||||
for (int64_t i = 0; i < tensor_.ndimension(); i++) {
|
||||
indices.push_back(i);
|
||||
}
|
||||
std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
|
||||
return strides[i1] > strides[i2];
|
||||
});
|
||||
Tensor tensor = tensor_.permute(indices);
|
||||
return tensor;
|
||||
}
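As a quick illustration (a hypothetical call site, not part of the patch), sort_strides permutes a transposed view back into memory order so that the innermost apply loop walks the smallest stride:

#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"

void sort_strides_example() {
  // A 3x4 float tensor viewed transposed: sizes {4, 3}, strides {1, 4}.
  at::Tensor t = at::CPU(at::kFloat).tensor({3, 4}).t();
  // Permuted back to decreasing strides: sizes {3, 4}, strides {4, 1}.
  at::Tensor s = at::sort_strides(t);
  (void)s;  // silence unused-variable warnings in this sketch
}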
|
||||
|
||||
template <typename Arg>
|
||||
inline void _setup_arrays(Tensor& tensor, Arg* iter) {
|
||||
int64_t max_dim = tensor.ndimension();
|
||||
iter->dim_ = 0;
|
||||
for (int64_t i = 0; i < max_dim; i++) {
|
||||
int64_t size = tensor.size(i);
|
||||
int64_t stride = tensor.stride(i);
|
||||
while (i + 1 < max_dim &&
|
||||
(tensor.size(i + 1) == 1 ||
|
||||
tensor.stride(i) == tensor.size(i + 1) * tensor.stride(i + 1))) {
|
||||
size = size * tensor.size(i + 1);
|
||||
if (tensor.size(i + 1) != 1)
|
||||
stride = tensor.stride(i + 1);
|
||||
i++;
|
||||
}
|
||||
iter->sizes_[iter->dim_] = size;
|
||||
iter->strides_[iter->dim_] = stride;
|
||||
iter->dim_++;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
struct strided_tensor_iter_fixed {
|
||||
public:
|
||||
T* data_ = NULL;
|
||||
int64_t dim_;
|
||||
|
||||
int64_t counter_[N];
|
||||
int64_t sizes_[N];
|
||||
int64_t strides_[N];
|
||||
|
||||
strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
|
||||
void operator=(strided_tensor_iter_fixed const& x) = delete;
|
||||
strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
|
||||
strided_tensor_iter_fixed(Tensor& tensor, bool sort_strides = false)
|
||||
: data_(tensor.data<T>()) {
|
||||
memset(counter_, 0, sizeof(int64_t) * N);
|
||||
_setup_arrays(tensor, this);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct strided_tensor_iter {
|
||||
private:
|
||||
public:
|
||||
T* data_ = NULL;
|
||||
int64_t dim_;
|
||||
|
||||
std::vector<int64_t> counter_;
|
||||
std::vector<int64_t> sizes_;
|
||||
std::vector<int64_t> strides_;
|
||||
|
||||
strided_tensor_iter(strided_tensor_iter const&) = delete;
|
||||
void operator=(strided_tensor_iter const& x) = delete;
|
||||
strided_tensor_iter(strided_tensor_iter&&) = default;
|
||||
strided_tensor_iter(Tensor& tensor)
|
||||
: data_(tensor.data<T>()),
|
||||
dim_(tensor.ndimension()),
|
||||
counter_(dim_, 0),
|
||||
sizes_(tensor.sizes()),
|
||||
strides_(tensor.strides()) {
|
||||
_setup_arrays(tensor, this);
|
||||
}
|
||||
};
|
||||
|
||||
inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
|
||||
if (tensors.size() == 0)
|
||||
return true;
|
||||
int64_t all_numel = tensors[0].numel();
|
||||
for (size_t i = 1; i < tensors.size(); i++) {
|
||||
if (tensors[i].numel() != all_numel)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
|
||||
std::ostringstream oss;
|
||||
oss << "inconsistent tensor size, expected ";
|
||||
for (size_t i = 0; i < tensors.size() - 1; i++) {
|
||||
oss << tensors[i].sizes() << ", ";
|
||||
}
|
||||
oss << "and " << tensors[tensors.size() - 1]
|
||||
<< " to have the same number of elements, but got ";
|
||||
for (size_t i = 0; i < tensors.size() - 1; i++) {
|
||||
oss << tensors[i].numel() << ", ";
|
||||
}
|
||||
oss << "and " << tensors[tensors.size() - 1].numel()
|
||||
<< " elements respectively";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
|
||||
checkBackend("CPU_tensor_apply", tensors, Backend::CPU);
|
||||
if (!_all_equal_numel(tensors))
|
||||
throw std::runtime_error(_all_equal_numel_error(tensors));
|
||||
// An empty tensor has no elements
|
||||
for (auto& t : tensors)
|
||||
if (t.sizes().equals({0}))
|
||||
return false;
|
||||
internal::init_tbb_num_threads();
|
||||
return true;
|
||||
}
|
||||
|
||||
inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
|
||||
int64_t dim = 0;
|
||||
for (auto& t : tensors)
|
||||
dim = std::max(dim, t.ndimension());
|
||||
return dim;
|
||||
}
|
||||
|
||||
inline void iterate(){};
|
||||
|
||||
template <typename Arg, typename... Args>
|
||||
inline void iterate(Arg& iter, Args&... iter_tail) {
|
||||
iter.counter_[iter.dim_ - 1]++;
|
||||
iter.data_ += iter.strides_[iter.dim_ - 1];
|
||||
if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
|
||||
for (int64_t i = iter.dim_ - 1; i > 0; i--) {
|
||||
if (iter.counter_[i] == iter.sizes_[i]) {
|
||||
iter.counter_[i] = 0;
|
||||
iter.counter_[i - 1]++;
|
||||
iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
|
||||
iter.strides_[i - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
iterate(iter_tail...);
|
||||
}
|
||||
|
||||
inline void forward(int64_t offset){};
|
||||
|
||||
template <typename Arg, typename... Args>
|
||||
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
|
||||
int64_t multi = offset;
|
||||
for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
|
||||
int64_t inc = multi % iter.sizes_[i];
|
||||
multi = multi / iter.sizes_[i];
|
||||
iter.data_ = iter.data_ + inc * iter.strides_[i];
|
||||
iter.counter_[i] += inc;
|
||||
}
|
||||
forward(offset, iter_tail...);
|
||||
}
|
||||
|
||||
inline int64_t max_dim() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename Arg, typename... Args>
|
||||
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
|
||||
return std::max(iter.dim_, max_dim(iter_tail...));
|
||||
}
|
||||
|
||||
inline void apply_op(){};
|
||||
|
||||
template <typename Op, typename... Args>
|
||||
inline void
|
||||
apply_op(int64_t numel, int64_t offset, const Op& op, Args... iters) {
|
||||
// For 0-dim tensors
|
||||
if (numel == 1 && max_dim(iters...) == 0) {
|
||||
op(*iters.data_...);
|
||||
return;
|
||||
}
|
||||
if (offset > 0)
|
||||
forward(offset, iters...);
|
||||
for (int64_t i = 0; i < numel; i++) {
|
||||
op(*iters.data_...);
|
||||
iterate(iters...);
|
||||
}
|
||||
}
|
||||
|
||||
/*
Apply a pointwise operator to a sequence of tensors.

The calling convention for op is a function/functor that takes the same
number of references of type scalar as the number of given tensors. For
example, to compute a = b * c, op would be of the form:
[](scalar& a_val, const scalar& b_val, const scalar& c_val) { a_val =
b_val * c_val; };
*/
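A hypothetical call site for this convention (square_into is illustrative and not part of ATen); it works for contiguous and non-contiguous CPU tensors alike:

#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"

// Squares every element of src into dst. Both must be CPU float tensors with
// the same number of elements; either may be a non-contiguous view.
void square_into(at::Tensor dst, at::Tensor src) {
  at::CPU_tensor_apply2<float, float>(
      dst, src, [](float& y, const float& x) { y = x * x; });
}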
|
||||
|
||||
template <typename scalar1, typename Op>
|
||||
inline void CPU_tensor_apply1(Tensor tensor1, const Op op) {
|
||||
if (!_apply_preamble({tensor1}))
|
||||
return;
|
||||
if (tensor1.ndimension() < 8) {
|
||||
apply_op(
|
||||
tensor1.numel(),
|
||||
0,
|
||||
op,
|
||||
strided_tensor_iter_fixed<scalar1, 8>(tensor1, true));
|
||||
} else {
|
||||
apply_op(tensor1.numel(), 0, op, strided_tensor_iter<scalar1>(tensor1));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar1, typename scalar2, typename Op>
|
||||
void CPU_tensor_apply2_dim(Tensor& tensor1, Tensor& tensor2, int64_t dim, Op op) {
|
||||
checkBackend("CPU_tensor_apply2", {tensor1, tensor2}, Backend::CPU);
|
||||
bool TH_TENSOR_APPLY_hasFinished = false;
|
||||
int64_t TH_TENSOR_dim_index = 0;
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar1, tensor1, dim, 1)
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar2, tensor2, dim, 1)
|
||||
auto t1_numel = tensor1.numel();
|
||||
auto t2_numel = tensor2.numel();
|
||||
if(t1_numel != t2_numel) {
|
||||
std::ostringstream oss;
|
||||
oss << "inconsistent tensor size, expected " << tensor1.sizes() << " and " << tensor2.sizes()
|
||||
<< " to have the same number of elements, but got " << t1_numel << " and " << t2_numel << " elements respectively";
|
||||
throw std::runtime_error(oss.str());
|
||||
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
|
||||
if (!_apply_preamble({tensor1, tensor2}))
|
||||
return;
|
||||
if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
|
||||
apply_op(
|
||||
tensor1.numel(),
|
||||
0,
|
||||
op,
|
||||
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
||||
strided_tensor_iter_fixed<scalar2, 8>(tensor2));
|
||||
} else {
|
||||
apply_op(
|
||||
tensor1.numel(),
|
||||
0,
|
||||
op,
|
||||
strided_tensor_iter<scalar1>(tensor1),
|
||||
strided_tensor_iter<scalar2>(tensor2));
|
||||
}
|
||||
while(!TH_TENSOR_APPLY_hasFinished)
|
||||
{
|
||||
/* Loop through the inner most region of the Tensor */
|
||||
for(; tensor1_i < tensor1_size && tensor2_i < tensor2_size; tensor1_i++, tensor2_i++, tensor1_data += tensor1_stride, tensor2_data += tensor2_stride)
|
||||
{
|
||||
op(*tensor1_data, *tensor2_data);
|
||||
}
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor1, 0)
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor2, 0)
|
||||
}
|
||||
if(tensor1_counter != NULL)
|
||||
delete [] tensor1_counter;
|
||||
if(tensor2_counter != NULL)
|
||||
delete [] tensor2_counter;
|
||||
}
|
||||
|
||||
/*
Apply a pointwise operator to two tensors.

The calling convention for op is a function/functor that takes two references to
type scalar; at least one of these references should be non-const in order to write the output.
For example, to compute a = b^2, op would be of the form:
[](scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; };
*/
|
||||
template<typename scalar1, typename scalar2, typename Op>
|
||||
void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, Op op) {
|
||||
CPU_tensor_apply2_dim<scalar1, scalar2, Op>(tensor1, tensor2, -1, op);
|
||||
template <typename scalar1, typename scalar2, typename scalar3, typename Op>
|
||||
inline void
|
||||
CPU_tensor_apply3(Tensor tensor1, Tensor tensor2, Tensor tensor3, const Op op) {
|
||||
if (!_apply_preamble({tensor1, tensor2, tensor3}))
|
||||
return;
|
||||
if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
|
||||
apply_op(
|
||||
tensor1.numel(),
|
||||
0,
|
||||
op,
|
||||
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
||||
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
|
||||
strided_tensor_iter_fixed<scalar3, 8>(tensor3));
|
||||
} else {
|
||||
apply_op(
|
||||
tensor1.numel(),
|
||||
0,
|
||||
op,
|
||||
strided_tensor_iter<scalar1>(tensor1),
|
||||
strided_tensor_iter<scalar2>(tensor2),
|
||||
strided_tensor_iter<scalar3>(tensor3));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar1, typename scalar2, typename scalar3, typename Op>
|
||||
void CPU_tensor_apply3_dim(Tensor &tensor1, Tensor& tensor2, Tensor& tensor3, int64_t dim, Op op) {
|
||||
checkBackend("CPU_tensor_apply3", {tensor1, tensor2, tensor3}, Backend::CPU);
|
||||
bool TH_TENSOR_APPLY_hasFinished = false;
|
||||
int64_t TH_TENSOR_dim_index = 0;
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar1, tensor1, dim, 1)
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar2, tensor2, dim, 1)
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar3, tensor3, dim, 1)
|
||||
|
||||
int elements_equal = 1;
|
||||
auto t1_numel = tensor1.numel();
|
||||
auto t2_numel = tensor2.numel();
|
||||
auto t3_numel = tensor3.numel();
|
||||
if(t1_numel!= t2_numel) {
|
||||
elements_equal = 0;
|
||||
} else if(t1_numel != t3_numel) {
|
||||
elements_equal = 0;
|
||||
template <
|
||||
typename scalar1,
|
||||
typename scalar2,
|
||||
typename scalar3,
|
||||
typename scalar4,
|
||||
typename Op>
|
||||
inline void CPU_tensor_apply4(
|
||||
Tensor tensor1,
|
||||
Tensor tensor2,
|
||||
Tensor tensor3,
|
||||
Tensor tensor4,
|
||||
const Op op) {
|
||||
if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
|
||||
return;
|
||||
if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
|
||||
apply_op(
|
||||
tensor1.numel(),
|
||||
0,
|
||||
op,
|
||||
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
||||
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
|
||||
strided_tensor_iter_fixed<scalar3, 8>(tensor3),
|
||||
strided_tensor_iter_fixed<scalar4, 8>(tensor4));
|
||||
} else {
|
||||
apply_op(
|
||||
tensor1.numel(),
|
||||
0,
|
||||
op,
|
||||
strided_tensor_iter<scalar1>(tensor1),
|
||||
strided_tensor_iter<scalar2>(tensor2),
|
||||
strided_tensor_iter<scalar3>(tensor3),
|
||||
strided_tensor_iter<scalar4>(tensor4));
|
||||
}
|
||||
if (elements_equal == 0) {
|
||||
std::ostringstream oss;
|
||||
oss << "inconsistent tensor size, expected " << tensor1.sizes() << ", " << tensor2.sizes() << ", and " << tensor3.sizes()
|
||||
<< " to have the same number of elements, but got " << t1_numel << ", " << t2_numel << ", and " << t3_numel << " elements respectively";
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
|
||||
while(!TH_TENSOR_APPLY_hasFinished)
|
||||
{
|
||||
/* Loop through the inner most region of the Tensor */
|
||||
for(; tensor1_i < tensor1_size && tensor2_i < tensor2_size && tensor3_i < tensor3_size; tensor1_i++, tensor2_i++, tensor3_i++, tensor1_data += tensor1_stride, tensor2_data += tensor2_stride, tensor3_data += tensor3_stride)
|
||||
{
|
||||
op(*tensor1_data, *tensor2_data, *tensor3_data);
|
||||
}
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor1, 0)
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor2, 0)
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor3, 0)
|
||||
}
|
||||
if(tensor1_counter != NULL)
|
||||
delete [] tensor1_counter;
|
||||
if(tensor2_counter != NULL)
|
||||
delete [] tensor2_counter;
|
||||
if(tensor3_counter != NULL)
|
||||
delete [] tensor3_counter;
|
||||
}
|
||||
|
||||
/*
Apply a pointwise operator to three tensors.

The calling convention for op is a function/functor that takes three references to
type scalar; at least one of these references should be non-const in order to write the output.
For example, to compute a = b + c, op would be of the form:
[](scalar &a_val, const scalar &b_val, const scalar &c_val) { a_val = b_val + c_val; };
*/
|
||||
template<typename scalar1, typename scalar2, typename scalar3, typename Op>
|
||||
void CPU_tensor_apply3(Tensor tensor1, Tensor tensor2, Tensor tensor3, Op op) {
|
||||
CPU_tensor_apply3_dim<scalar1, scalar2, scalar3, Op>(tensor1, tensor2, tensor3, -1, op);
|
||||
}
|
||||
|
||||
template <typename scalar1, typename scalar2, typename scalar3, typename scalar4, typename Op>
|
||||
void CPU_tensor_apply4_dim(Tensor &tensor1, Tensor& tensor2, Tensor& tensor3, Tensor& tensor4, int64_t dim, Op op) {
|
||||
checkBackend("CPU_tensor_apply4", {tensor1, tensor2, tensor3, tensor4}, Backend::CPU);
|
||||
bool TH_TENSOR_APPLY_hasFinished = false;
|
||||
int64_t TH_TENSOR_dim_index = 0;
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar1, tensor1, dim, 1)
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar2, tensor2, dim, 1)
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar3, tensor3, dim, 1)
|
||||
__ATH_TENSOR_APPLYX_PREAMBLE(scalar4, tensor4, dim, 1)
|
||||
|
||||
int elements_equal = 1;
|
||||
auto t1_numel = tensor1.numel();
|
||||
auto t2_numel = tensor2.numel();
|
||||
auto t3_numel = tensor3.numel();
|
||||
auto t4_numel = tensor4.numel();
|
||||
if(t1_numel!= t2_numel) {
|
||||
elements_equal = 0;
|
||||
} else if(t1_numel != t3_numel) {
|
||||
elements_equal = 0;
|
||||
} else if(t1_numel != t4_numel) {
|
||||
elements_equal = 0;
|
||||
template <typename scalar1, typename Op>
|
||||
inline void CPU_tensor_parallel_apply1(Tensor tensor1, const Op op) {
|
||||
if (!_apply_preamble({tensor1}))
|
||||
return;
|
||||
if (tensor1.numel() < internal::TBB_GRAIN_SIZE) {
|
||||
CPU_tensor_apply1<scalar1>(tensor1, op);
|
||||
return;
|
||||
}
|
||||
if (elements_equal == 0) {
|
||||
std::ostringstream oss;
|
||||
oss << "inconsistent tensor size, expected " << tensor1.sizes() << ", " << tensor2.sizes() << ", "
|
||||
<< tensor3.sizes() << ", and " << tensor4.sizes() << " to have the same number of elements, but got "
|
||||
<< t1_numel << ", " << t2_numel << ", " << t3_numel << ", and " << t4_numel << " elements respectively";
|
||||
throw std::runtime_error(oss.str());
|
||||
auto range = tbb::blocked_range<size_t>(0, tensor1.numel());
|
||||
if (tensor1.ndimension() < 8) {
|
||||
tbb::parallel_for(
|
||||
range, [&tensor1, &op](const tbb::blocked_range<size_t> r) {
|
||||
apply_op(
|
||||
r.end() - r.begin(),
|
||||
r.begin(),
|
||||
op,
|
||||
strided_tensor_iter_fixed<scalar1, 8>(tensor1, true));
|
||||
});
|
||||
} else {
|
||||
tbb::parallel_for(
|
||||
range, [&tensor1, &op](const tbb::blocked_range<size_t> r) {
|
||||
apply_op(
|
||||
r.end() - r.begin(),
|
||||
r.begin(),
|
||||
op,
|
||||
strided_tensor_iter<scalar1>(tensor1));
|
||||
});
|
||||
}
|
||||
}
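For reference, a small standalone sketch (not ATen code) of how the r.begin() offset that each TBB worker hands to apply_op is decomposed by forward() into per-dimension counters, innermost dimension first:

#include <cstdint>
#include <cstdio>

int main() {
  // For sizes {3, 4}, the linear offset 7 becomes counters {1, 3},
  // i.e. 7 == 1 * 4 + 3, and the data pointer advances by
  // 1 * stride[0] + 3 * stride[1].
  int64_t sizes[2] = {3, 4};
  int64_t offset = 7;
  int64_t counters[2] = {0, 0};
  for (int i = 1; i >= 0; --i) {
    counters[i] = offset % sizes[i];
    offset /= sizes[i];
  }
  std::printf("%lld %lld\n", (long long)counters[0], (long long)counters[1]);
  return 0;
}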
|
||||
|
||||
while(!TH_TENSOR_APPLY_hasFinished)
|
||||
{
|
||||
/* Loop through the inner most region of the Tensor */
|
||||
for(; tensor1_i < tensor1_size && tensor2_i < tensor2_size && tensor3_i < tensor3_size && tensor4_i < tensor4_size
|
||||
; tensor1_i++, tensor2_i++, tensor3_i++, tensor4_i++,
|
||||
tensor1_data += tensor1_stride, tensor2_data += tensor2_stride, tensor3_data += tensor3_stride, tensor4_data += tensor4_stride)
|
||||
{
|
||||
op(*tensor1_data, *tensor2_data, *tensor3_data, *tensor4_data);
|
||||
}
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor1, 0)
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor2, 0)
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor3, 0)
|
||||
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor4, 0)
|
||||
template <typename scalar1, typename scalar2, typename Op>
|
||||
inline void
|
||||
CPU_tensor_parallel_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
|
||||
if (!_apply_preamble({tensor1, tensor2}))
|
||||
return;
|
||||
if ((tensor1.numel() + tensor2.numel()) < internal::TBB_GRAIN_SIZE) {
|
||||
CPU_tensor_apply2<scalar1, scalar2>(tensor1, tensor2, op);
|
||||
return;
|
||||
}
|
||||
auto range = tbb::blocked_range<size_t>(0, tensor1.numel());
|
||||
if (tensor1.ndimension() < 8 && tensor2.ndimension() < 8) {
|
||||
tbb::parallel_for(
|
||||
range, [&tensor1, &tensor2, &op](const tbb::blocked_range<size_t> r) {
|
||||
apply_op(
|
||||
r.end() - r.begin(),
|
||||
r.begin(),
|
||||
op,
|
||||
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
||||
strided_tensor_iter_fixed<scalar2, 8>(tensor2));
|
||||
});
|
||||
} else {
|
||||
tbb::parallel_for(
|
||||
range, [&tensor1, &tensor2, &op](const tbb::blocked_range<size_t> r) {
|
||||
apply_op(
|
||||
r.end() - r.begin(),
|
||||
r.begin(),
|
||||
op,
|
||||
strided_tensor_iter<scalar1>(tensor1),
|
||||
strided_tensor_iter<scalar2>(tensor2));
|
||||
});
|
||||
}
|
||||
if(tensor1_counter != NULL)
|
||||
delete [] tensor1_counter;
|
||||
if(tensor2_counter != NULL)
|
||||
delete [] tensor2_counter;
|
||||
if(tensor3_counter != NULL)
|
||||
delete [] tensor3_counter;
|
||||
if(tensor4_counter != NULL)
|
||||
delete [] tensor4_counter;
|
||||
}
|
||||
|
||||
/*
Apply a pointwise operator to four tensors.

The calling convention for op is a function/functor that takes four references to
type scalar; at least one of these references should be non-const in order to write the output.
For example, to compute a = b + c * d, op would be of the form:
[](scalar &a_val, const scalar &b_val, const scalar &c_val, const scalar &d_val) {
a_val = b_val + c_val * d_val;
};
*/
|
||||
template<typename scalar1, typename scalar2, typename scalar3, typename scalar4, typename Op>
|
||||
void CPU_tensor_apply4(Tensor tensor1, Tensor tensor2, Tensor tensor3, Tensor tensor4, Op op) {
|
||||
CPU_tensor_apply4_dim<scalar1, scalar2, scalar3, scalar4, Op>(tensor1, tensor2, tensor3, tensor4, -1, op);
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -1114,7 +1114,6 @@
|
|||
- Int
|
||||
- Short
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1161,7 +1160,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1354,7 +1352,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1400,7 +1397,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1475,7 +1471,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1695,7 +1690,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1741,7 +1735,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1758,7 +1751,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1775,7 +1767,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
@ -1792,7 +1783,6 @@
|
|||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- method
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ namespace internal {
|
|||
// for a certain number of workers. If there are multiple threads making
|
||||
// a request at the size of the maximum number of threads, they will
|
||||
// be allocated a number proportional to the other requests.
|
||||
void init_tbb_num_threads();
|
||||
AT_API void init_tbb_num_threads();
|
||||
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ void checkDefined(CheckedFrom c, const TensorArg& t);
|
|||
void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
|
||||
|
||||
// FixMe: does TensorArg slow things down?
|
||||
void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> t, at::Backend backend);
|
||||
AT_API void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> t, at::Backend backend);
|
||||
|
||||
// Methods for getting data_ptr if tensor is defined
|
||||
void * maybe_data_ptr(const Tensor& tensor);
|
||||
|
|
|
|||
|
|
@ -3,48 +3,106 @@
|
|||
#include "ATen/ExpandUtils.h"
|
||||
#include "ATen/NativeFunctions.h"
|
||||
#include "ATen/WrapDimUtils.h"
|
||||
#include "cpu/UnaryOpsKernel.h"
|
||||
|
||||
#include "ATen/CPUApplyUtils.h"
|
||||
#include "ATen/Parallel.h"
|
||||
#include "ATen/native/cpu/UnaryOpsKernel.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace at { namespace native {
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
#define IMPLEMENT_UNARY_OP(op) \
|
||||
Tensor op(const Tensor& self) { \
|
||||
Tensor result = self.type().tensor(); \
|
||||
return at::op ## _out(result, self); \
|
||||
} \
|
||||
Tensor& op##_(Tensor& self) { \
|
||||
return at::op ## _out(self, self); \
|
||||
} \
|
||||
Tensor& _ ## op ## _out_cuda(Tensor& result, const Tensor& self) { \
|
||||
return at::_ ## op ## _out(result, self); \
|
||||
} \
|
||||
Tensor& _ ## op ## _out_cpu(Tensor& result, const Tensor& self) { \
|
||||
if (result.is_contiguous() && self.is_contiguous()) { \
|
||||
result.resize_(self.sizes()); \
|
||||
if (result.numel() > 0) { \
|
||||
op ## Impl(result, self); \
|
||||
} \
|
||||
return result; \
|
||||
} \
|
||||
return at::_ ## op ## _out(result, self); \
|
||||
#define IMPLEMENT_UNARY_OP_PREQUEL(op) \
|
||||
Tensor op(const Tensor& self) { \
|
||||
Tensor result = self.type().tensor(); \
|
||||
return at::op##_out(result, self); \
|
||||
} \
|
||||
Tensor& _##op##__cuda(Tensor& self) { \
|
||||
return at::_##op##_out(self, self); \
|
||||
} \
|
||||
Tensor& _##op##_out_cuda(Tensor& result, const Tensor& self) { \
|
||||
return at::_##op##_out(result, self); \
|
||||
}
|
||||
|
||||
#define IMPLEMENT_UNARY_OP_FLOAT_CMATH(op) \
|
||||
Tensor& _##op##__cpu(Tensor& self_) { \
|
||||
if (self_.numel() > 0) { \
|
||||
Tensor self = sort_strides(self_); \
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
|
||||
CPU_tensor_parallel_apply1<scalar_t>( \
|
||||
self, [](scalar_t& y) { y = std::op(y); }); \
|
||||
}); \
|
||||
} \
|
||||
return self_; \
|
||||
} \
|
||||
Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
|
||||
result.resize_(self.sizes()); \
|
||||
if (result.numel() > 0) { \
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
|
||||
CPU_tensor_parallel_apply2<scalar_t, scalar_t>( \
|
||||
result, self, [](scalar_t& y, scalar_t& x) { y = std::op(x); }); \
|
||||
}); \
|
||||
} \
|
||||
return result; \
|
||||
}
|
||||
|
||||
#define IMPLEMENT_UNARY_OP_VEC(op) \
|
||||
Tensor& _##op##__cpu(Tensor& self_) { \
|
||||
if (self_.numel() > 0) { \
|
||||
Tensor self = sort_strides(self_); \
|
||||
if (self.is_contiguous()) { \
|
||||
op##Impl(self, self); \
|
||||
} else { \
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
|
||||
CPU_tensor_parallel_apply1<scalar_t>( \
|
||||
self, [](scalar_t& y) { y = std::op(y); }); \
|
||||
}); \
|
||||
} \
|
||||
} \
|
||||
return self_; \
|
||||
} \
|
||||
Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
|
||||
result.resize_(self.sizes()); \
|
||||
if (result.numel() > 0) { \
|
||||
if (result.is_contiguous() && self.is_contiguous()) { \
|
||||
op##Impl(result, self); \
|
||||
} else { \
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
|
||||
CPU_tensor_parallel_apply2<scalar_t, scalar_t>( \
|
||||
result, self, [](scalar_t& y, scalar_t& x) { y = std::op(x); }); \
|
||||
}); \
|
||||
} \
|
||||
} \
|
||||
return result; \
|
||||
}
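Illustrative call sites (hypothetical, not part of the patch) showing which path the generated in-place entry point takes:

#include "ATen/ATen.h"

void ceil_paths() {
  at::Tensor a = at::CPU(at::kFloat).tensor({4, 4});
  a.fill_(1.5);
  a.ceil_();                         // contiguous: vectorized ceilImpl kernel
  at::Tensor b = a.narrow(1, 0, 2);  // strided slice of the same storage
  b.ceil_();                         // non-contiguous: CPU_tensor_parallel_apply1
}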
|
||||
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(abs)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(ceil)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(cos)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(exp)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(floor)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(log)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(round)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(sin)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(sqrt)
|
||||
IMPLEMENT_UNARY_OP_PREQUEL(trunc)
|
||||
|
||||
IMPLEMENT_UNARY_OP_VEC(abs)
|
||||
IMPLEMENT_UNARY_OP_VEC(ceil)
|
||||
IMPLEMENT_UNARY_OP_FLOAT_CMATH(cos)
|
||||
IMPLEMENT_UNARY_OP_FLOAT_CMATH(exp)
|
||||
IMPLEMENT_UNARY_OP_VEC(floor)
|
||||
IMPLEMENT_UNARY_OP_FLOAT_CMATH(log)
|
||||
IMPLEMENT_UNARY_OP_VEC(round)
|
||||
IMPLEMENT_UNARY_OP_FLOAT_CMATH(sin)
|
||||
IMPLEMENT_UNARY_OP_VEC(sqrt)
|
||||
IMPLEMENT_UNARY_OP_VEC(trunc)
|
||||
}
|
||||
|
||||
IMPLEMENT_UNARY_OP(abs)
|
||||
IMPLEMENT_UNARY_OP(ceil)
|
||||
IMPLEMENT_UNARY_OP(cos)
|
||||
IMPLEMENT_UNARY_OP(exp)
|
||||
IMPLEMENT_UNARY_OP(floor)
|
||||
IMPLEMENT_UNARY_OP(log)
|
||||
IMPLEMENT_UNARY_OP(round)
|
||||
IMPLEMENT_UNARY_OP(sin)
|
||||
IMPLEMENT_UNARY_OP(sqrt)
|
||||
IMPLEMENT_UNARY_OP(trunc)
|
||||
|
||||
}} // namespace at::native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -7,12 +7,14 @@
|
|||
#include "ATen/cpu/vec256/vec256.h"
|
||||
#include "ATen/native/cpu/CapabilityDispatch.h"
|
||||
|
||||
namespace at { namespace native { namespace {
|
||||
namespace at { namespace native {
|
||||
namespace {
|
||||
|
||||
using namespace vec256;
|
||||
|
||||
template <typename scalar_t, typename F>
|
||||
static void unary_kernel(scalar_t* arr_out, const scalar_t* arr_in, int64_t size, F func) {
|
||||
static void
|
||||
unary_kernel(scalar_t* arr_out, const scalar_t* arr_in, int64_t size, F func) {
|
||||
using Vec = Vec256<scalar_t>;
|
||||
int64_t size_rounded = size - (size % Vec::size);
|
||||
int64_t k = 0;
|
||||
|
|
@ -52,94 +54,59 @@ static void parallel_apply(Tensor& result, const Tensor& self, F f) {
|
|||
|
||||
static void abs_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_ALL_TYPES(self.type(), "abs", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.abs();
|
||||
});
|
||||
});
|
||||
parallel_apply<scalar_t>(
|
||||
result,
|
||||
self,
|
||||
[](const Vec256<scalar_t>& x) { return x.abs(); }); });
|
||||
}
|
||||
|
||||
static void ceil_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "ceil", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.ceil();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void cos_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "cos", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.cos();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void exp_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "exp", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.exp();
|
||||
});
|
||||
});
|
||||
parallel_apply<scalar_t>(
|
||||
result,
|
||||
self,
|
||||
[](const Vec256<scalar_t>& x) { return x.ceil(); }); });
|
||||
}
|
||||
|
||||
static void floor_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "floor", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.floor();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void log_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "log", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.log();
|
||||
});
|
||||
});
|
||||
parallel_apply<scalar_t>(
|
||||
result,
|
||||
self,
|
||||
[](const Vec256<scalar_t>& x) { return x.floor(); }); });
|
||||
}
|
||||
|
||||
static void round_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "round", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.round();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void sin_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "sin", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.sin();
|
||||
});
|
||||
});
|
||||
parallel_apply<scalar_t>(
|
||||
result,
|
||||
self,
|
||||
[](const Vec256<scalar_t>& x) { return x.round(); }); });
|
||||
}
|
||||
|
||||
static void sqrt_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "sqrt", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.sqrt();
|
||||
});
|
||||
});
|
||||
parallel_apply<scalar_t>(
|
||||
result,
|
||||
self,
|
||||
[](const Vec256<scalar_t>& x) { return x.sqrt(); }); });
|
||||
}
|
||||
|
||||
static void trunc_kernel(Tensor& result, const Tensor& self) {
|
||||
AT_DISPATCH_FLOATING_TYPES(self.type(), "trunc", [&] {
|
||||
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
|
||||
return x.trunc();
|
||||
});
|
||||
});
|
||||
parallel_apply<scalar_t>(
|
||||
result,
|
||||
self,
|
||||
[](const Vec256<scalar_t>& x) { return x.trunc(); }); });
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
REGISTER_DISPATCH(absImpl, &abs_kernel);
|
||||
REGISTER_DISPATCH(ceilImpl, &ceil_kernel);
|
||||
REGISTER_DISPATCH(cosImpl, &cos_kernel);
|
||||
REGISTER_DISPATCH(expImpl, &exp_kernel);
|
||||
REGISTER_DISPATCH(floorImpl, &floor_kernel);
|
||||
REGISTER_DISPATCH(logImpl, &log_kernel);
|
||||
REGISTER_DISPATCH(roundImpl, &round_kernel);
|
||||
REGISTER_DISPATCH(sinImpl, &sin_kernel);
|
||||
REGISTER_DISPATCH(sqrtImpl, &sqrt_kernel);
|
||||
REGISTER_DISPATCH(truncImpl, &trunc_kernel);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <stdexcept>
|
||||
#include "CapabilityDispatch.h"
|
||||
|
||||
|
|
@ -11,22 +10,16 @@ using unary_fn = void(*)(Tensor&, const Tensor&);
|
|||
|
||||
extern DispatchStub<unary_fn> absImpl;
|
||||
extern DispatchStub<unary_fn> ceilImpl;
|
||||
extern DispatchStub<unary_fn> cosImpl;
|
||||
extern DispatchStub<unary_fn> expImpl;
|
||||
extern DispatchStub<unary_fn> floorImpl;
|
||||
extern DispatchStub<unary_fn> logImpl;
|
||||
extern DispatchStub<unary_fn> roundImpl;
|
||||
extern DispatchStub<unary_fn> sinImpl;
|
||||
extern DispatchStub<unary_fn> sqrtImpl;
|
||||
extern DispatchStub<unary_fn> truncImpl;
|
||||
|
||||
// Missing unary functions
|
||||
// TODO: Add generic apply function for contiguous and non-contiguous tensors
|
||||
// The goal here is to move more ops entirely into ATen and take advantage of
|
||||
// automatic vectorization with file-specific flags
|
||||
// acos
|
||||
// asin
|
||||
// atan
|
||||
// cos
|
||||
// cosh
|
||||
// digamma
|
||||
// erf
|
||||
|
|
@ -37,6 +30,7 @@ extern DispatchStub<unary_fn> truncImpl;
|
|||
// log1p
|
||||
// rsqrt
|
||||
// sigmoid
|
||||
// sin
|
||||
// sinh
|
||||
// tan
|
||||
// tanh
|
||||
|
|
|
|||
|
|
@ -44,6 +44,9 @@
|
|||
- func: abs(Tensor self) -> Tensor
|
||||
|
||||
- func: abs_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _abs__cpu
|
||||
CUDA: _abs__cuda
|
||||
|
||||
- func: abs_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -118,6 +121,9 @@
|
|||
- func: ceil(Tensor self) -> Tensor
|
||||
|
||||
- func: ceil_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _ceil__cpu
|
||||
CUDA: _ceil__cuda
|
||||
|
||||
- func: ceil_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -170,6 +176,9 @@
|
|||
- func: cos(Tensor self) -> Tensor
|
||||
|
||||
- func: cos_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _cos__cpu
|
||||
CUDA: _cos__cuda
|
||||
|
||||
- func: cos_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -343,6 +352,9 @@
|
|||
- func: exp(Tensor self) -> Tensor
|
||||
|
||||
- func: exp_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _exp__cpu
|
||||
CUDA: _exp__cuda
|
||||
|
||||
- func: exp_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -368,6 +380,9 @@
|
|||
- func: floor(Tensor self) -> Tensor
|
||||
|
||||
- func: floor_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _floor__cpu
|
||||
CUDA: _floor__cuda
|
||||
|
||||
- func: floor_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -451,6 +466,9 @@
|
|||
- func: log(Tensor self) -> Tensor
|
||||
|
||||
- func: log_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _log__cpu
|
||||
CUDA: _log__cuda
|
||||
|
||||
- func: log_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -604,6 +622,9 @@
|
|||
- func: round(Tensor self) -> Tensor
|
||||
|
||||
- func: round_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _round__cpu
|
||||
CUDA: _round__cuda
|
||||
|
||||
- func: round_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -632,6 +653,9 @@
|
|||
- func: sin(Tensor self) -> Tensor
|
||||
|
||||
- func: sin_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _sin__cpu
|
||||
CUDA: _sin__cuda
|
||||
|
||||
- func: sin_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -719,6 +743,9 @@
|
|||
- func: sqrt(Tensor self) -> Tensor
|
||||
|
||||
- func: sqrt_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _sqrt__cpu
|
||||
CUDA: _sqrt__cuda
|
||||
|
||||
- func: sqrt_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
@ -774,6 +801,9 @@
|
|||
- func: trunc(Tensor self) -> Tensor
|
||||
|
||||
- func: trunc_(Tensor self) -> Tensor
|
||||
dispatch:
|
||||
CPU: _trunc__cpu
|
||||
CUDA: _trunc__cuda
|
||||
|
||||
- func: trunc_out(Tensor result, Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@ ENDIF(MSVC)
|
|||
ADD_EXECUTABLE(scalar_test scalar_test.cpp)
|
||||
target_link_libraries(scalar_test ATen)
|
||||
|
||||
ADD_EXECUTABLE(apply_utils_test apply_utils_test.cpp)
|
||||
target_link_libraries(apply_utils_test ATen)
|
||||
|
||||
ADD_EXECUTABLE(basic basic.cpp)
|
||||
target_link_libraries(basic ATen)
|
||||
|
||||
|
|
|
|||
aten/src/ATen/test/apply_utils_test.cpp (new file, 139 lines)
|
|
@ -0,0 +1,139 @@
|
|||
#define CATCH_CONFIG_MAIN
|
||||
#include "catch.hpp"
|
||||
|
||||
#include "ATen/ATen.h"
|
||||
#include "ATen/CPUApplyUtils.h"
|
||||
#include "test_assert.h"
|
||||
#include "test_seed.h"
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
using namespace at;
|
||||
|
||||
void fill_tensor(int64_t scalar, Tensor& t_) {
|
||||
auto t = t_.view(-1);
|
||||
for (int64_t i = 0; i < t.numel(); i++) {
|
||||
t[i] = (i + 1) * scalar;
|
||||
}
|
||||
}
|
||||
|
||||
// This test exercises all sequential applyX functions. Given a shape and two
|
||||
// transpose dimensions we create 5 tensors (a0, ..., a4) of the given shape and
|
||||
// transpose the dimension a with b for each tensor. Then we call the applyX
|
||||
// function on each floating type. a4 is allocated in doubles only, whereas a0,
|
||||
// ..., a3 are allocated in the given type. For each applyX function we once
|
||||
// write the same type as we read (using a0, ..., aX-1) and we once write to
|
||||
// double (using a4 as a target). We also exercise on a zero_dim and empty
|
||||
// tensor.
|
||||
void test(Type& type, IntList shape, int64_t a = 0, int64_t b = 1) {
|
||||
auto zero_dim = type.tensor({});
|
||||
zero_dim.fill_(2);
|
||||
zero_dim.exp_();
|
||||
AT_DISPATCH_FLOATING_TYPES(zero_dim.type(), "test0", [&] {
|
||||
ASSERT(zero_dim.data<scalar_t>()[0] == std::exp(2));
|
||||
});
|
||||
|
||||
auto empty_t = type.tensor({0});
|
||||
empty_t.fill_(3);
|
||||
empty_t.exp_();
|
||||
|
||||
auto a0 = type.tensor();
|
||||
auto a1 = type.tensor();
|
||||
auto a2 = type.tensor();
|
||||
auto a3 = type.tensor();
|
||||
auto a4 = CPU(kDouble).tensor();
|
||||
|
||||
std::vector<Tensor> tensors({a0, a1, a2, a3, a4});
|
||||
for (size_t i = 0; i < tensors.size(); i++) {
|
||||
tensors[i].resize_(shape);
|
||||
fill_tensor(i + 1, tensors[i]);
|
||||
if (a >= 0 && b >= 0) {
|
||||
tensors[i].transpose_(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(a0.type(), "test1", [&] {
|
||||
CPU_tensor_apply2<scalar_t, scalar_t>(
|
||||
a0, a1, [](scalar_t& y, const scalar_t& x) { y = x * x; });
|
||||
CPU_tensor_apply2<double, scalar_t>(
|
||||
a4, a1, [](double& y, scalar_t x) { y = (double)(x * x); });
|
||||
for (int64_t i = 0; i < a0.numel(); i++) {
|
||||
auto target = a1.data<scalar_t>()[i] * a1.data<scalar_t>()[i];
|
||||
ASSERT(a0.data<scalar_t>()[i] == target);
|
||||
ASSERT(a4.data<double>()[i] == target);
|
||||
}
|
||||
});
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(a0.type(), "test2", [&] {
|
||||
CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
|
||||
a0, a1, a2, [](scalar_t& y, const scalar_t& x, const scalar_t& z) {
|
||||
y = x * x + z;
|
||||
});
|
||||
CPU_tensor_apply3<double, scalar_t, scalar_t>(
|
||||
a4, a1, a2, [](double& y, const scalar_t& x, const scalar_t& z) {
|
||||
y = (double)(x * x + z);
|
||||
});
|
||||
for (int64_t i = 0; i < a0.numel(); i++) {
|
||||
auto target = a1.data<scalar_t>()[i] * a1.data<scalar_t>()[i];
|
||||
target = target + a2.data<scalar_t>()[i];
|
||||
ASSERT(a0.data<scalar_t>()[i] == target);
|
||||
ASSERT(a4.data<double>()[i] == target);
|
||||
}
|
||||
});
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(a0.type(), "test3", [&] {
|
||||
CPU_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
|
||||
a0,
|
||||
a1,
|
||||
a2,
|
||||
a3,
|
||||
[](scalar_t& y,
|
||||
const scalar_t& x,
|
||||
const scalar_t& z,
|
||||
const scalar_t& a) { y = x * x + z * a; });
|
||||
CPU_tensor_apply4<double, scalar_t, scalar_t, scalar_t>(
|
||||
a4,
|
||||
a1,
|
||||
a2,
|
||||
a3,
|
||||
[](double& y, const scalar_t& x, const scalar_t& z, const scalar_t& a) {
|
||||
y = (double)(x * x + z * a);
|
||||
});
|
||||
for (int64_t i = 0; i < a0.numel(); i++) {
|
||||
auto target = a1.data<scalar_t>()[i] * a1.data<scalar_t>()[i];
|
||||
target = target + a2.data<scalar_t>()[i] * a3.data<scalar_t>()[i];
|
||||
ASSERT(a0.data<scalar_t>()[i] == target);
|
||||
ASSERT(a4.data<double>()[i] == target);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
TEST_CASE("apply utils test 2-dim small contiguous", "[cpu]") {
|
||||
manual_seed(123, at::Backend::CPU);
|
||||
test(CPU(kDouble), {2, 1}, -1, -1);
|
||||
}
|
||||
|
||||
TEST_CASE("apply utils test 2-dim small", "[cpu]") {
|
||||
manual_seed(123, at::Backend::CPU);
|
||||
test(CPU(kDouble), {2, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("apply utils test 2-dim", "[cpu]") {
|
||||
manual_seed(123, at::Backend::CPU);
|
||||
test(CPU(kDouble), {20, 10});
|
||||
}
|
||||
|
||||
TEST_CASE("apply utils test 3-dim", "[cpu]") {
|
||||
manual_seed(123, at::Backend::CPU);
|
||||
test(CPU(kDouble), {3, 4, 2});
|
||||
}
|
||||
|
||||
TEST_CASE("apply utils test 3-dim medium", "[cpu]") {
|
||||
manual_seed(123, at::Backend::CPU);
|
||||
test(CPU(kDouble), {3, 40, 2});
|
||||
}
|
||||
|
||||
TEST_CASE("apply utils test 10-dim", "[cpu]") {
|
||||
manual_seed(123, at::Backend::CPU);
|
||||
test(CPU(kDouble), {3, 4, 2, 5, 2, 1, 3, 4, 2, 3});
|
||||
}
|
||||
|
|
@ -9,6 +9,7 @@ $BUILD_ROOT/src/ATen/test/atest
|
|||
$BUILD_ROOT/src/ATen/test/scalar_test
|
||||
$BUILD_ROOT/src/ATen/test/broadcast_test
|
||||
$BUILD_ROOT/src/ATen/test/wrapdim_test
|
||||
$BUILD_ROOT/src/ATen/test/apply_utils_test
|
||||
$BUILD_ROOT/src/ATen/test/dlconvertor_test
|
||||
$BUILD_ROOT/src/ATen/test/native_test
|
||||
$BUILD_ROOT/src/ATen/test/scalar_tensor_test
|
||||
|
|
|
|||
|
|
@ -266,7 +266,7 @@ class TestTorch(TestCase):
|
|||
|
||||
def compare_reference(input, dtype):
|
||||
input = torch.tensor(input, dtype=dtype)
|
||||
res1 = torchfn(input)
|
||||
res1 = torchfn(input.clone())
|
||||
res2 = input.clone().apply_(lambda x: mathfn(x))
|
||||
torch.testing.assert_allclose(res1, res2)
|
||||
|
||||
|
|
@ -287,6 +287,32 @@ class TestTorch(TestCase):
|
|||
check_non_contiguous((5, 7), torch.float)
|
||||
check_non_contiguous((1024,), torch.float)
|
||||
|
||||
# If size(dim) == 1, stride(dim) is not defined.
|
||||
# The code needs to be able to handle this
|
||||
def check_contiguous_size1(dtype):
|
||||
contig = torch.randn((5, 100), dtype=dtype)
|
||||
contig = contig[:1, :50]
|
||||
contig2 = torch.empty(contig.size(), dtype=dtype)
|
||||
contig2.copy_(contig)
|
||||
self.assertTrue(contig.is_contiguous())
|
||||
self.assertTrue(contig2.is_contiguous())
|
||||
self.assertEqual(torchfn(contig), torchfn(contig2), 'contiguous size1')
|
||||
|
||||
check_contiguous_size1(torch.double)
|
||||
check_contiguous_size1(torch.float)
|
||||
|
||||
def check_contiguous_size1_largedim(dtype):
|
||||
contig = torch.randn((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype)
|
||||
contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
|
||||
contig2 = torch.empty(contig.size(), dtype=dtype)
|
||||
contig2.copy_(contig)
|
||||
self.assertTrue(contig.is_contiguous())
|
||||
self.assertTrue(contig2.is_contiguous())
|
||||
self.assertEqual(torchfn(contig), torchfn(contig2), 'contiguous size1')
|
||||
|
||||
check_contiguous_size1_largedim(torch.double)
|
||||
check_contiguous_size1_largedim(torch.float)
|
||||
|
||||
def check_large(dtype):
|
||||
input = torch.randn(1024, 512, dtype=dtype)
|
||||
actual = torchfn(input)
|
||||
|
|
@ -298,11 +324,20 @@ class TestTorch(TestCase):
|
|||
check_large(torch.double)
|
||||
check_large(torch.float)
|
||||
|
||||
def _test_math_by_name(self, function_name):
|
||||
torchfn = getattr(torch, function_name)
|
||||
mathfn = getattr(math, function_name)
|
||||
def __test_math_by_name(self, function_name, mathfn, selffn):
|
||||
mathfn = getattr(math, mathfn)
|
||||
if selffn:
|
||||
def torchfn(x):
|
||||
return getattr(x, function_name)()
|
||||
else:
|
||||
torchfn = getattr(torch, function_name)
|
||||
self._test_math(torchfn, mathfn)
|
||||
|
||||
def _test_math_by_name(self, function_name, test_self=True):
|
||||
if test_self:
|
||||
self.__test_math_by_name(function_name + "_", function_name, True)
|
||||
self.__test_math_by_name(function_name, function_name, False)
|
||||
|
||||
def test_sin(self):
|
||||
self._test_math_by_name('sin')
|
||||
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ class TestDataLoader(TestCase):
|
|||
dataiter = iter(dataloader)
|
||||
self.assertEqual(len(list(dataiter)), 1)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
|
||||
@unittest.skip("FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN")
|
||||
def test_multi_keep(self):
|
||||
dataloader = torch.utils.data.DataLoader(self.dataset,
|
||||
batch_size=self.batch_size,
|
||||
|
|
|
|||