[te][reapply] Add fast log approximation based on sleef (#49575)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49575

This adds a fast log approximation. To benchmark the log implementations:

```
buck run mode/opt //caffe2/benchmarks/cpp/tensorexpr:tensorexpr_bench -c 'fbcode.caffe2_gpu_type=none'
```

Test Plan: buck test mode/no-gpu //caffe2/test/cpp/tensorexpr:tensorexpr -- *.fastLogFloat

Reviewed By: bertmaher

Differential Revision: D25627157

fbshipit-source-id: a4920f4f4005ce617d372b375e790ca966275cd9
Parent: c78fd76f18
Commit: 1047957831
benchmarks/cpp/tensorexpr/bench_approx.cpp — new file (+145)

@@ -0,0 +1,145 @@
```cpp
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/torch.h>

using namespace torch::jit::tensorexpr;

static void log_sleef(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  torch::jit::tensorexpr::Tensor* B =
      Compute("B", {N}, [&](const VarHandle& i) {
        return log(A.load(i));
      });
  LoopNest ln({B});
  ln.prepareForCodegen();
  ln.vectorizeInnerLoops();
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::log(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  assert(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void log_fast(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  torch::jit::tensorexpr::Tensor* B =
      Compute("B", {N}, [&](const VarHandle& i) {
        return fast_log(A.load(i));
      });
  LoopNest ln({B});
  ln.prepareForCodegen();
  ln.vectorizeInnerLoops();
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::log(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  assert(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void log_aten(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  for (auto _ : state) {
    at::native::log_out(B_t, A_t);
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void logit_fast(benchmark::State& state) {
  KernelScope ks;
  auto N = VarHandle("N", kInt);
  Placeholder A("A", kFloat, {N});
  torch::jit::tensorexpr::Tensor* B =
      Compute("B", {N}, [&](const VarHandle& i) {
        auto A_elem = A.load(i);
        return fast_log(A_elem / (FloatImm::make(1.0f) - A_elem));
      });
  LoopNest ln({B});
  ln.prepareForCodegen();
  ln.vectorizeInnerLoops();
  Stmt* s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::logit(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  assert(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

static void logit_aten(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  for (auto _ : state) {
    at::native::logit_out(B_t, A_t);
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
}

BENCHMARK(log_sleef)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(log_fast)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(log_aten)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(logit_fast)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
BENCHMARK(logit_aten)
    ->Args({2<<5})
    ->Args({2<<8})
    ->Args({2<<12})
    ->Args({2<<14});
```
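Each benchmark above reports throughput the same way: the counter value is the number of elements processed (`range(0)` × `iterations()`), and `benchmark::Counter::kIsRate` tells Google Benchmark to divide that by elapsed time, yielding logs per second; the `Args` sizes `2<<5` through `2<<14` are 64 to 32,768 elements. A minimal self-contained sketch of that idiom, using a hypothetical scalar baseline that is not part of this patch:

```cpp
#include <benchmark/benchmark.h>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical scalar baseline illustrating the counter idiom used above.
static void log_scalar(benchmark::State& state) {
  std::vector<float> a(state.range(0), 0.5f);
  std::vector<float> b(state.range(0));
  for (auto _ : state) {
    for (int64_t i = 0; i < state.range(0); ++i) {
      b[i] = std::log(a[i]);
    }
    benchmark::DoNotOptimize(b.data()); // keep the loop from being elided
  }
  // Total elements processed, flagged as a rate -> reported as elements/sec.
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}
BENCHMARK(log_scalar)->Args({2 << 5})->Args({2 << 14});
BENCHMARK_MAIN();
```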
@@ -733,6 +733,38 @@ TEST(ATen, logFloat) {
```cpp
  }
}

TEST(ATen, fastLogFloat) {
  KernelScope kernel_scope;
  const int kTotalSize = 128 * 128;
  Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat));
  Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat));

  VarHandle index = VarHandle("index", kInt);
  ExprHandle load_a = a_buf.load(index);
  Stmt* store_b = b_buf.store({index}, fast_log(load_a));
  Stmt* stmt = For::make(index, 0, kTotalSize, store_b);

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);

  for (int i = 0; i < kTotalSize; ++i) {
    a_v(i) = at::randn({1}).item().to<float>();
  }

  SimpleIREvaluator ir_eval(stmt, a_buf, b_buf);
  ir_eval(a_v, b_v);

  for (int i = 0; i < kTotalSize; ++i) {
    auto test = b_v(i);
    auto ref = std::log(a_v(i));
    if (std::isnan(ref)) {
      ASSERT_EQ(std::isnan(test), true);
    } else {
      ASSERT_FLOAT_EQ(test, ref);
    }
  }
}

TEST(ATen, log10Float) {
  KernelScope kernel_scope;
  const int kTotalSize = 128;
```
@@ -217,6 +217,38 @@ TEST(LLVM, BitCast) {
```cpp
  }
}

TEST(LLVM, fastLogFloat) {
  KernelScope kernel_scope;
  const int kTotalSize = 128 * 128;
  Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat));
  Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat));

  VarHandle index = VarHandle("index", kInt);
  ExprHandle load_a = a_buf.load(index);
  Stmt* store_b = b_buf.store({index}, fast_log(load_a));
  Stmt* stmt = For::make(index, 0, kTotalSize, store_b);

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);

  for (int i = 0; i < kTotalSize; ++i) {
    a_v(i) = at::randn({1}).item().to<float>();
  }

  LLVMCodeGen ir_eval(stmt, {a_buf, b_buf});
  ir_eval.call({a_v, b_v});

  for (int i = 0; i < kTotalSize; ++i) {
    auto test = b_v(i);
    auto ref = std::log(a_v(i));
    if (std::isnan(ref)) {
      ASSERT_EQ(std::isnan(test), true);
    } else {
      ASSERT_FLOAT_EQ(test, ref);
    }
  }
}

TEST(LLVM, LetTest01) {
  KernelScope kernel_scope;
```
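The two fastLogFloat tests are identical except for the backend: the ATen variant interprets the statement with SimpleIREvaluator, while this one compiles it through LLVMCodeGen and invokes call(). Both compare element-wise against std::log, requiring NaN for negative inputs.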
@@ -337,9 +337,12 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
```diff
       std::vector<T> result_v(lhs_v.size());
       for (size_t i = 0; i < lhs_v.size(); i++) {
         switch (op_type) {
-          case IRNodeType::kLshift:
-            result_v[i] = lhs_v[i] << rhs_v[i];
-            break;
+          case IRNodeType::kLshift: {
+            typename std::make_unsigned<T>::type a =
+                static_cast<typename std::make_unsigned<T>::type>(lhs_v[i]);
+            result_v[i] = a << rhs_v[i];
+            break;
+          }
           case IRNodeType::kRshift:
             result_v[i] = lhs_v[i] >> rhs_v[i];
             break;
```
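This evaluator change exists because `lhs << rhs` is undefined behavior when `lhs` is a negative signed integer (before C++20), and fast_log's ldexp3kf shifts a possibly negative exponent. A standalone sketch of the same workaround (illustrative, not library code):

```cpp
#include <cstdint>
#include <type_traits>

// Shift the unsigned representation instead of the signed value: unsigned
// left shift is well defined, and converting back to int32_t recovers the
// expected two's-complement bit pattern.
int32_t shift_left(int32_t lhs, int32_t rhs) {
  using U = std::make_unsigned<int32_t>::type; // uint32_t
  return static_cast<int32_t>(static_cast<U>(lhs) << rhs);
}

// Example: shift_left(-1, 23) == -8388608 (bit pattern 0xFF800000), the
// exponent-field shift fast_log relies on when e is negative.
```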
@@ -128,6 +128,46 @@ ExprHandle fabs(const ExprHandle& v) {
```cpp
  return Intrinsics::make(kFabs, v);
}

ExprHandle fast_log(const ExprHandle& v) {
  // This implementation is taken from sleef:
  // https://github.com/shibatch/sleef/blob/master/src/libm/sleefsp.c#L1131
  // The coefficients can be regenerated with this tool:
  // https://github.com/shibatch/sleef/blob/master/src/gencoef/gencoef.txt
  auto ilogb2kf = [](ExprHandle x) {
    auto y = (bitcast<int32_t>(x) >> IntImm::make(23)) & IntImm::make(0xff);
    return y - IntImm::make(0x7f);
  };

  auto ldexp3kf = [](ExprHandle x, ExprHandle e) {
    return bitcast<float>(bitcast<int32_t>(x) + (e << IntImm::make(23)));
  };
  auto e = ilogb2kf(v * FloatImm::make(1.0 / 0.75));
  auto m = ldexp3kf(v, IntImm::make(-1) * e);
  auto one = FloatImm::make(1.0f);
  auto x = (m - one) / (m + one);
  auto x2 = x * x;

  auto mlaf = [](ExprHandle x, ExprHandle y, float z) {
    return x * y + FloatImm::make(z);
  };

  auto t = FloatImm::make(0.2392828464508056640625);
  t = mlaf(t, x2, 0.28518211841583251953125);
  t = mlaf(t, x2, 0.400005877017974853515625);
  t = mlaf(t, x2, 0.666666686534881591796875);
  t = mlaf(t, x2, 2.0);
  x = x * t + FloatImm::make(0.693147180559945286226764) * e;
  x = IfThenElse::make(
      v < FloatImm::make(0),
      FloatImm::make(std::numeric_limits<float>::quiet_NaN()),
      x);
  x = IfThenElse::make(
      v == FloatImm::make(0),
      FloatImm::make(-std::numeric_limits<float>::infinity()),
      x);
  return x;
}

ExprHandle log(const ExprHandle& v) {
  return Intrinsics::make(kLog, v);
}
```
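For readers tracing the algorithm, here is a scalar C++ transcription of the same math — my sketch, not anything in the patch (the real code builds tensorexpr IR so the LLVM backend can vectorize it). The input is split as v = m · 2^e with m in [0.75, 1.5), log(m) is approximated by a polynomial in x = (m − 1)/(m + 1) (an atanh-series form), and e · ln(2) is added back:

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Scalar sketch of the sleef-style approximation above (illustrative only).
static float fast_log_scalar(float v) {
  auto to_bits = [](float x) {
    int32_t i;
    std::memcpy(&i, &x, sizeof(i)); // bitcast<int32_t>
    return i;
  };
  auto from_bits = [](int32_t i) {
    float x;
    std::memcpy(&x, &i, sizeof(x)); // bitcast<float>
    return x;
  };
  // ilogb2kf: unbiased exponent of v / 0.75, chosen so m lands in [0.75, 1.5).
  int32_t e = ((to_bits(v * (1.0f / 0.75f)) >> 23) & 0xff) - 0x7f;
  // ldexp3kf(v, -e): subtract e from the exponent field. The unsigned cast
  // mirrors the kLshift fix above, since -e may be negative.
  int32_t shifted = static_cast<int32_t>(static_cast<uint32_t>(-e) << 23);
  float m = from_bits(to_bits(v) + shifted);
  // Polynomial in x = (m - 1) / (m + 1); x * t approximates 2*atanh(x) = log(m).
  float x = (m - 1.0f) / (m + 1.0f);
  float x2 = x * x;
  float t = 0.2392828464508056640625f;
  t = t * x2 + 0.28518211841583251953125f;
  t = t * x2 + 0.400005877017974853515625f;
  t = t * x2 + 0.666666686534881591796875f;
  t = t * x2 + 2.0f;
  x = x * t + 0.693147180559945286226764f * e; // log(m) + e * ln(2)
  // Same special cases the IR version selects with IfThenElse.
  if (v < 0.0f) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  if (v == 0.0f) {
    return -std::numeric_limits<float>::infinity();
  }
  return x;
}
```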
@@ -290,6 +290,7 @@ TORCH_API ExprHandle exp(const ExprHandle& v);
```cpp
TORCH_API ExprHandle expm1(const ExprHandle& v);
TORCH_API ExprHandle fabs(const ExprHandle& v);
TORCH_API ExprHandle log(const ExprHandle& v);
TORCH_API ExprHandle fast_log(const ExprHandle& v);
TORCH_API ExprHandle log2(const ExprHandle& v);
TORCH_API ExprHandle log10(const ExprHandle& v);
TORCH_API ExprHandle log1p(const ExprHandle& v);
```