Speed up sum over a dimension (#6026)

Perf numbers:
https://gist.github.com/colesbury/9e28dd7b0f27b0b019f68adbd4bd4b88

I've changed the dispatch stub so that it doesn't require every kernel
to be compiled for every instruction set. Kernel implementations are
stored in the stub's table with the REGISTER_DISPATCH macro.

I've also moved vec256 to it's own folder and split up the
specializations before they get too unwieldy.

Change UnaryOpsKernel to use new DisaptchStub

 - Prefer signed integers. Mixing signed and unsigned integers is a
   pain and ATen mostly uses signed integers (int64_t).
 - Use inline lambda instead of struct for UnaryOps
 - Rename partial load overload "load_partial"
This commit is contained in:
Sam Gross 2018-03-29 18:13:43 -04:00 committed by GitHub
parent 3dffac91bc
commit e4c0bb1809
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 999 additions and 889 deletions

View File

@ -86,9 +86,9 @@ FOREACH(i RANGE ${NUM_CPU_CAPABILITY_NAMES})
SET(cpu_kernel_cpp ${NEW_IMPL} ${cpu_kernel_cpp}) # Create list of copies
LIST(GET CPU_CAPABILITY_FLAGS ${i} FLAGS)
IF(MSVC)
SET(MACRO_FLAG "/DCPUCAPABILITY${CPU_CAPABILITY}")
SET(MACRO_FLAG "/DCPU_CAPABILITY=${CPU_CAPABILITY} /DCPU_CAPABILITY_${CPU_CAPABILITY}")
ELSE(MSVC)
SET(MACRO_FLAG "-DCPUCAPABILITY${CPU_CAPABILITY}")
SET(MACRO_FLAG "-DCPU_CAPABILITY=${CPU_CAPABILITY} -DCPU_CAPABILITY_${CPU_CAPABILITY}")
ENDIF(MSVC)
SET_SOURCE_FILES_PROPERTIES(${NEW_IMPL} PROPERTIES COMPILE_FLAGS "${FLAGS} ${MACRO_FLAG}")
ENDFOREACH()

View File

@ -21,7 +21,7 @@ void init_tbb_num_threads();
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// smaller than GRAIN_SIZE chunks.
constexpr size_t TBB_GRAIN_SIZE = 32768;
constexpr int64_t TBB_GRAIN_SIZE = 32768;
} // namespace internal
template <class T, template <class> class OP>
@ -35,7 +35,7 @@ T parallel_reduce(
T result_;
static tbb::affinity_partitioner ap;
if ((size_t)(end - start) < internal::TBB_GRAIN_SIZE) {
if (end - start < internal::TBB_GRAIN_SIZE) {
result_ = f(data, start, end, init_);
} else {
result_ = tbb::parallel_reduce(
@ -89,28 +89,4 @@ void parallel_reduce_2d(
}
}
template <class T>
void parallel_for_1d(
void (*f)(T*, const T*, size_t, size_t),
Tensor& result,
const Tensor& self) {
internal::init_tbb_num_threads();
static tbb::affinity_partitioner ap;
T* arr_out = result.data<T>();
const T* arr_in = self.data<T>();
size_t start = 0;
size_t end = self.numel();
if (end - start < internal::TBB_GRAIN_SIZE) {
f(arr_out, arr_in, start, end);
} else {
tbb::parallel_for(
tbb::blocked_range<size_t>(start, end, internal::TBB_GRAIN_SIZE),
[&arr_out, &arr_in, &f](const tbb::blocked_range<size_t> r) {
f(arr_out, arr_in, r.begin(), r.end());
},
ap);
}
}
} // namespace at

View File

@ -0,0 +1,25 @@
#pragma once
#if defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#if _MSC_VER <= 1900
#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
#endif
#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(__GNUC__) && defined(__ARM_NEON__)
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#elif defined(__GNUC__) && defined(__IWMMXT__)
/* GCC-compatible compiler, targeting ARM with WMMX */
#include <mmintrin.h>
#elif (defined(__GNUC__) || defined(__xlC__)) && \
(defined(__VEC__) || defined(__ALTIVEC__))
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
#include <altivec.h>
#elif defined(__GNUC__) && defined(__SPE__)
/* GCC-compatible compiler, targeting PowerPC with SPE */
#include <spe.h>
#endif

View File

@ -0,0 +1,35 @@
#pragma once
#include "intrinsics.h"
#include "vec256_base.h"
#include "vec256_float.h"
#include "vec256_double.h"
#include "vec256_int.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
namespace at {
namespace vec256 {
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vec256<T>& vec) {
T buf[Vec256<T>::size()];
vec.store(buf);
stream << "vec[";
for (int i = 0; i != vec.size(); i++) {
if (i != 0) {
stream << ", ";
}
stream << buf[i];
}
stream << "]";
return stream;
}
}
}

View File

@ -0,0 +1,106 @@
#pragma once
#include <cstring>
#if defined(__GNUC__)
#define __at_align32__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align32__ __declspec(align(32))
#else
#define __at_align32__
#endif
namespace at {
namespace vec256 {
// NOTE: If you specialize on a type, you must define all operations!
// emulates vectorized types
template <class T>
struct Vec256 {
static constexpr int size = 32 / sizeof(T);
__at_align32__ T values[32 / sizeof(T)];
Vec256() {}
Vec256(T val) {
for (int i = 0; i != size; i++) {
values[i] = val;
}
}
void load(const void* ptr) {
std::memcpy(values, ptr, 32);
};
void load_partial(const void* ptr, int count) {
std::memcpy(values, ptr, count * sizeof(T));
}
static Vec256 s_load(const T* ptr) {
Vec256 vec;
vec.load(ptr);
return vec;
}
void store(T *ptr) const {
std::memcpy(ptr, values, 32);
}
void store_partial(void* ptr, int count) const {
std::memcpy(ptr, values, count * sizeof(T));
}
Vec256<T> map(T (*f)(T)) const {
Vec256<T> ret;
for (int64_t i = 0; i != size; i++) {
ret.values[i] = f(values[i]);
}
return ret;
}
Vec256<T> abs() const {
Vec256<T> ret;
for (int64_t i = 0; i < size; i++) {
ret.values[i] = values[i] < 0 ? -values[i] : values[i];
}
return ret;
}
Vec256<T> exp() const {
return map(std::exp);
}
Vec256<T> log() const {
return map(std::log);
}
Vec256<T> ceil() const {
return map(std::ceil);
}
Vec256<T> cos() const {
return map(std::cos);
}
Vec256<T> floor() const {
return map(std::floor);
}
Vec256<T> round() const {
return map(std::round);
}
Vec256<T> sin() const {
return map(std::sin);
}
Vec256<T> trunc() const {
return map(std::trunc);
}
Vec256<T> sqrt() const {
return map(std::sqrt);
}
};
template <class T> Vec256<T> operator+(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != c.size; i++) {
c.values[i] = a.values[i] + b.values[i];
}
return c;
}
template <class T> Vec256<T> operator*(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != c.size; i++) {
c.values[i] = a.values[i] * b.values[i];
}
return c;
}
}
}

View File

