Reapply "[jit] Implement ScriptProfile to collect instruction profiles." (#58783)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58783

This reverts commit fc804b5def.

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D28617037

Pulled By: zhxchen17

fbshipit-source-id: 645de2ede20500a5c218d6ec3c7faae94de37a14
Authored by Zhengxu Chen on 2021-05-24 18:22:01 -07:00; committed by Facebook GitHub Bot
parent 705dd9ffac
commit 2b0ec9c3cf
8 changed files with 351 additions and 1 deletion
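
For reference, a minimal C++ usage sketch of the profiler API this commit reintroduces, mirroring the unit test added below; the helper function profileNode and its Node argument are illustrative, not part of the patch:

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/script_profile.h>

using namespace torch::jit;

// Illustrative helper: profiles whatever runs while the span is alive.
void profileNode(Node& node) {
  ScriptProfile profile;
  profile.enable(); // register with the global profile registry
  {
    // RAII span: stamps start/end times for the node's source range and
    // forwards a Datapoint to every enabled profile on destruction.
    profiling::InstructionSpan span(node);
    // ... work attributed to this node runs here ...
  }
  profile.disable(); // deregister; dumpStats() may now be called
  const auto& stats = profile.dumpStats(); // SourceRef -> line -> InstructionStats
  (void)stats; // filter/aggregate as needed
}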

View File

@@ -65,6 +65,7 @@ set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp
${JIT_TEST_ROOT}/test_subgraph_utils.cpp
${JIT_TEST_ROOT}/test_utils.cpp
${JIT_TEST_ROOT}/test_script_profile.cpp
)
if(USE_CUDA)

View File

@@ -0,0 +1,62 @@
#include <gtest/gtest.h>
#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/script_profile.h>
namespace torch {
namespace jit {
TEST(ScriptProfileTest, Basic) {
const std::string source_string = R"V0G0N(
def foo(a, b):
return a + b #
)V0G0N";
auto begin = source_string.find("return");
auto end = source_string.find(" #");
Graph g;
const auto graph_string = R"IR(
graph(%a : Tensor,
%b : Tensor):
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::add(%a, %b, %2)
return (%3))IR";
torch::jit::parseIR(graph_string, &g);
auto source = std::make_shared<Source>(source_string, "", 0);
auto node = *g.nodes().begin();
node->setSourceRange(SourceRange{source, begin, end});
ScriptProfile p;
p.enable();
{
profiling::InstructionSpan g0(*node);
profiling::InstructionSpan g1(*node);
profiling::InstructionSpan g2(*node);
}
p.disable();
auto stats = p.dumpStats();
EXPECT_EQ(stats.size(), 1);
auto it = stats.find(*source.get());
EXPECT_NE(it, stats.end());
auto& lines = it->second;
EXPECT_EQ(lines.size(), 1);
const auto& stat = lines.at(source->lineno_for_offset(begin));
EXPECT_EQ(stat.count, 3);
}
TEST(ScriptProfileTest, CallingOrder) {
ScriptProfile p;
p.enable();
EXPECT_THROW(p.dumpStats(), c10::Error);
p.disable();
auto dp = std::make_shared<profiling::Datapoint>(SourceRange{});
EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error);
}
} // namespace jit
} // namespace torch

View File

@@ -254,6 +254,7 @@ core_sources_full_mobile = [
"torch/csrc/jit/runtime/logging.cpp",
"torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp",
"torch/csrc/jit/runtime/profiling_record.cpp",
"torch/csrc/jit/runtime/script_profile.cpp",
"torch/csrc/jit/runtime/symbolic_script.cpp",
"torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
"torch/csrc/jit/serialization/import.cpp",

View File

@@ -3,6 +3,7 @@
#include <functional>
#include <memory>
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
@@ -18,7 +19,7 @@ namespace jit {
* support heterogeneous lookup, and also shared_ptr is an implementation detail
* which should be encapsulated.
*/
class TORCH_API SourceRef {
class TORCH_API SourceRef : public CustomClassHolder {
public:
explicit SourceRef(std::shared_ptr<Source> source)
: source_(std::move(source)) {}
@@ -34,6 +35,9 @@ class TORCH_API SourceRef {
bool operator<(const SourceRef& other) const {
return *this < *other.source_.get();
}
const Source* operator->() const {
return source_.get();
}
private:
std::shared_ptr<Source> source_;

View File

@@ -14,6 +14,7 @@
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/constants.h>

View File

@@ -22,6 +22,7 @@
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/script_profile.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#ifdef USE_RPC
@@ -229,6 +230,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
// std::cout << "RUNNING ";
// frames.back().function->dump(std::cout, frame.pc);
Instruction inst = frame.function->instructions_[frame.pc];
profiling::InstructionSpan instSpan{
*frame.function->instructions_source()[frame.pc]};
switch (inst.op) {
case ENTER: {
const auto& obj = peek(stack, 0, 1);

View File

@@ -0,0 +1,177 @@
#include <torch/csrc/jit/runtime/script_profile.h>
#include <atomic>
#include <chrono>
#include <mutex>
#include <unordered_set>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/api/function_impl.h>
namespace torch {
namespace jit {
namespace {
class ProfilesRegistry {
public:
bool empty() {
return empty_.load(std::memory_order_relaxed);
}
void addProfile(ScriptProfile& p) {
std::lock_guard<std::mutex> g(mutex_);
enabledProfiles_.emplace(&p);
empty_.store(false, std::memory_order_relaxed);
}
void removeProfile(ScriptProfile& p) {
std::lock_guard<std::mutex> g(mutex_);
enabledProfiles_.erase(&p);
if (enabledProfiles_.empty()) {
empty_.store(true, std::memory_order_relaxed);
}
}
void send(std::unique_ptr<profiling::Datapoint> datapoint) {
auto shared = std::shared_ptr<profiling::Datapoint>(std::move(datapoint));
std::lock_guard<std::mutex> g(mutex_);
for (auto* p : enabledProfiles_) {
p->addDatapoint(shared);
}
}
private:
std::atomic<bool> empty_{true};
std::mutex mutex_;
std::unordered_set<ScriptProfile*> enabledProfiles_;
};
ProfilesRegistry& getProfilesRegistry() {
static auto registry = std::ref(*new ProfilesRegistry{});
return registry;
}
auto initBindings() {
torch::class_<SourceRef>("profiling", "SourceRef")
.def(
"starting_lineno",
[](const c10::intrusive_ptr<SourceRef>& self) {
return static_cast<int64_t>((*self)->starting_line_no());
})
.def("text", [](const c10::intrusive_ptr<SourceRef>& self) {
return (*self)->text();
});
torch::class_<InstructionStats>("profiling", "InstructionStats")
.def(
"count",
[](const c10::intrusive_ptr<InstructionStats>& self) {
return self->count;
})
.def("duration_ns", [](const c10::intrusive_ptr<InstructionStats>& self) {
return static_cast<int64_t>(self->duration.count());
});
torch::class_<SourceStats>("profiling", "SourceStats")
.def(
"source",
[](const c10::intrusive_ptr<SourceStats>& self) {
return c10::make_intrusive<SourceRef>(self->getSourceRef());
})
.def("line_map", &SourceStats::getLineMap);
torch::class_<ScriptProfile>("profiling", "_ScriptProfile")
.def(torch::init<>())
.def("enable", &ScriptProfile::enable)
.def("disable", &ScriptProfile::disable)
.def("_dump_stats", [](const c10::intrusive_ptr<ScriptProfile>& self) {
const auto& stats = self->dumpStats();
c10::List<c10::intrusive_ptr<SourceStats>> ret;
for (const auto& source : stats) {
SourceStats::LineMap lineMap;
for (const auto& line : source.second) {
lineMap.insert(
line.first, c10::make_intrusive<InstructionStats>(line.second));
}
ret.push_back(c10::make_intrusive<SourceStats>(
source.first, std::move(lineMap)));
}
return ret;
});
return nullptr;
}
const auto torchBindInitializer = initBindings();
} // namespace
namespace profiling {
InstructionSpan::InstructionSpan(Node& node) {
if (getProfilesRegistry().empty()) {
return;
}
datapoint_ = std::make_unique<Datapoint>(node.sourceRange());
}
InstructionSpan::~InstructionSpan() {
if (!datapoint_) {
return;
}
datapoint_->end = std::chrono::steady_clock::now();
getProfilesRegistry().send(std::move(datapoint_));
}
} // namespace profiling
void ScriptProfile::enable() {
if (!std::exchange(enabled_, true)) {
getProfilesRegistry().addProfile(*this);
}
}
void ScriptProfile::disable() {
if (std::exchange(enabled_, false)) {
getProfilesRegistry().removeProfile(*this);
}
}
void ScriptProfile::addDatapoint(
std::shared_ptr<profiling::Datapoint> datapoint) {
TORCH_CHECK(enabled_, "Cannot add datapoint to a disabled profiler.");
datapoints_.push_back(std::move(datapoint));
}
const ScriptProfile::SourceMap& ScriptProfile::dumpStats() {
TORCH_CHECK(!enabled_, "Only disabled profilers are allowed to dump stats.");
for (const auto& datapoint : datapoints_) {
if (const auto& source = datapoint->sourceRange.source()) {
if (auto fileLineCol = datapoint->sourceRange.file_line_col()) {
auto it = sourceMap_.find(*source.get());
if (it == sourceMap_.end()) {
it = sourceMap_.emplace(SourceRef{source}, LineMap{}).first;
}
auto& stats = it->second[std::get<1>(*fileLineCol)];
stats.count++;
stats.duration += datapoint->end - datapoint->start;
}
}
}
datapoints_.clear();
return sourceMap_;
}
ScriptProfile::~ScriptProfile() {
if (enabled_) {
getProfilesRegistry().removeProfile(*this);
}
}
} // namespace jit
} // namespace torch

View File

@@ -0,0 +1,101 @@
#pragma once
#include <chrono>
#include <map>
#include <string>
#include <ATen/core/ivalue.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/jit/frontend/source_ref.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace profiling {
struct Datapoint {
using Timepoint = std::chrono::time_point<std::chrono::steady_clock>;
SourceRange sourceRange;
Timepoint start;
Timepoint end;
explicit Datapoint(SourceRange sr)
: sourceRange(std::move(sr)), start(std::chrono::steady_clock::now()) {}
};
class TORCH_API InstructionSpan {
public:
explicit InstructionSpan(Node&);
~InstructionSpan();
InstructionSpan(InstructionSpan&&) = delete;
InstructionSpan& operator=(InstructionSpan&&) = delete;
private:
std::unique_ptr<Datapoint> datapoint_;
};
} // namespace profiling
struct TORCH_API InstructionStats : public CustomClassHolder {
int64_t count{0};
std::chrono::nanoseconds duration{0};
};
class TORCH_API SourceStats : public CustomClassHolder {
public:
using LineMap = c10::Dict<int64_t, c10::intrusive_ptr<InstructionStats>>;
SourceStats(SourceRef source, LineMap lineMap)
: source_(std::move(source)), lineMap_(std::move(lineMap)) {}
const SourceRef& getSourceRef() const {
return source_;
}
const LineMap& getLineMap() const {
return lineMap_;
}
private:
SourceRef source_;
LineMap lineMap_;
};
/**
* ScriptProfile is the underlying C++ implementation for TorchScript profiling.
* The profiled section is delimited by calls to enable() and disable():
*
* ...
* scriptProfile.enable();
* ...
* (scripts)
* ...
* scriptProfile.disable();
* ...
*
* To retrieve collected runtime data, users may call dumpStats() and do
* arbitrary filtering on the data they want. Note that dumpStats() should
* not be called inside a profiling section.
* In general, stats are aggregated per source function body, and then by line
* number.
*/
class TORCH_API ScriptProfile : public CustomClassHolder {
// Aggregates datapoints by function source id, then by line number.
using LineMap = std::map<int64_t, InstructionStats>;
using SourceMap = std::map<SourceRef, LineMap, std::less<>>;
public:
void enable();
void disable();
const SourceMap& dumpStats();
void addDatapoint(std::shared_ptr<profiling::Datapoint>);
~ScriptProfile() override;
private:
bool enabled_{false};
std::vector<std::shared_ptr<profiling::Datapoint>> datapoints_;
SourceMap sourceMap_;
};
} // namespace jit
} // namespace torch
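
Closing the loop on the header comment in the file above: a minimal sketch, not part of the patch, of consuming the map returned by dumpStats(). It relies only on the SourceRef::operator-> accessor and the InstructionStats fields added in this commit; printStats is an illustrative name.

#include <iostream>
#include <torch/csrc/jit/runtime/script_profile.h>

using namespace torch::jit;

// Illustrative helper: assumes `profile` has been enabled, run, and disabled.
void printStats(ScriptProfile& profile) {
  // dumpStats() yields SourceRef -> (line number -> InstructionStats).
  for (const auto& source : profile.dumpStats()) {
    std::cout << "source starting at line "
              << source.first->starting_line_no() << '\n';
    for (const auto& line : source.second) {
      std::cout << "  line " << line.first << ": count=" << line.second.count
                << ", duration_ns=" << line.second.duration.count() << '\n';
    }
  }
}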