mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-08 07:38:39 +01:00
In fast-math mode emit a tanh that has a faster min/max.
PiperOrigin-RevId: 164943597
This commit is contained in:
parent
87605f3d6a
commit
c0f9b0a91e
|
|
@ -100,7 +100,7 @@ operator()(llvm::Module& module) const {
|
||||||
|
|
||||||
CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
|
CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
|
||||||
|
|
||||||
runtime::RewriteIRRuntimeFunctions(&module);
|
runtime::RewriteIRRuntimeFunctions(&module, enable_fast_math_);
|
||||||
|
|
||||||
// Buffer for holding machine code prior to constructing the ObjectFile.
|
// Buffer for holding machine code prior to constructing the ObjectFile.
|
||||||
llvm::SmallVector<char, 0> stream_buffer;
|
llvm::SmallVector<char, 0> stream_buffer;
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class CompilerFunctor {
|
||||||
|
|
||||||
explicit CompilerFunctor(
|
explicit CompilerFunctor(
|
||||||
llvm::TargetMachine* target_machine, const Disassembler* disassembler,
|
llvm::TargetMachine* target_machine, const Disassembler* disassembler,
|
||||||
int opt_level, bool optimize_for_size,
|
int opt_level, bool optimize_for_size, bool enable_fast_math,
|
||||||
const VectorIntrinsics& available_intrinsics,
|
const VectorIntrinsics& available_intrinsics,
|
||||||
LLVMCompiler::ModuleHook pre_optimization_hook = nullptr,
|
LLVMCompiler::ModuleHook pre_optimization_hook = nullptr,
|
||||||
LLVMCompiler::ModuleHook post_optimization_hook = nullptr)
|
LLVMCompiler::ModuleHook post_optimization_hook = nullptr)
|
||||||
|
|
@ -50,6 +50,7 @@ class CompilerFunctor {
|
||||||
disassembler_(CHECK_NOTNULL(disassembler)),
|
disassembler_(CHECK_NOTNULL(disassembler)),
|
||||||
opt_level_(opt_level),
|
opt_level_(opt_level),
|
||||||
optimize_for_size_(optimize_for_size),
|
optimize_for_size_(optimize_for_size),
|
||||||
|
enable_fast_math_(enable_fast_math),
|
||||||
available_intrinsics_(available_intrinsics),
|
available_intrinsics_(available_intrinsics),
|
||||||
pre_optimization_hook_(pre_optimization_hook),
|
pre_optimization_hook_(pre_optimization_hook),
|
||||||
post_optimization_hook_(post_optimization_hook) {}
|
post_optimization_hook_(post_optimization_hook) {}
|
||||||
|
|
@ -72,6 +73,7 @@ class CompilerFunctor {
|
||||||
const Disassembler* disassembler_;
|
const Disassembler* disassembler_;
|
||||||
const unsigned opt_level_;
|
const unsigned opt_level_;
|
||||||
const bool optimize_for_size_;
|
const bool optimize_for_size_;
|
||||||
|
const bool enable_fast_math_;
|
||||||
const VectorIntrinsics available_intrinsics_;
|
const VectorIntrinsics available_intrinsics_;
|
||||||
LLVMCompiler::ModuleHook pre_optimization_hook_;
|
LLVMCompiler::ModuleHook pre_optimization_hook_;
|
||||||
LLVMCompiler::ModuleHook post_optimization_hook_;
|
LLVMCompiler::ModuleHook post_optimization_hook_;
|
||||||
|
|
|
||||||
|
|
@ -442,6 +442,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
|
||||||
CompilerTargetOptions(module->config()),
|
CompilerTargetOptions(module->config()),
|
||||||
CodeGenOptLevel(module->config()),
|
CodeGenOptLevel(module->config()),
|
||||||
options::OptimizeForSizeRequested(module->config()),
|
options::OptimizeForSizeRequested(module->config()),
|
||||||
|
module->config().debug_options().xla_enable_fast_math(),
|
||||||
pre_optimization_ir_hook, post_optimization_ir_hook);
|
pre_optimization_ir_hook, post_optimization_ir_hook);
|
||||||
llvm_module->setDataLayout(jit->data_layout());
|
llvm_module->setDataLayout(jit->data_layout());
|
||||||
llvm_module->setTargetTriple(jit->target_triple().getTriple());
|
llvm_module->setTargetTriple(jit->target_triple().getTriple());
|
||||||
|
|
@ -794,6 +795,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
||||||
CompilerFunctor compiler_functor(
|
CompilerFunctor compiler_functor(
|
||||||
target_machine.get(), &disassembler, opt_level,
|
target_machine.get(), &disassembler, opt_level,
|
||||||
options::OptimizeForSizeRequested(module->config()),
|
options::OptimizeForSizeRequested(module->config()),
|
||||||
|
module->config().debug_options().xla_enable_fast_math(),
|
||||||
CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook,
|
CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook,
|
||||||
post_optimization_ir_dump_hook);
|
post_optimization_ir_dump_hook);
|
||||||
llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
|
llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,33 @@ const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
|
||||||
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
|
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
llvm::Value* EmitFMinOrMax(llvm::IRBuilder<>* ir_builder, llvm::Module* module,
|
||||||
|
llvm::Type* vector_type, llvm::Value* lhs,
|
||||||
|
llvm::Value* rhs, bool is_min,
|
||||||
|
bool enable_fast_math) {
|
||||||
|
if (enable_fast_math) {
|
||||||
|
// Using an unordered comparison lets LLVM generate a vminps / vmaxps
|
||||||
|
// instruction on x86. vminps/vmaxps choose the second operand if either
|
||||||
|
// operand is a NaN and thus don't accurately implement the semantics of the
|
||||||
|
// minnum and maxnum intrinsics, necessitating different IR emission.
|
||||||
|
//
|
||||||
|
// We can _probably_ do this even when fast math is disabled, but we can
|
||||||
|
// certainly do this if fast math is enabled (and nnan applies).
|
||||||
|
auto* compare = ir_builder->CreateFCmp(
|
||||||
|
is_min ? llvm::FCmpInst::FCMP_ULE : llvm::FCmpInst::FCMP_UGE, lhs, rhs);
|
||||||
|
return ir_builder->CreateSelect(compare, lhs, rhs);
|
||||||
|
} else {
|
||||||
|
llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
|
||||||
|
module, is_min ? llvm::Intrinsic::minnum : llvm::Intrinsic::maxnum,
|
||||||
|
vector_type);
|
||||||
|
return ir_builder->CreateCall(intrinsic, {lhs, rhs});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||||
llvm::StringRef function_name,
|
llvm::StringRef function_name,
|
||||||
int vector_width) {
|
int vector_width,
|
||||||
|
bool enable_fast_math) {
|
||||||
llvm::Function* vector_tanh_function = module->getFunction(function_name);
|
llvm::Function* vector_tanh_function = module->getFunction(function_name);
|
||||||
if (vector_tanh_function == nullptr) {
|
if (vector_tanh_function == nullptr) {
|
||||||
// If the function declaration is not present in the module, there can't be
|
// If the function declaration is not present in the module, there can't be
|
||||||
|
|
@ -45,11 +69,6 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||||
llvm::VectorType* vector_type =
|
llvm::VectorType* vector_type =
|
||||||
llvm::VectorType::get(float_type, vector_width);
|
llvm::VectorType::get(float_type, vector_width);
|
||||||
|
|
||||||
llvm::Function* min_intrinsic = llvm::Intrinsic::getDeclaration(
|
|
||||||
module, llvm::Intrinsic::minnum, vector_type);
|
|
||||||
llvm::Function* max_intrinsic = llvm::Intrinsic::getDeclaration(
|
|
||||||
module, llvm::Intrinsic::maxnum, vector_type);
|
|
||||||
|
|
||||||
llvm::BasicBlock* vector_tanh_body =
|
llvm::BasicBlock* vector_tanh_body =
|
||||||
llvm::BasicBlock::Create(*context, "body", vector_tanh_function);
|
llvm::BasicBlock::Create(*context, "body", vector_tanh_function);
|
||||||
|
|
||||||
|
|
@ -59,15 +78,24 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||||
fast_math_flags.setUnsafeAlgebra();
|
fast_math_flags.setUnsafeAlgebra();
|
||||||
ir_builder.setFastMathFlags(fast_math_flags);
|
ir_builder.setFastMathFlags(fast_math_flags);
|
||||||
|
|
||||||
|
auto emit_fmin = [&](llvm::Value* lhs, llvm::Value* rhs) {
|
||||||
|
return EmitFMinOrMax(&ir_builder, module, vector_type, lhs, rhs,
|
||||||
|
/*is_min=*/true,
|
||||||
|
/*enable_fast_math=*/enable_fast_math);
|
||||||
|
};
|
||||||
|
auto emit_fmax = [&](llvm::Value* lhs, llvm::Value* rhs) {
|
||||||
|
return EmitFMinOrMax(&ir_builder, module, vector_type, lhs, rhs,
|
||||||
|
/*is_min=*/false,
|
||||||
|
/*enable_fast_math=*/enable_fast_math);
|
||||||
|
};
|
||||||
|
|
||||||
llvm::Value* input = &*vector_tanh_function->arg_begin();
|
llvm::Value* input = &*vector_tanh_function->arg_begin();
|
||||||
CHECK_EQ(input->getType(), vector_type);
|
CHECK_EQ(input->getType(), vector_type);
|
||||||
|
|
||||||
// This implements the same rational interpolant as implemented in Eigen3.
|
// This implements the same rational interpolant as implemented in Eigen3.
|
||||||
llvm::Value* input_clamped = ir_builder.CreateCall(
|
llvm::Value* input_clamped =
|
||||||
min_intrinsic,
|
emit_fmin(emit_fmax(input, llvm::ConstantFP::get(vector_type, -9.0)),
|
||||||
{ir_builder.CreateCall(max_intrinsic,
|
llvm::ConstantFP::get(vector_type, 9.0));
|
||||||
{input, llvm::ConstantFP::get(vector_type, -9.0)}),
|
|
||||||
llvm::ConstantFP::get(vector_type, 9.0)});
|
|
||||||
|
|
||||||
std::array<float, 7> numerator_coeffs(
|
std::array<float, 7> numerator_coeffs(
|
||||||
{{-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
|
{{-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
|
||||||
|
|
@ -105,11 +133,13 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void RewriteIRRuntimeFunctions(llvm::Module* module) {
|
void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
|
||||||
auto* tanh_v4f32 = EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName,
|
auto* tanh_v4f32 =
|
||||||
/*vector_width=*/4);
|
EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName,
|
||||||
auto* tanh_v8f32 = EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName,
|
/*vector_width=*/4, enable_fast_math);
|
||||||
/*vector_width=*/8);
|
auto* tanh_v8f32 =
|
||||||
|
EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName,
|
||||||
|
/*vector_width=*/8, enable_fast_math);
|
||||||
|
|
||||||
// Gather all the call sites, force inline them and then delete the vector
|
// Gather all the call sites, force inline them and then delete the vector
|
||||||
// function bodies.
|
// function bodies.
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ extern const char* const kTanhV8F32SymbolName;
|
||||||
// |LinkIRRuntimeFunctions| rewrites calls to these functions into generic LLVM
|
// |LinkIRRuntimeFunctions| rewrites calls to these functions into generic LLVM
|
||||||
// IR.
|
// IR.
|
||||||
|
|
||||||
void RewriteIRRuntimeFunctions(llvm::Module* module);
|
void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math);
|
||||||
|
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,7 @@ CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
|
||||||
|
|
||||||
SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||||
llvm::CodeGenOpt::Level opt_level,
|
llvm::CodeGenOpt::Level opt_level,
|
||||||
bool optimize_for_size,
|
bool optimize_for_size, bool enable_fast_math,
|
||||||
LLVMCompiler::ModuleHook pre_optimization_hook,
|
LLVMCompiler::ModuleHook pre_optimization_hook,
|
||||||
LLVMCompiler::ModuleHook post_optimization_hook)
|
LLVMCompiler::ModuleHook post_optimization_hook)
|
||||||
: target_machine_(
|
: target_machine_(
|
||||||
|
|
@ -186,12 +186,12 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||||
data_layout_(target_machine_->createDataLayout()),
|
data_layout_(target_machine_->createDataLayout()),
|
||||||
object_layer_(
|
object_layer_(
|
||||||
[] { return std::make_shared<llvm::SectionMemoryManager>(); }),
|
[] { return std::make_shared<llvm::SectionMemoryManager>(); }),
|
||||||
compile_layer_(
|
compile_layer_(object_layer_,
|
||||||
object_layer_,
|
CompilerFunctor(target_machine_.get(), &disassembler_,
|
||||||
CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,
|
opt_level, optimize_for_size,
|
||||||
optimize_for_size, GetAvailableIntrinsics(),
|
enable_fast_math, GetAvailableIntrinsics(),
|
||||||
std::move(pre_optimization_hook),
|
std::move(pre_optimization_hook),
|
||||||
std::move(post_optimization_hook))) {
|
std::move(post_optimization_hook))) {
|
||||||
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
|
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
|
||||||
<< " features: " << target_machine_->getTargetFeatureString().str();
|
<< " features: " << target_machine_->getTargetFeatureString().str();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,7 @@ class SimpleOrcJIT {
|
||||||
// level optimizations are applied.
|
// level optimizations are applied.
|
||||||
SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||||
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
||||||
|
bool enable_fast_math,
|
||||||
LLVMCompiler::ModuleHook pre_optimization_hook,
|
LLVMCompiler::ModuleHook pre_optimization_hook,
|
||||||
LLVMCompiler::ModuleHook post_optimization_hook);
|
LLVMCompiler::ModuleHook post_optimization_hook);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -292,6 +292,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
||||||
|
|
||||||
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
|
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value) const {
|
llvm::Value* rhs_value) const {
|
||||||
|
// TODO(b/64580527): We can do better here if fast-math is enabled.
|
||||||
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
|
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
|
||||||
{lhs_value, rhs_value},
|
{lhs_value, rhs_value},
|
||||||
{lhs_value->getType()}, ir_builder_);
|
{lhs_value->getType()}, ir_builder_);
|
||||||
|
|
@ -299,6 +300,7 @@ llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
|
||||||
|
|
||||||
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
|
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value) const {
|
llvm::Value* rhs_value) const {
|
||||||
|
// TODO(b/64580527): We can do better here if fast-math is enabled.
|
||||||
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
|
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
|
||||||
{lhs_value, rhs_value},
|
{lhs_value, rhs_value},
|
||||||
{lhs_value->getType()}, ir_builder_);
|
{lhs_value->getType()}, ir_builder_);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user