@ -0,0 +1,97 @@
#pragma once
#include "intrinsics.h"
#include "vec256_base.h"
namespace at {
namespace vec256 {
#ifdef __AVX__
template <> class Vec256<double> {
public:
static constexpr int size = 4;
__m256d values;
Vec256() {}
Vec256(__m256d v) : values(v) {}
Vec256(double val) {
values = _mm256_set1_pd(val);
}
operator __m256d() const {
return values;
}
void load(const void* ptr) {
values = _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
}
void load_partial(const void *ptr, int count) {
double tmp_values[size];
std::memcpy(tmp_values, ptr, count * sizeof(double));
load(tmp_values);
}
static Vec256<double> s_load(const void* ptr) {
Vec256<double> vec;
vec.load(ptr);
return vec;
}
void store(void* ptr) const {
_mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
}
void store_partial(void* ptr, int count) const {
double tmp_values[size];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(double));
}
Vec256<double> map(double (*f)(double)) const {
__at_align32__ double tmp[4];
store(tmp);
for (int64_t i = 0; i < 4; i++) {
tmp[i] = f(tmp[i]);
}
return s_load(tmp);
}
Vec256<double> abs() const {
auto mask = _mm256_set1_pd(-0.f);
return _mm256_andnot_pd(mask, values);
}
Vec256<double> exp() const {
return map(std::exp);
}
Vec256<double> log() const {
return map(std::log);
}
Vec256<double> sin() const {
return map(std::sin);
}
Vec256<double> cos() const {
return map(std::cos);
}
Vec256<double> ceil() const {
return _mm256_ceil_pd(values);
}
Vec256<double> floor() const {
return _mm256_floor_pd(values);
}
Vec256<double> round() const {
return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vec256<double> trunc() const {
return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vec256<double> sqrt() const {
return _mm256_sqrt_pd(values);
}
};
template <>
Vec256<double> inline operator+(const Vec256<double>& a, const Vec256<double>& b) {
return _mm256_add_pd(a, b);
}
template <>
Vec256<double> inline operator*(const Vec256<double>& a, const Vec256<double>& b) {
return _mm256_mul_pd(a, b);
}
#endif
}}

View File

@ -0,0 +1,116 @@
#pragma once
#include "intrinsics.h"
#include "vec256_base.h"
#ifdef __AVX2__
#include <ATen/native/cpu/avx_mathfun.h>
#endif
namespace at {
namespace vec256 {
#ifdef __AVX__
template <> class Vec256<float> {
public:
static constexpr int size = 8;
__m256 values;
Vec256() {}
Vec256(__m256 v) : values(v) {}
Vec256(float val) {
values = _mm256_set1_ps(val);
}
operator __m256() const {
return values;
}
void load(const void *ptr) {
values = _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
}
void load_partial(const void *ptr, int count) {
float tmp_values[size];
std::memcpy(tmp_values, ptr, count * sizeof(float));
load(tmp_values);
}
static Vec256<float> s_load(const void* ptr) {
Vec256<float> vec;
vec.load(ptr);
return vec;
}
void store(void *ptr) const {
_mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
}
void store_partial(void* ptr, int count) const {
float tmp_values[size];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(float));
}
Vec256<float> map(float (*f)(float)) const {
__at_align32__ float tmp[8];
store(tmp);
for (int64_t i = 0; i < 8; i++) {
tmp[i] = f(tmp[i]);
}
return s_load(tmp);
}
Vec256<float> abs() const {
auto mask = _mm256_set1_ps(-0.f);
return _mm256_andnot_ps(mask, values);
}
Vec256<float> exp() const {
#ifdef __AVX2__
return exp256_ps(values);
#else
return map(std::exp);
#endif
}
Vec256<float> log() const {
#ifdef __AVX2__
return log256_ps(values);
#else
return map(std::log);
#endif
}
Vec256<float> sin() const {
#ifdef __AVX2__
return sin256_ps(values);
#else
return map(std::sin);
#endif
}
Vec256<float> cos() const {
#ifdef __AVX2__
return cos256_ps(values);
#else
return map(std::cos);
#endif
}
Vec256<float> ceil() const {
return _mm256_ceil_ps(values);
}
Vec256<float> floor() const {
return _mm256_floor_ps(values);
}
Vec256<float> round() const {
return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vec256<float> trunc() const {
return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vec256<float> sqrt() const {
return _mm256_sqrt_ps(values);
}
};
template <>
Vec256<float> inline operator+(const Vec256<float>& a, const Vec256<float>& b) {
return _mm256_add_ps(a, b);
}
template <>
Vec256<float> inline operator*(const Vec256<float>& a, const Vec256<float>& b) {
return _mm256_mul_ps(a, b);
}
#endif
}}

View File

@ -0,0 +1,157 @@
#pragma once
#include "intrinsics.h"
#include "vec256_base.h"
namespace at {
namespace vec256 {
#ifdef __AVX2__
struct Vec256i {
__m256i values;
Vec256i() {}
Vec256i(__m256i v) : values(v) {}
operator __m256i() const {
return values;
}
void load(const void *ptr) {
values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
}
void store(void *ptr) const {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
}
};
template <>
struct Vec256<int64_t> : public Vec256i {
static constexpr int size = 4;
using Vec256i::Vec256i;
Vec256() {}
Vec256(int64_t v) { values = _mm256_set1_epi64x(v); }
static Vec256<int64_t> s_load(const void* ptr) {
Vec256<int64_t> vec;
vec.load(ptr);
return vec;
}
void load_partial(const void *ptr, int count) {
int64_t tmp_values[size];
std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
load(tmp_values);
}
void store_partial(void* ptr, int count) const {
__at_align32__ int64_t tmp_values[size];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
}
Vec256<int64_t> abs() const {
auto zero = _mm256_set1_epi64x(0);
auto is_larger = _mm256_cmpgt_epi64(zero, values);
auto inverse = _mm256_xor_si256(values, is_larger);
return _mm256_sub_epi64(inverse, is_larger);
}
};
template <>
struct Vec256<int32_t> : public Vec256i {
static constexpr int size = 8;
using Vec256i::Vec256i;
Vec256() {}
Vec256(int32_t v) { values = _mm256_set1_epi32(v); }
static Vec256<int32_t> s_load(const void* ptr) {
Vec256<int32_t> vec;
vec.load(ptr);
return vec;
}
void load_partial(const void *ptr, int count) {
int32_t tmp_values[size];
std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
load(tmp_values);
}
void store_partial(void* ptr, int count) const {
__at_align32__ int32_t tmp_values[size];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
}
Vec256<int32_t> abs() const {
return _mm256_abs_epi32(values);
}
};
template <>
struct Vec256<int16_t> : public Vec256i {
static constexpr int size = 16;
using Vec256i::Vec256i;
Vec256() {}
Vec256(int16_t v) { values = _mm256_set1_epi16(v); }
static Vec256<int16_t> s_load(const void* ptr) {
Vec256<int16_t> vec;
vec.load(ptr);
return vec;
}
void load_partial(const void *ptr, int count) {
int16_t tmp_values[size];
std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
load(tmp_values);
}
void store_partial(void* ptr, int count) const {
__at_align32__ int16_t tmp_values[size];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
}
Vec256<int16_t> abs() const {
return _mm256_abs_epi16(values);
}
};
template <>
Vec256<int64_t> inline operator+(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
return _mm256_add_epi64(a, b);
}
template <>
Vec256<int32_t> inline operator+(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
return _mm256_add_epi32(a, b);
}
template <>
Vec256<int16_t> inline operator+(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
return _mm256_add_epi16(a, b);
}
// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated
// This could be implemented more efficiently using epi32 instructions
// This is also technically avx compatible, but then we'll need AVX
// code for add as well.
template <>
Vec256<int64_t> inline operator*(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
int64_t a0 = _mm256_extract_epi64(a, 0);
int64_t a1 = _mm256_extract_epi64(a, 1);
int64_t a2 = _mm256_extract_epi64(a, 2);
int64_t a3 = _mm256_extract_epi64(a, 3);
int64_t b0 = _mm256_extract_epi64(b, 0);
int64_t b1 = _mm256_extract_epi64(b, 1);
int64_t b2 = _mm256_extract_epi64(b, 2);
int64_t b3 = _mm256_extract_epi64(b, 3);
int64_t c0 = a0 * b0;
int64_t c1 = a1 * b1;
int64_t c2 = a2 * b2;
int64_t c3 = a3 * b3;
return _mm256_set_epi64x(c3, c2, c1, c0);
}
template <>
Vec256<int32_t> inline operator*(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
return _mm256_mullo_epi32(a, b);
}
template <>
Vec256<int16_t> inline operator*(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
return _mm256_mullo_epi16(a, b);
}
#endif
}}

View File

@ -12,46 +12,39 @@
#include <map>
namespace at { namespace native {
using reduce_type = void(Tensor&, const Tensor&, size_t, bool);
reduce_type* sumImpl = &DispatchStub<reduce_type>::init<sumImplC, &sumImpl>;
reduce_type* prodImpl = &DispatchStub<reduce_type>::init<prodImplC, &prodImpl>;
namespace at {
namespace native {
// ALL REDUCE #################################################################
Tensor _reduce_cpu(reduce_type* f, const Tensor& self) {
Tensor result = self.type().tensor({});
f(result, self, 0, true);
return result;
}
Tensor _sum_cpu(const Tensor& self) {
if (self.is_contiguous())
return _reduce_cpu(sumImpl, self);
if (self.is_contiguous()) {
Tensor result = self.type().tensor({});
sum_kernel(result, self, at::nullopt);
return result;
}
return self._sumall();
}
Tensor _prod_cpu(const Tensor& self) {
if (self.is_contiguous())
return _reduce_cpu(prodImpl, self);
Tensor _prod_cpu(const Tensor &self) {
if (self.is_contiguous()) {
Tensor result = self.type().tensor({});
prod_kernel(result, self, at::nullopt);
return result;
}
return self._prodall();
}
Tensor _sum_cuda(const Tensor& self_) {
return self_._sumall();
}
Tensor _sum_cuda(const Tensor &self_) { return self_._sumall(); }
Tensor _prod_cuda(const Tensor& self_) {
return self_._prodall();
}
Tensor _prod_cuda(const Tensor &self_) { return self_._prodall(); }
// \ALL REDUCE ################################################################
// DIM REDUCE #################################################################
static bool
_dimreduce_return_trivial(Tensor& result, const Tensor& self, int64_t ident) {
static bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
int64_t ident) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
result.fill_(self);
@ -66,8 +59,8 @@ _dimreduce_return_trivial(Tensor& result, const Tensor& self, int64_t ident) {
return false;
}
static Tensor&
_dimreduce_setup(Tensor& result, const Tensor& self, int64_t dim) {
static Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
int64_t dim) {
IntList self_sizes = self.sizes();
std::vector<int64_t> result_sizes;
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
@ -76,62 +69,56 @@ _dimreduce_setup(Tensor& result, const Tensor& self, int64_t dim) {
return result;
}
Tensor& _reduce_out_cpu(
reduce_type* f,
Tensor& result,
const Tensor& self,
int64_t dim,
bool keepdim) {
result = _dimreduce_setup(result, self, dim);
f(result, self, dim, false);
if (!keepdim)
result.squeeze_(dim);
return result;
}
Tensor&
_sum_out_cpu(Tensor& result, const Tensor& self, int64_t dim_, bool keepdim) {
Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
if (_dimreduce_return_trivial(result, self, 0))
return result;
if (self.is_contiguous() && result.is_contiguous()) {
return _reduce_out_cpu(sumImpl, result, self, dim, keepdim);
_dimreduce_setup(result, self, dim);
sum_kernel(result, self, dim);
if (!keepdim) result.squeeze_(dim);
return result;
}
return at::_sum_out(result, self, dim, keepdim);
}
Tensor&
_prod_out_cpu(Tensor& result, const Tensor& self, int64_t dim_, bool keepdim) {
Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
if (_dimreduce_return_trivial(result, self, 1))
return result;
if (self.is_contiguous() && result.is_contiguous()) {
return _reduce_out_cpu(prodImpl, result, self, dim, keepdim);
_dimreduce_setup(result, self, dim);
prod_kernel(result, self, dim);
if (!keepdim) result.squeeze_(dim);
return result;
}
return at::_prod_out(result, self, dim, keepdim);
}
Tensor&
_sum_out_cuda(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
Tensor &_sum_out_cuda(Tensor &result, const Tensor &self, int64_t dim,
bool keepdim) {
return at::_sum_out(result, self, dim, keepdim);
}
Tensor&
_prod_out_cuda(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
Tensor &_prod_out_cuda(Tensor &result, const Tensor &self, int64_t dim,
bool keepdim) {
return at::_prod_out(result, self, dim, keepdim);
}
Tensor sum(const Tensor& self, int64_t dim_, bool keepdim) {
Tensor sum(const Tensor &self, int64_t dim_, bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
Tensor result = self.type().tensor();
return at::sum_out(result, self, dim, keepdim);
}
Tensor prod(const Tensor& self, int64_t dim_, bool keepdim) {
Tensor prod(const Tensor &self, int64_t dim_, bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
Tensor result = self.type().tensor();
return at::prod_out(result, self, dim, keepdim);
}
// \DIM REDUCE ################################################################
}} // namespace at::native
}
}

View File

@ -14,46 +14,35 @@
namespace at { namespace native {
using unary_type = void(Tensor&, const Tensor&);
#define DISPATCH(NAME) \
unary_type* NAME ## Impl = DispatchStub<unary_type>::init<NAME ## ImplC, &NAME ## Impl>;\
#define BASIC(NAME) \
Tensor NAME(const Tensor& self) { \
Tensor result = self.type().tensor(); \
return at::NAME ## _out(result, self); \
} \
#define SELF(NAME) \
Tensor& NAME##_(Tensor& self) { \
return at::NAME ## _out(self, self); \
} \
#define OUTCPU(NAME) \
Tensor& _ ## NAME ## _out_cpu(Tensor& result, const Tensor& self) { \
return _unops_out_cpu(NAME ## Impl, result, self) ? result \
: at::_ ## NAME ## _out(result, self); \
} \
#define OUTCUDA(NAME) \
Tensor& _ ## NAME ## _out_cuda(Tensor& result, const Tensor& self) { \
return at::_ ## NAME ## _out(result, self); \
} \
bool _unops_out_cpu(unary_type* f, Tensor& result, const Tensor& self) {
if (result.is_contiguous() && self.is_contiguous()) {
result.resize_(self.sizes());
f(result, self);
return true;
}
return false;
#define IMPLEMENT_UNARY_OP(op) \
Tensor op(const Tensor& self) { \
Tensor result = self.type().tensor(); \
return at::op ## _out(result, self); \
} \
Tensor& op##_(Tensor& self) { \
return at::op ## _out(self, self); \
} \
Tensor& _ ## op ## _out_cuda(Tensor& result, const Tensor& self) { \
return at::_ ## op ## _out(result, self); \
} \
Tensor& _ ## op ## _out_cpu(Tensor& result, const Tensor& self) { \
if (result.is_contiguous() && self.is_contiguous()) { \
result.resize_(self.sizes()); \
op ## Impl(result, self); \
return result; \
} \
return at::_ ## op ## _out(result, self); \
}
UNARY_OPS_MACRO(DISPATCH)
UNARY_OPS_MACRO(BASIC)
UNARY_OPS_MACRO(SELF)
UNARY_OPS_MACRO(OUTCPU)
UNARY_OPS_MACRO(OUTCUDA)
IMPLEMENT_UNARY_OP(abs)
IMPLEMENT_UNARY_OP(ceil)
IMPLEMENT_UNARY_OP(cos)
IMPLEMENT_UNARY_OP(exp)
IMPLEMENT_UNARY_OP(floor)
IMPLEMENT_UNARY_OP(log)
IMPLEMENT_UNARY_OP(round)
IMPLEMENT_UNARY_OP(sin)
IMPLEMENT_UNARY_OP(sqrt)
IMPLEMENT_UNARY_OP(trunc)
}} // namespace at::native

View File

@ -1,44 +1,93 @@
#include <type_traits>
#pragma once
#include "ATen/cpu/cpuinfo/include/cpuinfo.h"
#include <type_traits>
#include <iostream>
namespace at { namespace native {
// Implements instruction set specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX) are
// compiled multiple times with different compiler flags (e.g. -mavx). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/cpu/MyKernel.h:
// using fn_type = void(*)(const Tensor& x);
// DispatchStub<fn_type> stub;
//
// In native/cpu/MyKernel.cpp:
// void kernel(const Tensor& x) { ... }
// REGISTER_DISPATCH(stub, &kernel);
//
// To call:
// stub(tensor);
//
enum class CPUCapability { DEFAULT, AVX, AVX2 };
namespace at {
namespace native {
#if defined(CPUCAPABILITYDEFAULT)
constexpr CPUCapability CURRENT_CAPABILITY = CPUCapability::DEFAULT;
#elif defined(CPUCAPABILITYAVX)
constexpr CPUCapability CURRENT_CAPABILITY = CPUCapability::AVX;
#elif defined(CPUCAPABILITYAVX2)
constexpr CPUCapability CURRENT_CAPABILITY = CPUCapability::AVX2;
#endif
enum class CPUCapability { DEFAULT, AVX, AVX2, NUM_OPTIONS };
template <typename FnType>
struct DispatchStub {};
template <typename FnPtr>
struct DispatchStub {
static_assert(std::is_pointer<FnPtr>::value, "FnPtr should be a pointer type");
template <typename... ArgTypes>
struct DispatchStub<void(ArgTypes...)> {
using FnType = void(ArgTypes...);
template <template <CPUCapability> class allImpl, FnType** dispatch_ptr>
static void init(ArgTypes... args) {
*dispatch_ptr = allImpl<CPUCapability::DEFAULT>::function;
// Check if platform is supported
if (cpuinfo_initialize()) {
// Set function pointer to best implementation last
#if defined(HAVE_AVX_CPU_DEFINITION)
if (!std::getenv("ATEN_DISABLE_AVX") && cpuinfo_has_x86_avx()) {
*dispatch_ptr = allImpl<CPUCapability::AVX>::function;
}
#endif
#if defined(HAVE_AVX2_CPU_DEFINITION)
if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2()) {
*dispatch_ptr = allImpl<CPUCapability::AVX2>::function;
}
#endif
template <typename... ArgTypes>
void operator()(ArgTypes... args) {
if (!dispatch_ptr) {
dispatch_ptr = choose_impl();
}
(*dispatch_ptr)(args...);
}
FnPtr choose_impl() {
if (cpuinfo_initialize()) {
int avx2 = static_cast<int>(CPUCapability::AVX2);
if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() && table[avx2]) {
return table[avx2];
}
int avx = static_cast<int>(CPUCapability::AVX);
if (!std::getenv("ATEN_DISABLE_AVX") && cpuinfo_has_x86_avx() && table[avx]) {
return table[avx];
}
}
int def = static_cast<int>(CPUCapability::DEFAULT);
AT_ASSERT(table[def], "DispatchStub: missing default kernel");
return table[def];
}
FnPtr dispatch_ptr = nullptr;
FnPtr table[static_cast<int>(CPUCapability::NUM_OPTIONS)];
};
}} // namespace at::native
#if defined(CPU_CAPABILITY)
constexpr CPUCapability CURRENT_CAPABILITY = CPUCapability::CPU_CAPABILITY;
// Registers an implementation a kernel for the current CPU capability.
template<typename FnPtr>
struct RegisterDispatch {
RegisterDispatch(DispatchStub<FnPtr>& stub, FnPtr value) {
stub.table[static_cast<int>(CURRENT_CAPABILITY)] = value;
}
};
// We only define the stub once in the DEFAULT capability compilation
#if defined(CPU_CAPABILITY_DEFAULT)
#define _DEFINE_STUB(stub, fn) DispatchStub<decltype(fn)> stub
#else
#define _DEFINE_STUB(stub, fn)
#endif
#define REGISTER_DISPATCH(stub, fn) \
_DEFINE_STUB(stub, fn); \
static RegisterDispatch<decltype(fn)> stub ## __register(stub, fn);
#endif
}
}

View File

@ -1,139 +1,154 @@
#include "ATen/native/cpu/ReduceOpsKernel.h"
#include <numeric>
#include "ATen/Dispatch.h"
#include "ATen/Parallel.h"
#include "ATen/native/cpu/Vec256.h"
#include "ATen/optional.h"
#include "ATen/cpu/vec256/vec256.h"
namespace at { namespace native {
namespace at {
namespace native {
using namespace vec256;
// This adds the content of arr to sum
template <class scalar_t, template <class> class OP, CPUCapability C>
inline scalar_t
allreduce_kernel_(const scalar_t* arr, size_t start, size_t end, scalar_t sum) {
Vec256<scalar_t> part_sum;
// Use all 16 registers.
Vec256<scalar_t> tmp_sum[4], tmp_sum1, tmp_sum2, tmp_sum3;
Vec256<scalar_t> a[8];
size_t width =
256 / sizeof(scalar_t); // primitives per 256 bytes (two cache lines)
size_t epr = 32 / sizeof(scalar_t); // primitives per Vec256
size_t k = 0;
for (; k < (end - start) / width; k++) {
for (size_t i = 0; i < 8; i++) {
a[i].load(arr + (k * width) + i * epr + start);
}
for (size_t i = 0; i < 8; i += 2) {
tmp_sum[i / 2] = OP<Vec256<scalar_t>>()(a[i], a[i + 1]);
}
tmp_sum1 = OP<Vec256<scalar_t>>()(tmp_sum[0], tmp_sum[1]);
tmp_sum2 = OP<Vec256<scalar_t>>()(tmp_sum[2], tmp_sum[3]);
if (k == 0) {
part_sum = OP<Vec256<scalar_t>>()(tmp_sum1, tmp_sum2);
} else {
tmp_sum3 = OP<Vec256<scalar_t>>()(tmp_sum1, tmp_sum2);
part_sum = OP<Vec256<scalar_t>>()(part_sum, tmp_sum3);
}
}
if (k > 0) {
scalar_t sarr[32 / sizeof(scalar_t)];
part_sum.store(sarr);
for (size_t i = 0; i < part_sum.size; i++) {
sum = OP<scalar_t>()(sum, sarr[i]);
}
}
k = k * width + start;
for (; k < end; k++) {
sum = OP<scalar_t>()(sum, arr[k]);
}
return sum;
static inline int64_t round_down(int64_t a, int64_t m) {
return a - (a % m);
}
// This overwrites the content of outarr
template <class scalar_t, template <class> class OP, CPUCapability C>
inline void dimreduce_kernel_(
const scalar_t* arr,
scalar_t* outarr,
size_t num_rows,
size_t num_cols) {
size_t width =
256 / (sizeof(scalar_t)); // primitives per 256 bytes (two cache lines)
Vec256<scalar_t> a[8];
Vec256<scalar_t> b[8];
constexpr size_t epr = 32 / sizeof(scalar_t); // primitives per Vec256
size_t tile = 0;
for (; tile < (num_cols) / width; tile++) {
size_t row_ind = tile * width;
for (size_t i = 0; i < num_rows; i += 1) {
for (int ib = 0; ib < 8; ib++) {
if (i == 0) {
b[ib].load(arr + i * num_cols + tile * width + ib * epr);
} else {
a[ib].load(arr + i * num_cols + tile * width + ib * epr);
b[ib] = OP<Vec256<scalar_t>>()(b[ib], a[ib]);
template<typename F>
static void parallel_for(int64_t end, int64_t step, bool parallelize, F func) {
if (parallelize) {
tbb::parallel_for<int64_t>(0, end, step, func);
} else {
for (int64_t i = 0; i != end; i += step) {
func(i);
}
}
}
static tbb::affinity_partitioner ap;
// Vectorized reduction defined by reduce operation `Op` with identity `ident`.
// The reduction is built on top of reduce128, which reduces down a column
// 128 bytes wide (WIDTH scalar elements). The width of 128 bytes is chosen
// because of the "adjacent cache line prefetch" behavior on x86 CPUs.
template<typename scalar_t, template <class> class Op, int ident>
struct Reduction {
// reduction width in number of scalar elements
static constexpr int WIDTH = 128 / sizeof(scalar_t);
using Vec = Vec256<scalar_t>;
using Reduce = Op<Vec>;
using ReduceScalar = Op<scalar_t>;
static void apply(Tensor& res, const Tensor& self, at::optional<int64_t> dim) {
internal::init_tbb_num_threads();
auto out = res.data<scalar_t>();
auto data = self.data<scalar_t>();
auto numel = self.numel();
if (!dim.has_value()) {
*out = reduce_all(data, numel);
return;
}
int64_t n = self.size(*dim);
int64_t stride = self.stride(*dim);
int64_t batch = numel / (n * stride);
bool paralellize = batch * n > internal::TBB_GRAIN_SIZE;
parallel_for(batch, 1, paralellize, [=](int64_t b) {
if (stride == 1) {
out[b] = reduce_all(&data[b * n], n);
} else {
reduce2d(&data[b * n * stride], &out[b * stride], n, stride, stride);
}
});
}
static scalar_t reduce_all(const scalar_t* data, int64_t size) {
int64_t k = size / WIDTH;
scalar_t sum;
if (size > internal::TBB_GRAIN_SIZE) {
sum = tbb::parallel_reduce(
tbb::blocked_range<int64_t>(0, k, internal::TBB_GRAIN_SIZE / WIDTH),
scalar_t(ident),
[=](const tbb::blocked_range<int64_t>& r, scalar_t init) {
scalar_t buf[WIDTH];
reduce128(&data[r.begin() * WIDTH], buf, r.end() - r.begin(), WIDTH);
return std::accumulate(buf, buf + WIDTH, init, ReduceScalar());
},
ReduceScalar(),
ap);
} else {
scalar_t buf[WIDTH];
reduce128(data, buf, k, WIDTH);
sum = std::accumulate(buf, buf + WIDTH, scalar_t(ident), ReduceScalar());
}
for (int i = k * WIDTH; i != size; i++) {
sum = ReduceScalar()(sum, data[i]);
}
return sum;
}
// Reduce down a column of WIDTH elements (128 bytes) with the given number
// of rows. Stores the results in out[0 ... WIDTH-1].
static void reduce128(const scalar_t* data, scalar_t* out, int64_t rows, int64_t stride) {
Vec acc[4] = {ident, ident, ident, ident}; // 128 bytes (two cache lines)
static_assert(sizeof(acc) == 128, "accumulator should be 128 bytes");
for (int64_t row = 0; row != rows; row++) {
for (int j = 0; j != 4; j++) {
auto val = Vec::s_load(&data[row * stride + j * Vec::size]);
acc[j] = Reduce()(acc[j], val);
}
}
for (int j = 0; j != 4; j++) {
acc[j].store(&out[j * Vec::size]);
}
}
// Reduce a 2d matrix down each column. Stores the results in out[0 ... cols-1]
static void reduce2d(const scalar_t* data, scalar_t* out, int64_t rows, int64_t cols, int64_t stride) {
int64_t cols_rounded = round_down(cols, WIDTH);
bool paralellize = cols * rows > internal::TBB_GRAIN_SIZE;
parallel_for(cols_rounded, WIDTH, paralellize, [=](int64_t col) {
reduce128(&data[col], &out[col], rows, stride);
});
if (cols_rounded != cols) {
scalar_t buf[WIDTH];
for (int64_t j = 0; j != cols - cols_rounded; j++) {
buf[j] = ident;
}
for (int64_t row = 0; row != rows; row++) {
for (int64_t j = 0; j != cols - cols_rounded; j++) {
auto val = data[row * stride + j + cols_rounded];
buf[j] = ReduceScalar()(buf[j], val);
}
}
}
for (int ib = 0; ib < 8; ib++) {
b[ib].store(outarr + row_ind + ib * epr);
}
}
size_t k = tile * width;
for (; k < num_cols; k++) {
for (size_t i = 0; i < num_rows; i += 1) {
if (i == 0) {
outarr[k] = arr[i * num_cols + k];
} else {
outarr[k] = OP<scalar_t>()(outarr[k], arr[i * num_cols + k]);
for (int64_t j = 0; j != cols - cols_rounded; j++) {
out[j + cols_rounded] = buf[j];
}
}
}
}
};
template <template <class> class OP, CPUCapability C>
inline void allImpl(
Tensor& result,
const Tensor& self,
size_t dim,
bool all,
const char* name,
int64_t init) {
AT_DISPATCH_ALL_TYPES(self.type(), name, [&] {
if (all) {
result.fill_(at::parallel_reduce<scalar_t, OP>(
&allreduce_kernel_<scalar_t, OP, CURRENT_CAPABILITY>,
self.data<scalar_t>(),
(size_t)0,
(size_t)self.numel(),
(scalar_t)init));
} else {
at::parallel_reduce_2d<scalar_t>(
&dimreduce_kernel_<scalar_t, OP, CURRENT_CAPABILITY>,
self.sizes()[dim],
self.strides()[dim],
self.numel(),
self.data<scalar_t>(),
result.data<scalar_t>());
}
static void sum_kernel_impl(Tensor& result, const Tensor& self, at::optional<int64_t> dim) {
AT_DISPATCH_ALL_TYPES(self.type(), "sum", [&] {
Reduction<scalar_t, std::plus, 0>::apply(result, self, dim);
});
}
template <>
void sumImplC<CURRENT_CAPABILITY>::function(
Tensor& result,
const Tensor& self,
size_t dim,
bool all) {
allImpl<std::plus, CURRENT_CAPABILITY>(result, self, dim, all, "sum", 0);
static void prod_kernel_impl(Tensor& result, const Tensor& self, at::optional<int64_t> dim) {
AT_DISPATCH_ALL_TYPES(self.type(), "prod", [&] {
Reduction<scalar_t, std::multiplies, 1>::apply(result, self, dim);
});
}
template <>
void prodImplC<CURRENT_CAPABILITY>::function(
Tensor& result,
const Tensor& self,
size_t dim,
bool all) {
allImpl<std::multiplies, CURRENT_CAPABILITY>(
result, self, dim, all, "prod", 1);
REGISTER_DISPATCH(sum_kernel, &sum_kernel_impl);
REGISTER_DISPATCH(prod_kernel, &prod_kernel_impl);
}
}
}} // namespace at::native

View File

@ -1,21 +1,16 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <stdexcept>
#include <ATen/optional.h>
#include "CapabilityDispatch.h"
namespace at { namespace native {
namespace at {
namespace native {
template <CPUCapability C>
struct sumImplC {
static void
function(Tensor& result, const Tensor& self, size_t dim, bool all);
};
using reduce_fn = void(*)(Tensor &, const Tensor &, at::optional<int64_t>);
template <CPUCapability C>
struct prodImplC {
static void
function(Tensor& result, const Tensor& self, size_t dim, bool all);
};
extern DispatchStub<reduce_fn> sum_kernel;
extern DispatchStub<reduce_fn> prod_kernel;
}} // namespace at::native
}
}

View File

@ -1,76 +1,142 @@
#include "ATen/native/cpu/UnaryOpsKernel.h"
#include <cmath>
#include <iostream>
#include "ATen/Dispatch.h"
#include "ATen/Parallel.h"
#include "ATen/native/cpu/Vec256.h"
#include "ATen/cpu/vec256/vec256.h"
#include "ATen/native/cpu/CapabilityDispatch.h"
namespace at { namespace native {
using namespace vec256;
// This modifies arr in place with given OP
template <class scalar_t, template <class> class VOP, CPUCapability C>
inline void
kernel_(scalar_t* arr_out, const scalar_t* arr_in, size_t start, size_t end) {
Vec256<scalar_t> a;
size_t epr = 32 / sizeof(scalar_t); // primitives per Vec256
size_t k = start;
size_t vec_end = end > epr ? end - epr : 0;
for (; k < vec_end; k += epr) {
a.load(arr_in + k);
VOP<scalar_t>()(a).store(arr_out + k);
template <typename scalar_t, typename F>
static void unary_kernel(scalar_t* arr_out, const scalar_t* arr_in, int64_t size, F func) {
using Vec = Vec256<scalar_t>;
int64_t size_rounded = size - (size % Vec::size);
int64_t k = 0;
for (; k != size_rounded; k += Vec::size) {
auto value = func(Vec::s_load(arr_in + k));
value.store(arr_out + k);
}
size_t leftover = std::min((end - k), a.size);
a.load(arr_in + k, leftover);
VOP<scalar_t>()(a).store(arr_out + k, leftover);
auto leftover = size - k;
Vec a;
a.load_partial(arr_in + k, leftover);
func(a).store_partial(arr_out + k, leftover);
}
// Functions excluding one-offs
#define GENERIC_UNARY_OPS_MACRO(MACRO) \
MACRO (ceil) \
MACRO (cos) \
MACRO (exp) \
MACRO (floor) \
MACRO (log) \
MACRO (round) \
MACRO (sin) \
MACRO (sqrt) \
MACRO (trunc) \
template <class scalar_t, class F>
static void parallel_apply(Tensor& result, const Tensor& self, F f) {
internal::init_tbb_num_threads();
namespace {
static tbb::affinity_partitioner ap;
#define FUNCVOP(NAME) \
template <typename T> \
struct NAME##VOP { \
Vec256<T> operator()(Vec256<T>& x) const { \
return x.NAME(); \
} \
};
UNARY_OPS_MACRO(FUNCVOP)
} // namespace
#define FUNCImpl(NAME) \
template <> \
void NAME##ImplC<CURRENT_CAPABILITY>::function( \
Tensor& result, const Tensor& self) { \
AT_DISPATCH_FLOATING_TYPES(self.type(), NAME, [&] { \
at::parallel_for_1d<scalar_t>( \
&kernel_<scalar_t, NAME##VOP, CURRENT_CAPABILITY>, result, self); \
}); \
auto arr_out = result.data<scalar_t>();
auto arr_in = self.data<scalar_t>();
int64_t size = self.numel();
if (size < internal::TBB_GRAIN_SIZE) {
unary_kernel(arr_out, arr_in, size, f);
} else {
tbb::parallel_for(
tbb::blocked_range<int64_t>(0, size, internal::TBB_GRAIN_SIZE),
[&](const tbb::blocked_range<int64_t>& r) {
auto size = r.end() - r.begin();
unary_kernel(arr_out + r.begin(), arr_in + r.begin(), size, f);
},
ap);
}
}
GENERIC_UNARY_OPS_MACRO(FUNCImpl)
template <>
void absImplC<CURRENT_CAPABILITY>::function(
Tensor& result,
const Tensor& self) {
AT_DISPATCH_ALL_TYPES(self.type(), abs, [&] {
at::parallel_for_1d<scalar_t>(
&kernel_<scalar_t, absVOP, CURRENT_CAPABILITY>, result, self);
static void abs_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_ALL_TYPES(self.type(), "abs", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.abs();
});
});
}
static void ceil_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "ceil", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.ceil();
});
});
}
static void cos_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "cos", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.cos();
});
});
}
static void exp_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "exp", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.exp();
});
});
}
static void floor_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "floor", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.floor();
});
});
}
static void log_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "log", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.log();
});
});
}
static void round_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "round", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.round();
});
});
}
static void sin_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "sin", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.sin();
});
});
}
static void sqrt_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "sqrt", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.sqrt();
});
});
}
static void trunc_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "trunc", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.trunc();
});
});
}
REGISTER_DISPATCH(absImpl, &abs_kernel);
REGISTER_DISPATCH(ceilImpl, &ceil_kernel);
REGISTER_DISPATCH(cosImpl, &cos_kernel);
REGISTER_DISPATCH(expImpl, &exp_kernel);
REGISTER_DISPATCH(floorImpl, &floor_kernel);
REGISTER_DISPATCH(logImpl, &log_kernel);
REGISTER_DISPATCH(roundImpl, &round_kernel);
REGISTER_DISPATCH(sinImpl, &sin_kernel);
REGISTER_DISPATCH(sqrtImpl, &sqrt_kernel);
REGISTER_DISPATCH(truncImpl, &trunc_kernel);
}} // namespace at::native

