Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67865

- Add int version of vectorized PrefixSum
- Use unaligned load/store instructions
- Add exclusive scan version. "exclusive" means that the i-th input element is not included in the i-th sum; for details see https://en.cppreference.com/w/cpp/algorithm/exclusive_scan

Test Plan:
```
buck build mode/opt-clang //caffe2/benchmarks/cpp/tensorexpr:tensorexpr_bench
OMP_NUM_THREADS=1 numactl -m 0 -C 5 \
  ./buck-out/opt/gen/caffe2/benchmarks/cpp/tensorexpr/tensorexpr_bench --benchmark_filter=PrefixSumBench
```

For full benchmark results, see P465274613

```
Benchmark                                             Time       CPU  Iterations
-----------------------------------------------------------------------------------
PrefixSumBench/LocalInt/64                           57 ns     56 ns    12414048  GB/s=9.06239G/s
PrefixSumBench/LocalInt/256                         221 ns    221 ns     3160853  GB/s=9.28635G/s
PrefixSumBench/LocalInt/1024                        818 ns    817 ns      857922  GB/s=10.0235G/s
PrefixSumBench/LocalInt/4096                       3211 ns   3210 ns      217614  GB/s=10.2093G/s
PrefixSumBench/LocalInt/16384                     12806 ns  12804 ns       54805  GB/s=10.2364G/s
PrefixSumBench/LocalInt/65536                     51115 ns  51079 ns       13741  GB/s=10.2643G/s
PrefixSumBench/LocalInt/262144                   205974 ns 205912 ns        3401  GB/s=10.1847G/s
PrefixSumBench/LocalInt/1048576                  829523 ns 828859 ns         845  GB/s=10.1207G/s
PrefixSumBench/LocalIntAVX2/64                       45 ns     45 ns    15568113  GB/s=11.3549G/s
PrefixSumBench/LocalIntAVX2/256                     208 ns    208 ns     3371174  GB/s=9.86913G/s
PrefixSumBench/LocalIntAVX2/1024                    893 ns    892 ns      783154  GB/s=9.18629G/s
PrefixSumBench/LocalIntAVX2/4096                   3618 ns   3613 ns      193834  GB/s=9.06838G/s
PrefixSumBench/LocalIntAVX2/16384                 14416 ns  14411 ns       48564  GB/s=9.09543G/s
PrefixSumBench/LocalIntAVX2/65536                 57650 ns  57617 ns       12156  GB/s=9.09952G/s
PrefixSumBench/LocalIntAVX2/262144               230855 ns 230612 ns        3035  GB/s=9.09386G/s
PrefixSumBench/LocalIntAVX2/1048576              924265 ns 923777 ns         758  GB/s=9.08077G/s
PrefixSumBench/LocalIntAVX512/64                     23 ns     23 ns    24876551  GB/s=22.0697G/s
PrefixSumBench/LocalIntAVX512/256                    95 ns     95 ns     7387386  GB/s=21.556G/s
PrefixSumBench/LocalIntAVX512/1024                  435 ns    435 ns     1609682  GB/s=18.8425G/s
PrefixSumBench/LocalIntAVX512/4096                 1815 ns   1815 ns      385462  GB/s=18.0561G/s
PrefixSumBench/LocalIntAVX512/16384                7479 ns   7476 ns       93660  GB/s=17.5335G/s
PrefixSumBench/LocalIntAVX512/65536               30171 ns  29879 ns       23430  GB/s=17.5468G/s
PrefixSumBench/LocalIntAVX512/262144             125805 ns 125631 ns        5570  GB/s=16.6929G/s
PrefixSumBench/LocalIntAVX512/1048576            504216 ns 503983 ns        1384  GB/s=16.6446G/s
PrefixSumBench/ExclusiveScanIntAVX512/64             23 ns     23 ns    30058295
PrefixSumBench/ExclusiveScanIntAVX512/256           101 ns    101 ns     7398498
PrefixSumBench/ExclusiveScanIntAVX512/1024          435 ns    434 ns     1403877
PrefixSumBench/ExclusiveScanIntAVX512/4096         1979 ns   1978 ns      354016
PrefixSumBench/ExclusiveScanIntAVX512/16384        7828 ns   7819 ns       89551
PrefixSumBench/ExclusiveScanIntAVX512/65536       31206 ns  31192 ns       22408
PrefixSumBench/ExclusiveScanIntAVX512/262144     130106 ns 130023 ns        5388
PrefixSumBench/ExclusiveScanIntAVX512/1048576    525515 ns 524976 ns        1244
```

Reviewed By: navahgar, swolchok

Differential Revision: D32011740

fbshipit-source-id: 7962de710bd588291dd6bf0c719f579c55f7c063
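As a quick illustration of the two scan flavors benchmarked below (illustrative only; it uses the C++17 standard-library algorithms rather than the vectorized kernels in this file):

```
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> in{1, 2, 3, 4};
  std::vector<int> inc(in.size()), exc(in.size());
  // Inclusive scan (what PrefixSum computes): out[i] = in[0] + ... + in[i]
  std::inclusive_scan(in.begin(), in.end(), inc.begin());    // 1 3 6 10
  // Exclusive scan: out[i] = in[0] + ... + in[i-1]; in[i] itself is excluded
  std::exclusive_scan(in.begin(), in.end(), exc.begin(), 0); // 0 1 3 6
  for (int v : inc) std::printf("%d ", v);
  std::printf("\n");
  for (int v : exc) std::printf("%d ", v);
  std::printf("\n");
}
```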
396 lines · 12 KiB · C++
#include <benchmark/benchmark.h>
#include "ATen/Functions.h"

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

#include <immintrin.h>
#include <iostream>

using namespace torch::jit::tensorexpr;

namespace {
#ifdef __AVX2__

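// The _mm256_slli_siN(x) macros below shift the eight packed 32-bit lanes of x
// up by N positions (towards higher lane indices), filling the vacated low
// lanes with zeros. A cross-lane permute followed by a blend against zero is
// used because _mm256_slli_si256 only shifts within each 128-bit half.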
#define _mm256_slli_si1(x) \
  _mm256_blend_epi32( \
      _mm256_permutevar8x32_ps(x, _mm256_set_epi32(6, 5, 4, 3, 2, 1, 0, 7)), \
      _mm256_setzero_si256(), \
      1)
#define _mm256_slli_si2(x) \
  _mm256_blend_epi32( \
      _mm256_permutevar8x32_ps(x, _mm256_set_epi32(5, 4, 3, 2, 1, 0, 7, 6)), \
      _mm256_setzero_si256(), \
      3)
#define _mm256_slli_si4(x) \
  _mm256_blend_epi32( \
      _mm256_permutevar8x32_ps(x, _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4)), \
      _mm256_setzero_si256(), \
      15)

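// In-register inclusive scan over 8 lanes (Hillis-Steele style): after adding
// the vector shifted by 1, 2, and 4 lanes, lane i holds the sum of lanes 0..i.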
__m256 PrefixSum(__m256 x) {
  x = _mm256_add_ps(x, _mm256_slli_si1(x));
  x = _mm256_add_ps(x, _mm256_slli_si2(x));
  x = _mm256_add_ps(x, _mm256_slli_si4(x));
  return x; // local prefix sums
}

__m256i PrefixSumInt(__m256i x) {
  x = _mm256_add_epi32(x, _mm256_slli_si1(x));
  x = _mm256_add_epi32(x, _mm256_slli_si2(x));
  x = _mm256_add_epi32(x, _mm256_slli_si4(x));
  return x; // local prefix sums
}

// Util function to log the given value. Not used during benchmarking.
template <class T>
inline void Log(const __m256i& value) {
  const size_t n = sizeof(__m256i) / sizeof(T);
  T buffer[n];
  _mm256_storeu_si256((__m256i*)buffer, value);
  for (int i = 0; i < n; i++)
    std::cout << buffer[n - i - 1] << " ";
  std::cout << std::endl;
}
#endif

#ifdef __AVX512F__

#define _mm512_slli_si512(x, k) \
  _mm512_alignr_epi32(x, _mm512_setzero_si512(), 16 - k)

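// _mm512_slli_si512(x, k) shifts the 16 packed 32-bit lanes of x up by k
// positions with zero fill, implemented as a valignd against a zero vector.
// With 16 lanes the scan below needs four shift-and-add steps (1, 2, 4, 8).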
__m512 PrefixSum(__m512 x) {
  x = _mm512_add_ps(x, _mm512_slli_si512(x, 1));
  x = _mm512_add_ps(x, _mm512_slli_si512(x, 2));
  x = _mm512_add_ps(x, _mm512_slli_si512(x, 4));
  x = _mm512_add_ps(x, _mm512_slli_si512(x, 8));
  return x; // local prefix sums
}

__m512i PrefixSumInt(__m512i x) {
  x = _mm512_add_epi32(x, _mm512_slli_si512(x, 1));
  x = _mm512_add_epi32(x, _mm512_slli_si512(x, 2));
  x = _mm512_add_epi32(x, _mm512_slli_si512(x, 4));
  x = _mm512_add_epi32(x, _mm512_slli_si512(x, 8));
  return x; // local prefix sums
}

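// Helpers for reading the last lane (the running total of a scanned block) out
// of a vector: _mm512_extract_f32<15>() rotates lane 15 into lane 0 and
// extracts it; _mm512_extract_epi32() goes through the upper 256-bit half.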
template <int index>
float _mm512_extract_f32(__m512 target) {
  return _mm512_cvtss_f32(_mm512_alignr_epi32(target, target, index));
}

// extract the last i32 from target
int _mm512_extract_epi32(__m512i target) {
  __m256i x = _mm512_extracti32x8_epi32(target, 1);
  return _mm256_extract_epi32(x, 7);
}

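// Whole-array inclusive scans: each block of 16 elements is scanned in
// registers and the running total is carried into the next block. input_size
// is expected to be a multiple of 16; trailing elements are not processed.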
void PrefixSum(float* output_data, float* input_data, size_t input_size) {
  float carry = 0.0f;
  for (int i = 0; i < input_size / 16; i++) {
    __m512 x = _mm512_loadu_ps(input_data + i * 16);
    x = PrefixSum(x);
    x = _mm512_add_ps(x, _mm512_set1_ps(carry));
    carry = _mm512_extract_f32<15>(x);
    _mm512_storeu_ps((__m512*)(output_data + i * 16), x);
  }
}

void PrefixSum(int* output_data, int* input_data, size_t input_size) {
  int carry = 0;
  for (int i = 0; i < input_size / 16; i++) {
    __m512i x = _mm512_loadu_epi32(input_data + i * 16);
    x = PrefixSumInt(x);
    x = _mm512_add_epi32(x, _mm512_set1_epi32(carry));
    carry = _mm512_extract_epi32(x);
    _mm512_storeu_epi32((__m512i*)(output_data + i * 16), x);
  }
}
#endif

// PrefixSum: the same as inclusive scan
class PrefixSumBench : public benchmark::Fixture {
 public:
  void SetUp(const benchmark::State& state) override {
    input_size_ = state.range(0);
    input_ = torch::rand(input_size_);
    ref_ = prefixSum(input_);

    // no type promotion. Default is int->long.
    input_int_ = torch::randint(1000, {input_size_}, at::kInt);
    ref_int_ = at::cumsum(input_int_, 0, at::kInt);
  }

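  // Verify the last output produced against the ATen reference (when shapes
  // match) and report throughput; the 2 * nbytes factor accounts for reading
  // the input and writing the output once per iteration.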
  void TearDown(benchmark::State& state) override {
    if (output_.numel() > 0) {
      if (output_.numel() == ref_.numel()) {
        TORCH_CHECK(at::allclose(ref_, output_, 1e-3, 1e-3));
      }
      state.counters["GB/s"] = benchmark::Counter(
          uint64_t(state.iterations()) * 2 * output_.nbytes(),
          benchmark::Counter::kIsRate);
    } else {
      if (output_int_.numel() == ref_int_.numel()) {
        TORCH_CHECK(ref_int_.equal(output_int_));
      }
      state.counters["GB/s"] = benchmark::Counter(
          uint64_t(state.iterations()) * 2 * output_int_.nbytes(),
          benchmark::Counter::kIsRate);
    }
  }

  at::Tensor prefixSum(const at::Tensor& inp) {
    return at::cumsum(inp, 0);
  }

  void runATen(benchmark::State& state) {
    output_ = prefixSum(input_);
    for (auto _ : state) {
      at::cumsum_out(output_, input_, 0);
    }
  }

  void runLocal(benchmark::State& state) {
    output_ = at::empty_like(ref_);
    for (auto _ : state) {
      auto input_data = input_.data_ptr<float>();
      auto output_data = output_.data_ptr<float>();
      float sum = 0.0f;
      for (int i = 0; i < input_size_; ++i) {
        sum = sum + input_data[i];
        output_data[i] = sum;
      }
    }
  }

  // no type promotion
  void runLocalInt(benchmark::State& state) {
    output_int_ = at::empty_like(input_int_);
    for (auto _ : state) {
      auto input_data = input_int_.data_ptr<int>();
      auto output_data = output_int_.data_ptr<int>();
      int sum = 0;
      for (int i = 0; i < input_size_; ++i) {
        sum = sum + input_data[i];
        output_data[i] = sum;
      }
    }
  }

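  // Build the serial scan as an NNC statement: a one-element buffer `s` keeps
  // the running sum, each loop iteration adds input[i] into s and stores s to
  // output[i]; the statement is compiled with LLVMCodeGen and called with raw
  // float pointers.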
  void runNNC(benchmark::State& state) {
    BufHandle input("input", {input_size_}, kFloat);
    BufHandle output("output", {input_size_}, kFloat);
    BufHandle s("s", {1}, kFloat);
    VarHandle i("i", kInt);
    auto allocS = Allocate::make(s);
    auto initS = Store::make(s, {0}, 0.0f);
    auto accumS = Store::make(
        s, {0}, Add::make(Load::make(s, {0}), Load::make(input, {i})));
    auto store = Store::make(output, {i}, Load::make(s, {0}));
    auto forI = For::make(i, 0, input_size_, Block::make({accumS, store}));
    auto freeS = Free::make(s);
    auto par = Block::make({allocS, initS, forI, freeS});
    LoopNest nest(par, {output.node()});

    std::vector<CodeGen::BufferArg> buf_args;
    buf_args.emplace_back(input);
    buf_args.emplace_back(output);
    LLVMCodeGen cg(nest.root_stmt(), buf_args);

    std::vector<CodeGen::CallArg> call_args;
    output_ = at::empty_like(ref_);
    for (auto _ : state) {
      call_args.clear();
      call_args.emplace_back(input_.data_ptr<float>());
      call_args.emplace_back(output_.data_ptr<float>());
      cg.call(call_args);
    }
  }

#ifdef __AVX2__
  void runLocalAVX2(benchmark::State& state) {
    output_ = at::empty_like(ref_);
    for (auto _ : state) {
      float* input_data = input_.data_ptr<float>();
      float* output_data = output_.data_ptr<float>();

      float carry = 0.0f;
      for (int i = 0; i < input_size_ / 8; i++) {
        __m256 x = _mm256_loadu_ps(input_data + i * 8);
        x = PrefixSum(x);
        x = _mm256_add_ps(x, _mm256_set1_ps(carry));
        (reinterpret_cast<__m256*>(output_data))[i] = x;
        carry = _mm256_cvtss_f32(_mm256_permutevar8x32_ps(
            x, _mm256_set_epi32(6, 5, 4, 3, 2, 1, 0, 7)));
      }
    }
  }

  void runLocalIntAVX2(benchmark::State& state) {
    output_int_ = at::empty_like(input_int_);
    for (auto _ : state) {
      auto input_data = input_int_.data_ptr<int>();
      auto output_data = output_int_.data_ptr<int>();

      int carry = 0;
      for (size_t i = 0; i < input_size_ / 8; i++) {
        __m256i x = _mm256_loadu_si256((__m256i*)(input_data + i * 8));
        x = PrefixSumInt(x);
        x = _mm256_add_epi32(x, _mm256_set1_epi32(carry));
        _mm256_storeu_si256((__m256i*)(output_data + i * 8), x);
        carry = _mm256_extract_epi32(x, 7);
      }
    }
  }
#endif

#ifdef __AVX512F__
  void runLocalAVX512(benchmark::State& state) {
    output_ = at::empty_like(ref_);
    for (auto _ : state) {
      auto input_data = input_.data_ptr<float>();
      auto output_data = output_.data_ptr<float>();
      PrefixSum(output_data, input_data, input_size_);
    }
  }

  void runLocalIntAVX512(benchmark::State& state) {
    output_int_ = at::empty_like(input_int_);
    for (auto _ : state) {
      auto input_data = input_int_.data_ptr<int>();
      auto output_data = output_int_.data_ptr<int>();
      PrefixSum(output_data, input_data, input_size_);
    }
  }

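  // Exclusive-scan variants: the output has input_size + 1 elements,
  // output[0] is set to 0, and the inclusive scan is written starting at
  // output[1], so output[i] ends up being the sum of input[0..i-1].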
  void runExclusiveScanAVX512(benchmark::State& state) {
    output_ = at::empty({input_size_ + 1}, at::kFloat);
    for (auto _ : state) {
      auto input_data = input_.data_ptr<float>();
      auto output_data = output_.data_ptr<float>();
      output_data[0] = 0.0f;
      PrefixSum(output_data + 1, input_data, input_size_);
    }
  }

  void runExclusiveScanIntAVX512(benchmark::State& state) {
    output_int_ = at::empty({input_size_ + 1}, at::kInt);
    for (auto _ : state) {
      auto input_data = input_int_.data_ptr<int>();
      auto output_data = output_int_.data_ptr<int>();
      output_data[0] = 0;
      PrefixSum(output_data + 1, input_data, input_size_);
    }
  }

#endif

 private:
  int input_size_;
  at::Tensor input_;
  at::Tensor output_;
  at::Tensor ref_;
  at::Tensor input_int_;
  at::Tensor output_int_;
  at::Tensor ref_int_; // no type promotion
};

} // namespace

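// Each fixture method above is exposed as a separate Google Benchmark; the
// AVX2/AVX512 variants are only compiled and registered when the corresponding
// ISA macro is defined for this build.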
BENCHMARK_DEFINE_F(PrefixSumBench, ATen)(benchmark::State& state) {
  runATen(state);
}

BENCHMARK_DEFINE_F(PrefixSumBench, Local)(benchmark::State& state) {
  runLocal(state);
}

BENCHMARK_DEFINE_F(PrefixSumBench, LocalInt)(benchmark::State& state) {
  runLocalInt(state);
}

BENCHMARK_DEFINE_F(PrefixSumBench, NNC)(benchmark::State& state) {
  runNNC(state);
}

#ifdef __AVX2__
BENCHMARK_DEFINE_F(PrefixSumBench, LocalAVX2)(benchmark::State& state) {
  runLocalAVX2(state);
}
BENCHMARK_DEFINE_F(PrefixSumBench, LocalIntAVX2)(benchmark::State& state) {
  runLocalIntAVX2(state);
}
#endif

#ifdef __AVX512F__
BENCHMARK_DEFINE_F(PrefixSumBench, LocalAVX512)(benchmark::State& state) {
  runLocalAVX512(state);
}
BENCHMARK_DEFINE_F(PrefixSumBench, LocalIntAVX512)(benchmark::State& state) {
  runLocalIntAVX512(state);
}

BENCHMARK_DEFINE_F(PrefixSumBench, ExclusiveScanAVX512)
(benchmark::State& state) {
  runExclusiveScanAVX512(state);
}
BENCHMARK_DEFINE_F(PrefixSumBench, ExclusiveScanIntAVX512)
(benchmark::State& state) {
  runExclusiveScanIntAVX512(state);
}
#endif

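// Every benchmark sweeps input sizes from 2^6 to 2^20 elements in 4x steps.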
//---------- float benchmarks ----------//
BENCHMARK_REGISTER_F(PrefixSumBench, ATen)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});

BENCHMARK_REGISTER_F(PrefixSumBench, NNC)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});

BENCHMARK_REGISTER_F(PrefixSumBench, Local)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});

#ifdef __AVX2__
BENCHMARK_REGISTER_F(PrefixSumBench, LocalAVX2)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});
#endif

#ifdef __AVX512F__
BENCHMARK_REGISTER_F(PrefixSumBench, LocalAVX512)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});
BENCHMARK_REGISTER_F(PrefixSumBench, ExclusiveScanAVX512)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});
#endif

//---------- int benchmarks ----------//
BENCHMARK_REGISTER_F(PrefixSumBench, LocalInt)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});

#ifdef __AVX2__
BENCHMARK_REGISTER_F(PrefixSumBench, LocalIntAVX2)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});
#endif

#ifdef __AVX512F__
BENCHMARK_REGISTER_F(PrefixSumBench, LocalIntAVX512)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});
BENCHMARK_REGISTER_F(PrefixSumBench, ExclusiveScanIntAVX512)
    ->RangeMultiplier(4)
    ->Ranges({{1 << 6, 1 << 20}});
#endif