Profiling GraphExecutor

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19994
Differential Revision: D15307752
Pulled By: Krovatkin
fbshipit-source-id: 7b35191042199ef16823487e15fe639968cbdc89

This commit is contained in:
parent f4d9bfaa4d
commit 9499c7b7ee
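At a high level, the commit introduces a thread-local profiling mode, exposed to Python as torch._C._jit_set_profiling_mode, and a ProfilingGraphExecutorImpl that runs an instrumented copy of a graph while a profiling count lasts. A minimal usage sketch in Python, distilled from the test added below (three successful instrumented runs before the stubbed optimization path raises is what the test observes on this revision, not a documented contract):

import torch
from contextlib import contextmanager

@contextmanager
def enable_profiling_mode():
    # same helper this commit adds to the test suite
    torch._C._jit_set_profiling_mode(True)
    yield
    torch._C._jit_set_profiling_mode(False)

@torch.jit.script
def basic(x, y):
    return (x + y) - (x * y)

a, b = torch.rand(2, 3), torch.rand(2, 3)

with enable_profiling_mode():
    for _ in range(3):
        basic(a, b)  # instrumented runs; a profile node decrements the count
    try:
        basic(a, b)  # count exhausted; the optimizing path is still a stub
    except RuntimeError as e:
        assert "Not yet implemented" in str(e)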
@@ -380,6 +380,7 @@ if (NOT INTERN_BUILD_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/subgraph_matcher.cpp
     ${TORCH_SRC_DIR}/csrc/jit/symbolic_script.cpp
     ${TORCH_SRC_DIR}/csrc/jit/profiling_record.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/profiling_graph_executor_impl.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/alias_analysis.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize.cpp
@@ -258,6 +258,11 @@ def enable_cpu_fuser_if(cond):
             return wrapper
         return noop_fuser
 
+@contextmanager
+def enable_profiling_mode():
+    torch._C._jit_set_profiling_mode(True)
+    yield
+    torch._C._jit_set_profiling_mode(False)
 
 # note: not re-entrant, use unnested only
 @contextmanager
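As written, the helper leaves profiling mode enabled if the with-body raises. A more defensive variant (not part of this commit) would restore the flag in a finally block:

from contextlib import contextmanager
import torch

@contextmanager
def enable_profiling_mode():
    torch._C._jit_set_profiling_mode(True)
    try:
        yield
    finally:
        # restore the flag even when the body raises
        torch._C._jit_set_profiling_mode(False)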
@@ -5001,6 +5006,31 @@ a")
         # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
         self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
 
+    def test_profiling_graph_executor(self):
+        @torch.jit.script
+        def basic(x, y):
+            a = x + y
+            b = x * y
+            c = x + 1
+            d = a - c
+            e = b - c
+            return d + e
+
+        a = torch.rand(2, 3)
+        b = torch.rand(2, 3)
+
+        with enable_profiling_mode():
+            basic(a, b)
+            basic(a, b)
+            basic(a, b)
+
+            # this tests that a profiling count is being decremented by
+            # a profile instruction.
+            # this is the easiest way to test that a graph was instrumented
+            # from python
+            with self.assertRaisesRegex(RuntimeError, "Not yet implemented"):
+                basic(a, b)
+
     def test_resize_input_ops(self):
         # resize_ and resize_as resize the input tensor. because our shape analysis
         # is flow invariant, we set any Tensor that can alias a resized Tensor
@@ -68,6 +68,7 @@ libtorch_sources = [
     "torch/csrc/jit/register_c10_ops.cpp",
     "torch/csrc/jit/subgraph_matcher.cpp",
     "torch/csrc/jit/symbolic_script.cpp",
+    "torch/csrc/jit/profiling_graph_executor_impl.cpp",
     "torch/csrc/jit/profiling_record.cpp",
     "torch/csrc/jit/operator.cpp",
     "torch/csrc/jit/passes/alias_analysis.cpp",
@@ -6,6 +6,7 @@
 #include <torch/csrc/jit/argument_spec.h>
 #include <torch/csrc/jit/autodiff.h>
 #include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/graph_executor_impl.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/pass_manager.h>
@@ -27,6 +28,8 @@
 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/passes/specialize_autogradzero.h>
+#include <torch/csrc/jit/profiling_graph_executor_impl.h>
+#include <torch/csrc/jit/profiling_record.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/tracer.h>
@@ -58,6 +61,11 @@ std::shared_ptr<Graph> lastExecutedOptimizedGraph() {
   return last_executed_optimized_graph.lock();
 }
 
+void ExecutionPlan::run(Stack& stack) const {
+  InterpreterState(code).run(stack);
+  last_executed_optimized_graph = graph;
+}
+
 namespace {
 
 using tensor_list = std::vector<at::Tensor>;
@@ -70,31 +78,6 @@ using autograd::variable_list;
 const size_t autodiffSubgraphNodeThreshold = 2;
 const size_t autodiffSubgraphInlineThreshold = 5;
 
-struct ExecutionPlan {
-  ExecutionPlan() = default;
-  ExecutionPlan(std::shared_ptr<Graph> graph)
-      : code(graph), graph(std::move(graph)) {}
-
-  void run(Stack& stack) const {
-    InterpreterState(code).run(stack);
-    last_executed_optimized_graph = graph;
-  }
-
-  operator bool() const {
-    return static_cast<bool>(graph);
-  }
-
-  ExecutionPlanState getDebugState() {
-    ExecutionPlanState state;
-    state.code = &code;
-    state.graph = graph.get();
-    return state;
-  }
-
-  Code code;
-  std::shared_ptr<Graph> graph;
-};
-
 struct CaptureList {
   CaptureList(size_t capture_size) {
     capture_types_.reserve(capture_size);
@@ -489,27 +472,15 @@ GraphExecutor* getGradExecutor(Operation& op) {
 // and different requires_grad states, and handles specializations for each
 // situation. GraphExecutor is completely unaware of tracing or module
 // parameters to keep the tracing concerns separated.
-struct GraphExecutorImpl {
-  static std::shared_ptr<Graph> prepareGraph(std::shared_ptr<Graph>& graph) {
-    auto copy = graph->copy();
-    EraseShapeInformation(copy);
-    return copy;
-  }
-
-  GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
-      : graph(prepareGraph(graph)),
-        // until we have correct alias analysis any use of mutable operators
-        // disables all optimization
-        optimize(optimize),
-        num_inputs(this->graph->inputs().size()),
-        arg_spec_creator_(*graph),
-        num_outputs(this->graph->outputs().size()) {
+struct GraphExecutorImpl : public GraphExecutorImplBase {
+  GraphExecutorImpl(const std::shared_ptr<Graph>& graph, bool optimize)
+      : GraphExecutorImplBase(graph, optimize), arg_spec_creator_(*graph) {
     logging::getLogger()->addStatValue(
         logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
   }
 
   // entry point where execution begins
-  void run(Stack& stack) {
+  void run(Stack& stack) override {
     AT_CHECK(
         stack.size() >= num_inputs,
         "expected ",
@@ -529,7 +500,7 @@ struct GraphExecutorImpl {
     return execution_plan.run(stack);
   }
 
-  GraphExecutorState getDebugState() {
+  GraphExecutorState getDebugState() override {
     GraphExecutorState state;
     state.graph = graph.get();
     if (fallback) {
@@ -541,7 +512,7 @@ struct GraphExecutorImpl {
     return state;
   }
 
- private:
+ protected:
   friend struct GraphExecutor;
 
   const ExecutionPlan& getOrCompileFallback() {
@@ -720,18 +691,9 @@ struct GraphExecutorImpl {
     }
   }
 
-  // The unoptimized starting graph. This field is effectively const, but we
-  // can't make it so because Graph::copy() is not const (and making it const is
-  // not that easy at this point).
-  std::shared_ptr<Graph> graph;
-
-  // If false, we'll run the graph as we get it, without any optimizations.
-  // Useful for debugging.
-  const bool optimize;
-  const size_t num_inputs;
+  ~GraphExecutorImpl() override = default;
+
   ArgumentSpecCreator arg_spec_creator_;
-  const size_t num_outputs;
 
   // Populated only when optimize is false (and in that case plan_cache will be
   // unused). The compiled version of graph.
   ExecutionPlan fallback;
@@ -739,14 +701,15 @@ struct GraphExecutorImpl {
   // Mapping from argument configurations to optimized versions of the graph
   // that are specialized to the spec.
   std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
-
-  // GraphExecutors can be accessed from multiple threads, so this mutex needs
-  // to be held every time we access the fallback or plan_cache.
-  std::mutex compile_mutex;
 };
 
 GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
-    : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
+    : pImpl(
+          getProfilingMode()
+              ? dynamic_cast<GraphExecutorImplBase*>(
+                    new ProfilingGraphExecutorImpl(graph, optimize))
+              : dynamic_cast<GraphExecutorImplBase*>(
+                    new GraphExecutorImpl(graph, optimize))) {}
 
 void GraphExecutor::run(Stack& inputs) {
   return pImpl->run(inputs);
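Note that getProfilingMode() is consulted once, when the GraphExecutor is constructed; flipping the flag later does not change an existing executor. The sketch below illustrates the implication; that the executor of a scripted function is built lazily at its first invocation is inferred from the new test, not stated in this diff:

import torch

torch._C._jit_set_profiling_mode(True)

@torch.jit.script
def f(x):
    return x + 1

f(torch.ones(2))   # first call: the executor is created with profiling on
torch._C._jit_set_profiling_mode(False)
f(torch.ones(2))   # same executor; the flag only affects executors
                   # constructed after it was flipped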
@@ -26,7 +26,7 @@ struct GraphExecutorState {
   std::unordered_map<ArgumentSpec, ExecutionPlanState> execution_plans;
 };
 
-struct GraphExecutorImpl;
+struct GraphExecutorImplBase;
 struct TORCH_API GraphExecutor {
   GraphExecutor() = default;
   GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
@@ -38,7 +38,7 @@ struct TORCH_API GraphExecutor {
   GraphExecutorState getDebugState();
 
  private:
-  std::shared_ptr<GraphExecutorImpl> pImpl;
+  std::shared_ptr<GraphExecutorImplBase> pImpl;
 };
 
 // These passes need to run before it is valid to pass to the interpreter
@@ -48,6 +48,8 @@ TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
 TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
 TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();
 
+TORCH_API bool& getProfilingMode();
+
 namespace detail {
 
 GraphExecutor* getGradExecutor(Operation& op);
torch/csrc/jit/graph_executor_impl.h (new file, 101 lines)

#pragma once
#include <torch/csrc/jit/graph_executor.h>

#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/tracer.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/logging.h>

#include <cstdint>
#include <iterator>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph)
      : code(graph), graph(std::move(graph)) {}

  void run(Stack& stack) const;

  operator bool() const {
    return static_cast<bool>(graph);
  }

  ExecutionPlanState getDebugState() {
    ExecutionPlanState state;
    state.code = &code;
    state.graph = graph.get();
    return state;
  }

  Code code;
  std::shared_ptr<Graph> graph;
};

// a Graph can be created via tracing, or via a language-based frontend
// GraphExecutor runs it. It can run the same graph on many different sizes
// and different requires_grad states, and handles specializations for each
// situation. GraphExecutor is completely unaware of tracing or module
// parameters to keep the tracing concerns separated.
struct GraphExecutorImplBase {
  static std::shared_ptr<Graph> prepareGraph(
      const std::shared_ptr<Graph>& graph) {
    auto copy = graph->copy();
    EraseShapeInformation(copy);
    return copy;
  }

  GraphExecutorImplBase(const std::shared_ptr<Graph>& graph, bool optimize)
      : graph(prepareGraph(graph)),
        // until we have correct alias analysis any use of mutable operators
        // disables all optimization
        optimize(optimize),
        num_inputs(this->graph->inputs().size()),
        num_outputs(this->graph->outputs().size()) {}

  // entry point where execution begins
  virtual void run(Stack& stack) = 0;
  virtual GraphExecutorState getDebugState() = 0;
  virtual ~GraphExecutorImplBase() = default;

 protected:
  friend struct GraphExecutor;

  // The unoptimized starting graph. This field is effectively const, but we
  // can't make it so because Graph::copy() is not const (and making it const
  // is not that easy at this point).
  std::shared_ptr<Graph> graph;

  // If false, we'll run the graph as we get it, without any optimizations.
  // Useful for debugging.
  const bool optimize;
  const size_t num_inputs;
  const size_t num_outputs;

  // GraphExecutors can be accessed from multiple threads, so this mutex needs
  // to be held every time we access the fallback or plan_cache.
  std::mutex compile_mutex;
};

} // namespace jit
} // namespace torch
@@ -295,6 +295,9 @@ void initJITBindings(PyObject* module) {
             auto stack = toStack(args);
             checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
           })
+      .def(
+          "_jit_set_profiling_mode",
+          [](bool profiling_flag) { getProfilingMode() = profiling_flag; })
       .def(
           "_jit_fuser_get_fused_kernel_code",
           [](Graph& g, std::vector<at::Tensor> inps) {
torch/csrc/jit/profiling_graph_executor_impl.cpp (new file, 42 lines)

#include <torch/csrc/jit/profiling_graph_executor_impl.h>

namespace torch {
namespace jit {

thread_local bool profiling_mode = false;
bool& getProfilingMode() {
  return profiling_mode;
}

void ProfilingGraphExecutorImpl::run(Stack& stack) {
  AT_CHECK(
      stack.size() >= num_inputs,
      "expected ",
      num_inputs,
      " inputs, but got only ",
      stack.size());

  {
    std::lock_guard<std::mutex> lock(compile_mutex);
    if (!pr_) {
      auto g = graph->copy();
      runRequiredPasses(g);
      pr_ = ProfilingRecord::instrumentGraph(g);
      exec_plan_ = caffe2::make_unique<ExecutionPlan>(pr_->profiled_graph_);
    }
  }

  if (pr_->profiling_count_ > 0) {
    exec_plan_->run(stack);
  } else {
    AT_ERROR("Not yet implemented");
  }
  return;
}

GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
  AT_ERROR("not supported");
}

} // namespace jit
} // namespace torch
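For orientation, a Python model of the control flow in ProfilingGraphExecutorImpl::run above (illustration only; the default of three profiling runs mirrors what the new test observes, not a constant visible in this diff):

import threading

class ProfilingExecutorModel:
    # Toy stand-in for ProfilingGraphExecutorImpl: pr_/exec_plan_ become
    # 'instrumented'; profiling_count_ becomes 'profiling_count'.
    def __init__(self, graph, profiling_runs=3):
        self.graph = graph
        self.compile_mutex = threading.Lock()
        self.instrumented = None
        self.profiling_count = profiling_runs

    def run(self, stack):
        with self.compile_mutex:
            if self.instrumented is None:
                # copy the graph, run required passes, insert profile nodes
                self.instrumented = ("instrumented", self.graph)
        if self.profiling_count > 0:
            self.profiling_count -= 1  # in C++ a profile node does this
            return ("ran", self.instrumented, tuple(stack))
        raise RuntimeError("Not yet implemented")

# ex = ProfilingExecutorModel("g"): three ex.run([...]) calls succeed,
# the fourth raises, matching the behavior the new test asserts.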
torch/csrc/jit/profiling_graph_executor_impl.h (new file, 20 lines)

#pragma once
#include <torch/csrc/jit/graph_executor_impl.h>

namespace torch {
namespace jit {

struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
  using GraphExecutorImplBase::GraphExecutorImplBase;

  void run(Stack& stack) override;
  GraphExecutorState getDebugState() override;
  ~ProfilingGraphExecutorImpl() override = default;

 private:
  std::unique_ptr<ProfilingRecord> pr_;
  std::unique_ptr<ExecutionPlan> exec_plan_;
};

} // namespace jit
} // namespace torch
@@ -66,7 +66,10 @@ std::unique_ptr<ProfilingRecord> ProfilingRecord::instrumentGraph(
   pr->instrumentBlock(new_g->block());
   std::function<void(Stack&)> counter = [raw_pr](Stack&) {
     std::lock_guard<std::mutex> lock(raw_pr->mutex_);
-    raw_pr->profiling_count_--;
+    if (raw_pr->profiling_count_ > 0)
+    {
+      raw_pr->profiling_count_--;
+    }
   };
 
   auto pop = pr->createProfileNode(counter, {});
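The guard makes the counter saturate at zero instead of going negative, so profiling_count_ == 0 stays a stable "profiling finished" signal for the executor. A one-function Python model of the guarded callback (names are illustrative):

def make_profile_callback(record):
    # record is a dict standing in for ProfilingRecord state
    def counter(_stack):
        if record["profiling_count"] > 0:  # mirrors the guarded decrement
            record["profiling_count"] -= 1
    return counter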