View File

@ -1,4 +1,5 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <stdexcept>
@ -6,25 +7,18 @@
namespace at { namespace native {
#define FUNCImplC(NAME) \
template <CPUCapability C>\
struct NAME ## ImplC {\
static void function(Tensor& result, const Tensor& self);\
};\
using unary_fn = void(*)(Tensor&, const Tensor&);
#define UNARY_OPS_MACRO(MACRO) \
MACRO (abs) \
MACRO (ceil) \
MACRO (cos) \
MACRO (exp) \
MACRO (floor) \
MACRO (log) \
MACRO (round) \
MACRO (sin) \
MACRO (sqrt) \
MACRO (trunc) \
UNARY_OPS_MACRO(FUNCImplC)
extern DispatchStub<unary_fn> absImpl;
extern DispatchStub<unary_fn> ceilImpl;
extern DispatchStub<unary_fn> cosImpl;
extern DispatchStub<unary_fn> expImpl;
extern DispatchStub<unary_fn> floorImpl;
extern DispatchStub<unary_fn> logImpl;
extern DispatchStub<unary_fn> roundImpl;
extern DispatchStub<unary_fn> sinImpl;
extern DispatchStub<unary_fn> sqrtImpl;
extern DispatchStub<unary_fn> truncImpl;
// Missing unary functions
// TODO: Add generic apply function for contiguous and non-contiguous tensors
@ -46,6 +40,5 @@ UNARY_OPS_MACRO(FUNCImplC)
// sinh
// tan
// tanh
// trunc
}} // namespace at::native

