pytorch/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp
Bert Maher 71d5a8ea62 [nnc] Benchmark inference batchnorm (#52251)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52251

Batchnorm in inference is just a bunch of pointwise ops.  NNC
should be able to do a good job of this, and indeed it does.  For fun
I've included a fused BN->Relu (although the real fusion fun would be
Conv->BN->Relu...).
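
As a quick reminder of why it's pointwise (the `alpha`/`beta` below are exactly the ones the NNC kernels compute): with the running statistics frozen, batchnorm collapses to a per-channel affine transform,

```latex
y = \frac{\gamma\,(x - \mu)}{\sqrt{\sigma^2 + \varepsilon}} + \beta
  = \alpha x + \beta', \qquad
  \alpha = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
  \beta' = \beta - \mu\,\alpha
```

so once `alpha` and `beta'` are formed per channel, each output element is a single multiply-add.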

```
---------------------------------------------------------------------------------------
Benchmark                                Time           CPU Iterations UserCounters...
---------------------------------------------------------------------------------------
BatchNorm/ATen/1/64/112/112         252886 ns     252875 ns       2785 GB/s=25.3981G/s
BatchNorm/ATen/1/256/14/14           12145 ns      12145 ns      55347 GB/s=33.0525G/s
BatchNorm/ATen/1/128/28/28           18919 ns      18918 ns      37749 GB/s=42.437G/s
BatchNorm/ATen/1/64/56/56            61434 ns      61433 ns      11315 GB/s=26.1363G/s
BatchNorm/ATen/1/512/7/7             11924 ns      11923 ns      59070 GB/s=16.8327G/s
BatchNorm/ATen/5/64/112/112        1873321 ns    1873292 ns        382 GB/s=17.1424G/s
BatchNorm/ATen/5/256/14/14           83470 ns      83459 ns       8538 GB/s=24.0483G/s
BatchNorm/ATen/5/128/28/28          157521 ns     157520 ns       4440 GB/s=25.4829G/s
BatchNorm/ATen/5/64/56/56           314675 ns     314670 ns       2235 GB/s=25.513G/s
BatchNorm/ATen/5/512/7/7             48129 ns      48128 ns      14582 GB/s=20.851G/s

BatchNorm/NNC/1/64/112/112          249454 ns     249428 ns       2802 GB/s=25.749G/s
BatchNorm/NNC/1/256/14/14             9321 ns       9321 ns      74573 GB/s=43.066G/s
BatchNorm/NNC/1/128/28/28            16874 ns      16873 ns      40999 GB/s=47.5797G/s
BatchNorm/NNC/1/64/56/56             59276 ns      59275 ns      12047 GB/s=27.0878G/s
BatchNorm/NNC/1/512/7/7               3452 ns       3452 ns     202610 GB/s=58.1394G/s
BatchNorm/NNC/5/64/112/112         1820201 ns    1820038 ns        373 GB/s=17.6439G/s
BatchNorm/NNC/5/256/14/14            78429 ns      78420 ns       8871 GB/s=25.5935G/s
BatchNorm/NNC/5/128/28/28           155214 ns     155202 ns       4514 GB/s=25.8635G/s
BatchNorm/NNC/5/64/56/56            311454 ns     311449 ns       2163 GB/s=25.7768G/s
BatchNorm/NNC/5/512/7/7              26853 ns      26851 ns      25283 GB/s=37.3735G/s

BatchNorm/ATenRelu/1/64/112/112     378879 ns     378849 ns       1844 GB/s=16.9528G/s
BatchNorm/ATenRelu/1/256/14/14       16707 ns      16705 ns      41391 GB/s=24.029G/s
BatchNorm/ATenRelu/1/128/28/28       30235 ns      30235 ns      23060 GB/s=26.5529G/s
BatchNorm/ATenRelu/1/64/56/56        91164 ns      91160 ns       7662 GB/s=17.6132G/s
BatchNorm/ATenRelu/1/512/7/7         14681 ns      14681 ns      46088 GB/s=13.6707G/s
BatchNorm/ATenRelu/5/64/112/112    2864060 ns    2863566 ns        243 GB/s=11.2142G/s
BatchNorm/ATenRelu/5/256/14/14      118376 ns     118367 ns       5907 GB/s=16.9561G/s
BatchNorm/ATenRelu/5/128/28/28      237893 ns     237873 ns       2936 GB/s=16.8749G/s
BatchNorm/ATenRelu/5/64/56/56       472452 ns     472386 ns       1479 GB/s=16.9949G/s
BatchNorm/ATenRelu/5/512/7/7         61389 ns      61379 ns      11442 GB/s=16.3496G/s

BatchNorm/NNCRelu/1/64/112/112      248378 ns     248341 ns       2812 GB/s=25.8618G/s
BatchNorm/NNCRelu/1/256/14/14         9965 ns       9964 ns      76013 GB/s=40.2861G/s
BatchNorm/NNCRelu/1/128/28/28        16153 ns      16153 ns      43343 GB/s=49.7004G/s
BatchNorm/NNCRelu/1/64/56/56         58761 ns      58757 ns      12095 GB/s=27.3265G/s
BatchNorm/NNCRelu/1/512/7/7          10529 ns      10529 ns      66590 GB/s=19.0625G/s
BatchNorm/NNCRelu/5/64/112/112     1799001 ns    1798757 ns        362 GB/s=17.8527G/s
BatchNorm/NNCRelu/5/256/14/14        78252 ns      78246 ns       8974 GB/s=25.6504G/s
BatchNorm/NNCRelu/5/128/28/28       154940 ns     154923 ns       4483 GB/s=25.9102G/s
BatchNorm/NNCRelu/5/64/56/56        312329 ns     312324 ns       2244 GB/s=25.7046G/s
BatchNorm/NNCRelu/5/512/7/7          51203 ns      51199 ns      13559 GB/s=19.6004G/s
```

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D26440786

Pulled By: bertmaher

fbshipit-source-id: 7d3f7bf6eee4c37736e9875d31ae1b483af9fb6f
2021-02-16 10:57:38 -08:00

#include <benchmark/benchmark.h>

#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

using namespace torch::jit::tensorexpr;
namespace {

class BatchNorm : public benchmark::Fixture {
 public:
  void SetUp(const benchmark::State& state) override {
    // Benchmark arguments are {N, C, H, W}.
    N_ = state.range(0);
    C_ = state.range(1);
    H_ = state.range(2);
    W_ = state.range(3);
    input_ = torch::ones({N_, C_, H_, W_});
    weight_ = torch::ones({C_});
    bias_ = torch::ones({C_});
    mean_ = torch::ones({C_}) * 0.5f;
    var_ = torch::ones({C_}) * 0.1f;
    // Reference output, used to validate each benchmark's result in TearDown.
    ref_ = at::batch_norm(
        input_,
        weight_,
        bias_,
        mean_,
        var_,
        training_,
        momentum_,
        eps_,
        cudnn_enabled_);
    output_ = at::empty_like(ref_);
  }

  void TearDown(benchmark::State& state) override {
    TORCH_CHECK(at::allclose(ref_, output_));
    // Report achieved bandwidth: one read of the input plus one write of the
    // output per iteration.
    state.counters["GB/s"] = benchmark::Counter(
        uint64_t(state.iterations()) * (input_.nbytes() + ref_.nbytes()),
        benchmark::Counter::kIsRate);
  }

  int N_;
  int C_;
  int H_;
  int W_;
  at::Tensor input_;
  at::Tensor weight_;
  at::Tensor bias_;
  at::Tensor mean_;
  at::Tensor var_;
  at::Tensor output_;
  at::Tensor ref_;
  bool training_{false};
  float momentum_{0.1};
  float eps_{1.0e-5f};
  bool cudnn_enabled_{false};
};

} // namespace
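// Back-of-envelope check on the GB/s counter: the 1x64x112x112 fp32 shape
// moves 2 * 64 * 112 * 112 * 4 B ~= 6.42 MB per iteration (input read plus
// output write), so ATen's ~252886 ns/iteration works out to ~25.4 GB/s,
// matching the 25.3981G/s reported in the summary above.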
BENCHMARK_DEFINE_F(BatchNorm, ATen)(benchmark::State& state) {
  for (auto _ : state) {
    output_ = at::batch_norm(
        input_,
        weight_,
        bias_,
        mean_,
        var_,
        training_,
        momentum_,
        eps_,
        cudnn_enabled_);
  }
}
BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) {
  KernelScope ks;
  Placeholder input("input", kFloat, {N_, C_, H_, W_});
  Placeholder weight("weight", kFloat, {C_});
  Placeholder bias("bias", kFloat, {C_});
  Placeholder mean("mean", kFloat, {C_});
  Placeholder var("var", kFloat, {C_});
  VarHandle eps("eps", kFloat);

  using axis = const VarHandle&;
  Tensor* output = Compute(
      "output",
      {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
      [&](axis n, axis c, axis h, axis w) {
        // Compute affine terms.
        auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
        auto weight_v = weight.load(c);
        auto bias_v = bias.load(c);
        auto alpha = inv_var * weight_v;
        auto beta = bias_v - mean.load(c) * alpha;
        return input.load(n, c, h, w) * alpha + beta;
      });

  // Lower to a loop nest, simplify, and JIT-compile with the LLVM backend.
  LoopNest nest({output});
  nest.prepareForCodegen();
  Stmt* s = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(s, {input, weight, bias, mean, var, output, eps});

  std::vector<CodeGen::CallArg> args;
  for (auto _ : state) {
    args.clear();
    // Fresh output buffer each iteration; the final one is checked in
    // TearDown against the ATen reference.
    output_ = at::empty_like(input_);
    for (auto const& t : {input_, weight_, bias_, mean_, var_, output_}) {
      args.push_back(t.data_ptr<float>());
    }
    args.push_back(eps_);
    cg.call(args);
  }
}
BENCHMARK_DEFINE_F(BatchNorm, ATenRelu)(benchmark::State& state) {
  // Unfused baseline: batch_norm followed by a separate in-place relu.
  for (auto _ : state) {
    output_ = at::batch_norm(
        input_,
        weight_,
        bias_,
        mean_,
        var_,
        training_,
        momentum_,
        eps_,
        cudnn_enabled_);
    output_.relu_();
  }
}
BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) {
  KernelScope ks;
  Placeholder input("input", kFloat, {N_, C_, H_, W_});
  Placeholder weight("weight", kFloat, {C_});
  Placeholder bias("bias", kFloat, {C_});
  Placeholder mean("mean", kFloat, {C_});
  Placeholder var("var", kFloat, {C_});
  VarHandle eps("eps", kFloat);

  using axis = const VarHandle&;
  Tensor* output = Compute(
      "output",
      {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
      [&](axis n, axis c, axis h, axis w) {
        // Compute affine terms.
        auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
        auto weight_v = weight.load(c);
        auto bias_v = bias.load(c);
        auto alpha = inv_var * weight_v;
        auto beta = bias_v - mean.load(c) * alpha;
        auto bn = input.load(n, c, h, w) * alpha + beta;
        // Fused relu: select 0 where bn < 0, else bn.
        return CompareSelect::make(bn, 0.f, 0.f, bn, kLT);
      });

  LoopNest nest({output});
  nest.prepareForCodegen();
  Stmt* s = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(s, {input, weight, bias, mean, var, output, eps});

  std::vector<CodeGen::CallArg> args;
  for (auto _ : state) {
    args.clear();
    output_ = at::empty_like(input_);
    for (auto const& t : {input_, weight_, bias_, mean_, var_, output_}) {
      args.push_back(t.data_ptr<float>());
    }
    args.push_back(eps_);
    cg.call(args);
  }
}
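// Each Args tuple below is {N, C, H, W}, matching the state.range() calls in
// SetUp; the shapes resemble typical ResNet-style feature maps. At runtime a
// single variant can be selected with Google Benchmark's standard filter,
// e.g. --benchmark_filter='BatchNorm/NNC/'.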
BENCHMARK_REGISTER_F(BatchNorm, ATen)
    ->Args({1, 64, 112, 112})
    ->Args({1, 256, 14, 14})
    ->Args({1, 128, 28, 28})
    ->Args({1, 64, 56, 56})
    ->Args({1, 512, 7, 7})
    ->Args({5, 64, 112, 112})
    ->Args({5, 256, 14, 14})
    ->Args({5, 128, 28, 28})
    ->Args({5, 64, 56, 56})
    ->Args({5, 512, 7, 7});

BENCHMARK_REGISTER_F(BatchNorm, NNC)
    ->Args({1, 64, 112, 112})
    ->Args({1, 256, 14, 14})
    ->Args({1, 128, 28, 28})
    ->Args({1, 64, 56, 56})
    ->Args({1, 512, 7, 7})
    ->Args({5, 64, 112, 112})
    ->Args({5, 256, 14, 14})
    ->Args({5, 128, 28, 28})
    ->Args({5, 64, 56, 56})
    ->Args({5, 512, 7, 7});

BENCHMARK_REGISTER_F(BatchNorm, ATenRelu)
    ->Args({1, 64, 112, 112})
    ->Args({1, 256, 14, 14})
    ->Args({1, 128, 28, 28})
    ->Args({1, 64, 56, 56})
    ->Args({1, 512, 7, 7})
    ->Args({5, 64, 112, 112})
    ->Args({5, 256, 14, 14})
    ->Args({5, 128, 28, 28})
    ->Args({5, 64, 56, 56})
    ->Args({5, 512, 7, 7});

BENCHMARK_REGISTER_F(BatchNorm, NNCRelu)
    ->Args({1, 64, 112, 112})
    ->Args({1, 256, 14, 14})
    ->Args({1, 128, 28, 28})
    ->Args({1, 64, 56, 56})
    ->Args({1, 512, 7, 7})
    ->Args({5, 64, 112, 112})
    ->Args({5, 256, 14, 14})
    ->Args({5, 128, 28, 28})
    ->Args({5, 64, 56, 56})
    ->Args({5, 512, 7, 7});