mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pytorch] fix code analyzer for LLVM 9 & 10 (#42135)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42135 Tested the code analyzer with LLVM 9 & 10 and fixed a couple issues: - Rename local demangle() which is available as public API since LLVM 9; - Fix falsely associated op registrations due to the `phi` instruction; Test Plan: Imported from OSS Reviewed By: iseeyuan Differential Revision: D22795508 Pulled By: ljk53 fbshipit-source-id: 2d47af088acd3312a7ea5fd9361cdccd48940fe6
This commit is contained in:
parent
fd9205e14b
commit
8ddd2c4e1b
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <iostream>
|
||||
|
||||
#include <ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
|
|
@ -60,6 +61,11 @@ Tensor FF_op(const Tensor& self) {
|
|||
return self;
|
||||
}
|
||||
|
||||
// GG -> FF
|
||||
Tensor GG_op(const Tensor& self) {
|
||||
return call_FF_op(self);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// NB: Some of these registrations (AA, EE) are not what you
|
||||
|
|
@ -93,11 +99,14 @@ TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(_test, m) {
|
|||
|
||||
TORCH_LIBRARY_IMPL(_test, CPU, m) {
|
||||
m.impl_UNBOXED("EE", EE_op);
|
||||
m.impl("FF", torch::CppFunction::makeUnboxedOnly(FF_op));
|
||||
m.impl("FF",
|
||||
torch::dispatch(DispatchKey::CPU,
|
||||
torch::CppFunction::makeUnboxedOnly(FF_op))
|
||||
);
|
||||
m.impl("GG",
|
||||
[] (Tensor a) -> Tensor {
|
||||
return call_FF_op(a);
|
||||
});
|
||||
torch::dispatch(DispatchKey::CPU,
|
||||
c10::impl::hacky_wrapper_for_legacy_signatures(TORCH_FN((GG_op))))
|
||||
);
|
||||
m.impl("HH",
|
||||
[] (Tensor a) -> Tensor {
|
||||
return a;
|
||||
|
|
|
|||
|
|
@ -195,7 +195,9 @@ using PATH = std::unordered_map<std::string,
|
|||
std::unordered_map<std::string, std::string>>;
|
||||
|
||||
// Referenced the logic in llvm-cxxfilt.cpp.
|
||||
std::string demangle(const std::string& mangled) {
|
||||
// Starting from LLVM 9 it provides a `demangle()` API. Here we keep our ad-hoc
|
||||
// version for backward compatibility.
|
||||
std::string _demangle(const std::string& mangled) {
|
||||
int status;
|
||||
const char* decorated = mangled.c_str();
|
||||
size_t decoratedLength = mangled.length();
|
||||
|
|
@ -275,7 +277,7 @@ private:
|
|||
SET roots;
|
||||
for (const auto& F : visibleFuncs) {
|
||||
std::string name = F->getName();
|
||||
auto demangled = demangle(name);
|
||||
auto demangled = _demangle(name);
|
||||
if (RootSymbolPatternLoc.pattern->match(demangled)) {
|
||||
roots.insert(name);
|
||||
if (Verbose) {
|
||||
|
|
@ -299,12 +301,12 @@ private:
|
|||
visibleFuncs->insert(&F);
|
||||
}
|
||||
std::string caller = F.getName();
|
||||
std::string callerDemangled = demangle(caller);
|
||||
std::string callerDemangled = _demangle(caller);
|
||||
for (BasicBlock& BB : F) {
|
||||
for (Instruction& I : BB) {
|
||||
scanReferredFunctions(I, [&](Function* func) -> void {
|
||||
std::string callee = func->getName();
|
||||
std::string calleeDemangled = demangle(callee);
|
||||
std::string calleeDemangled = _demangle(callee);
|
||||
(*deps)[caller].insert(callee);
|
||||
if (Verbose > 1) {
|
||||
std::cerr << "[DEBUG][FUNC_CALL] " << callerDemangled << " => "
|
||||
|
|
@ -393,7 +395,7 @@ private:
|
|||
// APIs are almost always in the same function.
|
||||
static void scanConnectedNodes(
|
||||
Value* src,
|
||||
const VALUE_SET& blocked,
|
||||
VALUE_SET blocked,
|
||||
const std::function<void(Value*)>& CB, VALUE_MAP* debugPath) {
|
||||
std::deque<Value*> worklist;
|
||||
SmallPtrSet<Value*, 16> visited;
|
||||
|
|
@ -418,6 +420,32 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
auto blockSiblingOperands = [&](User* U, Value* V) -> void {
|
||||
// This is to handle a special case only appears in LLVM 9 (not in 5 - 8
|
||||
// and 10), where it can falsely associate unrelated PyTorch op
|
||||
// registrations.
|
||||
//
|
||||
// If the value `V` is used by a PHI-node `U`, then we should stop
|
||||
// crawling `U`'s operands, i.e. `V`'s siblings in `U`. E.g.:
|
||||
//
|
||||
// 114: ; preds = %111, %109
|
||||
// %115 = phi i32 [ %110, %109 ], [ %112, %111 ]
|
||||
//
|
||||
// `%115` might take the value of `%110` or `%112`, depending on from
|
||||
// which label it comes. Assuming `V` is `%110` and `U` is `%115`, we can
|
||||
// continue to scan `%115` but should not crawl `%112` as it does not
|
||||
// directly pass data from `%110` to `%112` (and vice versa).
|
||||
//
|
||||
// NB: we probably should do the same for other LLVM instructions with
|
||||
// this kind of selective semantics. But for the purpose of analyzing
|
||||
// PyTorch registrations it seems to be sufficent for now.
|
||||
if (isa<PHINode>(U)) {
|
||||
for (auto& S : U->operands()) {
|
||||
blocked.insert(S);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto expandUsers = [&](Value* V) -> void {
|
||||
// If the value is not constant, then the user of the value might pass
|
||||
// other value into it, e.g.:
|
||||
|
|
@ -434,6 +462,7 @@ private:
|
|||
}
|
||||
for (auto U : V->users()) {
|
||||
insert(U, V);
|
||||
blockSiblingOperands(U, V);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -524,7 +553,7 @@ private:
|
|||
if (!visitedOps->empty()) {
|
||||
if (Verbose) {
|
||||
std::cerr << "[INFO] ignore extra op schema str: " << *schemaStr
|
||||
<< " in: " << demangle(src->getFunction()->getName())
|
||||
<< " in: " << _demangle(src->getFunction()->getName())
|
||||
<< ", because already found valid op schema str: "
|
||||
<< *visitedOps->begin() << std::endl;
|
||||
}
|
||||
|
|
@ -544,7 +573,7 @@ private:
|
|||
(*visitedFunctions).insert(F->getName());
|
||||
}
|
||||
if (Verbose > 1) {
|
||||
std::cerr << "[DEBUG][FUNC] " << demangle(F->getName()) << std::endl;
|
||||
std::cerr << "[DEBUG][FUNC] " << _demangle(F->getName()) << std::endl;
|
||||
printDebugPath(debugPath.get(), src, V);
|
||||
}
|
||||
}
|
||||
|
|
@ -619,7 +648,7 @@ private:
|
|||
std::cerr << op << " ";
|
||||
}
|
||||
std::cerr << ") in a registration call in function: "
|
||||
<< demangle(I->getFunction()->getName())
|
||||
<< _demangle(I->getFunction()->getName())
|
||||
<< " contextualNamespace: " << contextualNamespace
|
||||
<< std::endl;
|
||||
}
|
||||
|
|
@ -628,7 +657,7 @@ private:
|
|||
if (visitedFunctions.empty()) {
|
||||
std::cerr << "[WARNING] could not find registered function for op: "
|
||||
<< op << " in function: "
|
||||
<< demangle(I->getFunction()->getName())
|
||||
<< _demangle(I->getFunction()->getName())
|
||||
<< " contextualNamespace: " << contextualNamespace
|
||||
<< std::endl;
|
||||
}
|
||||
|
|
@ -636,7 +665,7 @@ private:
|
|||
(*schemaStrToFunctions)[op].insert(func);
|
||||
if (Verbose) {
|
||||
std::cerr << "[DEBUG][OP_REG] " << op << " => "
|
||||
<< demangle(func) << std::endl;
|
||||
<< _demangle(func) << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -644,7 +673,7 @@ private:
|
|||
}
|
||||
|
||||
static std::string inferContextualNamespace(Instruction* I) {
|
||||
auto functionName = demangle(I->getFunction()->getName());
|
||||
auto functionName = _demangle(I->getFunction()->getName());
|
||||
for (auto& pattern : TorchLibraryInitPattern) {
|
||||
if (!pattern.pattern->match(functionName)) {
|
||||
continue;
|
||||
|
|
@ -703,13 +732,13 @@ private:
|
|||
std::cerr << op << " ";
|
||||
}
|
||||
std::cerr << ") in a invocation call in function: "
|
||||
<< demangle(caller) << std::endl;
|
||||
<< _demangle(caller) << std::endl;
|
||||
}
|
||||
for (const auto& op : visitedOps) {
|
||||
opSchemaStrs->insert(op);
|
||||
(*functionToSchemaStrs)[caller].insert(op);
|
||||
if (Verbose) {
|
||||
std::cerr << "[DEBUG][OP_CALL] " << demangle(caller) << " => "
|
||||
std::cerr << "[DEBUG][OP_CALL] " << _demangle(caller) << " => "
|
||||
<< op << std::endl;
|
||||
}
|
||||
}
|
||||
|
|
@ -790,7 +819,7 @@ private:
|
|||
|
||||
static void printDebugValue(Value* V) {
|
||||
if (auto F = dyn_cast<Function>(V)) {
|
||||
std::cerr << "[FUNC] " << demangle(F->getName());
|
||||
std::cerr << "[FUNC] " << _demangle(F->getName());
|
||||
} else if (isa<Constant>(V)) {
|
||||
std::cerr << "[CONST] " << *V;
|
||||
} else if (isa<Instruction>(V)) {
|
||||
|
|
@ -806,14 +835,14 @@ private:
|
|||
std::ostream& out, const SET& keys, const GRAPH& graph,
|
||||
const PATH* path) {
|
||||
for (const auto& K : keys) {
|
||||
out << "- name: " << demangle(K) << std::endl;
|
||||
out << "- name: " << _demangle(K) << std::endl;
|
||||
auto it = graph.find(K);
|
||||
if (it == graph.end() || it->second.empty()) {
|
||||
continue;
|
||||
}
|
||||
out << " depends:" << std::endl;
|
||||
for (const auto& value : it->second) {
|
||||
out << " - name: " << demangle(value) << std::endl;
|
||||
out << " - name: " << _demangle(value) << std::endl;
|
||||
if (path) {
|
||||
std::vector<std::string> rpath;
|
||||
for (std::string prev = value;
|
||||
|
|
@ -821,7 +850,7 @@ private:
|
|||
prev = path->at(K).at(prev));
|
||||
out << " path:" << std::endl;
|
||||
for (auto pit = rpath.rbegin(); pit != rpath.rend(); ++pit) {
|
||||
out << " - " << demangle(*pit) << std::endl;
|
||||
out << " - " << _demangle(*pit) << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user