Support non-contiguous tensors for unary ops (#6119)

This commit is contained in:
cpuhrsch 2018-04-27 19:31:34 +00:00 committed by Adam Paszke
parent a6bfa16c17
commit ae35e0e924
13 changed files with 693 additions and 395 deletions

View File

@ -1,6 +1,6 @@
#pragma once
#include <sstream>
#include "ATen/Parallel.h"
#include "ATen/TensorUtils.h"
namespace at {
@ -8,11 +8,11 @@ namespace at {
/*
* The basic strategy for apply is as follows:
*
* 1. Starting with the outermost index, loop until we reach a dimension where
* the data is no longer contiguous, i.e. the stride at that dimension is not
* equal to the size of the tensor defined by the outer dimensions. Let's call
* this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
* A is equal to the entire Tensor. Let's call the inner tensor B.
*
* 2. We loop through the indices in B, starting at its outermost dimension. For
* example, if B is a 2x2 matrix, then we do:
@ -22,289 +22,370 @@ namespace at {
* B[1][0]
* B[1][1]
*
* We set the offset into the underlying storage as (storageOffset + stride_B *
* index_B), i.e. basically we compute the offset into the storage as we would
* normally for a Tensor. But because we are guaranteed the subsequent data is
* contiguous in memory, we can simply loop for sizeof(A) iterations and perform
* the operation, without having to follow the order described by the strides of
* A.
*
* 3. As an optimization, we merge dimensions of A that are contiguous in
* memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
* then the first two dimensions can be merged for the purposes of APPLY,
* reducing the number of nested loops.
*/
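As a reading aid, here is a minimal, self-contained sketch of the dimension-merging optimization described in step 3. The helper collapse_contiguous_dims and the main() driver are illustrative only and are not part of this header; _setup_arrays below applies the same idea directly to a Tensor's sizes and strides.

#include <cstdint>
#include <cstdio>
#include <vector>

// Merge adjacent dimensions that are already laid out contiguously in memory,
// i.e. stride[i] == size[i+1] * stride[i+1]. Size-1 dimensions are absorbed
// too, since they do not affect the iteration order.
void collapse_contiguous_dims(
    std::vector<int64_t>& sizes,
    std::vector<int64_t>& strides) {
  std::vector<int64_t> out_sizes, out_strides;
  for (size_t i = 0; i < sizes.size(); i++) {
    int64_t size = sizes[i];
    int64_t stride = strides[i];
    while (i + 1 < sizes.size() &&
           (sizes[i + 1] == 1 ||
            strides[i] == sizes[i + 1] * strides[i + 1])) {
      size *= sizes[i + 1];
      if (sizes[i + 1] != 1)
        stride = strides[i + 1];
      i++;
    }
    out_sizes.push_back(size);
    out_strides.push_back(stride);
  }
  sizes = out_sizes;
  strides = out_strides;
}

int main() {
  // A 3x3x3x3 view narrowed out of a contiguous 3x3x4x3 tensor keeps the
  // parent's strides {36, 12, 3, 1}. The first two and the last two
  // dimensions each merge, leaving sizes {9, 9} and strides {12, 1}.
  std::vector<int64_t> sizes = {3, 3, 3, 3};
  std::vector<int64_t> strides = {36, 12, 3, 1};
  collapse_contiguous_dims(sizes, strides);
  for (size_t i = 0; i < sizes.size(); i++)
    std::printf("dim %zu: size %lld, stride %lld\n",
                i, (long long)sizes[i], (long long)strides[i]);
  return 0;
}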
// TODO: turn this macro into a proper template
#define __ATH_TENSOR_APPLYX_PREAMBLE(TYPE, ATENSOR, DIM, ALLOW_CONTIGUOUS) \
TYPE *ATENSOR##_data = NULL; \
int64_t *ATENSOR##_counter = NULL, *ATENSOR##_sizes = NULL, *ATENSOR##_strides = NULL, *ATENSOR##_dimOffset = NULL; \
int64_t ATENSOR##_stride = 0, ATENSOR##_size = 0, ATENSOR##_dim = 0, ATENSOR##_i; \
int ATENSOR##_contiguous = ALLOW_CONTIGUOUS && DIM < 0; \
\
if(ATENSOR.sizes().equals({0})) \
TH_TENSOR_APPLY_hasFinished = true; \
else \
{ \
ATENSOR##_data = ATENSOR.data<TYPE>(); \
ATENSOR##_size = 1; \
ATENSOR##_stride = 1; \
for(ATENSOR##_i = ATENSOR.dim() - 1; ATENSOR##_i >= 0; ATENSOR##_i--) { \
if(ATENSOR.sizes()[ATENSOR##_i] != 1) { \
if(ATENSOR.strides()[ATENSOR##_i] == ATENSOR##_size && ATENSOR##_i != DIM) \
ATENSOR##_size *= ATENSOR.sizes()[ATENSOR##_i]; \
else{ \
ATENSOR##_contiguous = 0; \
break; \
} \
} \
} \
if (!ATENSOR##_contiguous) { \
/* Find the dimension of contiguous sections */ \
ATENSOR##_dim = 1; \
for(ATENSOR##_i = ATENSOR.dim() - 2; ATENSOR##_i >= 0; ATENSOR##_i--) \
{ \
if(ATENSOR.strides()[ATENSOR##_i] != ATENSOR.strides()[ATENSOR##_i+1] * ATENSOR.sizes()[ATENSOR##_i+1] || ATENSOR##_i == DIM || ATENSOR##_i+1 == DIM) \
ATENSOR##_dim++; \
} \
/* Allocate an array of 3*dim elements, where dim is the number of contiguous sections */ \
ATENSOR##_counter = new int64_t[3*ATENSOR##_dim]; \
ATENSOR##_sizes = ATENSOR##_counter + ATENSOR##_dim; \
ATENSOR##_strides = ATENSOR##_counter + 2*ATENSOR##_dim; \
TH_TENSOR_dim_index = ATENSOR##_dim-1; \
ATENSOR##_dimOffset = (DIM == ATENSOR.dim()-1) ? &ATENSOR##_i : &ATENSOR##_counter[DIM]; \
ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR.dim()-1]; \
ATENSOR##_strides[TH_TENSOR_dim_index] = ATENSOR.strides()[ATENSOR.dim()-1]; \
/* ATENSOR##_counter tracks where we are in the storage. The offset into the */ \
/* storage is given by storage_offset + sum_i(strides[i] * counter[i]), i.e. the */ \
/* dot product of the stride and counter vectors. This sets the starting position for the loop. */ \
for(ATENSOR##_i = ATENSOR##_dim-1; ATENSOR##_i >= 0; --ATENSOR##_i) { \
ATENSOR##_counter[ATENSOR##_i] = 0; \
} \
for(ATENSOR##_i = ATENSOR.dim()-2; ATENSOR##_i >= 0; --ATENSOR##_i) { \
if (ATENSOR.strides()[ATENSOR##_i] == ATENSOR.strides()[ATENSOR##_i+1] * ATENSOR.sizes()[ATENSOR##_i+1] && ATENSOR##_i != DIM && ATENSOR##_i+1 != DIM) { \
ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR##_i] * ATENSOR##_sizes[TH_TENSOR_dim_index]; \
if (DIM != ATENSOR.dim()-1 && ATENSOR##_i < DIM) \
ATENSOR##_dimOffset--; \
} else { \
--TH_TENSOR_dim_index; \
ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR##_i]; \
ATENSOR##_strides[TH_TENSOR_dim_index] = ATENSOR.strides()[ATENSOR##_i]; \
} \
} \
/* Size of the innermost section */ \
ATENSOR##_size = ATENSOR##_sizes[ATENSOR##_dim-1]; \
/* Stride of the innermost section */ \
ATENSOR##_stride = ATENSOR##_strides[ATENSOR##_dim-1]; \
} \
} \
ATENSOR##_i = 0;
// TODO: turn this macro into a proper template
#define __ATH_TENSOR_APPLYX_UPDATE_COUNTERS(ATENSOR, ALWAYS_UPDATE) \
if(ATENSOR##_i == ATENSOR##_size || ALWAYS_UPDATE) \
{ \
if(ATENSOR##_contiguous) \
break; \
\
if(ATENSOR##_dim == 1) \
break; \
\
/* Reset pointer to beginning of loop */ \
ATENSOR##_data -= ATENSOR##_size*ATENSOR##_stride; \
for(ATENSOR##_i = ATENSOR##_dim-2; ATENSOR##_i >= 0; ATENSOR##_i--) \
{ \
ATENSOR##_counter[ATENSOR##_i]++; \
/* Jump ahead by the stride of this dimension */ \
ATENSOR##_data += ATENSOR##_strides[ATENSOR##_i]; \
\
if(ATENSOR##_counter[ATENSOR##_i] == ATENSOR##_sizes[ATENSOR##_i]) \
{ \
if(ATENSOR##_i == 0) \
{ \
TH_TENSOR_APPLY_hasFinished = true; \
break; \
} \
else \
{ \
/* Reset the pointer to the beginning of the chunk defined by this dimension */ \
ATENSOR##_data -= ATENSOR##_counter[ATENSOR##_i]*ATENSOR##_strides[ATENSOR##_i]; \
ATENSOR##_counter[ATENSOR##_i] = 0; \
} \
} \
else \
break; \
} \
ATENSOR##_i = 0; \
inline Tensor sort_strides(Tensor& tensor_) {
IntList strides = tensor_.strides();
std::vector<int64_t> indices;
indices.reserve(tensor_.ndimension());
for (int64_t i = 0; i < tensor_.ndimension(); i++) {
indices.push_back(i);
}
std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
return strides[i1] > strides[i2];
});
Tensor tensor = tensor_.permute(indices);
return tensor;
}
template <typename Arg>
inline void _setup_arrays(Tensor& tensor, Arg* iter) {
int64_t max_dim = tensor.ndimension();
iter->dim_ = 0;
for (int64_t i = 0; i < max_dim; i++) {
int64_t size = tensor.size(i);
int64_t stride = tensor.stride(i);
while (i + 1 < max_dim &&
(tensor.size(i + 1) == 1 ||
tensor.stride(i) == tensor.size(i + 1) * tensor.stride(i + 1))) {
size = size * tensor.size(i + 1);
if (tensor.size(i + 1) != 1)
stride = tensor.stride(i + 1);
i++;
}
iter->sizes_[iter->dim_] = size;
iter->strides_[iter->dim_] = stride;
iter->dim_++;
}
}
template <typename T, int N>
struct strided_tensor_iter_fixed {
public:
T* data_ = NULL;
int64_t dim_;
int64_t counter_[N];
int64_t sizes_[N];
int64_t strides_[N];
strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
void operator=(strided_tensor_iter_fixed const& x) = delete;
strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
strided_tensor_iter_fixed(Tensor& tensor, bool sort_strides = false)
: data_(tensor.data<T>()) {
memset(counter_, 0, sizeof(int64_t) * N);
_setup_arrays(tensor, this);
}
};
template <typename T>
struct strided_tensor_iter {
private:
public:
T* data_ = NULL;
int64_t dim_;
std::vector<int64_t> counter_;
std::vector<int64_t> sizes_;
std::vector<int64_t> strides_;
strided_tensor_iter(strided_tensor_iter const&) = delete;
void operator=(strided_tensor_iter const& x) = delete;
strided_tensor_iter(strided_tensor_iter&&) = default;
strided_tensor_iter(Tensor& tensor)
: data_(tensor.data<T>()),
dim_(tensor.ndimension()),
counter_(dim_, 0),
sizes_(tensor.sizes()),
strides_(tensor.strides()) {
_setup_arrays(tensor, this);
}
};
inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
if (tensors.size() == 0)
return true;
int64_t all_numel = tensors[0].numel();
for (size_t i = 1; i < tensors.size(); i++) {
if (tensors[i].numel() != all_numel)
return false;
}
return true;
}
inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
std::ostringstream oss;
oss << "inconsistent tensor size, expected ";
for (size_t i = 0; i < tensors.size() - 1; i++) {
oss << tensors[i].sizes() << ", ";
}
oss << "and " << tensors[tensors.size() - 1]
<< " to have the same number of elements, but got ";
for (size_t i = 0; i < tensors.size() - 1; i++) {
oss << tensors[i].numel() << ", ";
}
oss << "and " << tensors[tensors.size() - 1].numel()
<< " elements respectively";
return oss.str();
}
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
checkBackend("CPU_tensor_apply", tensors, Backend::CPU);
if (!_all_equal_numel(tensors))
throw std::runtime_error(_all_equal_numel_error(tensors));
// An empty tensor has no elements
for (auto& t : tensors)
if (t.sizes().equals({0}))
return false;
internal::init_tbb_num_threads();
return true;
}
inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
int64_t dim = 0;
for (auto& t : tensors)
dim = std::max(dim, t.ndimension());
return dim;
}
inline void iterate(){};
template <typename Arg, typename... Args>
inline void iterate(Arg& iter, Args&... iter_tail) {
iter.counter_[iter.dim_ - 1]++;
iter.data_ += iter.strides_[iter.dim_ - 1];
if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
for (int64_t i = iter.dim_ - 1; i > 0; i--) {
if (iter.counter_[i] == iter.sizes_[i]) {
iter.counter_[i] = 0;
iter.counter_[i - 1]++;
iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
iter.strides_[i - 1];
}
}
}
iterate(iter_tail...);
}
inline void forward(int64_t offset){};
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
int64_t multi = offset;
for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
int64_t inc = multi % iter.sizes_[i];
multi = multi / iter.sizes_[i];
iter.data_ = iter.data_ + inc * iter.strides_[i];
iter.counter_[i] += inc;
}
forward(offset, iter_tail...);
}
inline int64_t max_dim() {
return 0;
}
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
return std::max(iter.dim_, max_dim(iter_tail...));
}
inline void apply_op(){};
template <typename Op, typename... Args>
inline void
apply_op(int64_t numel, int64_t offset, const Op& op, Args... iters) {
// For 0-dim tensors
if (numel == 1 && max_dim(iters...) == 0) {
op(*iters.data_...);
return;
}
if (offset > 0)
forward(offset, iters...);
for (int64_t i = 0; i < numel; i++) {
op(*iters.data_...);
iterate(iters...);
}
}
/*
Apply a pointwise operator to a sequence of tensors
The calling convention for op is a function/functor that takes the same
number of scalar references as the number of given tensors. For example,
to compute a = b * c, op would be of the form:
[](scalar& a_val, const scalar& b_val, const scalar& c_val) { a_val =
b_val * c_val; };
*/
template <typename scalar1, typename Op>
inline void CPU_tensor_apply1(Tensor tensor1, const Op op) {
if (!_apply_preamble({tensor1}))
return;
if (tensor1.ndimension() < 8) {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1, true));
} else {
apply_op(tensor1.numel(), 0, op, strided_tensor_iter<scalar1>(tensor1));
}
}
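A hypothetical usage sketch (not part of this header): square_add and its tensor names are made up, and it assumes c, a and b are CPU float tensors that already have the same number of elements, since apply checks numel but does not resize. CPU_tensor_apply3 walks all three tensors according to their strides, so none of them needs to be contiguous.

#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"

// Hypothetical helper: pointwise c = a * a + b over CPU float tensors of equal
// numel. The op receives one scalar reference per tensor, in the same order as
// the tensors are passed.
void square_add(at::Tensor& c, const at::Tensor& a, const at::Tensor& b) {
  at::CPU_tensor_apply3<float, float, float>(
      c, a, b, [](float& c_val, const float& a_val, const float& b_val) {
        c_val = a_val * a_val + b_val;
      });
}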
template <typename scalar1, typename scalar2, typename Op>
void CPU_tensor_apply2_dim(Tensor& tensor1, Tensor& tensor2, int64_t dim, Op op) {
checkBackend("CPU_tensor_apply2", {tensor1, tensor2}, Backend::CPU);
bool TH_TENSOR_APPLY_hasFinished = false;
int64_t TH_TENSOR_dim_index = 0;
__ATH_TENSOR_APPLYX_PREAMBLE(scalar1, tensor1, dim, 1)
__ATH_TENSOR_APPLYX_PREAMBLE(scalar2, tensor2, dim, 1)
auto t1_numel = tensor1.numel();
auto t2_numel = tensor2.numel();
if(t1_numel != t2_numel) {
std::ostringstream oss;
oss << "inconsistent tensor size, expected " << tensor1.sizes() << " and " << tensor2.sizes()
<< " to have the same number of elements, but got " << t1_numel << " and " << t2_numel << " elements respectively";
throw std::runtime_error(oss.str());
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
if (!_apply_preamble({tensor1, tensor2}))
return;
if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
strided_tensor_iter_fixed<scalar2, 8>(tensor2));
} else {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter<scalar1>(tensor1),
strided_tensor_iter<scalar2>(tensor2));
}
while(!TH_TENSOR_APPLY_hasFinished)
{
/* Loop through the innermost region of the Tensor */
for(; tensor1_i < tensor1_size && tensor2_i < tensor2_size; tensor1_i++, tensor2_i++, tensor1_data += tensor1_stride, tensor2_data += tensor2_stride)
{
op(*tensor1_data, *tensor2_data);
}
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor1, 0)
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor2, 0)
}
if(tensor1_counter != NULL)
delete [] tensor1_counter;
if(tensor2_counter != NULL)
delete [] tensor2_counter;
}
/*
Apply a pointwise operator to two tensors.
The calling convention for op is a function/functor that takes two references to
type scalar; at least one of these references should be non-const in order to write the output.
For example, to compute a = b^2, op would be of the form:
[](scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; };
*/
template<typename scalar1, typename scalar2, typename Op>
void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, Op op) {
CPU_tensor_apply2_dim<scalar1, scalar2, Op>(tensor1, tensor2, -1, op);
template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline void
CPU_tensor_apply3(Tensor tensor1, Tensor tensor2, Tensor tensor3, const Op op) {
if (!_apply_preamble({tensor1, tensor2, tensor3}))
return;
if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
strided_tensor_iter_fixed<scalar3, 8>(tensor3));
} else {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter<scalar1>(tensor1),
strided_tensor_iter<scalar2>(tensor2),
strided_tensor_iter<scalar3>(tensor3));
}
}
template<typename scalar1, typename scalar2, typename scalar3, typename Op>
void CPU_tensor_apply3_dim(Tensor &tensor1, Tensor& tensor2, Tensor& tensor3, int64_t dim, Op op) {
checkBackend("CPU_tensor_apply3", {tensor1, tensor2, tensor3}, Backend::CPU);
bool TH_TENSOR_APPLY_hasFinished = false;
int64_t TH_TENSOR_dim_index = 0;
__ATH_TENSOR_APPLYX_PREAMBLE(scalar1, tensor1, dim, 1)
__ATH_TENSOR_APPLYX_PREAMBLE(scalar2, tensor2, dim, 1)
__ATH_TENSOR_APPLYX_PREAMBLE(scalar3, tensor3, dim, 1)
int elements_equal = 1;
auto t1_numel = tensor1.numel();
auto t2_numel = tensor2.numel();
auto t3_numel = tensor3.numel();
if(t1_numel!= t2_numel) {
elements_equal = 0;
} else if(t1_numel != t3_numel) {
elements_equal = 0;
template <
typename scalar1,
typename scalar2,
typename scalar3,
typename scalar4,
typename Op>
inline void CPU_tensor_apply4(
Tensor tensor1,
Tensor tensor2,
Tensor tensor3,
Tensor tensor4,
const Op op) {
if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
return;
if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
strided_tensor_iter_fixed<scalar3, 8>(tensor3),
strided_tensor_iter_fixed<scalar4, 8>(tensor4));
} else {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter<scalar1>(tensor1),
strided_tensor_iter<scalar2>(tensor2),
strided_tensor_iter<scalar3>(tensor3),
strided_tensor_iter<scalar4>(tensor4));
}
if (elements_equal == 0) {
std::ostringstream oss;
oss << "inconsistent tensor size, expected " << tensor1.sizes() << ", " << tensor2.sizes() << ", and " << tensor3.sizes()
<< " to have the same number of elements, but got " << t1_numel << ", " << t2_numel << ", and " << t3_numel << " elements respectively";
throw std::runtime_error(oss.str());
}
while(!TH_TENSOR_APPLY_hasFinished)
{
/* Loop through the innermost region of the Tensor */
for(; tensor1_i < tensor1_size && tensor2_i < tensor2_size && tensor3_i < tensor3_size; tensor1_i++, tensor2_i++, tensor3_i++, tensor1_data += tensor1_stride, tensor2_data += tensor2_stride, tensor3_data += tensor3_stride)
{
op(*tensor1_data, *tensor2_data, *tensor3_data);
}
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor1, 0)
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor2, 0)
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor3, 0)
}
if(tensor1_counter != NULL)
delete [] tensor1_counter;
if(tensor2_counter != NULL)
delete [] tensor2_counter;
if(tensor3_counter != NULL)
delete [] tensor3_counter;
}
/*
Apply a pointwise operator to three tensors.
The calling convention for op is a function/functor that takes three references to
type scalar; at least one of these references should be non-const in order to write the output.
For example, to compute a = b + c, op would be of the form:
[](scalar &a_val, const scalar &b_val, const scalar &c_val) { a_val = b_val + c_val; };
*/
template<typename scalar1, typename scalar2, typename scalar3, typename Op>
void CPU_tensor_apply3(Tensor tensor1, Tensor tensor2, Tensor tensor3, Op op) {
CPU_tensor_apply3_dim<scalar1, scalar2, scalar3, Op>(tensor1, tensor2, tensor3, -1, op);
}
template <typename scalar1, typename scalar2, typename scalar3, typename scalar4, typename Op>
void CPU_tensor_apply4_dim(Tensor &tensor1, Tensor& tensor2, Tensor& tensor3, Tensor& tensor4, int64_t dim, Op op) {
checkBackend("CPU_tensor_apply4", {tensor1, tensor2, tensor3, tensor4}, Backend::CPU);
bool TH_TENSOR_APPLY_hasFinished = false;
int64_t TH_TENSOR_dim_index = 0;
__ATH_TENSOR_APPLYX_PREAMBLE(scalar1, tensor1, dim, 1)
__ATH_TENSOR_APPLYX_PREAMBLE(scalar2, tensor2, dim, 1)
__ATH_TENSOR_APPLYX_PREAMBLE(scalar3, tensor3, dim, 1)
__ATH_TENSOR_APPLYX_PREAMBLE(scalar4, tensor4, dim, 1)
int elements_equal = 1;
auto t1_numel = tensor1.numel();
auto t2_numel = tensor2.numel();
auto t3_numel = tensor3.numel();
auto t4_numel = tensor4.numel();
if(t1_numel!= t2_numel) {
elements_equal = 0;
} else if(t1_numel != t3_numel) {
elements_equal = 0;
} else if(t1_numel != t4_numel) {
elements_equal = 0;
template <typename scalar1, typename Op>
inline void CPU_tensor_parallel_apply1(Tensor tensor1, const Op op) {
if (!_apply_preamble({tensor1}))
return;
if (tensor1.numel() < internal::TBB_GRAIN_SIZE) {
CPU_tensor_apply1<scalar1>(tensor1, op);
return;
}
if (elements_equal == 0) {
std::ostringstream oss;
oss << "inconsistent tensor size, expected " << tensor1.sizes() << ", " << tensor2.sizes() << ", "
<< tensor3.sizes() << ", and " << tensor4.sizes() << " to have the same number of elements, but got "
<< t1_numel << ", " << t2_numel << ", " << t3_numel << ", and " << t4_numel << " elements respectively";
throw std::runtime_error(oss.str());
auto range = tbb::blocked_range<size_t>(0, tensor1.numel());
if (tensor1.ndimension() < 8) {
tbb::parallel_for(
range, [&tensor1, &op](const tbb::blocked_range<size_t> r) {
apply_op(
r.end() - r.begin(),
r.begin(),
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1, true));
});
} else {
tbb::parallel_for(
range, [&tensor1, &op](const tbb::blocked_range<size_t> r) {
apply_op(
r.end() - r.begin(),
r.begin(),
op,
strided_tensor_iter<scalar1>(tensor1));
});
}
}
while(!TH_TENSOR_APPLY_hasFinished)
{
/* Loop through the innermost region of the Tensor */
for(; tensor1_i < tensor1_size && tensor2_i < tensor2_size && tensor3_i < tensor3_size && tensor4_i < tensor4_size
; tensor1_i++, tensor2_i++, tensor3_i++, tensor4_i++,
tensor1_data += tensor1_stride, tensor2_data += tensor2_stride, tensor3_data += tensor3_stride, tensor4_data += tensor4_stride)
{
op(*tensor1_data, *tensor2_data, *tensor3_data, *tensor4_data);
}
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor1, 0)
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor2, 0)
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor3, 0)
__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor4, 0)
template <typename scalar1, typename scalar2, typename Op>
inline void
CPU_tensor_parallel_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
if (!_apply_preamble({tensor1, tensor2}))
return;
if ((tensor1.numel() + tensor2.numel()) < internal::TBB_GRAIN_SIZE) {
CPU_tensor_apply2<scalar1, scalar2>(tensor1, tensor2, op);
return;
}
auto range = tbb::blocked_range<size_t>(0, tensor1.numel());
if (tensor1.ndimension() < 8 && tensor2.ndimension() < 8) {
tbb::parallel_for(
range, [&tensor1, &tensor2, &op](const tbb::blocked_range<size_t> r) {
apply_op(
r.end() - r.begin(),
r.begin(),
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
strided_tensor_iter_fixed<scalar2, 8>(tensor2));
});
} else {
tbb::parallel_for(
range, [&tensor1, &tensor2, &op](const tbb::blocked_range<size_t> r) {
apply_op(
r.end() - r.begin(),
r.begin(),
op,
strided_tensor_iter<scalar1>(tensor1),
strided_tensor_iter<scalar2>(tensor2));
});
}
if(tensor1_counter != NULL)
delete [] tensor1_counter;
if(tensor2_counter != NULL)
delete [] tensor2_counter;
if(tensor3_counter != NULL)
delete [] tensor3_counter;
if(tensor4_counter != NULL)
delete [] tensor4_counter;
}
/*
Apply a pointwise operator to four tensors.
The calling convention for op is a function/functor that takes four references to
type scalar; at least one of these references should be non-const in order to write the output.
For example, to compute a = b + c * d, op would be of the form:
[](scalar &a_val, const scalar &b_val, const scalar &c_val, const scalar &d_val) {
a_val = b_val + c_val * d_val;
};
*/
template<typename scalar1, typename scalar2, typename scalar3, typename scalar4, typename Op>
void CPU_tensor_apply4(Tensor tensor1, Tensor tensor2, Tensor tensor3, Tensor tensor4, Op op) {
CPU_tensor_apply4_dim<scalar1, scalar2, scalar3, scalar4, Op>(tensor1, tensor2, tensor3, tensor4, -1, op);
}
}
} // namespace at

View File

@ -1114,7 +1114,6 @@
- Int
- Short
backends:
- CPU
- CUDA
variants:
- method
@ -1161,7 +1160,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@ -1354,7 +1352,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@ -1400,7 +1397,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@ -1475,7 +1471,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@ -1695,7 +1690,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@ -1741,7 +1735,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@ -1758,7 +1751,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@ -1775,7 +1767,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@ -1792,7 +1783,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method

View File

@ -15,7 +15,7 @@ namespace internal {
// for a certain number of workers. If there are multiple threads making
// a request at the size of the maximum number of threads, they will
// be allocated a number proportional to the other requests.
void init_tbb_num_threads();
AT_API void init_tbb_num_threads();
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
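A hedged illustration of this heuristic (the function below is hypothetical and not part of ATen; it assumes a contiguous CPU double tensor and that the TBB headers are available): below the grain size the work stays serial, above it the range is split across TBB workers, mirroring how the parallel apply functions fall back to their sequential counterparts.

#include "ATen/ATen.h"
#include "ATen/Parallel.h"
#include <tbb/tbb.h>
#include <functional>

// Hypothetical example of the grain-size gating described above.
double sum_with_grain_size(at::Tensor t) {
  at::internal::init_tbb_num_threads();   // configure the scheduler, as _apply_preamble does
  const double* data = t.data<double>();  // assumes a contiguous CPU double tensor
  int64_t n = t.numel();
  if (n < at::internal::TBB_GRAIN_SIZE) { // too little work to amortize threading
    double acc = 0;
    for (int64_t i = 0; i < n; i++)
      acc += data[i];
    return acc;
  }
  return tbb::parallel_reduce(
      tbb::blocked_range<int64_t>(0, n),
      0.0,
      [&](const tbb::blocked_range<int64_t>& r, double acc) {
        for (int64_t i = r.begin(); i < r.end(); i++)
          acc += data[i];
        return acc;
      },
      std::plus<double>());
}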

View File

@ -72,7 +72,7 @@ void checkDefined(CheckedFrom c, const TensorArg& t);
void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
// FixMe: does TensorArg slow things down?
void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> t, at::Backend backend);
AT_API void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> t, at::Backend backend);
// Methods for getting data_ptr if tensor is defined
void * maybe_data_ptr(const Tensor& tensor);

View File

@ -3,48 +3,106 @@
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "cpu/UnaryOpsKernel.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Parallel.h"
#include "ATen/native/cpu/UnaryOpsKernel.h"
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>
#include <map>
namespace at {
namespace native {
#define IMPLEMENT_UNARY_OP(op) \
Tensor op(const Tensor& self) { \
Tensor result = self.type().tensor(); \
return at::op ## _out(result, self); \
} \
Tensor& op##_(Tensor& self) { \
return at::op ## _out(self, self); \
} \
Tensor& _ ## op ## _out_cuda(Tensor& result, const Tensor& self) { \
return at::_ ## op ## _out(result, self); \
} \
Tensor& _ ## op ## _out_cpu(Tensor& result, const Tensor& self) { \
if (result.is_contiguous() && self.is_contiguous()) { \
result.resize_(self.sizes()); \
if (result.numel() > 0) { \
op ## Impl(result, self); \
} \
return result; \
} \
return at::_ ## op ## _out(result, self); \
#define IMPLEMENT_UNARY_OP_PREQUEL(op) \
Tensor op(const Tensor& self) { \
Tensor result = self.type().tensor(); \
return at::op##_out(result, self); \
} \
Tensor& _##op##__cuda(Tensor& self) { \
return at::_##op##_out(self, self); \
} \
Tensor& _##op##_out_cuda(Tensor& result, const Tensor& self) { \
return at::_##op##_out(result, self); \
}
#define IMPLEMENT_UNARY_OP_FLOAT_CMATH(op) \
Tensor& _##op##__cpu(Tensor& self_) { \
if (self_.numel() > 0) { \
Tensor self = sort_strides(self_); \
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
CPU_tensor_parallel_apply1<scalar_t>( \
self, [](scalar_t& y) { y = std::op(y); }); \
}); \
} \
return self_; \
} \
Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
result.resize_(self.sizes()); \
if (result.numel() > 0) { \
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
CPU_tensor_parallel_apply2<scalar_t, scalar_t>( \
result, self, [](scalar_t& y, scalar_t& x) { y = std::op(x); }); \
}); \
} \
return result; \
}
#define IMPLEMENT_UNARY_OP_VEC(op) \
Tensor& _##op##__cpu(Tensor& self_) { \
if (self_.numel() > 0) { \
Tensor self = sort_strides(self_); \
if (self.is_contiguous()) { \
op##Impl(self, self); \
} else { \
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
CPU_tensor_parallel_apply1<scalar_t>( \
self, [](scalar_t& y) { y = std::op(y); }); \
}); \
} \
} \
return self_; \
} \
Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
result.resize_(self.sizes()); \
if (result.numel() > 0) { \
if (result.is_contiguous() && self.is_contiguous()) { \
op##Impl(result, self); \
} else { \
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
CPU_tensor_parallel_apply2<scalar_t, scalar_t>( \
result, self, [](scalar_t& y, scalar_t& x) { y = std::op(x); }); \
}); \
} \
} \
return result; \
}
IMPLEMENT_UNARY_OP_PREQUEL(abs)
IMPLEMENT_UNARY_OP_PREQUEL(ceil)
IMPLEMENT_UNARY_OP_PREQUEL(cos)
IMPLEMENT_UNARY_OP_PREQUEL(exp)
IMPLEMENT_UNARY_OP_PREQUEL(floor)
IMPLEMENT_UNARY_OP_PREQUEL(log)
IMPLEMENT_UNARY_OP_PREQUEL(round)
IMPLEMENT_UNARY_OP_PREQUEL(sin)
IMPLEMENT_UNARY_OP_PREQUEL(sqrt)
IMPLEMENT_UNARY_OP_PREQUEL(trunc)
IMPLEMENT_UNARY_OP_VEC(abs)
IMPLEMENT_UNARY_OP_VEC(ceil)
IMPLEMENT_UNARY_OP_FLOAT_CMATH(cos)
IMPLEMENT_UNARY_OP_FLOAT_CMATH(exp)
IMPLEMENT_UNARY_OP_VEC(floor)
IMPLEMENT_UNARY_OP_FLOAT_CMATH(log)
IMPLEMENT_UNARY_OP_VEC(round)
IMPLEMENT_UNARY_OP_FLOAT_CMATH(sin)
IMPLEMENT_UNARY_OP_VEC(sqrt)
IMPLEMENT_UNARY_OP_VEC(trunc)
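For readability, the following is roughly what IMPLEMENT_UNARY_OP_FLOAT_CMATH(cos) expands to for the in-place variant; it is a hand-expanded sketch for illustration only, not additional code in this file.

// Hand expansion of IMPLEMENT_UNARY_OP_FLOAT_CMATH(cos), in-place variant:
// sort strides for better memory locality, then apply std::cos element-wise
// via the (possibly parallel) strided apply.
Tensor& _cos__cpu(Tensor& self_) {
  if (self_.numel() > 0) {
    Tensor self = sort_strides(self_);
    AT_DISPATCH_FLOATING_TYPES(self.type(), cos, [&] {
      CPU_tensor_parallel_apply1<scalar_t>(
          self, [](scalar_t& y) { y = std::cos(y); });
    });
  }
  return self_;
}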
}
IMPLEMENT_UNARY_OP(abs)
IMPLEMENT_UNARY_OP(ceil)
IMPLEMENT_UNARY_OP(cos)
IMPLEMENT_UNARY_OP(exp)
IMPLEMENT_UNARY_OP(floor)
IMPLEMENT_UNARY_OP(log)
IMPLEMENT_UNARY_OP(round)
IMPLEMENT_UNARY_OP(sin)
IMPLEMENT_UNARY_OP(sqrt)
IMPLEMENT_UNARY_OP(trunc)
} // namespace at

View File

@ -7,12 +7,14 @@
#include "ATen/cpu/vec256/vec256.h"
#include "ATen/native/cpu/CapabilityDispatch.h"
namespace at { namespace native {
namespace {
using namespace vec256;
template <typename scalar_t, typename F>
static void
unary_kernel(scalar_t* arr_out, const scalar_t* arr_in, int64_t size, F func) {
using Vec = Vec256<scalar_t>;
int64_t size_rounded = size - (size % Vec::size);
int64_t k = 0;
@ -52,94 +54,59 @@ static void parallel_apply(Tensor& result, const Tensor& self, F f) {
static void abs_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_ALL_TYPES(self.type(), "abs", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.abs();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.abs(); });
});
}
static void ceil_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "ceil", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.ceil();
});
});
}
static void cos_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "cos", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.cos();
});
});
}
static void exp_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "exp", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.exp();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.ceil(); });
});
}
static void floor_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "floor", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.floor();
});
});
}
static void log_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "log", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.log();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.floor(); });
});
}
static void round_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "round", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.round();
});
});
}
static void sin_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "sin", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.sin();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.round(); });
});
}
static void sqrt_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "sqrt", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.sqrt();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.sqrt(); });
});
}
static void trunc_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "trunc", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.trunc();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.trunc(); });
});
}
} // anonymous namespace
REGISTER_DISPATCH(absImpl, &abs_kernel);
REGISTER_DISPATCH(ceilImpl, &ceil_kernel);
REGISTER_DISPATCH(cosImpl, &cos_kernel);
REGISTER_DISPATCH(expImpl, &exp_kernel);
REGISTER_DISPATCH(floorImpl, &floor_kernel);
REGISTER_DISPATCH(logImpl, &log_kernel);
REGISTER_DISPATCH(roundImpl, &round_kernel);
REGISTER_DISPATCH(sinImpl, &sin_kernel);
REGISTER_DISPATCH(sqrtImpl, &sqrt_kernel);
REGISTER_DISPATCH(truncImpl, &trunc_kernel);

View File

@ -1,7 +1,6 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <stdexcept>
#include "CapabilityDispatch.h"
@ -11,22 +10,16 @@ using unary_fn = void(*)(Tensor&, const Tensor&);
extern DispatchStub<unary_fn> absImpl;
extern DispatchStub<unary_fn> ceilImpl;
extern DispatchStub<unary_fn> cosImpl;
extern DispatchStub<unary_fn> expImpl;
extern DispatchStub<unary_fn> floorImpl;
extern DispatchStub<unary_fn> logImpl;
extern DispatchStub<unary_fn> roundImpl;
extern DispatchStub<unary_fn> sinImpl;
extern DispatchStub<unary_fn> sqrtImpl;
extern DispatchStub<unary_fn> truncImpl;
// Missing unary functions
// TODO: Add generic apply function for contiguous and non-contiguous tensors
// The goal here is to move more ops entirely into ATen and take advantage of
// automatic vectorization with file-specific flags
// acos
// asin
// atan
// cos
// cosh
// digamma
// erf
@ -37,6 +30,7 @@ extern DispatchStub<unary_fn> truncImpl;
// log1p
// rsqrt
// sigmoid
// sin
// sinh
// tan
// tanh

View File

@ -44,6 +44,9 @@
- func: abs(Tensor self) -> Tensor
- func: abs_(Tensor self) -> Tensor
dispatch:
CPU: _abs__cpu
CUDA: _abs__cuda
- func: abs_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -118,6 +121,9 @@
- func: ceil(Tensor self) -> Tensor
- func: ceil_(Tensor self) -> Tensor
dispatch:
CPU: _ceil__cpu
CUDA: _ceil__cuda
- func: ceil_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -170,6 +176,9 @@
- func: cos(Tensor self) -> Tensor
- func: cos_(Tensor self) -> Tensor
dispatch:
CPU: _cos__cpu
CUDA: _cos__cuda
- func: cos_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -343,6 +352,9 @@
- func: exp(Tensor self) -> Tensor
- func: exp_(Tensor self) -> Tensor
dispatch:
CPU: _exp__cpu
CUDA: _exp__cuda
- func: exp_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -368,6 +380,9 @@
- func: floor(Tensor self) -> Tensor
- func: floor_(Tensor self) -> Tensor
dispatch:
CPU: _floor__cpu
CUDA: _floor__cuda
- func: floor_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -451,6 +466,9 @@
- func: log(Tensor self) -> Tensor
- func: log_(Tensor self) -> Tensor
dispatch:
CPU: _log__cpu
CUDA: _log__cuda
- func: log_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -604,6 +622,9 @@
- func: round(Tensor self) -> Tensor
- func: round_(Tensor self) -> Tensor
dispatch:
CPU: _round__cpu
CUDA: _round__cuda
- func: round_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -632,6 +653,9 @@
- func: sin(Tensor self) -> Tensor
- func: sin_(Tensor self) -> Tensor
dispatch:
CPU: _sin__cpu
CUDA: _sin__cuda
- func: sin_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -719,6 +743,9 @@
- func: sqrt(Tensor self) -> Tensor
- func: sqrt_(Tensor self) -> Tensor
dispatch:
CPU: _sqrt__cpu
CUDA: _sqrt__cuda
- func: sqrt_out(Tensor result, Tensor self) -> Tensor
variants: function
@ -774,6 +801,9 @@
- func: trunc(Tensor self) -> Tensor
- func: trunc_(Tensor self) -> Tensor
dispatch:
CPU: _trunc__cpu
CUDA: _trunc__cuda
- func: trunc_out(Tensor result, Tensor self) -> Tensor
variants: function

View File

@ -7,6 +7,9 @@ ENDIF(MSVC)
ADD_EXECUTABLE(scalar_test scalar_test.cpp)
target_link_libraries(scalar_test ATen)
ADD_EXECUTABLE(apply_utils_test apply_utils_test.cpp)
target_link_libraries(apply_utils_test ATen)
ADD_EXECUTABLE(basic basic.cpp)
target_link_libraries(basic ATen)

View File

@ -0,0 +1,139 @@
#define CATCH_CONFIG_MAIN
#include "catch.hpp"
#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "test_assert.h"
#include "test_seed.h"
#include <iostream>
using namespace std;
using namespace at;
void fill_tensor(int64_t scalar, Tensor& t_) {
auto t = t_.view(-1);
for (int64_t i = 0; i < t.numel(); i++) {
t[i] = (i + 1) * scalar;
}
}
// This test exercises all sequential applyX functions. Given a shape and two
// transpose dimensions we create 5 tensors (a0, ..., a4) of the given shape and
// transpose dimension a with dimension b for each tensor. Then we call the
// applyX function on each floating type. a4 is allocated as double only,
// whereas a0, ..., a3 are allocated in the given type. For each applyX function
// we write once in the same type as we read (using a0, ..., aX-1) and once to
// double (using a4 as the target). We also exercise a zero-dim and an empty
// tensor.
void test(Type& type, IntList shape, int64_t a = 0, int64_t b = 1) {
auto zero_dim = type.tensor({});
zero_dim.fill_(2);
zero_dim.exp_();
AT_DISPATCH_FLOATING_TYPES(zero_dim.type(), "test0", [&] {
ASSERT(zero_dim.data<scalar_t>()[0] == std::exp(2));
});
auto empty_t = type.tensor({0});
empty_t.fill_(3);
empty_t.exp_();
auto a0 = type.tensor();
auto a1 = type.tensor();
auto a2 = type.tensor();
auto a3 = type.tensor();
auto a4 = CPU(kDouble).tensor();
std::vector<Tensor> tensors({a0, a1, a2, a3, a4});
for (size_t i = 0; i < tensors.size(); i++) {
tensors[i].resize_(shape);
fill_tensor(i + 1, tensors[i]);
if (a >= 0 && b >= 0) {
tensors[i].transpose_(a, b);
}
}
AT_DISPATCH_FLOATING_TYPES(a0.type(), "test1", [&] {
CPU_tensor_apply2<scalar_t, scalar_t>(
a0, a1, [](scalar_t& y, const scalar_t& x) { y = x * x; });
CPU_tensor_apply2<double, scalar_t>(
a4, a1, [](double& y, scalar_t x) { y = (double)(x * x); });
for (int64_t i = 0; i < a0.numel(); i++) {
auto target = a1.data<scalar_t>()[i] * a1.data<scalar_t>()[i];
ASSERT(a0.data<scalar_t>()[i] == target);
ASSERT(a4.data<double>()[i] == target);
}
});
AT_DISPATCH_FLOATING_TYPES(a0.type(), "test2", [&] {
CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
a0, a1, a2, [](scalar_t& y, const scalar_t& x, const scalar_t& z) {
y = x * x + z;
});
CPU_tensor_apply3<double, scalar_t, scalar_t>(
a4, a1, a2, [](double& y, const scalar_t& x, const scalar_t& z) {
y = (double)(x * x + z);
});
for (int64_t i = 0; i < a0.numel(); i++) {
auto target = a1.data<scalar_t>()[i] * a1.data<scalar_t>()[i];
target = target + a2.data<scalar_t>()[i];
ASSERT(a0.data<scalar_t>()[i] == target);
ASSERT(a4.data<double>()[i] == target);
}
});
AT_DISPATCH_FLOATING_TYPES(a0.type(), "test3", [&] {
CPU_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
a0,
a1,
a2,
a3,
[](scalar_t& y,
const scalar_t& x,
const scalar_t& z,
const scalar_t& a) { y = x * x + z * a; });
CPU_tensor_apply4<double, scalar_t, scalar_t, scalar_t>(
a4,
a1,
a2,
a3,
[](double& y, const scalar_t& x, const scalar_t& z, const scalar_t& a) {
y = (double)(x * x + z * a);
});
for (int64_t i = 0; i < a0.numel(); i++) {
auto target = a1.data<scalar_t>()[i] * a1.data<scalar_t>()[i];
target = target + a2.data<scalar_t>()[i] * a3.data<scalar_t>()[i];
ASSERT(a0.data<scalar_t>()[i] == target);
ASSERT(a4.data<double>()[i] == target);
}
});
}
TEST_CASE("apply utils test 2-dim small contiguous", "[cpu]") {
manual_seed(123, at::Backend::CPU);
test(CPU(kDouble), {2, 1}, -1, -1);
}
TEST_CASE("apply utils test 2-dim small", "[cpu]") {
manual_seed(123, at::Backend::CPU);
test(CPU(kDouble), {2, 1});
}
TEST_CASE("apply utils test 2-dim", "[cpu]") {
manual_seed(123, at::Backend::CPU);
test(CPU(kDouble), {20, 10});
}
TEST_CASE("apply utils test 3-dim", "[cpu]") {
manual_seed(123, at::Backend::CPU);
test(CPU(kDouble), {3, 4, 2});
}
TEST_CASE("apply utils test 3-dim medium", "[cpu]") {
manual_seed(123, at::Backend::CPU);
test(CPU(kDouble), {3, 40, 2});
}
TEST_CASE("apply utils test 10-dim", "[cpu]") {
manual_seed(123, at::Backend::CPU);
test(CPU(kDouble), {3, 4, 2, 5, 2, 1, 3, 4, 2, 3});
}

View File

@ -9,6 +9,7 @@ $BUILD_ROOT/src/ATen/test/atest
$BUILD_ROOT/src/ATen/test/scalar_test
$BUILD_ROOT/src/ATen/test/broadcast_test
$BUILD_ROOT/src/ATen/test/wrapdim_test
$BUILD_ROOT/src/ATen/test/apply_utils_test
$BUILD_ROOT/src/ATen/test/dlconvertor_test
$BUILD_ROOT/src/ATen/test/native_test
$BUILD_ROOT/src/ATen/test/scalar_tensor_test

View File

@ -266,7 +266,7 @@ class TestTorch(TestCase):
def compare_reference(input, dtype):
input = torch.tensor(input, dtype=dtype)
res1 = torchfn(input)
res1 = torchfn(input.clone())
res2 = input.clone().apply_(lambda x: mathfn(x))
torch.testing.assert_allclose(res1, res2)
@ -287,6 +287,32 @@ class TestTorch(TestCase):
check_non_contiguous((5, 7), torch.float)
check_non_contiguous((1024,), torch.float)
# If size(dim) == 1, stride(dim) is not defined.
# The code needs to be able to handle this
def check_contiguous_size1(dtype):
contig = torch.randn((5, 100), dtype=dtype)
contig = contig[:1, :50]
contig2 = torch.empty(contig.size(), dtype=dtype)
contig2.copy_(contig)
self.assertTrue(contig.is_contiguous())
self.assertTrue(contig2.is_contiguous())
self.assertEqual(torchfn(contig), torchfn(contig2), 'contiguous size1')
check_contiguous_size1(torch.double)
check_contiguous_size1(torch.float)
def check_contiguous_size1_largedim(dtype):
contig = torch.randn((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype)
contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
contig2 = torch.empty(contig.size(), dtype=dtype)
contig2.copy_(contig)
self.assertTrue(contig.is_contiguous())
self.assertTrue(contig2.is_contiguous())
self.assertEqual(torchfn(contig), torchfn(contig2), 'contiguous size1')
check_contiguous_size1_largedim(torch.double)
check_contiguous_size1_largedim(torch.float)
def check_large(dtype):
input = torch.randn(1024, 512, dtype=dtype)
actual = torchfn(input)
@ -298,11 +324,20 @@ class TestTorch(TestCase):
check_large(torch.double)
check_large(torch.float)
def _test_math_by_name(self, function_name):
torchfn = getattr(torch, function_name)
mathfn = getattr(math, function_name)
def __test_math_by_name(self, function_name, mathfn, selffn):
mathfn = getattr(math, mathfn)
if selffn:
def torchfn(x):
return getattr(x, function_name)()
else:
torchfn = getattr(torch, function_name)
self._test_math(torchfn, mathfn)
def _test_math_by_name(self, function_name, test_self=True):
if test_self:
self.__test_math_by_name(function_name + "_", function_name, True)
self.__test_math_by_name(function_name, function_name, False)
def test_sin(self):
self._test_math_by_name('sin')

View File

@ -246,7 +246,7 @@ class TestDataLoader(TestCase):
dataiter = iter(dataloader)
self.assertEqual(len(list(dataiter)), 1)
@unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
@unittest.skip("FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN")
def test_multi_keep(self):
dataloader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,