Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72390
This class didn't add much value and only created boilerplate. This change
removes it and replaces all its uses with plain `ExprHandle`s, as sketched
below the commit trailers.
A side effect is that loop variables now get different names, which required
large mechanical updates to our tests.
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D34030296
Pulled By: ZolotukhinM
fbshipit-source-id: 2ba4e313506a43ab129a10d99e72b638b7d40108
(cherry picked from commit c2ec46a058)
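For context, a minimal sketch of the API change, assuming the removed class was a `DimArg`-style dimension wrapper (the summary does not name it): dimensions passed to `te::Compute` are now plain `ExprHandle`s, matching the benchmark file below.

```cpp
namespace te = torch::jit::tensorexpr;

te::BufHandle A("A", {512}, te::kFloat);

// Old form (hypothetical reconstruction): the removed wrapper class could
// carry a loop-variable name hint alongside the dimension expression.
//   te::Tensor relu = te::Compute(
//       "relu", {te::DimArg(te::ExprHandle(512), "i")},
//       [&](const te::VarHandle& i) {
//         return te::Max::make(A.load(i), 0.f, false);
//       });

// New form: the dimension is an ExprHandle, and the loop variable gets a
// default name; this is the source of the mechanical test renames noted above.
te::Tensor relu = te::Compute(
    "relu", {te::ExprHandle(512)}, [&](const te::VarHandle& i) {
      return te::Max::make(A.load(i), 0.f, false);
    });
```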
72 lines · 2.6 KiB · C++
#include <benchmark/benchmark.h>

#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#ifdef TORCH_ENABLE_LLVM

namespace te = torch::jit::tensorexpr;

// Measures end-to-end compilation time: IR construction, inlining,
// simplification, and LLVM code generation all run inside the timed loop.
static void BM_CompileSwish(benchmark::State& state) {
  for (auto _ : state) {
    constexpr int N = 512;
    te::VarHandle n("n", te::kInt);
    te::BufHandle A("A", {N}, te::kFloat);
    // Elementwise pipeline: sixth(i) = A[i] * (min(max(A[i], 0), 6) + 3) / 6.
    te::Tensor relu = te::Compute("relu", {n}, [&](const te::VarHandle& i) {
      return te::Max::make(A.load(i), 0.f, false);
    });
    te::Tensor min6 = te::Compute("min6", {n}, [&](const te::VarHandle& i) {
      return te::Min::make(relu.load(i), 6.f, false);
    });
    te::Tensor plus3 = te::Compute("plus3", {n}, [&](const te::VarHandle& i) {
      return min6.load(i) + 3.f;
    });
    te::Tensor times = te::Compute("times", {n}, [&](const te::VarHandle& i) {
      return A.load(i) * plus3.load(i);
    });
    te::Tensor sixth = te::Compute("sixth", {n}, [&](const te::VarHandle& i) {
      return times.load(i) * 1.f / 6.f;
    });
    // Inline all intermediate tensors so a single fused loop remains.
    te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
    for (auto tensor : {relu, min6, plus3, times}) {
      nest.computeInline(tensor.buf());
    }
    nest.prepareForCodegen();
    te::StmtPtr s = te::IRSimplifier::simplify(nest.root_stmt());
    te::LLVMCodeGen cg(s, {A, sixth, n});
  }
}

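// Variant that isolates the LLVM stage: the IR is built, inlined, and
// simplified once up front, so the timed loop below measures only
// LLVMCodeGen construction.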
static void BM_CompileSwishLLVMOnly(benchmark::State& state) {
  constexpr int N = 512;
  te::VarHandle n("n", te::kInt);
  te::BufHandle A("A", {N}, te::kFloat);
  te::Tensor relu = te::Compute("relu", {n}, [&](const te::VarHandle& i) {
    return te::Max::make(A.load(i), 0.f, false);
  });
  te::Tensor min6 = te::Compute("min6", {n}, [&](const te::VarHandle& i) {
    return te::Min::make(relu.load(i), 6.f, false);
  });
  te::Tensor plus3 = te::Compute(
      "plus3", {n}, [&](const te::VarHandle& i) { return min6.load(i) + 3.f; });
  te::Tensor times = te::Compute("times", {n}, [&](const te::VarHandle& i) {
    return A.load(i) * plus3.load(i);
  });
  te::Tensor sixth = te::Compute("sixth", {n}, [&](const te::VarHandle& i) {
    return times.load(i) * 1.f / 6.f;
  });
  te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
  for (auto tensor : {relu, min6, plus3, times}) {
    nest.computeInline(tensor.buf());
  }
  nest.prepareForCodegen();
  te::StmtPtr s = te::IRSimplifier::simplify(nest.root_stmt());
  for (auto _ : state) {
    te::LLVMCodeGen cg(s, {A, sixth, n});
  }
}

BENCHMARK(BM_CompileSwish);
BENCHMARK(BM_CompileSwishLLVMOnly);

#endif // TORCH_ENABLE_LLVM
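The file registers its benchmarks but defines no `main`, so it relies on a driver linked elsewhere in the build. A minimal standalone driver would look like the sketch below (the binary name in the usage comment is hypothetical).

```cpp
#include <benchmark/benchmark.h>

// Expands to a main() that runs all BENCHMARK()-registered benchmarks.
// Usage, assuming a hypothetical binary name:
//   ./compile_swish_bench --benchmark_filter=BM_CompileSwish
BENCHMARK_MAIN();
```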