[pytorch] fix code analyzer for LLVM 9 & 10 (#42135)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42135

Tested the code analyzer with LLVM 9 & 10 and fixed a couple issues:
- Rename local demangle() which is available as public API since LLVM 9;
- Fix falsely associated op registrations due to the `phi` instruction;

Test Plan: Imported from OSS

Reviewed By: iseeyuan

Differential Revision: D22795508

Pulled By: ljk53

fbshipit-source-id: 2d47af088acd3312a7ea5fd9361cdccd48940fe6
This commit is contained in:
Jiakai Liu 2020-07-28 14:55:26 -07:00 committed by Facebook GitHub Bot
parent fd9205e14b
commit 8ddd2c4e1b
2 changed files with 59 additions and 21 deletions

View File

@ -2,6 +2,7 @@
#include <iostream>
#include <ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h>
#include <c10/core/TensorOptions.h>
#include <torch/library.h>
@ -60,6 +61,11 @@ Tensor FF_op(const Tensor& self) {
return self;
}
// GG -> FF
Tensor GG_op(const Tensor& self) {
return call_FF_op(self);
}
namespace {
// NB: Some of these registrations (AA, EE) are not what you
@ -93,11 +99,14 @@ TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(_test, m) {
TORCH_LIBRARY_IMPL(_test, CPU, m) {
m.impl_UNBOXED("EE", EE_op);
m.impl("FF", torch::CppFunction::makeUnboxedOnly(FF_op));
m.impl("FF",
torch::dispatch(DispatchKey::CPU,
torch::CppFunction::makeUnboxedOnly(FF_op))
);
m.impl("GG",
[] (Tensor a) -> Tensor {
return call_FF_op(a);
});
torch::dispatch(DispatchKey::CPU,
c10::impl::hacky_wrapper_for_legacy_signatures(TORCH_FN((GG_op))))
);
m.impl("HH",
[] (Tensor a) -> Tensor {
return a;

View File

@ -195,7 +195,9 @@ using PATH = std::unordered_map<std::string,
std::unordered_map<std::string, std::string>>;
// Referenced the logic in llvm-cxxfilt.cpp.
std::string demangle(const std::string& mangled) {
// Starting from LLVM 9 it provides a `demangle()` API. Here we keep our ad-hoc
// version for backward compatibility.
std::string _demangle(const std::string& mangled) {
int status;
const char* decorated = mangled.c_str();
size_t decoratedLength = mangled.length();
@ -275,7 +277,7 @@ private:
SET roots;
for (const auto& F : visibleFuncs) {
std::string name = F->getName();
auto demangled = demangle(name);
auto demangled = _demangle(name);
if (RootSymbolPatternLoc.pattern->match(demangled)) {
roots.insert(name);
if (Verbose) {
@ -299,12 +301,12 @@ private:
visibleFuncs->insert(&F);
}
std::string caller = F.getName();
std::string callerDemangled = demangle(caller);
std::string callerDemangled = _demangle(caller);
for (BasicBlock& BB : F) {
for (Instruction& I : BB) {
scanReferredFunctions(I, [&](Function* func) -> void {
std::string callee = func->getName();
std::string calleeDemangled = demangle(callee);
std::string calleeDemangled = _demangle(callee);
(*deps)[caller].insert(callee);
if (Verbose > 1) {
std::cerr << "[DEBUG][FUNC_CALL] " << callerDemangled << " => "
@ -393,7 +395,7 @@ private:
// APIs are almost always in the same function.
static void scanConnectedNodes(
Value* src,
const VALUE_SET& blocked,
VALUE_SET blocked,
const std::function<void(Value*)>& CB, VALUE_MAP* debugPath) {
std::deque<Value*> worklist;
SmallPtrSet<Value*, 16> visited;
@ -418,6 +420,32 @@ private:
}
};
auto blockSiblingOperands = [&](User* U, Value* V) -> void {
// This is to handle a special case only appears in LLVM 9 (not in 5 - 8
// and 10), where it can falsely associate unrelated PyTorch op
// registrations.
//
// If the value `V` is used by a PHI-node `U`, then we should stop
// crawling `U`'s operands, i.e. `V`'s siblings in `U`. E.g.:
//
// 114: ; preds = %111, %109
// %115 = phi i32 [ %110, %109 ], [ %112, %111 ]
//
// `%115` might take the value of `%110` or `%112`, depending on from
// which label it comes. Assuming `V` is `%110` and `U` is `%115`, we can
// continue to scan `%115` but should not crawl `%112` as it does not
// directly pass data from `%110` to `%112` (and vice versa).
//
// NB: we probably should do the same for other LLVM instructions with
// this kind of selective semantics. But for the purpose of analyzing
// PyTorch registrations it seems to be sufficent for now.
if (isa<PHINode>(U)) {
for (auto& S : U->operands()) {
blocked.insert(S);
}
}
};
auto expandUsers = [&](Value* V) -> void {
// If the value is not constant, then the user of the value might pass
// other value into it, e.g.:
@ -434,6 +462,7 @@ private:
}
for (auto U : V->users()) {
insert(U, V);
blockSiblingOperands(U, V);
}
};
@ -524,7 +553,7 @@ private:
if (!visitedOps->empty()) {
if (Verbose) {
std::cerr << "[INFO] ignore extra op schema str: " << *schemaStr
<< " in: " << demangle(src->getFunction()->getName())
<< " in: " << _demangle(src->getFunction()->getName())
<< ", because already found valid op schema str: "
<< *visitedOps->begin() << std::endl;
}
@ -544,7 +573,7 @@ private:
(*visitedFunctions).insert(F->getName());
}
if (Verbose > 1) {
std::cerr << "[DEBUG][FUNC] " << demangle(F->getName()) << std::endl;
std::cerr << "[DEBUG][FUNC] " << _demangle(F->getName()) << std::endl;
printDebugPath(debugPath.get(), src, V);
}
}
@ -619,7 +648,7 @@ private:
std::cerr << op << " ";
}
std::cerr << ") in a registration call in function: "
<< demangle(I->getFunction()->getName())
<< _demangle(I->getFunction()->getName())
<< " contextualNamespace: " << contextualNamespace
<< std::endl;
}
@ -628,7 +657,7 @@ private:
if (visitedFunctions.empty()) {
std::cerr << "[WARNING] could not find registered function for op: "
<< op << " in function: "
<< demangle(I->getFunction()->getName())
<< _demangle(I->getFunction()->getName())
<< " contextualNamespace: " << contextualNamespace
<< std::endl;
}
@ -636,7 +665,7 @@ private:
(*schemaStrToFunctions)[op].insert(func);
if (Verbose) {
std::cerr << "[DEBUG][OP_REG] " << op << " => "
<< demangle(func) << std::endl;
<< _demangle(func) << std::endl;
}
}
}
@ -644,7 +673,7 @@ private:
}
static std::string inferContextualNamespace(Instruction* I) {
auto functionName = demangle(I->getFunction()->getName());
auto functionName = _demangle(I->getFunction()->getName());
for (auto& pattern : TorchLibraryInitPattern) {
if (!pattern.pattern->match(functionName)) {
continue;
@ -703,13 +732,13 @@ private:
std::cerr << op << " ";
}
std::cerr << ") in a invocation call in function: "
<< demangle(caller) << std::endl;
<< _demangle(caller) << std::endl;
}
for (const auto& op : visitedOps) {
opSchemaStrs->insert(op);
(*functionToSchemaStrs)[caller].insert(op);
if (Verbose) {
std::cerr << "[DEBUG][OP_CALL] " << demangle(caller) << " => "
std::cerr << "[DEBUG][OP_CALL] " << _demangle(caller) << " => "
<< op << std::endl;
}
}
@ -790,7 +819,7 @@ private:
static void printDebugValue(Value* V) {
if (auto F = dyn_cast<Function>(V)) {
std::cerr << "[FUNC] " << demangle(F->getName());
std::cerr << "[FUNC] " << _demangle(F->getName());
} else if (isa<Constant>(V)) {
std::cerr << "[CONST] " << *V;
} else if (isa<Instruction>(V)) {
@ -806,14 +835,14 @@ private:
std::ostream& out, const SET& keys, const GRAPH& graph,
const PATH* path) {
for (const auto& K : keys) {
out << "- name: " << demangle(K) << std::endl;
out << "- name: " << _demangle(K) << std::endl;
auto it = graph.find(K);
if (it == graph.end() || it->second.empty()) {
continue;
}
out << " depends:" << std::endl;
for (const auto& value : it->second) {
out << " - name: " << demangle(value) << std::endl;
out << " - name: " << _demangle(value) << std::endl;
if (path) {
std::vector<std::string> rpath;
for (std::string prev = value;
@ -821,7 +850,7 @@ private:
prev = path->at(K).at(prev));
out << " path:" << std::endl;
for (auto pit = rpath.rbegin(); pit != rpath.rend(); ++pit) {
out << " - " << demangle(*pit) << std::endl;
out << " - " << _demangle(*pit) << std::endl;
}
}
}