mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Pytorch lite predictor] Use KinetoEdgeCPUProfiler for operator profiling. (#63367)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63367 This diff changes the way operator profiling is done in lite predictor benchmarking binary. Instead of using custom callbacks it uses KinetoEdgeCPUProfiler to profile events and then generate operator level metric from it. Since KinetoEvents do not contain cpu clock time, now we report only wallclock time. This unifies various profiling effort that we have for benchmarking purpose. In production we will still use observer based mechanism, but the advantage of using kineto profiler is that we get few other things for free, such as: - chrome trace generation. - operator level memory profiling (to be added) - flop counts (to be added) Furthermore possible we can use python post processing script to parse chrome trace and generate output similar to torch.profiler. (To be done) Test Plan: aibench run Model without debug info: https://www.internalfb.com/intern/aibench/details/219598441154763 Model with debug info and `--print_module_info true` (see Operator summary has now module hierarchy information). https://www.internalfb.com/intern/aibench/details/617154236292985 Reviewed By: raziel Differential Revision: D30327514 fbshipit-source-id: 3bb2f2daaaedfb04bd6f5d9c91292783f9c4344f
This commit is contained in:
parent
7ca4728e6d
commit
bc9277dca3
|
|
@ -455,171 +455,6 @@ TEST(LiteInterpreterTest, BuiltinFunction) {
|
|||
AT_ASSERT(str == expected);
|
||||
}
|
||||
|
||||
#if !defined FB_XPLAT_BUILD
|
||||
TEST(LiteInterpreterTest, ModuleInfoBasic) {
|
||||
Module m("M");
|
||||
m.define(R"JIT(
|
||||
def forward(self, x):
|
||||
return 2 * x
|
||||
)JIT");
|
||||
|
||||
std::stringstream ss;
|
||||
m._save_for_mobile(ss, {}, true);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
|
||||
std::unordered_set<std::string> module_debug_info_set;
|
||||
size_t pc = 0;
|
||||
while (true) {
|
||||
try {
|
||||
std::string module_info = bc.get_forward_method_debug_info(pc);
|
||||
if (!module_info.empty() &&
|
||||
(module_info.find("debug_handle") == std::string::npos)) {
|
||||
module_debug_info_set.insert(module_info);
|
||||
}
|
||||
++pc;
|
||||
} catch (const std::exception& e) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
AT_ASSERT(module_debug_info_set.count("top(M)::<unknown>.aten::mul"));
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, NotSaveModuleInfo) {
|
||||
Module m("M");
|
||||
m.define(R"JIT(
|
||||
def forward(self, x):
|
||||
return x + 5
|
||||
)JIT");
|
||||
|
||||
std::stringstream ss;
|
||||
m._save_for_mobile(ss);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
|
||||
size_t pc = 0;
|
||||
while (true) {
|
||||
try {
|
||||
std::string module_info = bc.get_forward_method_debug_info(pc);
|
||||
AT_ASSERT(
|
||||
module_info.empty() ||
|
||||
(module_info.find("debug_handle") != std::string::npos));
|
||||
++pc;
|
||||
} catch (const std::exception& e) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
|
||||
Module a("A");
|
||||
a.define(R"JIT(
|
||||
def forward(self, x):
|
||||
return 2 * x + 5
|
||||
)JIT");
|
||||
Module b("B");
|
||||
b.register_module("A0", a);
|
||||
b.define(R"JIT(
|
||||
def forward(self, x):
|
||||
return self.A0.forward(x) + 1
|
||||
)JIT");
|
||||
|
||||
std::stringstream ss;
|
||||
b._save_for_mobile(ss, {}, true);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
|
||||
std::set<std::string> module_debug_info_set;
|
||||
size_t pc = 0;
|
||||
while (true) {
|
||||
try {
|
||||
std::string module_info = bc.get_forward_method_debug_info(pc);
|
||||
if (!module_info.empty() &&
|
||||
(module_info.find("debug_handle") == std::string::npos)) {
|
||||
module_debug_info_set.insert(module_info);
|
||||
}
|
||||
++pc;
|
||||
} catch (const std::exception& e) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
AT_ASSERT(module_debug_info_set.count("top(B)::<unknown>.aten::add"));
|
||||
AT_ASSERT(module_debug_info_set.count(
|
||||
"top(B)::<unknown>.A0(A)::forward.aten::add"));
|
||||
AT_ASSERT(module_debug_info_set.count(
|
||||
"top(B)::<unknown>.A0(A)::forward.aten::mul"));
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
|
||||
Module a("A");
|
||||
a.define(R"JIT(
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
)JIT");
|
||||
Module b("B");
|
||||
b.define(R"JIT(
|
||||
def forward(self, x):
|
||||
return x + 2
|
||||
)JIT");
|
||||
Module c("C");
|
||||
c.register_module("A0", a);
|
||||
c.register_module("B0", b);
|
||||
c.define(R"JIT(
|
||||
def forward(self, x):
|
||||
return self.A0.forward(x) + self.B0.forward(x)
|
||||
)JIT");
|
||||
|
||||
std::stringstream ss;
|
||||
c._save_for_mobile(ss, {}, true);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
|
||||
std::set<std::string> module_debug_info_set;
|
||||
size_t pc = 0;
|
||||
while (true) {
|
||||
try {
|
||||
std::string module_info = bc.get_forward_method_debug_info(pc);
|
||||
if (!module_info.empty() &&
|
||||
(module_info.find("debug_handle") == std::string::npos)) {
|
||||
module_debug_info_set.insert(module_info);
|
||||
}
|
||||
++pc;
|
||||
} catch (const std::exception& e) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.aten::add"));
|
||||
AT_ASSERT(module_debug_info_set.count(
|
||||
"top(C)::<unknown>.A0(A)::forward.aten::add"));
|
||||
AT_ASSERT(module_debug_info_set.count(
|
||||
"top(C)::<unknown>.B0(B)::forward.aten::add"));
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
|
||||
auto runtime_bytecode_version = _get_runtime_bytecode_version();
|
||||
AT_ASSERT(
|
||||
runtime_bytecode_version ==
|
||||
caffe2::serialize::kMaxSupportedBytecodeVersion);
|
||||
}
|
||||
|
||||
/**
|
||||
* The test below is disarmed for FB internal xplat builds since
|
||||
* BUCK requires us to pass in the script_module_v4.ptl file in
|
||||
* as a resource dependency of the build rule for this file, and
|
||||
* we would need to access it via the C++ Resources API instead
|
||||
* of directly reading from disk (which is what the open source
|
||||
* build/run does).
|
||||
*/
|
||||
TEST(LiteInterpreterTest, GetByteCodeVersion) {
|
||||
std::string filePath(__FILE__);
|
||||
auto test_model_file_v4 =
|
||||
filePath.substr(0, filePath.find_last_of("/\\") + 1);
|
||||
test_model_file_v4.append("script_module_v4.ptl");
|
||||
|
||||
auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
|
||||
AT_ASSERT(version_v4 == 4);
|
||||
}
|
||||
#endif // !defined(FB_XPLAT_BUILD)
|
||||
|
||||
namespace {
|
||||
|
||||
void compareModelOutput(
|
||||
|
|
|
|||
|
|
@ -319,7 +319,7 @@ core_sources_full_mobile_no_backend_interface = [
|
|||
"torch/csrc/jit/testing/hooks_for_testing.cpp",
|
||||
"torch/csrc/utils/tensor_flatten.cpp",
|
||||
"torch/csrc/utils/variadic.cpp",
|
||||
] + libtorch_profiler_sources
|
||||
]
|
||||
|
||||
core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [
|
||||
"torch/csrc/jit/backends/backend_debug_info.cpp",
|
||||
|
|
@ -337,7 +337,7 @@ core_sources_full = core_sources_full_mobile + [
|
|||
"torch/csrc/jit/tensorexpr/external_functions_codegen.cpp",
|
||||
]
|
||||
|
||||
libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources)
|
||||
libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources + libtorch_profiler_sources)
|
||||
|
||||
# These files are the only ones that are supported on Windows.
|
||||
libtorch_distributed_base_sources = [
|
||||
|
|
|
|||
|
|
@ -13,6 +13,12 @@ namespace jit {
|
|||
|
||||
namespace {
|
||||
|
||||
C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
|
||||
const std::string& debug_handles_string) {
|
||||
return "Debug info for handle(s): " + debug_handles_string +
|
||||
", was not found.";
|
||||
}
|
||||
|
||||
std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
|
||||
const DebugInfoTuple& source_callstack,
|
||||
const std::string& caller_name) {
|
||||
|
|
@ -152,8 +158,7 @@ std::string MobileDebugTable::getModuleHierarchyInfo(
|
|||
const std::string& top_module_type_name) const {
|
||||
const auto it = callstack_ptr_map_.find(debug_handle);
|
||||
if (it == callstack_ptr_map_.end()) {
|
||||
return "Module info for handle, " + std::to_string(debug_handle) +
|
||||
", not found.";
|
||||
return debugHandlesNotFoundMessage(std::to_string(debug_handle));
|
||||
}
|
||||
return (getStackTraceWithModuleHierarchy(
|
||||
{it->second}, "top", top_module_type_name))
|
||||
|
|
@ -172,8 +177,7 @@ std::string MobileDebugTable::getSourceDebugString(
|
|||
const std::string& top_module_type_name) const {
|
||||
const auto it = callstack_ptr_map_.find(debug_handle);
|
||||
if (it == callstack_ptr_map_.end()) {
|
||||
return "Debug info for handle, " + std::to_string(debug_handle) +
|
||||
", not found.";
|
||||
return debugHandlesNotFoundMessage(std::to_string(debug_handle));
|
||||
}
|
||||
return (getStackTraceWithModuleHierarchy(
|
||||
{it->second}, "top", top_module_type_name))
|
||||
|
|
@ -208,8 +212,7 @@ std::pair<std::string, std::string> MobileDebugTable::
|
|||
debug_handles_string += std::to_string(debug_handle);
|
||||
}
|
||||
debug_handles_string += "}";
|
||||
debug_handles_string =
|
||||
"Debug info for handles: " + debug_handles_string + ", was not found.";
|
||||
debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
|
||||
return {debug_handles_string, debug_handles_string};
|
||||
}
|
||||
return (getStackTraceWithModuleHierarchy(
|
||||
|
|
|
|||
|
|
@ -517,12 +517,15 @@ mobile::Module BytecodeDeserializer::deserialize(
|
|||
auto bvals = std::move(*readArchive("bytecode", mcu).toTuple()).elements();
|
||||
|
||||
c10::optional<std::vector<IValue>> debug_handles;
|
||||
bool has_debug_handles{false};
|
||||
if (reader_->hasRecord("mobile_debug_handles.pkl")) {
|
||||
debug_handles =
|
||||
readArchive("mobile_debug_handles", mcu).toTuple()->elements();
|
||||
has_debug_handles = true;
|
||||
}
|
||||
parseMethods(bvals, debug_handles, *mcu);
|
||||
auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu);
|
||||
m.setHasDebugHandles(has_debug_handles);
|
||||
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
||||
MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_);
|
||||
m.setDebugTable(std::move(debug_table));
|
||||
|
|
|
|||
|
|
@ -57,6 +57,9 @@ bool InterpreterState::run(Stack& stack) {
|
|||
auto inst_with_handle = code_->instructions_with_handles_.at(pc);
|
||||
Instruction inst = inst_with_handle.instruction;
|
||||
DebugHandle debug_handle = inst_with_handle.debug_handle;
|
||||
// If no valid debug handle found then just log pc.
|
||||
// This is possible when we did not save debug handles
|
||||
debug_handle = debug_handle == -1 ? pc : debug_handle;
|
||||
|
||||
// std::cout << "RUNNING " << pc << " "
|
||||
// << code_->instructions_with_handles_[pc].instruction;
|
||||
|
|
|
|||
|
|
@ -145,8 +145,7 @@ std::string Module::getCallStack(const int64_t debug_handle) const {
|
|||
// We really need to change this part, so in the next step for profiling support
|
||||
// for delegates, the first thing will be to rewrite how profiling is done
|
||||
// for lite interpreter.
|
||||
std::string Module::get_forward_method_debug_info(size_t pc) const {
|
||||
auto debug_handle = find_method("forward")->get_debug_handle(pc);
|
||||
std::string Module::get_forward_method_debug_info(int64_t debug_handle) const {
|
||||
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
||||
return getDebugTable().getModuleHierarchyInfo(
|
||||
debug_handle, getTopModuleTypeName(*this));
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class TORCH_API Module {
|
|||
}
|
||||
const std::vector<at::Tensor> parameters() const;
|
||||
const std::map<std::string, at::Tensor> named_parameters() const;
|
||||
std::string get_forward_method_debug_info(size_t pc) const;
|
||||
std::string get_forward_method_debug_info(int64_t debug_handle) const;
|
||||
std::string getModuleHierarchy(const int64_t debug_handle) const;
|
||||
std::string getCallStack(const int64_t debug_handle) const;
|
||||
/// Enables "training" mode.
|
||||
|
|
@ -115,11 +115,20 @@ class TORCH_API Module {
|
|||
return debug_table_;
|
||||
}
|
||||
|
||||
void setHasDebugHandles(bool has_debug_handles) {
|
||||
has_debug_handles_ = has_debug_handles;
|
||||
}
|
||||
|
||||
bool hasDebugHandles() const {
|
||||
return has_debug_handles_;
|
||||
}
|
||||
|
||||
private:
|
||||
c10::intrusive_ptr<c10::ivalue::Object> object_;
|
||||
std::unordered_map<std::string, std::string> metadata_;
|
||||
std::shared_ptr<CompilationUnit> cu_;
|
||||
MobileDebugTable debug_table_;
|
||||
bool has_debug_handles_;
|
||||
};
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace profiler = torch::autograd::profiler;
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace mobile {
|
||||
|
|
@ -27,17 +26,26 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
|
|||
if (with_modules || with_stack) {
|
||||
auto post_processing = [this, with_stack, with_modules](
|
||||
std::vector<profiler::KinetoEvent>& events) {
|
||||
std::string no_debug_info("Model was not saved with debug information");
|
||||
for (auto& e : events) {
|
||||
if (with_modules) {
|
||||
// Since KinetoEvents's module hierarchy takes vector of strings we
|
||||
// just construct a temporary vector using one string element
|
||||
e.moduleHierarchy(std::vector<std::string>(
|
||||
{this->m_.getModuleHierarchy(e.debugHandle())}));
|
||||
if (this->m_.hasDebugHandles()) {
|
||||
e.moduleHierarchy(std::vector<std::string>(
|
||||
{this->m_.getModuleHierarchy(e.debugHandle())}));
|
||||
} else {
|
||||
e.moduleHierarchy(std::vector<std::string>({no_debug_info}));
|
||||
}
|
||||
} else if (with_stack) {
|
||||
// Since KinetoEvents's stack trace takes vector of strings we just
|
||||
// construct a temporary vector using one string element
|
||||
e.stack(std::vector<std::string>(
|
||||
{this->m_.getCallStack(e.debugHandle())}));
|
||||
if (this->m_.hasDebugHandles()) {
|
||||
e.stack(std::vector<std::string>(
|
||||
{this->m_.getCallStack(e.debugHandle())}));
|
||||
} else {
|
||||
e.stack(std::vector<std::string>({no_debug_info}));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -55,8 +63,33 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
|
|||
trace_file_name_ = fname;
|
||||
}
|
||||
|
||||
const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
|
||||
disableProfiler() {
|
||||
TORCH_CHECK(
|
||||
!profiler_result_,
|
||||
"KinetoEdgeCPUProfiler already disabled. "
|
||||
"To get list of events use getProfilerResults()");
|
||||
profiler_result_ = profiler::disableProfiler();
|
||||
return profiler_result_;
|
||||
}
|
||||
|
||||
const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
|
||||
getProfilerResult() {
|
||||
TORCH_CHECK(
|
||||
profiler_result_,
|
||||
"KinetoEdgeCPUProfiler has not been disabled. "
|
||||
"use disableProfiler() API first, which returns the ProfilerResult.");
|
||||
return profiler_result_;
|
||||
}
|
||||
|
||||
KinetoEdgeCPUProfiler::~KinetoEdgeCPUProfiler() {
|
||||
profiler::disableProfiler()->save(trace_file_name_);
|
||||
if (!trace_file_name_.empty()) {
|
||||
if (profiler_result_) {
|
||||
profiler_result_->save(trace_file_name_);
|
||||
} else {
|
||||
profiler::disableProfiler()->save(trace_file_name_);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include <torch/csrc/autograd/profiler_kineto.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
|
||||
namespace profiler = torch::autograd::profiler;
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace mobile {
|
||||
|
|
@ -53,6 +54,9 @@ class TORCH_API KinetoEdgeCPUProfiler {
|
|||
const bool with_flops = false,
|
||||
const bool with_modules = false);
|
||||
|
||||
const std::unique_ptr<profiler::ProfilerResult>& disableProfiler();
|
||||
const std::unique_ptr<profiler::ProfilerResult>& getProfilerResult();
|
||||
|
||||
~KinetoEdgeCPUProfiler();
|
||||
|
||||
private:
|
||||
|
|
@ -62,6 +66,7 @@ class TORCH_API KinetoEdgeCPUProfiler {
|
|||
*/
|
||||
const mobile::Module& m_;
|
||||
std::string trace_file_name_;
|
||||
std::unique_ptr<profiler::ProfilerResult> profiler_result_;
|
||||
};
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user