Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57397

Introduces two main classes in the C++ runtime:

ScriptProfile is the implementation for enabling and disabling interpreter profiling in C++. It should only be used from Python, and the corresponding Python API will be added in the next diff.

InstructionSpan is a utility class to instrument the execution of each single instruction. A start timestamp is recorded in the constructor, and an end timestamp is recorded in the destructor. During destruction, it sends the runtime data to all enabled ScriptProfile instances.

Test Plan: build/bin/test_jit --gtest_filter='ScriptProfileTest.Basic'

Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D28133579

fbshipit-source-id: e7e30e96151367022793ab3ad323f01c51ad4a3b
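For orientation, the usage pattern described above can be sketched roughly as follows. This is not code from the diff: profiledRun and the elided instruction dispatch are hypothetical placeholders, while ScriptProfile, profiling::InstructionSpan, enable(), disable(), and dumpStats() are the APIs exercised by the test file below.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/script_profile.h>

namespace {

// Hypothetical helper: wraps the execution of a single instruction that
// belongs to `node` in an InstructionSpan while `profile` is enabled.
void profiledRun(torch::jit::Node& node, torch::jit::ScriptProfile& profile) {
  profile.enable(); // start collecting datapoints
  {
    // The constructor records the start timestamp; the destructor records
    // the end timestamp and forwards the datapoint to every enabled
    // ScriptProfile instance.
    torch::jit::profiling::InstructionSpan span(node);
    // ... execute the instruction associated with `node` here ...
  }
  profile.disable(); // stop collecting
  auto stats = profile.dumpStats(); // per-source, per-line statistics
  (void)stats;
}

} // namespace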
63 lines
1.6 KiB
C++
#include <gtest/gtest.h>

#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/script_profile.h>

namespace torch {
namespace jit {

TEST(ScriptProfileTest, Basic) {
  const std::string source_string = R"V0G0N(
def foo(a, b):
  return a + b #
)V0G0N";
  auto begin = source_string.find("return");
  auto end = source_string.find(" #");

  Graph g;
  const auto graph_string = R"IR(
graph(%a : Tensor,
      %b : Tensor):
  %2 : int = prim::Constant[value=1]()
  %3 : Tensor = aten::add(%a, %b, %2)
  return (%3))IR";

  torch::jit::parseIR(graph_string, &g);
  auto source = std::make_shared<Source>(source_string, "", 0);
  auto node = *g.nodes().begin();
  // Attach the `return a + b` span of the Python source to the first node.
  node->setSourceRange(SourceRange{source, begin, end});

  ScriptProfile p;
  p.enable();
  {
    // Three spans over the same instruction: each records start/end
    // timestamps, so the profiled line should report a count of 3.
    profiling::InstructionSpan g0(*node);
    profiling::InstructionSpan g1(*node);
    profiling::InstructionSpan g2(*node);
  }
  p.disable();

  auto stats = p.dumpStats();
  EXPECT_EQ(stats.size(), 1);
  auto it = stats.find(*source.get());
  EXPECT_NE(it, stats.end());
  auto& lines = it->second;
  EXPECT_EQ(lines.size(), 1);
  const auto& stat = lines.at(source->lineno_for_offset(begin));
  EXPECT_EQ(stat.count, 3);
}

TEST(ScriptProfileTest, CallingOrder) {
  ScriptProfile p;
  p.enable();
  // Dumping stats while profiling is still enabled is an error.
  EXPECT_THROW(p.dumpStats(), c10::Error);
  p.disable();
  // Adding a datapoint once profiling has been disabled is also an error.
  auto dp = std::make_shared<profiling::Datapoint>(SourceRange{});
  EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error);
}

} // namespace jit
} // namespace torch