View File

@ -1,518 +0,0 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include "Intrinsics.h"
#ifdef __AVX2__
#include "avx_mathfun.h"
#endif
#if defined(__GNUC__)
#define __at_align32__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align32__ __declspec(align(32))
#else
#define __at_align32__
#endif
// NOTE:
// If you specialize on a type, you must define all operations!
// C arrays and intrinsic types don't mix
//
// NOTE:
// When testing make sure to test all capabilities (AVX, AVX2, DEFAULT, etc.)
namespace at { namespace native { namespace vec256 {
template <class T>
class Vec256 {
public:
__at_align32__ T values[32 / sizeof(T)]; // Mimics AVX behavior
inline void load(const T* ptr) {
std::memcpy(values, ptr, 32);
};
inline void store(T* ptr) const {
std::memcpy(ptr, values, 32);
}
inline void load(const T* ptr, size_t count) {
size_t section = count * sizeof(T);
std::memcpy(values, ptr, section);
};
inline void store(T* ptr, size_t count) const {
size_t section = count * sizeof(T);
std::memcpy(ptr, values, section);
}
size_t size = 32 / sizeof(T);
inline void operator=(const Vec256<T>& b) {
std::memcpy(values, b.values, 32);
}
inline Vec256<T> map(T (*f)(T)) {
Vec256<T> ret;
for (size_t i = 0; i < size; i++) {
ret.values[i] = f(values[i]);
}
return ret;
}
inline Vec256<T> abs() {
Vec256<T> ret;
for (size_t i = 0; i < size; i++)
ret.values[i] = values[i] < 0 ? -values[i] : values[i];
return ret;
}
inline Vec256<T> exp() {
return map(std::exp);
}
inline Vec256<T> log() {
return map(std::log);
}
inline Vec256<T> ceil() {
return map(std::ceil);
}
inline Vec256<T> cos() {
return map(std::cos);
}
inline Vec256<T> floor() {
return map(std::floor);
}
inline Vec256<T> round() {
return map(std::round);
}
inline Vec256<T> sin() {
return map(std::sin);
}
inline Vec256<T> trunc() {
return map(std::trunc);
}
inline Vec256<T> sqrt() {
return map(std::sqrt);
}
};
template <class T>
Vec256<T> operator+(const Vec256<T>& a, const Vec256<T>& b) {
Vec256<T> c = Vec256<T>();
for (size_t i = 0; i < a.size; i++)
c.values[i] = a.values[i] + b.values[i];
return c;
}
template <class T>
Vec256<T> operator*(const Vec256<T>& a, const Vec256<T>& b) {
Vec256<T> c = Vec256<T>();
for (size_t i = 0; i < a.size; i++)
c.values[i] = a.values[i] * b.values[i];
return c;
}
#ifdef __AVX__
template <>
class Vec256<float> {
public:
__m256 values;
Vec256<float>() {}
inline void load(const float* ptr) {
values = _mm256_loadu_ps(ptr);
}
inline void store(float* ptr) const {
_mm256_storeu_ps(ptr, values);
}
inline void load(const float* ptr, size_t count) {
float tmp_values[8];
std::memcpy(tmp_values, ptr, count * sizeof(float));
load(tmp_values);
}
inline void store(float* ptr, size_t count) const {
float tmp_values[8];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(float));
}
size_t size = 8;
inline void operator=(const Vec256<float>& b) {
values = b.values;
}
inline Vec256<float> map(float (*f)(float)) {
__at_align32__ float tmp[8];
store(tmp);
for (size_t i = 0; i < 8; i++)
tmp[i] = f(tmp[i]);
Vec256<float> ret;
ret.load(tmp);
return ret;
}
inline Vec256<float> abs() {
Vec256<float> ret;
__m256 mask = _mm256_set1_ps(-0.f);
ret.values = _mm256_andnot_ps(mask, values);
return ret;
}
#ifdef __AVX2__
inline Vec256<float> exp() {
Vec256<float> ret;
ret.values = exp256_ps(values);
return ret;
}
#else
inline Vec256<float> exp() {
return map(std::exp);
}
#endif
#ifdef __AVX2__
inline Vec256<float> log() {
Vec256<float> ret;
ret.values = log256_ps(values);
return ret;
}
#else
inline Vec256<float> log() {
return map(std::log);
}
#endif
#ifdef __AVX2__
inline Vec256<float> sin() {
Vec256<float> ret;
ret.values = sin256_ps(values);
return ret;
}
#else
inline Vec256<float> sin() {
return map(std::sin);
}
#endif
#ifdef __AVX2__
inline Vec256<float> cos() {
Vec256<float> ret;
ret.values = cos256_ps(values);
return ret;
}
#else
inline Vec256<float> cos() {
return map(std::cos);
}
#endif
inline Vec256<float> ceil() {
Vec256<float> ret;
ret.values = _mm256_ceil_ps(values);
return ret;
}
inline Vec256<float> floor() {
Vec256<float> ret;
ret.values = _mm256_floor_ps(values);
return ret;
}
inline Vec256<float> round() {
Vec256<float> ret;
ret.values = _mm256_round_ps(
values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return ret;
}
inline Vec256<float> trunc() {
Vec256<float> ret;
ret.values =
_mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
return ret;
}
inline Vec256<float> sqrt() {
Vec256<float> ret;
ret.values = _mm256_sqrt_ps(values);
return ret;
}
};
template <>
class Vec256<double> {
public:
__m256d values;
Vec256<double>() {}
inline void load(const double* ptr) {
values = _mm256_loadu_pd(ptr);
}
inline void store(double* ptr) const {
_mm256_storeu_pd(ptr, values);
}
inline void load(const double* ptr, size_t count) {
double tmp_values[4];
std::memcpy(tmp_values, ptr, count * sizeof(double));
load(tmp_values);
}
inline void store(double* ptr, size_t count) const {
double tmp_values[4];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(double));
}
size_t size = 4;
inline void operator=(const Vec256<double>& b) {
values = b.values;
}
inline Vec256<double> map(double (*f)(double)) {
__at_align32__ double tmp[4];
store(tmp);
for (size_t i = 0; i < 4; i++)
tmp[i] = f(tmp[i]);
Vec256<double> ret;
ret.load(tmp);
return ret;
}
inline Vec256<double> abs() {
Vec256<double> ret;
__m256d mask = _mm256_set1_pd(-0.);
ret.values = _mm256_andnot_pd(mask, values);
return ret;
}
inline Vec256<double> exp() {
return map(std::exp);
}
inline Vec256<double> log() {
return map(std::log);
}
inline Vec256<double> cos() {
return map(std::cos);
}
inline Vec256<double> ceil() {
Vec256<double> ret;
ret.values = _mm256_ceil_pd(values);
return ret;
}
inline Vec256<double> floor() {
Vec256<double> ret;
ret.values = _mm256_floor_pd(values);
return ret;
}
inline Vec256<double> round() {
Vec256<double> ret;
ret.values = _mm256_round_pd(
values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return ret;
}
inline Vec256<double> sin() {
return map(std::sin);
}
inline Vec256<double> trunc() {
Vec256<double> ret;
ret.values =
_mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
return ret;
}
inline Vec256<double> sqrt() {
Vec256<double> ret;
ret.values = _mm256_sqrt_pd(values);
return ret;
}
};
template <>
Vec256<float> inline operator+(const Vec256<float>& a, const Vec256<float>& b) {
Vec256<float> c = Vec256<float>();
c.values = _mm256_add_ps(a.values, b.values);
return c;
}
template <>
Vec256<float> inline operator*(const Vec256<float>& a, const Vec256<float>& b) {
Vec256<float> c = Vec256<float>();
c.values = _mm256_mul_ps(a.values, b.values);
return c;
}
template <>
Vec256<double> inline operator+(
const Vec256<double>& a,
const Vec256<double>& b) {
Vec256<double> c = Vec256<double>();
c.values = _mm256_add_pd(a.values, b.values);
return c;
}
template <>
Vec256<double> inline operator*(
const Vec256<double>& a,
const Vec256<double>& b) {
Vec256<double> c = Vec256<double>();
c.values = _mm256_mul_pd(a.values, b.values);
return c;
}
#endif
#ifdef __AVX2__
template <>
class Vec256<int64_t> {
public:
__m256i values;
Vec256<int64_t>() {}
inline void load(const int64_t* ptr) {
values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
}
inline void store(int64_t* ptr) const {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
}
inline void load(const int64_t* ptr, size_t count) {
int64_t tmp_values[4];
std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
load(tmp_values);
}
inline void store(int64_t* ptr, size_t count) const {
int64_t tmp_values[4];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
}
size_t size = 4;
inline void operator=(const Vec256<int64_t>& b) {
values = b.values;
}
inline Vec256<int64_t> abs() {
__m256i zero = _mm256_set1_epi64x(0);
__m256i is_larger = _mm256_cmpgt_epi64(zero, values);
__m256i inverse = _mm256_xor_si256(values, is_larger);
Vec256<int64_t> ret;
ret.values = _mm256_sub_epi64(inverse, is_larger);
return ret;
}
};
template <>
class Vec256<int32_t> {
public:
__m256i values;
Vec256<int32_t>() {}
inline void load(const int32_t* ptr) {
values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
}
inline void store(int32_t* ptr) const {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
}
inline void load(const int32_t* ptr, size_t count) {
int32_t tmp_values[8];
std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
load(tmp_values);
}
inline void store(int32_t* ptr, size_t count) const {
int32_t tmp_values[8];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
}
size_t size = 8;
inline void operator=(const Vec256<int32_t>& b) {
values = b.values;
}
inline Vec256<int32_t> abs() {
Vec256<int32_t> ret;
ret.values = _mm256_abs_epi32(values);
return ret;
}
};
template <>
class Vec256<int16_t> {
public:
__m256i values;
Vec256<int16_t>() {}
inline void load(const int16_t* ptr) {
values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
}
inline void store(int16_t* ptr) const {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
}
inline void load(const int16_t* ptr, size_t count) {
int16_t tmp_values[16];
std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
load(tmp_values);
}
inline void store(int16_t* ptr, size_t count) const {
int16_t tmp_values[16];
store(tmp_values);
std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
}
size_t size = 16;
inline void operator=(const Vec256<int16_t>& b) {
values = b.values;
}
inline Vec256<int16_t> abs() {
Vec256<int16_t> ret;
ret.values = _mm256_abs_epi16(values);
return ret;
}
};
template <>
Vec256<int64_t> inline operator+(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b) {
Vec256<int64_t> c = Vec256<int64_t>();
c.values = _mm256_add_epi64(a.values, b.values);
return c;
}
template <>
Vec256<int32_t> inline operator+(
const Vec256<int32_t>& a,
const Vec256<int32_t>& b) {
Vec256<int32_t> c = Vec256<int32_t>();
c.values = _mm256_add_epi32(a.values, b.values);
return c;
}
template <>
Vec256<int16_t> inline operator+(
const Vec256<int16_t>& a,
const Vec256<int16_t>& b) {
Vec256<int16_t> c = Vec256<int16_t>();
c.values = _mm256_add_epi16(a.values, b.values);
return c;
}
// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated
// This could be implemented more efficiently using epi32 instructions
// This is also technically avx compatible, but then we'll need AVX
// code for add as well.
template <>
Vec256<int64_t> inline operator*(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b) {
Vec256<int64_t> c = Vec256<int64_t>();
int64_t a0 = _mm256_extract_epi64(a.values, 0);
int64_t a1 = _mm256_extract_epi64(a.values, 1);
int64_t a2 = _mm256_extract_epi64(a.values, 2);
int64_t a3 = _mm256_extract_epi64(a.values, 3);
int64_t b0 = _mm256_extract_epi64(b.values, 0);
int64_t b1 = _mm256_extract_epi64(b.values, 1);
int64_t b2 = _mm256_extract_epi64(b.values, 2);
int64_t b3 = _mm256_extract_epi64(b.values, 3);
int64_t c0 = a0 * b0;
int64_t c1 = a1 * b1;
int64_t c2 = a2 * b2;
int64_t c3 = a3 * b3;
c.values = _mm256_set_epi64x(c3, c2, c1, c0);
return c;
}
template <>
Vec256<int32_t> inline operator*(
const Vec256<int32_t>& a,
const Vec256<int32_t>& b) {
Vec256<int32_t> c = Vec256<int32_t>();
c.values = _mm256_mullo_epi32(a.values, b.values);
return c;
}
template <>
Vec256<int16_t> inline operator*(
const Vec256<int16_t>& a,
const Vec256<int16_t>& b) {
Vec256<int16_t> c = Vec256<int16_t>();
c.values = _mm256_mullo_epi16(a.values, b.values);
return c;
}
#endif
}}} // namespace at::native::vec256

