mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Reapply "[jit] Implement ScriptProfile to collect instruction profiles." (#58783)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58783
This reverts commit fc804b5def.
Test Plan: Imported from OSS
Reviewed By: gmagogsfm
Differential Revision: D28617037
Pulled By: zhxchen17
fbshipit-source-id: 645de2ede20500a5c218d6ec3c7faae94de37a14
This commit is contained in:
parent
705dd9ffac
commit
2b0ec9c3cf
|
|
@ -65,6 +65,7 @@ set(JIT_TEST_SRCS
|
|||
${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp
|
||||
${JIT_TEST_ROOT}/test_subgraph_utils.cpp
|
||||
${JIT_TEST_ROOT}/test_utils.cpp
|
||||
${JIT_TEST_ROOT}/test_script_profile.cpp
|
||||
)
|
||||
|
||||
if(USE_CUDA)
|
||||
|
|
|
|||
62
test/cpp/jit/test_script_profile.cpp
Normal file
62
test/cpp/jit/test_script_profile.cpp
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/util/Optional.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/runtime/script_profile.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
TEST(ScriptProfileTest, Basic) {
|
||||
const std::string source_string = R"V0G0N(
|
||||
def foo(a, b):
|
||||
return a + b #
|
||||
)V0G0N";
|
||||
auto begin = source_string.find("return");
|
||||
auto end = source_string.find(" #");
|
||||
|
||||
Graph g;
|
||||
const auto graph_string = R"IR(
|
||||
graph(%a : Tensor,
|
||||
%b : Tensor):
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%3 : Tensor = aten::add(%a, %b, %2)
|
||||
return (%3))IR";
|
||||
|
||||
torch::jit::parseIR(graph_string, &g);
|
||||
auto source = std::make_shared<Source>(source_string, "", 0);
|
||||
auto node = *g.nodes().begin();
|
||||
node->setSourceRange(SourceRange{source, begin, end});
|
||||
|
||||
ScriptProfile p;
|
||||
p.enable();
|
||||
{
|
||||
profiling::InstructionSpan g0(*node);
|
||||
profiling::InstructionSpan g1(*node);
|
||||
profiling::InstructionSpan g2(*node);
|
||||
}
|
||||
p.disable();
|
||||
|
||||
auto stats = p.dumpStats();
|
||||
EXPECT_EQ(stats.size(), 1);
|
||||
auto it = stats.find(*source.get());
|
||||
EXPECT_NE(it, stats.end());
|
||||
auto& lines = it->second;
|
||||
EXPECT_EQ(lines.size(), 1);
|
||||
const auto& stat = lines.at(source->lineno_for_offset(begin));
|
||||
EXPECT_EQ(stat.count, 3);
|
||||
}
|
||||
|
||||
TEST(ScriptProfileTest, CallingOrder) {
|
||||
ScriptProfile p;
|
||||
p.enable();
|
||||
EXPECT_THROW(p.dumpStats(), c10::Error);
|
||||
p.disable();
|
||||
auto dp = std::make_shared<profiling::Datapoint>(SourceRange{});
|
||||
EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -254,6 +254,7 @@ core_sources_full_mobile = [
|
|||
"torch/csrc/jit/runtime/logging.cpp",
|
||||
"torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp",
|
||||
"torch/csrc/jit/runtime/profiling_record.cpp",
|
||||
"torch/csrc/jit/runtime/script_profile.cpp",
|
||||
"torch/csrc/jit/runtime/symbolic_script.cpp",
|
||||
"torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
|
||||
"torch/csrc/jit/serialization/import.cpp",
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <torch/csrc/jit/frontend/source_range.h>
|
||||
|
||||
|
|
@ -18,7 +19,7 @@ namespace jit {
|
|||
* support heteogeneous lookup, and also shared_ptr is an implementation detail
|
||||
* which should be encapsulated.
|
||||
*/
|
||||
class TORCH_API SourceRef {
|
||||
class TORCH_API SourceRef : public CustomClassHolder {
|
||||
public:
|
||||
explicit SourceRef(std::shared_ptr<Source> source)
|
||||
: source_(std::move(source)) {}
|
||||
|
|
@ -34,6 +35,9 @@ class TORCH_API SourceRef {
|
|||
bool operator<(const SourceRef& other) const {
|
||||
return *this < *other.source_.get();
|
||||
}
|
||||
const Source* operator->() const {
|
||||
return source_.get();
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Source> source_;
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <torch/csrc/jit/frontend/parser.h>
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/jit/ir/constants.h>
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/runtime/profiling_record.h>
|
||||
#include <torch/csrc/jit/runtime/script_profile.h>
|
||||
#include <torch/csrc/jit/runtime/vararg_functions.h>
|
||||
|
||||
#ifdef USE_RPC
|
||||
|
|
@ -229,6 +230,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
|||
// std::cout << "RUNNING ";
|
||||
// frames.back().function->dump(std::cout, frame.pc);
|
||||
Instruction inst = frame.function->instructions_[frame.pc];
|
||||
profiling::InstructionSpan instSpan{
|
||||
*frame.function->instructions_source()[frame.pc]};
|
||||
switch (inst.op) {
|
||||
case ENTER: {
|
||||
const auto& obj = peek(stack, 0, 1);
|
||||
|
|
|
|||
177
torch/csrc/jit/runtime/script_profile.cpp
Normal file
177
torch/csrc/jit/runtime/script_profile.cpp
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
#include <torch/csrc/jit/runtime/script_profile.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
class ProfilesRegistry {
|
||||
public:
|
||||
bool empty() {
|
||||
return empty_.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void addProfile(ScriptProfile& p) {
|
||||
std::lock_guard<std::mutex> g(mutex_);
|
||||
enabledProfiles_.emplace(&p);
|
||||
empty_.store(false, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void removeProfile(ScriptProfile& p) {
|
||||
std::lock_guard<std::mutex> g(mutex_);
|
||||
enabledProfiles_.erase(&p);
|
||||
if (enabledProfiles_.empty()) {
|
||||
empty_.store(true, std::memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
void send(std::unique_ptr<profiling::Datapoint> datapoint) {
|
||||
auto shared = std::shared_ptr<profiling::Datapoint>(std::move(datapoint));
|
||||
std::lock_guard<std::mutex> g(mutex_);
|
||||
for (auto* p : enabledProfiles_) {
|
||||
p->addDatapoint(shared);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::atomic<bool> empty_{true};
|
||||
std::mutex mutex_;
|
||||
std::unordered_set<ScriptProfile*> enabledProfiles_;
|
||||
};
|
||||
|
||||
ProfilesRegistry& getProfilesRegistry() {
|
||||
static auto registry = std::ref(*new ProfilesRegistry{});
|
||||
return registry;
|
||||
}
|
||||
|
||||
auto initBindings() {
|
||||
torch::class_<SourceRef>("profiling", "SourceRef")
|
||||
.def(
|
||||
"starting_lineno",
|
||||
[](const c10::intrusive_ptr<SourceRef>& self) {
|
||||
return static_cast<int64_t>((*self)->starting_line_no());
|
||||
})
|
||||
.def("text", [](const c10::intrusive_ptr<SourceRef>& self) {
|
||||
return (*self)->text();
|
||||
});
|
||||
|
||||
torch::class_<InstructionStats>("profiling", "InstructionStats")
|
||||
.def(
|
||||
"count",
|
||||
[](const c10::intrusive_ptr<InstructionStats>& self) {
|
||||
return self->count;
|
||||
})
|
||||
.def("duration_ns", [](const c10::intrusive_ptr<InstructionStats>& self) {
|
||||
return static_cast<int64_t>(self->duration.count());
|
||||
});
|
||||
|
||||
torch::class_<SourceStats>("profiling", "SourceStats")
|
||||
.def(
|
||||
"source",
|
||||
[](const c10::intrusive_ptr<SourceStats>& self) {
|
||||
return c10::make_intrusive<SourceRef>(self->getSourceRef());
|
||||
})
|
||||
.def("line_map", &SourceStats::getLineMap);
|
||||
|
||||
torch::class_<ScriptProfile>("profiling", "_ScriptProfile")
|
||||
.def(torch::init<>())
|
||||
.def("enable", &ScriptProfile::enable)
|
||||
.def("disable", &ScriptProfile::disable)
|
||||
.def("_dump_stats", [](const c10::intrusive_ptr<ScriptProfile>& self) {
|
||||
const auto& stats = self->dumpStats();
|
||||
c10::List<c10::intrusive_ptr<SourceStats>> ret;
|
||||
for (const auto& source : stats) {
|
||||
SourceStats::LineMap lineMap;
|
||||
for (const auto& line : source.second) {
|
||||
lineMap.insert(
|
||||
line.first, c10::make_intrusive<InstructionStats>(line.second));
|
||||
}
|
||||
ret.push_back(c10::make_intrusive<SourceStats>(
|
||||
source.first, std::move(lineMap)));
|
||||
}
|
||||
return ret;
|
||||
});
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto torchBindInitializer = initBindings();
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace profiling {
|
||||
|
||||
InstructionSpan::InstructionSpan(Node& node) {
|
||||
if (getProfilesRegistry().empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
datapoint_ = std::make_unique<Datapoint>(node.sourceRange());
|
||||
}
|
||||
|
||||
InstructionSpan::~InstructionSpan() {
|
||||
if (!datapoint_) {
|
||||
return;
|
||||
}
|
||||
|
||||
datapoint_->end = std::chrono::steady_clock::now();
|
||||
getProfilesRegistry().send(std::move(datapoint_));
|
||||
}
|
||||
|
||||
} // namespace profiling
|
||||
|
||||
void ScriptProfile::enable() {
|
||||
if (!std::exchange(enabled_, true)) {
|
||||
getProfilesRegistry().addProfile(*this);
|
||||
}
|
||||
}
|
||||
|
||||
void ScriptProfile::disable() {
|
||||
if (std::exchange(enabled_, false)) {
|
||||
getProfilesRegistry().removeProfile(*this);
|
||||
}
|
||||
}
|
||||
|
||||
void ScriptProfile::addDatapoint(
|
||||
std::shared_ptr<profiling::Datapoint> datapoint) {
|
||||
TORCH_CHECK(enabled_, "Cannot only add datapoint to disabled profilers.");
|
||||
datapoints_.push_back(std::move(datapoint));
|
||||
}
|
||||
|
||||
const ScriptProfile::SourceMap& ScriptProfile::dumpStats() {
|
||||
TORCH_CHECK(!enabled_, "Only disabled profilers are allowed to dump stats.");
|
||||
|
||||
for (const auto& datapoint : datapoints_) {
|
||||
if (const auto& source = datapoint->sourceRange.source()) {
|
||||
if (auto fileLineCol = datapoint->sourceRange.file_line_col()) {
|
||||
auto it = sourceMap_.find(*source.get());
|
||||
if (it == sourceMap_.end()) {
|
||||
it = sourceMap_.emplace(SourceRef{source}, LineMap{}).first;
|
||||
}
|
||||
auto& stats = it->second[std::get<1>(*fileLineCol)];
|
||||
stats.count++;
|
||||
stats.duration += datapoint->end - datapoint->start;
|
||||
}
|
||||
}
|
||||
}
|
||||
datapoints_.clear();
|
||||
|
||||
return sourceMap_;
|
||||
}
|
||||
|
||||
ScriptProfile::~ScriptProfile() {
|
||||
if (enabled_) {
|
||||
getProfilesRegistry().removeProfile(*this);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
101
torch/csrc/jit/runtime/script_profile.h
Normal file
101
torch/csrc/jit/runtime/script_profile.h
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <torch/csrc/jit/frontend/source_ref.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace profiling {
|
||||
|
||||
struct Datapoint {
|
||||
using Timepoint = std::chrono::time_point<std::chrono::steady_clock>;
|
||||
SourceRange sourceRange;
|
||||
Timepoint start;
|
||||
Timepoint end;
|
||||
|
||||
explicit Datapoint(SourceRange sr)
|
||||
: sourceRange(std::move(sr)), start(std::chrono::steady_clock::now()) {}
|
||||
};
|
||||
|
||||
class TORCH_API InstructionSpan {
|
||||
public:
|
||||
explicit InstructionSpan(Node&);
|
||||
~InstructionSpan();
|
||||
InstructionSpan(InstructionSpan&&) = delete;
|
||||
InstructionSpan& operator=(InstructionSpan&&) = delete;
|
||||
|
||||
private:
|
||||
std::unique_ptr<Datapoint> datapoint_;
|
||||
};
|
||||
|
||||
} // namespace profiling
|
||||
|
||||
struct TORCH_API InstructionStats : public CustomClassHolder {
|
||||
int64_t count{0};
|
||||
std::chrono::nanoseconds duration{0};
|
||||
};
|
||||
|
||||
class TORCH_API SourceStats : public CustomClassHolder {
|
||||
public:
|
||||
using LineMap = c10::Dict<int64_t, c10::intrusive_ptr<InstructionStats>>;
|
||||
|
||||
SourceStats(SourceRef source, LineMap lineMap)
|
||||
: source_(std::move(source)), lineMap_(std::move(lineMap)) {}
|
||||
|
||||
const SourceRef& getSourceRef() const {
|
||||
return source_;
|
||||
}
|
||||
|
||||
const LineMap& getLineMap() const {
|
||||
return lineMap_;
|
||||
}
|
||||
|
||||
private:
|
||||
SourceRef source_;
|
||||
LineMap lineMap_;
|
||||
};
|
||||
|
||||
/**
|
||||
* ScriptProfile is an underlying C++ implementation for TorchScript profiling.
|
||||
* The profiling section is specified by calling enable() and disable():
|
||||
*
|
||||
* ...
|
||||
* scriptProfile.enable();
|
||||
* ...
|
||||
* (scripts)
|
||||
* ...
|
||||
* scriptProfile.disable();
|
||||
* ...
|
||||
*
|
||||
* To retrieve collected runtime data, users may call dumpStats() and do
|
||||
* arbitrary filtering on the data they want. Note that dumpStats() should
|
||||
* not be called inside a profiling section.
|
||||
* In general, stats are aggregated per source function body, and then by line
|
||||
* number.
|
||||
*/
|
||||
class TORCH_API ScriptProfile : public CustomClassHolder {
|
||||
// Aggregates datapoints by function source id, then by line number.
|
||||
using LineMap = std::map<int64_t, InstructionStats>;
|
||||
using SourceMap = std::map<SourceRef, LineMap, std::less<>>;
|
||||
|
||||
public:
|
||||
void enable();
|
||||
void disable();
|
||||
const SourceMap& dumpStats();
|
||||
void addDatapoint(std::shared_ptr<profiling::Datapoint>);
|
||||
~ScriptProfile() override;
|
||||
|
||||
private:
|
||||
bool enabled_{false};
|
||||
std::vector<std::shared_ptr<profiling::Datapoint>> datapoints_;
|
||||
SourceMap sourceMap_;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
Loading…
Reference in New Issue
Block a user