Revert D31445797: [nnc] Added a cache to use singleton instances of PytorchLLVMJIT for every triple,cpu,attrs combination

Test Plan: revert-hammer

Differential Revision: D31445797 (7e5ef5e517)

Original commit changeset: 4e1450100928

fbshipit-source-id: fc13b34dbb66c7a22816eb46cf6d98ae9f332d39
Raghavan Raman 2021-10-08 12:36:18 -07:00 committed by Facebook GitHub Bot
parent 097fdcdf0c
commit 2e6fa0261f
5 changed files with 43 additions and 121 deletions

View File

@@ -1804,12 +1804,6 @@ TEST(LLVM, CodeGenKernelFuncName) {
// Check that the kernel function name used by LLVMCodeGen
// is not empty.
ASSERT_NE(cg.kernel_func_name(), "");
// Do another codegen and ensure that the kernel func name is different
// from the one above.
LLVMCodeGen cg2(store, {a, b});
ASSERT_NE(cg2.kernel_func_name(), "");
ASSERT_NE(cg.kernel_func_name(), cg2.kernel_func_name());
}
} // namespace jit

View File

@@ -99,10 +99,6 @@ class TORCH_API CodeGen {
return kernel_func_name_;
}
void set_kernel_func_name(std::string kernel_func_name) {
kernel_func_name_ = std::move(kernel_func_name);
}
protected:
static void* argToPtr(const BufferArg& bufferArg, const CallArg& callArg);
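The setter removed here existed only because, with the JIT cache, LLVMCodeGen had to uniquify the kernel function name after the base class was constructed and patch it back in; once the cache is gone, the name is simply forwarded through the CodeGen constructor (see the LLVMCodeGen constructor hunk further down). A minimal sketch of the two flows, using a hypothetical CodeGenLike stand-in rather than the real class:

    #include <string>
    #include <utility>

    // Hypothetical stand-in, not the real CodeGen.
    class CodeGenLike {
     public:
      // Post-revert flow: the final kernel name arrives via the constructor.
      explicit CodeGenLike(std::string name)
          : kernel_func_name_(std::move(name)) {}
      const std::string& kernel_func_name() const { return kernel_func_name_; }
      // Pre-revert flow: the name was patched in after construction, once the
      // cached JIT had produced a uniquified version of it.
      void set_kernel_func_name(std::string name) {
        kernel_func_name_ = std::move(name);
      }
     private:
      std::string kernel_func_name_;
    };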

View File

@@ -152,11 +152,13 @@ struct FunctionCallee {
class LLVMCodeGenCallee {
public:
LLVMCodeGenCallee(llvm::orc::PytorchLLVMJIT* jit, void* kernelAddress)
: jit_(jit), kernelAddress_(kernelAddress) {}
LLVMCodeGenCallee(
std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit,
void* kernelAddress)
: jit_(std::move(jit)), kernelAddress_(kernelAddress) {}
llvm::orc::PytorchLLVMJIT* getJIT() {
return jit_;
return jit_.get();
}
void* getKernelAddress() {
@@ -168,13 +170,7 @@ class LLVMCodeGenCallee {
}
private:
// This is not strictly needed in the callee. We just need the JIT to stay
// alive for calls to this kernel to work. Since the JIT is owned by the
// PytorchLLVMJITCache, we don't need to save it here.
//
// We retain a pointer to the JIT here only to denote that it is necessary
// for the calls to work.
llvm::orc::PytorchLLVMJIT* jit_;
std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
void* kernelAddress_;
};
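This hunk is the core of the revert for the callee: with the cache, the callee merely borrowed a JIT that the process-wide PytorchLLVMJITCache kept alive; without it, each codegen owns its JIT again, so a unique_ptr member ties the JIT's lifetime to the callee. A minimal sketch of the two ownership models, with a hypothetical Jit type standing in for PytorchLLVMJIT:

    #include <memory>
    #include <utility>

    struct Jit {};  // hypothetical stand-in for PytorchLLVMJIT

    // Pre-revert: non-owning; a global cache must keep the JIT alive.
    struct BorrowingCallee {
      Jit* jit;
    };

    // Post-revert: owning; kernel calls stay valid exactly as long as the
    // callee itself exists, with no global state involved.
    struct OwningCallee {
      explicit OwningCallee(std::unique_ptr<Jit> j) : jit(std::move(j)) {}
      std::unique_ptr<Jit> jit;
    };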
@@ -182,7 +178,7 @@ class LLVMCodeGenImpl : public IRVisitor {
private:
std::unique_ptr<llvm::LLVMContext> context_;
llvm::IRBuilder<> irb_;
llvm::orc::PytorchLLVMJIT* jit_;
std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
std::unique_ptr<llvm::Module> module_;
llvm::Function* fn_;
llvm::BasicBlock* bb_;
@@ -239,10 +235,13 @@ class LLVMCodeGenImpl : public IRVisitor {
at::Device device,
Dtype dtype,
std::string kernel_func_name,
llvm::orc::PytorchLLVMJIT* jit);
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs);
~LLVMCodeGenImpl() = default;
llvm::JITTargetAddress getKernelAddress() const;
std::unique_ptr<llvm::orc::PytorchLLVMJIT> releaseJIT();
void visit(AddPtr v) override;
void visit(SubPtr v) override;
@@ -340,15 +339,18 @@ LLVMCodeGen::LLVMCodeGen(
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs)
: CodeGen(stmt, args, device) {
auto jit = llvm::orc::PytorchLLVMJITCache::getPytorchLLVMJITInstance(
triple, cpu, attrs);
auto unique_kernel_func_name = jit->getUniqueFunctionName(kernel_func_name);
set_kernel_func_name(unique_kernel_func_name);
impl_ = std::make_unique<LLVMCodeGenImpl>(
stmt, args, device, dtype, unique_kernel_func_name, jit);
: CodeGen(stmt, args, device, kernel_func_name),
impl_(std::make_unique<LLVMCodeGenImpl>(
stmt,
args,
device,
dtype,
this->kernel_func_name(),
triple,
cpu,
attrs)) {
callee_ = std::make_unique<LLVMCodeGenCallee>(
jit, (void*)impl_->getKernelAddress());
impl_->releaseJIT(), (void*)impl_->getKernelAddress());
}
void LLVMCodeGen::cleanup_memory() {
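The constructor change above reverses the data flow: instead of fetching a shared JIT from the cache and handing the same raw pointer to both the impl and the callee, LLVMCodeGenImpl now builds its own JIT and LLVMCodeGen moves it into the callee via releaseJIT(), so the JIT outlives the impl's compilation state. A compilable sketch of that construct-then-release handoff, with hypothetical Jit, Impl, and Callee stand-ins:

    #include <memory>
    #include <utility>

    struct Jit {};  // hypothetical stand-in for PytorchLLVMJIT

    struct Impl {
      std::unique_ptr<Jit> jit_ = std::make_unique<Jit>();  // built during codegen
      std::unique_ptr<Jit> releaseJIT() { return std::move(jit_); }
    };

    struct Callee {
      explicit Callee(std::unique_ptr<Jit> jit) : jit_(std::move(jit)) {}
      std::unique_ptr<Jit> jit_;  // keeps the JIT alive for later kernel calls
    };

    int main() {
      Impl impl;                         // impl owns the JIT while compiling
      Callee callee(impl.releaseJIT());  // ownership moves to the callee
      // impl can be discarded now; the compiled kernel remains callable.
    }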
@@ -410,16 +412,27 @@ llvm::JITTargetAddress LLVMCodeGenImpl::getKernelAddress() const {
return kernelAddress_;
}
std::unique_ptr<llvm::orc::PytorchLLVMJIT> LLVMCodeGenImpl::releaseJIT() {
return std::move(jit_);
}
namespace {
// Global mutex to protect LLVM initialization. TargetRegistry::lookupTarget
// in particular is not thread-safe.
static std::mutex llvmInitMutex;
} // namespace
LLVMCodeGenImpl::LLVMCodeGenImpl(
StmtPtr stmt,
const std::vector<CodeGen::BufferArg>& args,
at::Device device,
Dtype dtype,
std::string kernel_func_name,
llvm::orc::PytorchLLVMJIT* jit)
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs)
: context_(std::make_unique<llvm::LLVMContext>()),
irb_(getContext()),
jit_(jit),
kernel_func_name_(std::move(kernel_func_name)) {
// Manually map types to LLVM types.
ByteTy_ = llvm::Type::getInt8Ty(getContext());
@@ -434,6 +447,14 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
VoidTy_ = llvm::Type::getVoidTy(getContext());
BoolTy_ = ByteTy_;
{
std::lock_guard<std::mutex> g(llvmInitMutex);
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmPrinters();
jit_ = std::make_unique<llvm::orc::PytorchLLVMJIT>(triple, cpu, attrs);
}
module_ = std::make_unique<llvm::Module>("pytorch", getContext());
module_->setDataLayout(jit_->getDataLayout());
module_->setTargetTriple(jit_->getTargetMachine().getTargetTriple().str());
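Because the JIT is now created inside LLVMCodeGenImpl rather than once per cache entry, LLVM target initialization moves back here and is serialized through llvmInitMutex (TargetRegistry::lookupTarget in particular is not thread-safe). A standalone sketch of that guarded-initialization pattern; the placeholder below stands in for the llvm::InitializeAll*() calls and is not the real LLVM API:

    #include <mutex>

    namespace {
    std::mutex initMutex;  // serializes first-time target setup across threads

    void initializeTargetsPlaceholder() {
      // stand-in for llvm::InitializeAllTargets(), InitializeAllTargetMCs(),
      // and InitializeAllAsmPrinters(), which must not run concurrently
    }
    }  // namespace

    void ensureTargetsInitialized() {
      std::lock_guard<std::mutex> guard(initMutex);
      initializeTargetsPlaceholder();
    }

std::call_once would give the same guarantee; the mutex form simply mirrors the code above.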

View File

@@ -18,7 +18,6 @@
#include <llvm/Support/CFGUpdate.h>
#include <llvm/Support/DynamicLibrary.h>
#include <llvm/Support/Host.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
@@ -316,50 +315,6 @@ void dumpCFG(const llvm::cfg::Update<llvm::BasicBlock*>& update) {
}
#endif
std::string PytorchLLVMJIT::getUniqueFunctionName(const std::string& name) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = existing_functions_.find(name);
if (it == existing_functions_.end()) {
existing_functions_[name] = 0;
return name;
}
existing_functions_[name] = it->second + 1;
std::string unique_name = name + "_" + std::to_string(it->second + 1);
return unique_name;
}
std::unordered_map<std::string, std::unique_ptr<PytorchLLVMJIT>>
PytorchLLVMJITCache::jit_cache_;
std::mutex PytorchLLVMJITCache::mutex_;
PytorchLLVMJIT* PytorchLLVMJITCache::getPytorchLLVMJITInstance(
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs) {
std::string cacheKey = getCacheKey(triple, cpu, attrs);
std::lock_guard<std::mutex> lock(mutex_);
auto it = jit_cache_.find(cacheKey);
if (it != jit_cache_.end()) {
return it->second.get();
}
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmPrinters();
auto jit = std::make_unique<PytorchLLVMJIT>(triple, cpu, attrs);
auto jit_to_return = jit.get();
jit_cache_[cacheKey] = std::move(jit);
return jit_to_return;
}
std::string PytorchLLVMJITCache::getCacheKey(
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs) {
return "triple:" + std::string(triple ? *triple : "") +
"cpu:" + std::string(cpu ? *cpu : "") +
"attrs:" + std::string(attrs ? *attrs : "");
}
} // end namespace orc
} // end namespace llvm
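For reference, this is the mechanism being deleted: a process-wide cache handing out one PytorchLLVMJIT per (triple, cpu, attrs) key, created lazily under a lock and reused on every later request. A generic, compilable sketch of that keyed-singleton pattern, with a hypothetical Jit type in place of the real class:

    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>
    #include <utility>

    struct Jit {};  // hypothetical stand-in for PytorchLLVMJIT

    // Repeated lookups with the same key return the same instance, which is
    // why the cached design let the callee hold only a raw, non-owning pointer.
    Jit* getCachedJit(const std::string& key) {
      static std::unordered_map<std::string, std::unique_ptr<Jit>> cache;
      static std::mutex mutex;
      std::lock_guard<std::mutex> lock(mutex);
      auto it = cache.find(key);
      if (it != cache.end()) {
        return it->second.get();
      }
      auto jit = std::make_unique<Jit>();
      Jit* result = jit.get();
      cache[key] = std::move(jit);
      return result;
    }

The removed getCacheKey above simply concatenated the three optional strings into such a key ("triple:...cpu:...attrs:...").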

View File

@@ -12,7 +12,6 @@
#include <memory>
#include <string>
#include <unordered_map>
namespace torch {
namespace jit {
@@ -61,34 +60,12 @@ class TORCH_API PytorchLLVMJIT {
c10::optional<std::string> attrs);
~PytorchLLVMJIT();
// While creating any function in the module that is being added to this JIT,
// get a unique name by calling the `getUniqueFunctionName()` method. That
// ensures that there are no duplicate function names in this JIT.
void addModule(std::unique_ptr<Module> M, std::unique_ptr<LLVMContext> C);
JITSymbol findSymbol(const std::string Name);
bool hasSymbol(const std::string& Name);
// Returns a function name that is unique in this JIT (among the function
// names tracked by calling this method).
//
// When getUniqueFunctionName is called with a name that has never been used
// before, it returns the input name as is. When it is called with the same
// name subsequently, it appends "_<num>" to the name to uniquify it.
//
// For example:
// * First call to getUniqueFunctionName("func") => returns "func"
// * Second call to getUniqueFunctionName("func") => returns "func_1"
// * Third call to getUniqueFunctionName("func") => returns "func_2"
//
// NOTE: This method does not keep track of all the functions that are added
// to this JIT. It only keeps track of the function names that are uniquified
// by calling this method directly.
//
// Recommendation: Call this method before adding any function to this JIT.
std::string getUniqueFunctionName(const std::string& name);
TargetMachine& getTargetMachine();
const DataLayout& getDataLayout();
@@ -96,27 +73,6 @@ class TORCH_API PytorchLLVMJIT {
private:
// Use the PImpl idiom here to hide the no-rtti parts of the JIT structure.
std::unique_ptr<PytorchLLVMJITImpl> impl_;
std::mutex mutex_;
std::unordered_map<std::string, int> existing_functions_;
};
class TORCH_API PytorchLLVMJITCache {
public:
static PytorchLLVMJIT* getPytorchLLVMJITInstance(
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs);
private:
static std::unordered_map<std::string, std::unique_ptr<PytorchLLVMJIT>>
jit_cache_;
static std::mutex mutex_;
static std::string getCacheKey(
c10::optional<std::string> triple,
c10::optional<std::string> cpu,
c10::optional<std::string> attrs);
};
} // end namespace orc