Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52251

Batchnorm in inference is just a bunch of pointwise ops. NNC should be able to do a good job of this, and indeed it does. For fun I've included a fused BN->Relu (although the real fusion fun would be Conv->BN->Relu...).

```
---------------------------------------------------------------------------------------
Benchmark                              Time             CPU   Iterations UserCounters...
---------------------------------------------------------------------------------------
BatchNorm/ATen/1/64/112/112        252886 ns    252875 ns       2785   GB/s=25.3981G/s
BatchNorm/ATen/1/256/14/14          12145 ns     12145 ns      55347   GB/s=33.0525G/s
BatchNorm/ATen/1/128/28/28          18919 ns     18918 ns      37749   GB/s=42.437G/s
BatchNorm/ATen/1/64/56/56           61434 ns     61433 ns      11315   GB/s=26.1363G/s
BatchNorm/ATen/1/512/7/7            11924 ns     11923 ns      59070   GB/s=16.8327G/s
BatchNorm/ATen/5/64/112/112       1873321 ns   1873292 ns        382   GB/s=17.1424G/s
BatchNorm/ATen/5/256/14/14          83470 ns     83459 ns       8538   GB/s=24.0483G/s
BatchNorm/ATen/5/128/28/28         157521 ns    157520 ns       4440   GB/s=25.4829G/s
BatchNorm/ATen/5/64/56/56          314675 ns    314670 ns       2235   GB/s=25.513G/s
BatchNorm/ATen/5/512/7/7            48129 ns     48128 ns      14582   GB/s=20.851G/s
BatchNorm/NNC/1/64/112/112         249454 ns    249428 ns       2802   GB/s=25.749G/s
BatchNorm/NNC/1/256/14/14            9321 ns      9321 ns      74573   GB/s=43.066G/s
BatchNorm/NNC/1/128/28/28           16874 ns     16873 ns      40999   GB/s=47.5797G/s
BatchNorm/NNC/1/64/56/56            59276 ns     59275 ns      12047   GB/s=27.0878G/s
BatchNorm/NNC/1/512/7/7              3452 ns      3452 ns     202610   GB/s=58.1394G/s
BatchNorm/NNC/5/64/112/112        1820201 ns   1820038 ns        373   GB/s=17.6439G/s
BatchNorm/NNC/5/256/14/14           78429 ns     78420 ns       8871   GB/s=25.5935G/s
BatchNorm/NNC/5/128/28/28          155214 ns    155202 ns       4514   GB/s=25.8635G/s
BatchNorm/NNC/5/64/56/56           311454 ns    311449 ns       2163   GB/s=25.7768G/s
BatchNorm/NNC/5/512/7/7             26853 ns     26851 ns      25283   GB/s=37.3735G/s
BatchNorm/ATenRelu/1/64/112/112    378879 ns    378849 ns       1844   GB/s=16.9528G/s
BatchNorm/ATenRelu/1/256/14/14      16707 ns     16705 ns      41391   GB/s=24.029G/s
BatchNorm/ATenRelu/1/128/28/28      30235 ns     30235 ns      23060   GB/s=26.5529G/s
BatchNorm/ATenRelu/1/64/56/56       91164 ns     91160 ns       7662   GB/s=17.6132G/s
BatchNorm/ATenRelu/1/512/7/7        14681 ns     14681 ns      46088   GB/s=13.6707G/s
BatchNorm/ATenRelu/5/64/112/112   2864060 ns   2863566 ns        243   GB/s=11.2142G/s
BatchNorm/ATenRelu/5/256/14/14     118376 ns    118367 ns       5907   GB/s=16.9561G/s
BatchNorm/ATenRelu/5/128/28/28     237893 ns    237873 ns       2936   GB/s=16.8749G/s
BatchNorm/ATenRelu/5/64/56/56      472452 ns    472386 ns       1479   GB/s=16.9949G/s
BatchNorm/ATenRelu/5/512/7/7        61389 ns     61379 ns      11442   GB/s=16.3496G/s
BatchNorm/NNCRelu/1/64/112/112     248378 ns    248341 ns       2812   GB/s=25.8618G/s
BatchNorm/NNCRelu/1/256/14/14        9965 ns      9964 ns      76013   GB/s=40.2861G/s
BatchNorm/NNCRelu/1/128/28/28       16153 ns     16153 ns      43343   GB/s=49.7004G/s
BatchNorm/NNCRelu/1/64/56/56        58761 ns     58757 ns      12095   GB/s=27.3265G/s
BatchNorm/NNCRelu/1/512/7/7         10529 ns     10529 ns      66590   GB/s=19.0625G/s
BatchNorm/NNCRelu/5/64/112/112    1799001 ns   1798757 ns        362   GB/s=17.8527G/s
BatchNorm/NNCRelu/5/256/14/14       78252 ns     78246 ns       8974   GB/s=25.6504G/s
BatchNorm/NNCRelu/5/128/28/28      154940 ns    154923 ns       4483   GB/s=25.9102G/s
BatchNorm/NNCRelu/5/64/56/56       312329 ns    312324 ns       2244   GB/s=25.7046G/s
BatchNorm/NNCRelu/5/512/7/7         51203 ns     51199 ns      13559   GB/s=19.6004G/s
```

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D26440786

Pulled By: bertmaher

fbshipit-source-id: 7d3f7bf6eee4c37736e9875d31ae1b483af9fb6f
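For reference, the algebra that makes inference-mode batch norm purely pointwise, and that the NNC kernel below computes as `alpha`/`beta`: with fixed running statistics, each channel reduces to an affine transform

$$\alpha_c = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \varepsilon}}, \qquad \beta_c = b_c - \mu_c\,\alpha_c, \qquad y_{n,c,h,w} = \alpha_c\, x_{n,c,h,w} + \beta_c,$$

where $\gamma$ is the weight, $b$ the bias, and $\mu$, $\sigma^2$ the running mean and variance. The fused BN->Relu variant simply clamps the result at zero.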
219 lines
6.0 KiB
C++
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

using namespace torch::jit::tensorexpr;

namespace {
class BatchNorm : public benchmark::Fixture {
 public:
  void SetUp(const benchmark::State& state) override {
    N_ = state.range(0);
    C_ = state.range(1);
    H_ = state.range(2);
    W_ = state.range(3);
    input_ = torch::ones({N_, C_, H_, W_});
    weight_ = torch::ones({C_});
    bias_ = torch::ones({C_});
    mean_ = torch::ones({C_}) * 0.5f;
    var_ = torch::ones({C_}) * 0.1f;
    ref_ = at::batch_norm(
        input_,
        weight_,
        bias_,
        mean_,
        var_,
        training_,
        momentum_,
        eps_,
        cudnn_enabled_);
    output_ = at::empty_like(ref_);
  }

  void TearDown(benchmark::State& state) override {
    // Each benchmark writes its result into output_; compare it against the
    // ATen reference computed in SetUp, and report effective bandwidth.
    TORCH_CHECK(at::allclose(ref_, output_));
    state.counters["GB/s"] = benchmark::Counter(
        uint64_t(state.iterations()) * (input_.nbytes() + ref_.nbytes()),
        benchmark::Counter::kIsRate);
  }

  int N_;
  int C_;
  int H_;
  int W_;
  at::Tensor input_;
  at::Tensor weight_;
  at::Tensor bias_;
  at::Tensor mean_;
  at::Tensor var_;
  at::Tensor output_;
  at::Tensor ref_;
  bool training_{false};
  float momentum_{0.1};
  float eps_{1.0e-5f};
  bool cudnn_enabled_{false};
};
} // namespace

BENCHMARK_DEFINE_F(BatchNorm, ATen)(benchmark::State& state) {
  for (auto _ : state) {
    output_ = at::batch_norm(
        input_,
        weight_,
        bias_,
        mean_,
        var_,
        training_,
        momentum_,
        eps_,
        cudnn_enabled_);
  }
}

BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) {
  KernelScope ks;

  Placeholder input("input", kFloat, {N_, C_, H_, W_});
  Placeholder weight("weight", kFloat, {C_});
  Placeholder bias("bias", kFloat, {C_});
  Placeholder mean("mean", kFloat, {C_});
  Placeholder var("var", kFloat, {C_});
  VarHandle eps("eps", kFloat);

  using axis = const VarHandle&;
  Tensor* output = Compute(
      "output",
      {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
      [&](axis n, axis c, axis h, axis w) {
        // Compute affine terms.
        auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
        auto weight_v = weight.load(c);
        auto bias_v = bias.load(c);
        auto alpha = inv_var * weight_v;
        auto beta = bias_v - mean.load(c) * alpha;

        return input.load(n, c, h, w) * alpha + beta;
      });
  LoopNest nest({output});
  nest.prepareForCodegen();
  Stmt* s = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(s, {input, weight, bias, mean, var, output, eps});

  std::vector<CodeGen::CallArg> args;
  for (auto _ : state) {
    args.clear();
    output_ = at::empty_like(input_);
    for (auto const& t : {input_, weight_, bias_, mean_, var_, output_}) {
      args.push_back(t.data_ptr<float>());
    }
    args.push_back(eps_);
    cg.call(args);
  }
}

BENCHMARK_DEFINE_F(BatchNorm, ATenRelu)(benchmark::State& state) {
  for (auto _ : state) {
    output_ = at::batch_norm(
        input_,
        weight_,
        bias_,
        mean_,
        var_,
        training_,
        momentum_,
        eps_,
        cudnn_enabled_);
    output_.relu_();
  }
}

BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) {
  KernelScope ks;

  Placeholder input("input", kFloat, {N_, C_, H_, W_});
  Placeholder weight("weight", kFloat, {C_});
  Placeholder bias("bias", kFloat, {C_});
  Placeholder mean("mean", kFloat, {C_});
  Placeholder var("var", kFloat, {C_});
  VarHandle eps("eps", kFloat);

  using axis = const VarHandle&;
  Tensor* output = Compute(
      "output",
      {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
      [&](axis n, axis c, axis h, axis w) {
        // Compute affine terms.
        auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
        auto weight_v = weight.load(c);
        auto bias_v = bias.load(c);
        auto alpha = inv_var * weight_v;
        auto beta = bias_v - mean.load(c) * alpha;

        auto bn = input.load(n, c, h, w) * alpha + beta;
        // Fused ReLU: select 0 where bn < 0, otherwise keep bn.
        return CompareSelect::make(bn, 0.f, 0.f, bn, kLT);
      });
  LoopNest nest({output});
  nest.prepareForCodegen();
  Stmt* s = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(s, {input, weight, bias, mean, var, output, eps});

  std::vector<CodeGen::CallArg> args;
  for (auto _ : state) {
    args.clear();
    output_ = at::empty_like(input_);
    for (auto const& t : {input_, weight_, bias_, mean_, var_, output_}) {
      args.push_back(t.data_ptr<float>());
    }
    args.push_back(eps_);
    cg.call(args);
  }
}

BENCHMARK_REGISTER_F(BatchNorm, ATen)
    ->Args({1, 64, 112, 112})
    ->Args({1, 256, 14, 14})
    ->Args({1, 128, 28, 28})
    ->Args({1, 64, 56, 56})
    ->Args({1, 512, 7, 7})
    ->Args({5, 64, 112, 112})
    ->Args({5, 256, 14, 14})
    ->Args({5, 128, 28, 28})
    ->Args({5, 64, 56, 56})
    ->Args({5, 512, 7, 7});
BENCHMARK_REGISTER_F(BatchNorm, NNC)
    ->Args({1, 64, 112, 112})
    ->Args({1, 256, 14, 14})
    ->Args({1, 128, 28, 28})
    ->Args({1, 64, 56, 56})
    ->Args({1, 512, 7, 7})
    ->Args({5, 64, 112, 112})
    ->Args({5, 256, 14, 14})
    ->Args({5, 128, 28, 28})
    ->Args({5, 64, 56, 56})
    ->Args({5, 512, 7, 7});
BENCHMARK_REGISTER_F(BatchNorm, ATenRelu)
    ->Args({1, 64, 112, 112})
    ->Args({1, 256, 14, 14})
    ->Args({1, 128, 28, 28})
    ->Args({1, 64, 56, 56})
    ->Args({1, 512, 7, 7})
    ->Args({5, 64, 112, 112})
    ->Args({5, 256, 14, 14})
    ->Args({5, 128, 28, 28})
    ->Args({5, 64, 56, 56})
    ->Args({5, 512, 7, 7});
BENCHMARK_REGISTER_F(BatchNorm, NNCRelu)
    ->Args({1, 64, 112, 112})
    ->Args({1, 256, 14, 14})
    ->Args({1, 128, 28, 28})
    ->Args({1, 64, 56, 56})
    ->Args({1, 512, 7, 7})
    ->Args({5, 64, 112, 112})
    ->Args({5, 256, 14, 14})
    ->Args({5, 128, 28, 28})
    ->Args({5, 64, 56, 56})
    ->Args({5, 512, 7, 7});
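Note that this translation unit only defines and registers the benchmarks; `main()` is expected to come from the surrounding benchmark target. As a minimal sketch, assuming the file were built standalone against Google Benchmark rather than through PyTorch's own benchmark binary, an explicit entry point would look like:

```
// Hypothetical standalone entry point; not needed when linking against a
// target that already provides main() (e.g. Google Benchmark's benchmark_main).
BENCHMARK_MAIN();
```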