Mirror of https://github.com/zebrajr/pytorch.git
Revert D31445797: [nnc] Added a cache to use singleton instances of PytorchLLVMJIT for every triple,cpu,attrs combination
Test Plan: revert-hammer
Differential Revision: D31445797 (7e5ef5e517)
Original commit changeset: 4e1450100928
fbshipit-source-id: fc13b34dbb66c7a22816eb46cf6d98ae9f332d39
Parent: 097fdcdf0c
Commit: 2e6fa0261f
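For reference, the reverted change made llvm::orc::PytorchLLVMJITCache hand out a single long-lived PytorchLLVMJIT per (triple, cpu, attrs) combination instead of letting each codegen own its own JIT. The snippet below is a minimal, self-contained sketch of that caching pattern only; Jit and JitCache are placeholder names standing in for the real NNC types, and the actual removed code appears in the hunks further down.

    #include <memory>
    #include <mutex>
    #include <optional>
    #include <string>
    #include <unordered_map>

    // Placeholder for llvm::orc::PytorchLLVMJIT; the real class wraps an LLVM ORC JIT.
    struct Jit {
      Jit(std::optional<std::string> triple,
          std::optional<std::string> cpu,
          std::optional<std::string> attrs) {}
    };

    class JitCache {
     public:
      // Returns the one Jit instance for a given (triple, cpu, attrs) combination,
      // creating it on first use. Instances live for the lifetime of the process.
      static Jit* getInstance(
          std::optional<std::string> triple,
          std::optional<std::string> cpu,
          std::optional<std::string> attrs) {
        std::string key = "triple:" + triple.value_or("") +
            "cpu:" + cpu.value_or("") + "attrs:" + attrs.value_or("");
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = cache_.find(key);
        if (it != cache_.end()) {
          return it->second.get();
        }
        auto jit = std::make_unique<Jit>(triple, cpu, attrs);
        Jit* raw = jit.get();
        cache_[key] = std::move(jit);
        return raw;
      }

     private:
      static std::unordered_map<std::string, std::unique_ptr<Jit>> cache_;
      static std::mutex mutex_;
    };

    std::unordered_map<std::string, std::unique_ptr<Jit>> JitCache::cache_;
    std::mutex JitCache::mutex_;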
@@ -1804,12 +1804,6 @@ TEST(LLVM, CodeGenKernelFuncName) {
   // Check that the kernel function name used by LLVMCodeGen
   // is not empty.
   ASSERT_NE(cg.kernel_func_name(), "");
-
-  // Do another codegen and ensure that the kernel func name is different
-  // from the one above.
-  LLVMCodeGen cg2(store, {a, b});
-  ASSERT_NE(cg2.kernel_func_name(), "");
-  ASSERT_NE(cg.kernel_func_name(), cg2.kernel_func_name());
 }

 } // namespace jit
@@ -99,10 +99,6 @@ class TORCH_API CodeGen {
     return kernel_func_name_;
   }

-  void set_kernel_func_name(std::string kernel_func_name) {
-    kernel_func_name_ = std::move(kernel_func_name);
-  }
-
  protected:
   static void* argToPtr(const BufferArg& bufferArg, const CallArg& callArg);

@@ -152,11 +152,13 @@ struct FunctionCallee {

 class LLVMCodeGenCallee {
  public:
-  LLVMCodeGenCallee(llvm::orc::PytorchLLVMJIT* jit, void* kernelAddress)
-      : jit_(jit), kernelAddress_(kernelAddress) {}
+  LLVMCodeGenCallee(
+      std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit,
+      void* kernelAddress)
+      : jit_(std::move(jit)), kernelAddress_(kernelAddress) {}

   llvm::orc::PytorchLLVMJIT* getJIT() {
-    return jit_;
+    return jit_.get();
   }

   void* getKernelAddress() {
@@ -168,13 +170,7 @@ class LLVMCodeGenCallee {
   }

  private:
-  // This is not necessarily needed in the callee. We just need the JIT to be
-  // alive for the call to this kernel to work. Since the JIT is owned by the
-  // PytorchLLVMJITCache, we don't need to save them here.
-  //
-  // Retaining a pointer to the JIT here only to denote that this is necessary
-  // for the calls to work.
-  llvm::orc::PytorchLLVMJIT* jit_;
+  std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
   void* kernelAddress_;
 };

@@ -182,7 +178,7 @@ class LLVMCodeGenImpl : public IRVisitor {
  private:
   std::unique_ptr<llvm::LLVMContext> context_;
   llvm::IRBuilder<> irb_;
-  llvm::orc::PytorchLLVMJIT* jit_;
+  std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
   std::unique_ptr<llvm::Module> module_;
   llvm::Function* fn_;
   llvm::BasicBlock* bb_;
@@ -239,10 +235,13 @@ class LLVMCodeGenImpl : public IRVisitor {
       at::Device device,
       Dtype dtype,
       std::string kernel_func_name,
-      llvm::orc::PytorchLLVMJIT* jit);
+      c10::optional<std::string> triple,
+      c10::optional<std::string> cpu,
+      c10::optional<std::string> attrs);
   ~LLVMCodeGenImpl() = default;

   llvm::JITTargetAddress getKernelAddress() const;
+  std::unique_ptr<llvm::orc::PytorchLLVMJIT> releaseJIT();

   void visit(AddPtr v) override;
   void visit(SubPtr v) override;
@@ -340,15 +339,18 @@ LLVMCodeGen::LLVMCodeGen(
     c10::optional<std::string> triple,
     c10::optional<std::string> cpu,
     c10::optional<std::string> attrs)
-    : CodeGen(stmt, args, device) {
-  auto jit = llvm::orc::PytorchLLVMJITCache::getPytorchLLVMJITInstance(
-      triple, cpu, attrs);
-  auto unique_kernel_func_name = jit->getUniqueFunctionName(kernel_func_name);
-  set_kernel_func_name(unique_kernel_func_name);
-  impl_ = std::make_unique<LLVMCodeGenImpl>(
-      stmt, args, device, dtype, unique_kernel_func_name, jit);
+    : CodeGen(stmt, args, device, kernel_func_name),
+      impl_(std::make_unique<LLVMCodeGenImpl>(
+          stmt,
+          args,
+          device,
+          dtype,
+          this->kernel_func_name(),
+          triple,
+          cpu,
+          attrs)) {
   callee_ = std::make_unique<LLVMCodeGenCallee>(
-      jit, (void*)impl_->getKernelAddress());
+      impl_->releaseJIT(), (void*)impl_->getKernelAddress());
 }

 void LLVMCodeGen::cleanup_memory() {
@@ -410,16 +412,27 @@ llvm::JITTargetAddress LLVMCodeGenImpl::getKernelAddress() const {
   return kernelAddress_;
 }

+std::unique_ptr<llvm::orc::PytorchLLVMJIT> LLVMCodeGenImpl::releaseJIT() {
+  return std::move(jit_);
+}
+
+namespace {
+// Global mutex to protect LLVM initialization. TargetRegistry::lookupTarget
+// in particular is not thread-safe.
+static std::mutex llvmInitMutex;
+} // namespace
+
 LLVMCodeGenImpl::LLVMCodeGenImpl(
     StmtPtr stmt,
     const std::vector<CodeGen::BufferArg>& args,
     at::Device device,
     Dtype dtype,
     std::string kernel_func_name,
-    llvm::orc::PytorchLLVMJIT* jit)
+    c10::optional<std::string> triple,
+    c10::optional<std::string> cpu,
+    c10::optional<std::string> attrs)
     : context_(std::make_unique<llvm::LLVMContext>()),
       irb_(getContext()),
-      jit_(jit),
       kernel_func_name_(std::move(kernel_func_name)) {
   // Manually map types to LLVM types.
   ByteTy_ = llvm::Type::getInt8Ty(getContext());
@@ -434,6 +447,14 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
   VoidTy_ = llvm::Type::getVoidTy(getContext());
   BoolTy_ = ByteTy_;

+  {
+    std::lock_guard<std::mutex> g(llvmInitMutex);
+    llvm::InitializeAllTargets();
+    llvm::InitializeAllTargetMCs();
+    llvm::InitializeAllAsmPrinters();
+    jit_ = std::make_unique<llvm::orc::PytorchLLVMJIT>(triple, cpu, attrs);
+  }
+
   module_ = std::make_unique<llvm::Module>("pytorch", getContext());
   module_->setDataLayout(jit_->getDataLayout());
   module_->setTargetTriple(jit_->getTargetMachine().getTargetTriple().str());
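After the revert, each LLVMCodeGenImpl once again builds and owns its own PytorchLLVMJIT (constructed under llvmInitMutex, as in the hunk above), and LLVMCodeGen then transfers that ownership into the LLVMCodeGenCallee via releaseJIT(), so the JIT lives exactly as long as the compiled kernel needs to stay callable. A rough sketch of that ownership hand-off, using placeholder types rather than the real classes:

    #include <memory>
    #include <utility>

    struct Jit {};  // stand-in for llvm::orc::PytorchLLVMJIT

    // Stand-in for LLVMCodeGenImpl: compiles the kernel and temporarily owns the JIT.
    class Impl {
     public:
      Impl() : jit_(std::make_unique<Jit>()) {}

      // Hands ownership of the JIT to the caller; jit_ is empty afterwards.
      std::unique_ptr<Jit> releaseJIT() {
        return std::move(jit_);
      }

     private:
      std::unique_ptr<Jit> jit_;
    };

    // Stand-in for LLVMCodeGenCallee: keeps the JIT alive while the kernel is callable.
    class Callee {
     public:
      Callee(std::unique_ptr<Jit> jit, void* kernelAddress)
          : jit_(std::move(jit)), kernelAddress_(kernelAddress) {}

     private:
      std::unique_ptr<Jit> jit_;
      void* kernelAddress_;
    };

    int main() {
      Impl impl;
      Callee callee(impl.releaseJIT(), /*kernelAddress=*/nullptr);
      // impl no longer owns the JIT; callee keeps it alive until it is destroyed.
    }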
@@ -18,7 +18,6 @@
 #include <llvm/Support/CFGUpdate.h>
 #include <llvm/Support/DynamicLibrary.h>
 #include <llvm/Support/Host.h>
-#include <llvm/Support/TargetSelect.h>
 #include <llvm/Support/raw_ostream.h>
 #include <llvm/Target/TargetMachine.h>

@@ -316,50 +315,6 @@ void dumpCFG(const llvm::cfg::Update<llvm::BasicBlock*>& update) {
 }
 #endif

-std::string PytorchLLVMJIT::getUniqueFunctionName(const std::string& name) {
-  std::lock_guard<std::mutex> lock(mutex_);
-  auto it = existing_functions_.find(name);
-  if (it == existing_functions_.end()) {
-    existing_functions_[name] = 0;
-    return name;
-  }
-  existing_functions_[name] = it->second + 1;
-  std::string unique_name = name + "_" + std::to_string(it->second + 1);
-  return unique_name;
-}
-
-std::unordered_map<std::string, std::unique_ptr<PytorchLLVMJIT>>
-    PytorchLLVMJITCache::jit_cache_;
-std::mutex PytorchLLVMJITCache::mutex_;
-
-PytorchLLVMJIT* PytorchLLVMJITCache::getPytorchLLVMJITInstance(
-    c10::optional<std::string> triple,
-    c10::optional<std::string> cpu,
-    c10::optional<std::string> attrs) {
-  std::string cacheKey = getCacheKey(triple, cpu, attrs);
-  std::lock_guard<std::mutex> lock(mutex_);
-  auto it = jit_cache_.find(cacheKey);
-  if (it != jit_cache_.end()) {
-    return it->second.get();
-  }
-  llvm::InitializeAllTargets();
-  llvm::InitializeAllTargetMCs();
-  llvm::InitializeAllAsmPrinters();
-  auto jit = std::make_unique<PytorchLLVMJIT>(triple, cpu, attrs);
-  auto jit_to_return = jit.get();
-  jit_cache_[cacheKey] = std::move(jit);
-  return jit_to_return;
-}
-
-std::string PytorchLLVMJITCache::getCacheKey(
-    c10::optional<std::string> triple,
-    c10::optional<std::string> cpu,
-    c10::optional<std::string> attrs) {
-  return "triple:" + std::string(triple ? *triple : "") +
-      "cpu:" + std::string(cpu ? *cpu : "") +
-      "attrs:" + std::string(attrs ? *attrs : "");
-}
-
 } // end namespace orc
 } // end namespace llvm

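The getUniqueFunctionName implementation removed above kept a per-name counter: the first request for a name returned it unchanged, and later requests appended "_1", "_2", and so on, as described in the header comments removed further down. A standalone sketch of the same counter idea (placeholder class name, and without the mutex the real method held):

    #include <cassert>
    #include <string>
    #include <unordered_map>

    // Mirrors the removed counter logic: the first use of a name returns it as-is,
    // subsequent uses return name_1, name_2, ...
    class FunctionNameUniquer {
     public:
      std::string unique(const std::string& name) {
        auto it = counts_.find(name);
        if (it == counts_.end()) {
          counts_[name] = 0;
          return name;
        }
        ++it->second;
        return name + "_" + std::to_string(it->second);
      }

     private:
      std::unordered_map<std::string, int> counts_;
    };

    int main() {
      FunctionNameUniquer u;
      assert(u.unique("func") == "func");
      assert(u.unique("func") == "func_1");
      assert(u.unique("func") == "func_2");
    }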
@@ -12,7 +12,6 @@

 #include <memory>
 #include <string>
-#include <unordered_map>

 namespace torch {
 namespace jit {
@@ -61,34 +60,12 @@ class TORCH_API PytorchLLVMJIT {
       c10::optional<std::string> attrs);
   ~PytorchLLVMJIT();

-  // While creating any function in the module that is being added to this JIT,
-  // get a unique name by calling `getUniqueFunctionName()` method. That
-  // ensures that there is no duplicate function names in this JIT.
   void addModule(std::unique_ptr<Module> M, std::unique_ptr<LLVMContext> C);

   JITSymbol findSymbol(const std::string Name);

   bool hasSymbol(const std::string& Name);

-  // Returns a function name that is unique in this JIT (among the function
-  // names tracked by calling this method).
-  //
-  // When getUniqueFunctionName is called with a name that has never been used
-  // before, it returns the input name as is. When it is called with the same
-  // name subsequently, it appends "_<num>" to the name to uniquify it.
-  //
-  // For example:
-  // * First call to getUniqueFunctionName("func") => returns "func"
-  // * Second call to getUniqueFunctionName("func") => returns "func_1"
-  // * Third call to getUniqueFunctionName("func") => returns "func_2"
-  //
-  // NOTE: This method does not keep track of all the functions that are added
-  // to this JIT. It only keeps track of the function names that are uniquified
-  // by calling this method directly.
-  //
-  // Recommendation: Call this method before adding any function to this JIT.
-  std::string getUniqueFunctionName(const std::string& name);
-
   TargetMachine& getTargetMachine();

   const DataLayout& getDataLayout();
@@ -96,27 +73,6 @@ class TORCH_API PytorchLLVMJIT {
  private:
   // Use the PImpl idiom here to hide the no-rtti parts of the JIT structure.
   std::unique_ptr<PytorchLLVMJITImpl> impl_;
-
-  std::mutex mutex_;
-  std::unordered_map<std::string, int> existing_functions_;
 };

-class TORCH_API PytorchLLVMJITCache {
- public:
-  static PytorchLLVMJIT* getPytorchLLVMJITInstance(
-      c10::optional<std::string> triple,
-      c10::optional<std::string> cpu,
-      c10::optional<std::string> attrs);
-
- private:
-  static std::unordered_map<std::string, std::unique_ptr<PytorchLLVMJIT>>
-      jit_cache_;
-  static std::mutex mutex_;
-
-  static std::string getCacheKey(
-      c10::optional<std::string> triple,
-      c10::optional<std::string> cpu,
-      c10::optional<std::string> attrs);
-};
-
 } // end namespace orc