Revert D31445799: [nnc] Use given kernel function name while emitting code

Test Plan: revert-hammer

Differential Revision:
D31445799 (c30dc52739)

Original commit changeset: 8d1642098313

fbshipit-source-id: 6b9d8c816437e9fcba8eb19cc683bc0a46a04cf5
This commit is contained in:
Raghavan Raman 2021-10-08 12:36:18 -07:00 committed by Facebook GitHub Bot
parent 2e6fa0261f
commit 92ce188510
2 changed files with 5 additions and 30 deletions

View File

@ -1793,19 +1793,6 @@ TEST(LLVM, CustomTarget) {
->run(ss.str());
}
TEST(LLVM, CodeGenKernelFuncName) {
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
std::vector<int32_t> a_buffer = {42};
std::vector<int32_t> b_buffer = {-11};
auto store = b.store({0}, a.load(0));
LLVMCodeGen cg(store, {a, b});
// Check that the kernel function name used by LLVMCodeGen
// is not empty.
ASSERT_NE(cg.kernel_func_name(), "");
}
} // namespace jit
} // namespace torch

View File

@ -184,7 +184,6 @@ class LLVMCodeGenImpl : public IRVisitor {
llvm::BasicBlock* bb_;
llvm::Value* value_{nullptr};
llvm::JITTargetAddress kernelAddress_;
std::string kernel_func_name_;
#define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE);
@ -234,7 +233,6 @@ class LLVMCodeGenImpl : public IRVisitor {
const std::vector<CodeGen::BufferArg>& args,
at::Device device,
Dtype dtype,
std::string kernel_func_name,
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs);
@ -340,15 +338,8 @@ LLVMCodeGen::LLVMCodeGen(
c10::optional<std::string> cpu,
c10::optional<std::string> attrs)
: CodeGen(stmt, args, device, kernel_func_name),
impl_(std::make_unique<LLVMCodeGenImpl>(
stmt,
args,
device,
dtype,
this->kernel_func_name(),
triple,
cpu,
attrs)) {
impl_(std::make_unique<
LLVMCodeGenImpl>(stmt, args, device, dtype, triple, cpu, attrs)) {
callee_ = std::make_unique<LLVMCodeGenCallee>(
impl_->releaseJIT(), (void*)impl_->getKernelAddress());
}
@ -427,13 +418,10 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
const std::vector<CodeGen::BufferArg>& args,
at::Device device,
Dtype dtype,
std::string kernel_func_name,
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs)
: context_(std::make_unique<llvm::LLVMContext>()),
irb_(getContext()),
kernel_func_name_(std::move(kernel_func_name)) {
: context_(std::make_unique<llvm::LLVMContext>()), irb_(getContext()) {
// Manually map types to LLVM types.
ByteTy_ = llvm::Type::getInt8Ty(getContext());
CharTy_ = llvm::Type::getInt8Ty(getContext());
@ -490,7 +478,7 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
emitKernel(stmt, params);
jit_->addModule(std::move(module_), std::move(context_));
auto sym = jit_->findSymbol(kernel_func_name_);
auto sym = jit_->findSymbol("wrapper");
kernelAddress_ = assertSuccess(sym.getAddress());
}
@ -522,7 +510,7 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector<llvm::Type*>& params) {
auto wrapper = llvm::Function::Create(
llvm::FunctionType::get(IntTy_, {voidPtrPtrTy}, false),
llvm::Function::ExternalLinkage,
kernel_func_name_,
"wrapper",
module_.get());
auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper);
irb_.SetInsertPoint(wrapBB);