pytorch/benchmarks/cpp/tensorexpr/bench_approx.cpp
Zirui Tao 2b202667c1 [1/N] CPU pointwise optimization: Add a benchmark for Relu
Summary: As title

Test Plan:
Building: finished in 01:58.4 min (100%) 16761/16761 jobs, 16761 updated
  Total time: 02:32.3 min
Run on (24 X 2394.45 MHz CPU s)
2021-02-16 21:29:30
----------------------------------------------------------------------------------------------------
Benchmark                                             Time           CPU Iterations UserCounters...
----------------------------------------------------------------------------------------------------
relu_nnc/64                                        1738 ns       1738 ns     410535 log/s=36.8257M/s
relu_nnc/512                                       1708 ns       1708 ns     408678 log/s=299.711M/s
relu_nnc/8192                                      3297 ns       3297 ns     214362 log/s=2.48499G/s
relu_nnc/32768                                    10725 ns      10722 ns      61032 log/s=3.05603G/s
log_nnc_sleef/64                                   2076 ns       2075 ns     326248 log/s=30.8436M/s
log_nnc_sleef/512                                  3070 ns       3069 ns     230616 log/s=166.81M/s
log_nnc_sleef/8192                                22214 ns      22210 ns      31251 log/s=368.849M/s
log_nnc_sleef/32768                               85835 ns      85824 ns       8366 log/s=381.804M/s
log_nnc_fast/64                                    1852 ns       1852 ns     379123 log/s=34.5532M/s
log_nnc_fast/512                                   2456 ns       2456 ns     299463 log/s=208.503M/s
log_nnc_fast/8192                                 10953 ns      10952 ns      69894 log/s=747.957M/s
log_nnc_fast/32768                                35424 ns      35422 ns      19986 log/s=925.08M/s
log_nnc_vml/64                                     2361 ns       2361 ns     356220 log/s=27.1063M/s
log_nnc_vml/512                                    2218 ns       2218 ns     313444 log/s=230.857M/s
log_nnc_vml/8192                                   8420 ns       8420 ns      81594 log/s=972.912M/s
log_nnc_vml/32768                                 29484 ns      29484 ns      21701 log/s=1.1114G/s
log_aten/64                                       15970 ns      15970 ns      44401 log/s=4.00742M/s
log_aten/512                                      18344 ns      18344 ns      41056 log/s=27.9114M/s
log_aten/8192                                     24894 ns      24893 ns      27414 log/s=329.084M/s
log_aten/32768                                    29129 ns      29125 ns      22477 log/s=1.12508G/s
logit_nnc_sleef/64                                 2379 ns       2379 ns     261168 logit/s=26.8981M/s
logit_nnc_sleef/512                                5778 ns       5774 ns     114009 logit/s=88.6757M/s
logit_nnc_sleef/8192                              57268 ns      57236 ns      12429 logit/s=143.127M/s
logit_nnc_sleef/32768                            216356 ns     216344 ns       3026 logit/s=151.462M/s
logit_nnc_fast/64                                  2178 ns       2173 ns     282306 logit/s=29.4565M/s
logit_nnc_fast/512                                 2955 ns       2943 ns     202527 logit/s=173.95M/s
logit_nnc_fast/8192                               14836 ns      14835 ns      46794 logit/s=552.192M/s
logit_nnc_fast/32768                              53999 ns      53997 ns      12842 logit/s=606.846M/s
logit_nnc_vml/64                                   2132 ns       2132 ns     335874 logit/s=30.018M/s
logit_nnc_vml/512                                  3029 ns       3029 ns     250988 logit/s=169.058M/s
logit_nnc_vml/8192                                13264 ns      13263 ns      53504 logit/s=617.655M/s
logit_nnc_vml/32768                               49395 ns      48284 ns      14526 logit/s=678.654M/s
logit_aten/64                                     88180 ns      86690 ns       9270 logit/s=738.261k/s
logit_aten/512                                    54682 ns      54489 ns      10000 logit/s=9.3964M/s
logit_aten/8192                                  170878 ns     164357 ns       6965 logit/s=49.8427M/s
logit_aten/32768                                 452291 ns     434638 ns       3967 logit/s=75.3915M/s
logit_caffe2/64                                   30170 ns      29902 ns      24686 logit/s=2.14029M/s
logit_caffe2/512                                 203517 ns     201201 ns       3570 logit/s=2.54472M/s
logit_caffe2/8192                               3199528 ns    3157098 ns        220 logit/s=2.59479M/s
logit_caffe2/32768                             12520838 ns   12504846 ns         56 logit/s=2.62042M/s
tanh_nnc_fast/64                                   1979 ns       1977 ns     309745 tanh/s=32.3752M/s
tanh_nnc_fast/512                                  2331 ns       2331 ns     300937 tanh/s=219.636M/s
tanh_nnc_fast/8192                                 8323 ns       8323 ns      83601 tanh/s=984.26M/s
tanh_nnc_fast/32768                               30767 ns      30766 ns      23024 tanh/s=1065.06M/s
tanh_aten/64                                      17181 ns      17180 ns      36818 tanh/s=3.72522M/s
tanh_aten/512                                     19071 ns      19036 ns      37243 tanh/s=26.8968M/s
tanh_aten/8192                                    53542 ns      52006 ns      16268 tanh/s=157.521M/s
tanh_aten/32768                                  619869 ns     587600 ns       1000 tanh/s=55.7658M/s
tanh_caffe2/64                                     9668 ns       9654 ns      70926 tanh/s=6.62919M/s
tanh_caffe2/512                                   70409 ns      70409 ns       9881 tanh/s=7.27184M/s
tanh_caffe2/8192                                1179098 ns    1179011 ns        644 tanh/s=6.9482M/s
tanh_caffe2/32768                               4384300 ns    4382613 ns        156 tanh/s=7.47682M/s
BatchNorm/ATen/1/64/112/112                    23186429 ns   23183715 ns         27 GB/s=277.028M/s
BatchNorm/ATen/1/256/14/14                      1772907 ns    1770636 ns        394 GB/s=226.703M/s
BatchNorm/ATen/1/128/28/28                      3069417 ns    3069229 ns        232 GB/s=261.569M/s
BatchNorm/ATen/1/64/56/56                       6367276 ns    6367190 ns        111 GB/s=252.173M/s
BatchNorm/ATen/1/512/7/7                        1334734 ns    1334373 ns        516 GB/s=150.411M/s
BatchNorm/ATen/5/64/112/112                   131727903 ns  131721364 ns          7 GB/s=243.792M/s
BatchNorm/ATen/5/256/14/14                      7879002 ns    7874672 ns         85 GB/s=254.873M/s
BatchNorm/ATen/5/128/28/28                     15561373 ns   15269781 ns         42 GB/s=262.877M/s
BatchNorm/ATen/5/64/56/56                      29169722 ns   29107393 ns         24 GB/s=275.812M/s
BatchNorm/ATen/5/512/7/7                        5042006 ns    5028687 ns        100 GB/s=199.559M/s
BatchNorm/NNC/1/64/112/112                      3303598 ns    3271058 ns        188 GB/s=1.96344G/s
BatchNorm/NNC/1/256/14/14                        330641 ns     326644 ns       2033 GB/s=1.22889G/s
BatchNorm/NNC/1/128/28/28                        498706 ns     497894 ns       1131 GB/s=1.61242G/s
BatchNorm/NNC/1/64/56/56                        1116910 ns    1114768 ns        641 GB/s=1.44033G/s
BatchNorm/NNC/1/512/7/7                          163380 ns     163351 ns       3493 GB/s=1.22867G/s
BatchNorm/NNC/5/64/112/112                     16392078 ns   16386427 ns         41 GB/s=1.95971G/s
BatchNorm/NNC/5/256/14/14                       1133781 ns    1133369 ns        674 GB/s=1.77086G/s
BatchNorm/NNC/5/128/28/28                       2053208 ns    2053211 ns        276 GB/s=1.95503G/s
BatchNorm/NNC/5/64/56/56                        3874949 ns    3874734 ns        165 GB/s=2.07193G/s
BatchNorm/NNC/5/512/7/7                          653665 ns     651498 ns       1236 GB/s=1.54033G/s
BatchNorm/ATenRelu/1/64/112/112                36878892 ns   36100523 ns         22 GB/s=177.907M/s
BatchNorm/ATenRelu/1/256/14/14                  6404318 ns    5544976 ns        100 GB/s=72.3913M/s
BatchNorm/ATenRelu/1/128/28/28                  5897059 ns    5735509 ns        106 GB/s=139.973M/s
BatchNorm/ATenRelu/1/64/56/56                  10075458 ns    9965146 ns         62 GB/s=161.125M/s
BatchNorm/ATenRelu/1/512/7/7                    2680507 ns    2662541 ns        254 GB/s=75.3806M/s
BatchNorm/ATenRelu/5/64/112/112               145738113 ns  144253693 ns          5 GB/s=222.612M/s
BatchNorm/ATenRelu/5/256/14/14                 13582519 ns   13427209 ns         65 GB/s=149.476M/s
BatchNorm/ATenRelu/5/128/28/28                 22747138 ns   22627185 ns         31 GB/s=177.401M/s
BatchNorm/ATenRelu/5/64/56/56                  53609692 ns   52936728 ns         15 GB/s=151.656M/s
BatchNorm/ATenRelu/5/512/7/7                   11378314 ns   11083777 ns         65 GB/s=90.5395M/s
BatchNorm/NNCRelu/1/64/112/112                  3154436 ns    3148939 ns        193 GB/s=2.03958G/s
BatchNorm/NNCRelu/1/256/14/14                    337341 ns     337163 ns       1926 GB/s=1.19055G/s
BatchNorm/NNCRelu/1/128/28/28                    505570 ns     505569 ns       1231 GB/s=1.58794G/s
BatchNorm/NNCRelu/1/64/56/56                     903452 ns     903421 ns        659 GB/s=1.77728G/s
BatchNorm/NNCRelu/1/512/7/7                      158521 ns     158321 ns       3781 GB/s=1.2677G/s
BatchNorm/NNCRelu/5/64/112/112                 15488210 ns   15480019 ns         41 GB/s=2.07446G/s
BatchNorm/NNCRelu/5/256/14/14                   1149186 ns    1148963 ns        649 GB/s=1.74683G/s
BatchNorm/NNCRelu/5/128/28/28                   2011589 ns    2011424 ns        320 GB/s=1.99564G/s
BatchNorm/NNCRelu/5/64/56/56                    3776274 ns    3776060 ns        161 GB/s=2.12607G/s
BatchNorm/NNCRelu/5/512/7/7                      699762 ns     699582 ns        975 GB/s=1.43446G/s
BM_CompileSwish                                30471825 ns   30470017 ns         24
BM_CompileSwishLLVMOnly                        27479624 ns   27473475 ns         25
FusedOverhead                                    196219 ns     196195 ns       3342
UnfusedOverhead                                  220210 ns     220119 ns       3302
Gemm/Torch/128/128/128                           115526 ns     115343 ns       7414 GFLOPS=36.3637G/s
Gemm/TensorExprNoopt/128/128/128                3155851 ns    3155706 ns        210 GFLOPS=1.32912G/s
Gemm/TensorExprTile32x32/128/128/128             124454 ns     124452 ns       5774 GFLOPS=33.7021G/s
Gemm/TensorExprTile4x16/128/128/128              174408 ns     174366 ns       3987 GFLOPS=24.0546G/s
Gemm/TensorExprTile4x16VecUnroll/128/128/128      72949 ns      72948 ns       9028 GFLOPS=57.4974G/s
Gemm/TensorExprTile4x16Cache/128/128/128          73237 ns      73234 ns       9501 GFLOPS=57.2726G/s
Reduce1D/Torch/16777216                       426865265 ns  426853756 ns          2 BYTES=157.217M/s
Reduce1D/Naive/16777216                       132347709 ns  132343710 ns          5 BYTES=507.08M/s
Reduce1D/NativeRfactor/16777216               234668375 ns  234664682 ns          3 BYTES=285.978M/s
Reduce1D/TeNaive/16777216                      20468304 ns   20467906 ns         34 BYTES=3.27874G/s
Reduce1D/TeSplitTail/16777216                  20378995 ns   20378678 ns         34 BYTES=3.29309G/s
Reduce1D/TeSplitMask/16777216                  20371783 ns   20371260 ns         36 BYTES=3.29429G/s
Reduce1D/TeRfactorV2/16777216                   8235908 ns    8235723 ns         84 BYTES=8.14851G/s

CPU info:

Running ```sudo lshw -class processor``` reports 24 CPUs, all with the identical configuration below:

  *-cpu:0
       description: CPU
       product: Intel Core Processor (Broadwell)
       vendor: Intel Corp.
       physical id: 400
       bus info: cpu@0
       version: 6.61.2
       slot: CPU 0
       size: 2GHz
       capacity: 2GHz
       width: 64 bits
       capabilities: fpu fpu_exception wp vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx rdtscp x86-64 constant_tsc rep_good nopl xtopology cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat
       configuration: cores=1 enabledcores=1 microcode=1 threads=1

Reviewed By: bwasti

Differential Revision: D26275048

fbshipit-source-id: 3de669f622eb8cd328787caa878dc0c05de600a5
2021-02-17 17:18:28 -08:00

#include <benchmark/benchmark.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/torch.h>
#include "caffe2/operators/tanh_op.h"
#include "caffe2/operators/logit_op.h"
using namespace torch::jit;
using namespace torch::jit::tensorexpr;
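
// Split the outermost loop of `target` by `width` and vectorize the main
// (non-tail) inner loop.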
void vectorize(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target, int width) {
  auto loops = ln->getLoopStmtsFor(target);
  For *outer, *inner, *tail;
  ln->splitWithTail(loops[0], width, &outer, &inner, &tail);
  ln->vectorize(inner);
}

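// Pointwise schedule used by most of the NNC benchmarks below: split the loop
// into 16 * 8 = 128-element blocks, vectorize the block body, then split the
// remaining outer loop by 8 and unroll it.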
void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target) {
  std::vector<For*> loops = ln->getLoopStmtsFor(target);
  For *outer, *inner, *tail;
  ln->splitWithTail(loops[0], 16 * 8, &outer, &inner, &tail);
  ln->vectorize(inner);
  ln->splitWithTail(outer, 8, &outer, &inner, &tail);
  Stmt* unrolled;
  LoopNest::unroll(inner, &unrolled);
}

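// NNC relu: B[i] = A[i] < 0 ? 0 : A[i] via CompareSelect, compiled with
// LLVMCodeGen and verified against at::relu before timing. Throughput is
// reported under the "log/s" counter key, matching the results table above.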
static void relu_nnc(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  auto clamp = 0;
  torch::jit::tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto min = FloatImm::make(clamp);
      return CompareSelect::make(elem, min, min, elem, kLT);
    }();
    return A_elem;
  });
  LoopNest ln({B});
  optimizePointwise(&ln, B);
  ln.prepareForCodegen();
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::randn({state.range(0)});
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::relu(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

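// Elementwise log over positive inputs, in three NNC flavors: log() (which the
// LLVM backend lowers to SLEEF calls, hence the _sleef suffix), the fast_log()
// polynomial approximation, and the log_vml() intrinsic. Each kernel is
// validated against at::log once before the timed loop.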
static void log_nnc_sleef(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  torch::jit::tensorexpr::Tensor* B =
      Compute("B", {N}, [&](const VarHandle& i) {
        return log(A.load(i));
      });
  LoopNest ln({B});
  ln.prepareForCodegen();
  vectorize(&ln, B, 8);
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::log(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void log_nnc_fast(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  torch::jit::tensorexpr::Tensor* B =
      Compute("B", {N}, [&](const VarHandle& i) {
        return fast_log(A.load(i));
      });
  LoopNest ln({B});
  optimizePointwise(&ln, B);
  ln.prepareForCodegen();
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::log(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void log_nnc_vml(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  torch::jit::tensorexpr::Tensor* B =
      Compute("B", {N}, [&](const VarHandle& i) {
        return log_vml(A.load(i));
      });
  LoopNest ln({B});
  vectorize(&ln, B, 8);
  ln.prepareForCodegen();
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::log(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

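// ATen baseline: at::native::log_out into a preallocated output tensor.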
static void log_aten(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  for (auto _ : state) {
    at::native::log_out(B_t, A_t);
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

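// logit(x) = log(x / (1 - x)) with the input clamped to [eps, 1 - eps]
// (eps = 1e-6), mirroring at::logit. Outputs are compared after nan_to_num to
// avoid spurious mismatches on non-finite values.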
static void logit_nnc_sleef(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  auto clamp = 1e-6f;
  tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto min = FloatImm::make(clamp);
      auto max = FloatImm::make(1.0f - clamp);
      elem = CompareSelect::make(elem, min, min, elem, kLT);
      return CompareSelect::make(elem, max, max, elem, kGT);
    }();
    return log(A_elem / (FloatImm::make(1.0f) - A_elem));
  });
  LoopNest ln({B});
  ln.prepareForCodegen();
  optimizePointwise(&ln, B);
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::logit(A_t, clamp);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref)));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void logit_nnc_fast(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  auto clamp = 1e-6f;
  tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto min = FloatImm::make(clamp);
      auto max = FloatImm::make(1.0f - clamp);
      elem = CompareSelect::make(elem, min, min, elem, kLT);
      return CompareSelect::make(elem, max, max, elem, kGT);
    }();
    return fast_log(A_elem / (FloatImm::make(1.0f) - A_elem));
  });
  LoopNest ln({B});
  ln.prepareForCodegen();
  optimizePointwise(&ln, B);
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::logit(A_t, clamp);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref)));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void logit_nnc_vml(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  auto clamp = 1e-6f;
  tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto min = FloatImm::make(clamp);
      auto max = FloatImm::make(1.0f - clamp);
      elem = CompareSelect::make(elem, min, min, elem, kLT);
      return CompareSelect::make(elem, max, max, elem, kGT);
    }();
    return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem));
  });
  LoopNest ln({B});
  ln.prepareForCodegen();
  vectorize(&ln, B, 16);
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::logit(A_t, clamp);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref)));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void logit_aten(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto clamp = 1e-6f;
  for (auto _ : state) {
    at::native::logit_out(B_t, A_t, clamp);
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

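// Eigen-based logit kernel modeled on caffe2's logit operator: clamp to
// [eps_, 1 - eps_], then log(x / (1 - x)). logit_caffe2 below validates it
// against at::native::logit_out once and then times the raw kernel.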
template <typename T>
void logit_caffe2_impl(int size, const T* X, T* Y, float eps_ = 1e-6f) {
  using namespace caffe2;
  ConstEigenVectorMap<T> X_vec(X, size);
  EigenVectorMap<T> Y_vec(Y, size);
  Y_vec = X_vec.array().min(static_cast<T>(1.0f - eps_));
  Y_vec = Y_vec.array().max(eps_);
  Y_vec = (Y_vec.array() / (T(1) - Y_vec.array())).log();
}

static void logit_caffe2(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  at::Tensor B_ref = torch::randn({state.range(0)});
  auto N = state.range(0);
  auto X = A_t.data_ptr<float>();
  auto Y = B_t.data_ptr<float>();
  auto clamp = 1e-6f;
  at::native::logit_out(B_ref, A_t, clamp);
  logit_caffe2_impl(N, X, Y, clamp);
  TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref)));
  for (auto _ : state) {
    logit_caffe2_impl(N, X, Y, clamp);
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

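// tanh benchmarks: NNC's fast_tanh() approximation vs. at::native::tanh_out
// vs. caffe2's TanhFunctor. The approximation is checked with looser
// tolerances (rtol 1e-3, atol 1e-6).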
static void tanh_nnc_fast(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  torch::jit::tensorexpr::Tensor* B =
      Compute("B", {N}, [&](const VarHandle& i) {
        return fast_tanh(A.load(i));
      });
  LoopNest ln({B});
  optimizePointwise(&ln, B);
  ln.prepareForCodegen();
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::tanh(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref, 1e-3f, 1e-6f));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["tanh/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void tanh_aten(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  for (auto _ : state) {
    at::native::tanh_out(B_t, A_t);
  }
  state.counters["tanh/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void tanh_caffe2(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  at::Tensor B_ref = torch::randn({state.range(0)});
  auto N = state.range(0);
  auto X = A_t.data_ptr<float>();
  auto Y = B_t.data_ptr<float>();
  caffe2::CPUContext c;
  auto tanh = caffe2::TanhFunctor<caffe2::CPUContext>();
  at::native::tanh_out(B_ref, A_t);
  tanh(N, X, Y, &c);
  TORCH_CHECK(at::native::allclose(B_t, B_ref, 1e-3f, 1e-6f));
  for (auto _ : state) {
    tanh(N, X, Y, &c);
  }
  state.counters["tanh/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

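// Register every benchmark over 64, 512, 8192, and 32768 elements
// (2<<5 through 2<<14), matching the sizes in the results table above.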
BENCHMARK(relu_nnc)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(log_nnc_sleef)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(log_nnc_fast)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(log_nnc_vml)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(log_aten)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(logit_nnc_sleef)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(logit_nnc_fast)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(logit_nnc_vml)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(logit_aten)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(logit_caffe2)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(tanh_nnc_fast)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(tanh_aten)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(tanh_caffe2)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});