[Pytorch lite predictor] Use KinetoEdgeCPUProfiler for operator profiling. (#63367)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63367

This diff changes the way operator profiling is done in the lite predictor
benchmarking binary.
Instead of using custom callbacks, it uses KinetoEdgeCPUProfiler to profile
events and then generates operator-level metrics from them.
Since KinetoEvents do not contain CPU clock time, we now report only wall-clock
time.
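
For illustration, a minimal sketch of how a benchmarking binary might drive the profiler. The trace-file path and the constructor parameter names before with_flops/with_modules are assumptions inferred from the profiler_edge header change further down, not part of this diff:

#include <vector>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/profiler_edge.h>

void runWithProfiling(
    torch::jit::mobile::Module& m,
    std::vector<c10::IValue> inputs) {
  // Profiling is scoped: events are collected from construction until the
  // profiler is destroyed (or disableProfiler() is called), at which point
  // a chrome trace is saved to the given file.
  torch::jit::mobile::KinetoEdgeCPUProfiler profiler(
      m,
      "/tmp/lite_predictor_trace.json", // hypothetical output path
      /*report_input_shapes=*/false,    // assumed parameter name
      /*profile_memory=*/false,         // assumed parameter name
      /*with_stack=*/false,
      /*with_flops=*/false,
      /*with_modules=*/true);
  m.forward(std::move(inputs));
}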
This unifies the various profiling efforts we have for benchmarking purposes. In
production we will still use the observer-based mechanism, but the advantage of
using the Kineto profiler is that we get a few other things for free, such as:
- chrome trace generation.
- operator-level memory profiling (to be added)
- flop counts (to be added)

Furthermore, it should be possible to use a Python post-processing script to
parse the chrome trace and generate output similar to torch.profiler. (To be
done)
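
In the meantime, the ProfilerResult returned by disableProfiler()/getProfilerResult() (see the profiler_edge diff below) can be consumed directly in C++. A hedged sketch of deriving per-operator wall-clock totals; the name() and durationUs() accessors on KinetoEvent are assumptions about the profiler API, not part of this diff:

#include <cstdint>
#include <string>
#include <unordered_map>
#include <torch/csrc/autograd/profiler_kineto.h>

namespace profiler = torch::autograd::profiler;

// Fold the captured events into a per-operator wall-clock total (in us).
std::unordered_map<std::string, int64_t> perOpWallTimeUs(
    const profiler::ProfilerResult& result) {
  std::unordered_map<std::string, int64_t> op_time_us;
  for (const auto& e : result.events()) {
    op_time_us[e.name()] += e.durationUs();
  }
  return op_time_us;
}

A benchmarking binary would call disableProfiler() on the edge profiler once inference finishes and pass the returned result to such a helper.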

Test Plan:
aibench run
Model without debug info:
https://www.internalfb.com/intern/aibench/details/219598441154763
Model with debug info and `--print_module_info true` (note that the operator summary now includes module hierarchy information):
https://www.internalfb.com/intern/aibench/details/617154236292985

Reviewed By: raziel

Differential Revision: D30327514

fbshipit-source-id: 3bb2f2daaaedfb04bd6f5d9c91292783f9c4344f
Kimish Patel 2021-08-30 20:53:50 -07:00 committed by Facebook GitHub Bot
parent 7ca4728e6d
commit bc9277dca3
9 changed files with 72 additions and 182 deletions

View File

@@ -455,171 +455,6 @@ TEST(LiteInterpreterTest, BuiltinFunction) {
AT_ASSERT(str == expected);
}
#if !defined FB_XPLAT_BUILD
TEST(LiteInterpreterTest, ModuleInfoBasic) {
Module m("M");
m.define(R"JIT(
def forward(self, x):
return 2 * x
)JIT");
std::stringstream ss;
m._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::unordered_set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() &&
(module_info.find("debug_handle") == std::string::npos)) {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
AT_ASSERT(module_debug_info_set.count("top(M)::<unknown>.aten::mul"));
}
TEST(LiteInterpreterTest, NotSaveModuleInfo) {
Module m("M");
m.define(R"JIT(
def forward(self, x):
return x + 5
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
AT_ASSERT(
module_info.empty() ||
(module_info.find("debug_handle") != std::string::npos));
++pc;
} catch (const std::exception& e) {
break;
}
}
}
TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
Module a("A");
a.define(R"JIT(
def forward(self, x):
return 2 * x + 5
)JIT");
Module b("B");
b.register_module("A0", a);
b.define(R"JIT(
def forward(self, x):
return self.A0.forward(x) + 1
)JIT");
std::stringstream ss;
b._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() &&
(module_info.find("debug_handle") == std::string::npos)) {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
AT_ASSERT(module_debug_info_set.count("top(B)::<unknown>.aten::add"));
AT_ASSERT(module_debug_info_set.count(
"top(B)::<unknown>.A0(A)::forward.aten::add"));
AT_ASSERT(module_debug_info_set.count(
"top(B)::<unknown>.A0(A)::forward.aten::mul"));
}
TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
Module a("A");
a.define(R"JIT(
def forward(self, x):
return x + 1
)JIT");
Module b("B");
b.define(R"JIT(
def forward(self, x):
return x + 2
)JIT");
Module c("C");
c.register_module("A0", a);
c.register_module("B0", b);
c.define(R"JIT(
def forward(self, x):
return self.A0.forward(x) + self.B0.forward(x)
)JIT");
std::stringstream ss;
c._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() &&
(module_info.find("debug_handle") == std::string::npos)) {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.aten::add"));
AT_ASSERT(module_debug_info_set.count(
"top(C)::<unknown>.A0(A)::forward.aten::add"));
AT_ASSERT(module_debug_info_set.count(
"top(C)::<unknown>.B0(B)::forward.aten::add"));
}
TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
auto runtime_bytecode_version = _get_runtime_bytecode_version();
AT_ASSERT(
runtime_bytecode_version ==
caffe2::serialize::kMaxSupportedBytecodeVersion);
}
/**
* The test below is disarmed for FB internal xplat builds since
* BUCK requires us to pass in the script_module_v4.ptl file in
* as a resource dependency of the build rule for this file, and
* we would need to access it via the C++ Resources API instead
* of directly reading from disk (which is what the open source
* build/run does).
*/
TEST(LiteInterpreterTest, GetByteCodeVersion) {
std::string filePath(__FILE__);
auto test_model_file_v4 =
filePath.substr(0, filePath.find_last_of("/\\") + 1);
test_model_file_v4.append("script_module_v4.ptl");
auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
AT_ASSERT(version_v4 == 4);
}
#endif // !defined(FB_XPLAT_BUILD)
namespace {
void compareModelOutput(

View File

@@ -319,7 +319,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/testing/hooks_for_testing.cpp",
"torch/csrc/utils/tensor_flatten.cpp",
"torch/csrc/utils/variadic.cpp",
] + libtorch_profiler_sources
]
core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [
"torch/csrc/jit/backends/backend_debug_info.cpp",
@@ -337,7 +337,7 @@ core_sources_full = core_sources_full_mobile + [
"torch/csrc/jit/tensorexpr/external_functions_codegen.cpp",
]
libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources)
libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources + libtorch_profiler_sources)
# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [

View File

@@ -13,6 +13,12 @@ namespace jit {
namespace {
C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
const std::string& debug_handles_string) {
return "Debug info for handle(s): " + debug_handles_string +
", was not found.";
}
std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
const DebugInfoTuple& source_callstack,
const std::string& caller_name) {
@@ -152,8 +158,7 @@ std::string MobileDebugTable::getModuleHierarchyInfo(
const std::string& top_module_type_name) const {
const auto it = callstack_ptr_map_.find(debug_handle);
if (it == callstack_ptr_map_.end()) {
return "Module info for handle, " + std::to_string(debug_handle) +
", not found.";
return debugHandlesNotFoundMessage(std::to_string(debug_handle));
}
return (getStackTraceWithModuleHierarchy(
{it->second}, "top", top_module_type_name))
@@ -172,8 +177,7 @@ std::string MobileDebugTable::getSourceDebugString(
const std::string& top_module_type_name) const {
const auto it = callstack_ptr_map_.find(debug_handle);
if (it == callstack_ptr_map_.end()) {
return "Debug info for handle, " + std::to_string(debug_handle) +
", not found.";
return debugHandlesNotFoundMessage(std::to_string(debug_handle));
}
return (getStackTraceWithModuleHierarchy(
{it->second}, "top", top_module_type_name))
@@ -208,8 +212,7 @@ std::pair<std::string, std::string> MobileDebugTable::
debug_handles_string += std::to_string(debug_handle);
}
debug_handles_string += "}";
debug_handles_string =
"Debug info for handles: " + debug_handles_string + ", was not found.";
debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
return {debug_handles_string, debug_handles_string};
}
return (getStackTraceWithModuleHierarchy(

View File

@@ -517,12 +517,15 @@ mobile::Module BytecodeDeserializer::deserialize(
auto bvals = std::move(*readArchive("bytecode", mcu).toTuple()).elements();
c10::optional<std::vector<IValue>> debug_handles;
bool has_debug_handles{false};
if (reader_->hasRecord("mobile_debug_handles.pkl")) {
debug_handles =
readArchive("mobile_debug_handles", mcu).toTuple()->elements();
has_debug_handles = true;
}
parseMethods(bvals, debug_handles, *mcu);
auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu);
m.setHasDebugHandles(has_debug_handles);
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_);
m.setDebugTable(std::move(debug_table));

View File

@@ -57,6 +57,9 @@ bool InterpreterState::run(Stack& stack) {
auto inst_with_handle = code_->instructions_with_handles_.at(pc);
Instruction inst = inst_with_handle.instruction;
DebugHandle debug_handle = inst_with_handle.debug_handle;
// If no valid debug handle is found then just use the pc instead.
// This is possible when we did not save debug handles.
debug_handle = debug_handle == -1 ? pc : debug_handle;
// std::cout << "RUNNING " << pc << " "
// << code_->instructions_with_handles_[pc].instruction;

View File

@@ -145,8 +145,7 @@ std::string Module::getCallStack(const int64_t debug_handle) const {
// We really need to change this part, so in the next step for profiling support
// for delegates, the first thing will be to rewrite how profiling is done
// for lite interpreter.
std::string Module::get_forward_method_debug_info(size_t pc) const {
auto debug_handle = find_method("forward")->get_debug_handle(pc);
std::string Module::get_forward_method_debug_info(int64_t debug_handle) const {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
return getDebugTable().getModuleHierarchyInfo(
debug_handle, getTopModuleTypeName(*this));

View File

@@ -78,7 +78,7 @@ class TORCH_API Module {
}
const std::vector<at::Tensor> parameters() const;
const std::map<std::string, at::Tensor> named_parameters() const;
std::string get_forward_method_debug_info(size_t pc) const;
std::string get_forward_method_debug_info(int64_t debug_handle) const;
std::string getModuleHierarchy(const int64_t debug_handle) const;
std::string getCallStack(const int64_t debug_handle) const;
/// Enables "training" mode.
@@ -115,11 +115,20 @@ class TORCH_API Module {
return debug_table_;
}
void setHasDebugHandles(bool has_debug_handles) {
has_debug_handles_ = has_debug_handles;
}
bool hasDebugHandles() const {
return has_debug_handles_;
}
private:
c10::intrusive_ptr<c10::ivalue::Object> object_;
std::unordered_map<std::string, std::string> metadata_;
std::shared_ptr<CompilationUnit> cu_;
MobileDebugTable debug_table_;
bool has_debug_handles_;
};
} // namespace mobile
} // namespace jit

View File

@@ -2,7 +2,6 @@
#include <string>
#include <vector>
namespace profiler = torch::autograd::profiler;
namespace torch {
namespace jit {
namespace mobile {
@@ -27,17 +26,26 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
if (with_modules || with_stack) {
auto post_processing = [this, with_stack, with_modules](
std::vector<profiler::KinetoEvent>& events) {
std::string no_debug_info("Model was not saved with debug information");
for (auto& e : events) {
if (with_modules) {
// Since KinetoEvent's module hierarchy takes a vector of strings, we
// just construct a temporary vector using one string element
e.moduleHierarchy(std::vector<std::string>(
{this->m_.getModuleHierarchy(e.debugHandle())}));
if (this->m_.hasDebugHandles()) {
e.moduleHierarchy(std::vector<std::string>(
{this->m_.getModuleHierarchy(e.debugHandle())}));
} else {
e.moduleHierarchy(std::vector<std::string>({no_debug_info}));
}
} else if (with_stack) {
// Since KinetoEvent's stack trace takes a vector of strings, we just
// construct a temporary vector using one string element
e.stack(std::vector<std::string>(
{this->m_.getCallStack(e.debugHandle())}));
if (this->m_.hasDebugHandles()) {
e.stack(std::vector<std::string>(
{this->m_.getCallStack(e.debugHandle())}));
} else {
e.stack(std::vector<std::string>({no_debug_info}));
}
}
}
};
@@ -55,8 +63,33 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
trace_file_name_ = fname;
}
const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
disableProfiler() {
TORCH_CHECK(
!profiler_result_,
"KinetoEdgeCPUProfiler already disabled. "
"To get list of events use getProfilerResults()");
profiler_result_ = profiler::disableProfiler();
return profiler_result_;
}
const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
getProfilerResult() {
TORCH_CHECK(
profiler_result_,
"KinetoEdgeCPUProfiler has not been disabled. "
"use disableProfiler() API first, which returns the ProfilerResult.");
return profiler_result_;
}
KinetoEdgeCPUProfiler::~KinetoEdgeCPUProfiler() {
profiler::disableProfiler()->save(trace_file_name_);
if (!trace_file_name_.empty()) {
if (profiler_result_) {
profiler_result_->save(trace_file_name_);
} else {
profiler::disableProfiler()->save(trace_file_name_);
}
}
}
} // namespace mobile
} // namespace jit

View File

@@ -2,6 +2,7 @@
#include <torch/csrc/autograd/profiler_kineto.h>
#include <torch/csrc/jit/mobile/module.h>
namespace profiler = torch::autograd::profiler;
namespace torch {
namespace jit {
namespace mobile {
@@ -53,6 +54,9 @@ class TORCH_API KinetoEdgeCPUProfiler {
const bool with_flops = false,
const bool with_modules = false);
const std::unique_ptr<profiler::ProfilerResult>& disableProfiler();
const std::unique_ptr<profiler::ProfilerResult>& getProfilerResult();
~KinetoEdgeCPUProfiler();
private:
@@ -62,6 +66,7 @@ class TORCH_API KinetoEdgeCPUProfiler {
*/
const mobile::Module& m_;
std::string trace_file_name_;
std::unique_ptr<profiler::ProfilerResult> profiler_result_;
};
} // namespace mobile
} // namespace jit