mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D31445799: [nnc] Use given kernel function name while emitting code
Test Plan: revert-hammer
Differential Revision:
D31445799 (c30dc52739)
Original commit changeset: 8d1642098313
fbshipit-source-id: 6b9d8c816437e9fcba8eb19cc683bc0a46a04cf5
This commit is contained in:
parent
2e6fa0261f
commit
92ce188510
|
|
@ -1793,19 +1793,6 @@ TEST(LLVM, CustomTarget) {
|
|||
->run(ss.str());
|
||||
}
|
||||
|
||||
TEST(LLVM, CodeGenKernelFuncName) {
|
||||
BufHandle a("A", {1}, kInt);
|
||||
BufHandle b("B", {1}, kInt);
|
||||
std::vector<int32_t> a_buffer = {42};
|
||||
std::vector<int32_t> b_buffer = {-11};
|
||||
|
||||
auto store = b.store({0}, a.load(0));
|
||||
LLVMCodeGen cg(store, {a, b});
|
||||
// Check that the kernel function name used by LLVMCodeGen
|
||||
// is not empty.
|
||||
ASSERT_NE(cg.kernel_func_name(), "");
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
|
|
|
|||
|
|
@ -184,7 +184,6 @@ class LLVMCodeGenImpl : public IRVisitor {
|
|||
llvm::BasicBlock* bb_;
|
||||
llvm::Value* value_{nullptr};
|
||||
llvm::JITTargetAddress kernelAddress_;
|
||||
std::string kernel_func_name_;
|
||||
|
||||
#define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_;
|
||||
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE);
|
||||
|
|
@ -234,7 +233,6 @@ class LLVMCodeGenImpl : public IRVisitor {
|
|||
const std::vector<CodeGen::BufferArg>& args,
|
||||
at::Device device,
|
||||
Dtype dtype,
|
||||
std::string kernel_func_name,
|
||||
c10::optional<std::string> triple,
|
||||
c10::optional<std::string> cpu,
|
||||
c10::optional<std::string> attrs);
|
||||
|
|
@ -340,15 +338,8 @@ LLVMCodeGen::LLVMCodeGen(
|
|||
c10::optional<std::string> cpu,
|
||||
c10::optional<std::string> attrs)
|
||||
: CodeGen(stmt, args, device, kernel_func_name),
|
||||
impl_(std::make_unique<LLVMCodeGenImpl>(
|
||||
stmt,
|
||||
args,
|
||||
device,
|
||||
dtype,
|
||||
this->kernel_func_name(),
|
||||
triple,
|
||||
cpu,
|
||||
attrs)) {
|
||||
impl_(std::make_unique<
|
||||
LLVMCodeGenImpl>(stmt, args, device, dtype, triple, cpu, attrs)) {
|
||||
callee_ = std::make_unique<LLVMCodeGenCallee>(
|
||||
impl_->releaseJIT(), (void*)impl_->getKernelAddress());
|
||||
}
|
||||
|
|
@ -427,13 +418,10 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
|
|||
const std::vector<CodeGen::BufferArg>& args,
|
||||
at::Device device,
|
||||
Dtype dtype,
|
||||
std::string kernel_func_name,
|
||||
c10::optional<std::string> triple,
|
||||
c10::optional<std::string> cpu,
|
||||
c10::optional<std::string> attrs)
|
||||
: context_(std::make_unique<llvm::LLVMContext>()),
|
||||
irb_(getContext()),
|
||||
kernel_func_name_(std::move(kernel_func_name)) {
|
||||
: context_(std::make_unique<llvm::LLVMContext>()), irb_(getContext()) {
|
||||
// Manually map types to LLVM types.
|
||||
ByteTy_ = llvm::Type::getInt8Ty(getContext());
|
||||
CharTy_ = llvm::Type::getInt8Ty(getContext());
|
||||
|
|
@ -490,7 +478,7 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
|
|||
emitKernel(stmt, params);
|
||||
|
||||
jit_->addModule(std::move(module_), std::move(context_));
|
||||
auto sym = jit_->findSymbol(kernel_func_name_);
|
||||
auto sym = jit_->findSymbol("wrapper");
|
||||
kernelAddress_ = assertSuccess(sym.getAddress());
|
||||
}
|
||||
|
||||
|
|
@ -522,7 +510,7 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector<llvm::Type*>& params) {
|
|||
auto wrapper = llvm::Function::Create(
|
||||
llvm::FunctionType::get(IntTy_, {voidPtrPtrTy}, false),
|
||||
llvm::Function::ExternalLinkage,
|
||||
kernel_func_name_,
|
||||
"wrapper",
|
||||
module_.get());
|
||||
auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper);
|
||||
irb_.SetInsertPoint(wrapBB);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user