View File

@ -574,13 +574,10 @@ class TestTorch(TestCase):
diff = np.abs(nvs - tvs.numpy()).sum()
self.assertEqual(diff, 0)
sizes = []
sizes += [[2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3]]
sizes += [[4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]
sizes += [[1, 32 * 8 * 32 * 8]]
sizes += [[1, 32770]]
for size in sizes:
_run_test(size)
_run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3])
_run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
_run_test([1, 32 * 8 * 32 * 8])
_run_test([1, 32770])
def _testCSelection(self, torchfn, mathfn):
# Two tensors
@ -1160,15 +1157,41 @@ class TestTorch(TestCase):
def test_cpow(self):
self._test_cop(torch.pow, lambda x, y: float('nan') if x < 0 else math.pow(x, y))
# TODO: these tests only check if it's possible to pass a return value
# it'd be good to expand them
def test_sum(self):
def test_sum_all(self):
def check_sum_all(tensor):
pylist = tensor.reshape(-1).tolist()
self.assertEqual(tensor.sum(), sum(pylist))
check_sum_all(torch.tensor([1, 2, 3, 4, 5]))
check_sum_all(torch.randn(200000))
check_sum_all(torch.randn(2000, 2)[:, 0])
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
def test_sum_dim(self):
def check_sum_dim(tensor, dim):
expected = tensor.numpy().sum(dim)
actual = tensor.sum(dim)
self.assertEqual(expected.shape, actual.shape)
self.assertTrue(np.allclose(expected, actual.numpy()))
check_sum_dim(torch.randn(3, 5, 7), 0)
check_sum_dim(torch.randn(3, 5, 7), 1)
check_sum_dim(torch.randn(3, 5, 7), 2)
check_sum_dim(torch.randn(100000), -1)
check_sum_dim(torch.randn(5, 400000), 1)
check_sum_dim(torch.randn(50, 50, 50), 0)
check_sum_dim(torch.randn(50, 50, 50), 1)
check_sum_dim(torch.randn(50, 50, 50), 2)
def test_sum_out(self):
x = torch.rand(100, 100)
res1 = torch.sum(x, 1)
res2 = torch.Tensor()
torch.sum(x, 1, out=res2)
self.assertEqual(res1, res2)
# TODO: these tests only check if it's possible to pass a return value
# it'd be good to expand them
def test_prod(self):
x = torch.rand(100, 100)
res1 = torch.prod(x, 1)
@ -2396,8 +2419,7 @@ class TestTorch(TestCase):
def test_cat_scalars(self):
x = torch.tensor(0)
y = torch.tensor(1)
with self.assertRaisesRegexp(RuntimeError,
'zero-dimensional.*cannot be concatenated'):
with self.assertRaisesRegex(RuntimeError, 'zero-dimensional.*cannot be concatenated'):
torch.cat([x, y])
@staticmethod