#include #include #include #include #include #include #include using namespace torch::jit::tensorexpr; namespace { class ConcatBench : public benchmark::Fixture { public: void init(const std::vector> input_sizes, int concat_dim) { input_sizes_ = std::move(input_sizes); concat_dim_ = concat_dim; inputs_.resize(input_sizes_.size()); for (const auto i : c10::irange(input_sizes_.size())) { inputs_[i] = torch::ones({input_sizes_[i][0], input_sizes_[i][1]}); } output_size_.resize(input_sizes_.front().size()); for (const auto i : c10::irange(output_size_.size())) { if (i == static_cast(concat_dim_)) { output_size_[i] = 0; for (const auto j : c10::irange(input_sizes_.size())) { output_size_[i] += input_sizes_[j][i]; } } else { output_size_[i] = input_sizes_.front()[i]; } } ref_ = at::cat(inputs_, concat_dim_); output_ = at::empty_like(ref_); } void TearDown(benchmark::State& state) override { TORCH_CHECK(at::allclose(ref_, output_)); state.counters["GB/s"] = benchmark::Counter( uint64_t(state.iterations()) * 2 * output_.nbytes(), benchmark::Counter::kIsRate); } void runATen(benchmark::State& state) { for (auto _ : state) { output_ = at::cat(inputs_, concat_dim_); } } void runNNC(benchmark::State& state) { size_t num_inputs = inputs_.size(); size_t num_dims = 2; std::vector inputs; for (size_t i = 0; i < num_inputs; ++i) { inputs.emplace_back(BufHandle( "input" + std::to_string(i), {input_sizes_[i][0], input_sizes_[i][1]}, kFloat)); } Tensor output = Compute( "aten_cat", {output_size_[0], output_size_[1]}, [&](const VarHandle& m, const VarHandle& n) { int d = 0; std::vector cumulative_concat_dim_sizes(num_inputs); for (const auto i : c10::irange(num_inputs)) { cumulative_concat_dim_sizes[i] = d; d += input_sizes_[i][concat_dim_]; } auto load = inputs.back().load(m, n - cumulative_concat_dim_sizes.back()); for (size_t i = num_inputs - 1; i > 0; --i) { load = ifThenElse( CompareSelect::make( n, IntImm::make(cumulative_concat_dim_sizes[i]), kLT), inputs[i - 1].load(m, n - cumulative_concat_dim_sizes[i - 1]), load); } return load; }); LoopNest nest({output}); nest.prepareForCodegen(); StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); std::vector buf_args(inputs.begin(), inputs.end()); buf_args.push_back(output); LLVMCodeGen cg(s, buf_args); std::vector call_args; for (auto _ : state) { output_ = at::empty_like(ref_); call_args.clear(); for (const auto& inp : inputs_) { call_args.push_back(inp.data_ptr()); } call_args.push_back(output_.data_ptr()); cg.call(call_args); } } void runNNCLoop(benchmark::State& state) { size_t num_inputs = inputs_.size(); size_t num_dims = 2; TORCH_INTERNAL_ASSERT(concat_dim_ == 1); auto output_buf = alloc( alloc("aten_cat", kHandle), std::vector( {alloc(output_size_[0]), alloc(output_size_[1])}), kFloat); std::vector inputs; std::vector for_stmts(num_inputs); int cumulative_input_sizes = 0; for (size_t i = 0; i < num_inputs; ++i) { inputs.emplace_back(BufHandle( "input" + std::to_string(i), {input_sizes_[i][0], input_sizes_[i][1]}, kFloat)); std::vector for_vars(num_inputs); for (const auto d : c10::irange(num_dims)) { for_vars[d] = alloc("i" + std::to_string(i) + "_" + std::to_string(d), kInt); } auto store = alloc( output_buf, std::vector( {for_vars[0], alloc(for_vars[1], alloc(cumulative_input_sizes))}), alloc( inputs[i].node(), std::vector({for_vars[0], for_vars[1]}))); auto for_st = alloc( for_vars[0], alloc(0), alloc(input_sizes_[i][0]), alloc( for_vars[1], alloc(0), alloc(input_sizes_[i][1]), store)); for_stmts[i] = for_st; cumulative_input_sizes += input_sizes_[i][1]; } auto output = Tensor(output_buf, alloc(for_stmts)); LoopNest nest({output}); nest.prepareForCodegen(); nest.vectorizeInnerLoops(); StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); std::vector buf_args(inputs.begin(), inputs.end()); buf_args.push_back(output); LLVMCodeGen cg(s, buf_args); std::vector call_args; for (auto _ : state) { output_ = at::empty_like(ref_); call_args.clear(); for (const auto& inp : inputs_) { call_args.push_back(inp.data_ptr()); } call_args.push_back(output_.data_ptr()); cg.call(call_args); } } std::vector> input_sizes_; int concat_dim_; std::vector inputs_; std::vector output_size_; at::Tensor output_; at::Tensor ref_; }; class Concat2D2Input : public ConcatBench { public: void SetUp(const benchmark::State& state) override { init( {{state.range(0), state.range(1)}, {state.range(2), state.range(3)}}, state.range(4)); } }; } // namespace BENCHMARK_DEFINE_F(Concat2D2Input, ATen)(benchmark::State& state) { runATen(state); } BENCHMARK_DEFINE_F(Concat2D2Input, NNC)(benchmark::State& state) { runNNC(state); } BENCHMARK_DEFINE_F(Concat2D2Input, NNCLoop)(benchmark::State& state) { runNNCLoop(state); } BENCHMARK_REGISTER_F(Concat2D2Input, ATen) ->Args({1, 160, 1, 14, 1}) ->Args({1, 580, 1, 174, 1}) ->Args({20, 160, 20, 14, 1}) ->Args({20, 580, 20, 174, 1}) ->Args({8, 512, 8, 512, 1}); BENCHMARK_REGISTER_F(Concat2D2Input, NNC) ->Args({1, 160, 1, 14, 1}) ->Args({1, 580, 1, 174, 1}) ->Args({20, 160, 20, 14, 1}) ->Args({20, 580, 20, 174, 1}) ->Args({8, 512, 8, 512, 1}); BENCHMARK_REGISTER_F(Concat2D2Input, NNCLoop) ->Args({1, 160, 1, 14, 1}) ->Args({1, 580, 1, 174, 1}) ->Args({20, 160, 20, 14, 1}) ->Args({20, 580, 20, 174, 1}) ->Args({8, 512, 8, 512, 1}); namespace { class Concat2D3Input : public ConcatBench { public: void SetUp(const benchmark::State& state) override { init( {{state.range(0), state.range(1)}, {state.range(2), state.range(3)}, {state.range(4), state.range(5)}}, state.range(6)); } }; } // namespace BENCHMARK_DEFINE_F(Concat2D3Input, ATen)(benchmark::State& state) { runATen(state); } BENCHMARK_DEFINE_F(Concat2D3Input, NNC)(benchmark::State& state) { runNNC(state); } BENCHMARK_DEFINE_F(Concat2D3Input, NNCLoop)(benchmark::State& state) { runNNCLoop(state); } BENCHMARK_REGISTER_F(Concat2D3Input, ATen)->Args({8, 512, 8, 512, 8, 512, 1}); BENCHMARK_REGISTER_F(Concat2D3Input, NNC)->Args({8, 512, 8, 512, 8, 512, 1}); BENCHMARK_REGISTER_F(Concat2D3Input, NNCLoop) ->Args({8, 512, 8, 512, 8, 512, 1}); namespace { class Concat2D7Input : public ConcatBench { public: void SetUp(const benchmark::State& state) override { init( {{state.range(0), state.range(1)}, {state.range(2), state.range(3)}, {state.range(4), state.range(5)}, {state.range(6), state.range(7)}, {state.range(8), state.range(9)}, {state.range(10), state.range(11)}, {state.range(12), state.range(13)}}, state.range(14)); } }; } // namespace BENCHMARK_DEFINE_F(Concat2D7Input, ATen)(benchmark::State& state) { runATen(state); } BENCHMARK_DEFINE_F(Concat2D7Input, NNC)(benchmark::State& state) { runNNC(state); } BENCHMARK_DEFINE_F(Concat2D7Input, NNCLoop)(benchmark::State& state) { runNNCLoop(state); } BENCHMARK_REGISTER_F(Concat2D7Input, ATen) ->Args({8, 128, 8, 256, 8, 384, 8, 512, 8, 512, 8, 512, 8, 512, 1}); BENCHMARK_REGISTER_F(Concat2D7Input, NNC) ->Args({8, 128, 8, 256, 8, 384, 8, 512, 8, 512, 8, 512, 8, 512, 1}); BENCHMARK_REGISTER_F(Concat2D7Input, NNCLoop) ->Args({8, 128, 8, 256, 8, 384, 8, 512, 8, 512, 8, 512, 8, 512, 1});