pytorch/test/cpp/jit/test_backend_compiler_preprocess.cpp
Kimish Patel 2ce21b2e61 [Pytorch backend delegation] Preprocess to accept BackendDebugInfoRecorder (#58873)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58873

Prior to this PR:
In order to generate debug handles corresponding to the graph being
lowered, the backend's preprocess calls the free function
generate_debug_handles and gets back a map of Node* to debug handles.
To facilitate this, to_backend owns a BackendDebugInfoRecorder and
initializes a thread-local pointer to it. generate_debug_handles then
queries that thread-local pointer to check whether there is a valid
BackendDebugInfoRecorder for the current context and, if there is,
generates the debug handles.

After this PR:
The signature of preprocess is changed so that backends must register a
preprocess function that accepts an instance of BackendDebugInfoRecorder
by reference. generate_debug_handles is no longer a free function; it
is now part of the API of BackendDebugInfoRecorder. A backend's
preprocess function therefore calls generate_debug_handles on the
BackendDebugInfoRecorder instead of calling the free function.
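
As an illustration, a preprocess under the new contract has the shape
sketched below. This sketch is distilled from the demo backend in the
file below, where the debug-handle facility arrives as a
BackendDebugHandleGenerator callable passed into preprocess; names such
as my_preprocess and my_backend are placeholders and the blob emission
is elided.

#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>

namespace torch {
namespace jit {
namespace {
// Sketch only: debug handles are obtained through the argument handed to
// preprocess rather than through a free function.
c10::IValue my_preprocess(
    const Module& mod,
    const c10::Dict<IValue, IValue>& method_compile_spec,
    const BackendDebugHandleGenerator& generate_debug_handles) {
  c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
  for (const auto& method : mod.get_methods()) {
    auto graph = method.function().graph()->copy();
    // Per-node debug handles for this lowered graph.
    auto node_debug_handles = generate_debug_handles(graph);
    (void)node_debug_handles; // consulted while emitting the real blob
    compiled.insert(method.name(), std::string("<compiled blob>"));
  }
  return compiled;
}
// Placeholder name; the demo below registers "backend_with_compiler_demo".
static auto reg = backend_preprocess_register("my_backend", my_preprocess);
} // namespace
} // namespace jit
} // namespace torch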

Reason for this change:
The RAII approach of initializing a thread-local pointer results in a
loose contract with backends, which may lead to backends not storing
debug information. Making it part of the API means backends have to be
aware of BackendDebugInfoRecorder and must explicitly choose not to
generate/store debug information if that is what they decide to do.

Test Plan:
backend tests

Imported from OSS

Reviewed By: jbschlosser, raziel

Differential Revision: D28648613

fbshipit-source-id: c9b7e7bf0f78e87023ea7bc08612cf893b08cb98
2021-06-11 10:16:00 -07:00

#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
namespace torch {
namespace jit {
namespace {
// For this backend, the actual compilation happens ahead of time (AOT) in
// the preprocess function. It is included here to demonstrate the backend as
// a whole piece; it is used when compilation is required. A dummy function
// can be passed when the runtime backend library makes no use of compilation.
c10::IValue preprocess(
    const Module& mod,
    const c10::Dict<IValue, IValue>& method_compile_spec,
    const BackendDebugHandleGenerator& generate_debug_handles) {
  // The output of this process is a dictionary:
  // Key: method name.
  // Val: compiled blob (represented by a string).
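  // For example, a method whose inlined graph contains a single aten::add
  // node yields the blob "aten::add<debug_handle>0" (the numeric handle
  // value is illustrative; it comes from generate_debug_handles).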
  c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
  for (const auto& method : mod.get_methods()) {
    auto graph = method.function().graph()->copy();
    // Must inline the graph for debug info map.
    Inline(*graph);
    // To test module hierarchy, the test module contains getattr nodes that
    // serve no purpose after inlining. Leaving them in would lead to
    // compilation errors here, so run dead code elimination to remove them.
    EliminateDeadCode(graph);
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto key = method.name();
    auto node_debug_handles = generate_debug_handles(graph);
    std::stringstream ss;
    for (const auto& node : graph->nodes()) {
      switch (node->kind()) {
        case prim::Constant:
          ss << node->kind().toDisplayString() << "#"
             << toIValue(node->output()).value();
          ss << "<debug_handle>" << node_debug_handles[node];
          break;
        // NOLINTNEXTLINE(bugprone-branch-clone)
        case aten::add:
          ss << node->kind().toQualString();
          ss << "<debug_handle>" << node_debug_handles[node];
          break;
        case aten::sub:
          ss << node->kind().toQualString();
          ss << "<debug_handle>" << node_debug_handles[node];
          break;
        default:
          TORCH_CHECK(
              false,
              "The node of ",
              node->kind().toQualString(),
              " is not supported in this compiler. Source code: ",
              node->sourceRange().str());
          break;
      }
      ss << ",";
    }
    std::string blob = ss.str();
    if (!blob.empty()) {
      blob.pop_back();
    }
    compiled.insert(method.name(), blob);
  }
  return compiled;
}
constexpr auto backend_name = "backend_with_compiler_demo";
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static auto pre_reg = backend_preprocess_register(backend_name, preprocess);
} // namespace
} // namespace jit
} // namespace torch