Profiling GraphExecutor

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19994

Differential Revision: D15307752

Pulled By: Krovatkin

fbshipit-source-id: 7b35191042199ef16823487e15fe639968cbdc89
Author: Nikolay Korovaiko (2019-05-10 23:02:41 -07:00), committed by Facebook Github Bot
parent f4d9bfaa4d
commit 9499c7b7ee
10 changed files with 227 additions and 61 deletions
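In short: this change adds a ProfilingGraphExecutorImpl that instruments a graph with profile nodes and runs the instrumented plan while a profiling count remains, selected by a new thread-local flag (getProfilingMode()) that is read when a function's GraphExecutor is constructed and exposed to Python as torch._C._jit_set_profiling_mode. A minimal usage sketch assembled only from pieces visible in this diff (the three-runs-then-error behavior mirrors the new test; the actual count lives in ProfilingRecord):

    import torch

    torch._C._jit_set_profiling_mode(True)  # thread-local; see getProfilingMode()

    @torch.jit.script
    def basic(x, y):
        return x + y

    a, b = torch.rand(2, 3), torch.rand(2, 3)
    basic(a, b)  # runs the instrumented graph; a profile node decrements the count
    basic(a, b)
    basic(a, b)
    # Once the profiling count reaches zero the executor raises
    # RuntimeError("Not yet implemented"): optimizing from the collected
    # profiles is left for a later change.
    torch._C._jit_set_profiling_mode(False)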

File: caffe2/CMakeLists.txt

@@ -380,6 +380,7 @@ if (NOT INTERN_BUILD_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/subgraph_matcher.cpp
     ${TORCH_SRC_DIR}/csrc/jit/symbolic_script.cpp
     ${TORCH_SRC_DIR}/csrc/jit/profiling_record.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/profiling_graph_executor_impl.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/alias_analysis.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize.cpp

File: test/test_jit.py

@@ -258,6 +258,11 @@ def enable_cpu_fuser_if(cond):
         return wrapper
     return noop_fuser

+@contextmanager
+def enable_profiling_mode():
+    torch._C._jit_set_profiling_mode(True)
+    yield
+    torch._C._jit_set_profiling_mode(False)

 # note: not re-entrant, use unnested only
 @contextmanager

@@ -5001,6 +5006,31 @@ a")
         # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
         self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)

+    def test_profiling_graph_executor(self):
+        @torch.jit.script
+        def basic(x, y):
+            a = x + y
+            b = x * y
+            c = x + 1
+            d = a - c
+            e = b - c
+            return d + e
+
+        a = torch.rand(2, 3)
+        b = torch.rand(2, 3)
+
+        with enable_profiling_mode():
+            basic(a, b)
+            basic(a, b)
+            basic(a, b)
+            # This tests that the profiling count is decremented by a
+            # profile instruction; checking for the error below is the
+            # easiest way to verify from Python that the graph was
+            # instrumented.
+            with self.assertRaisesRegex(RuntimeError, "Not yet implemented"):
+                basic(a, b)
+
     def test_resize_input_ops(self):
         # resize_ and resize_as resize the input tensor. because our shape analysis
         # is flow invariant, we set any Tensor that can alias a resized Tensor
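One caveat on the new enable_profiling_mode helper above: it does not restore the flag if the wrapped code raises, and like the neighboring fuser helpers it is not re-entrant. A try/finally variant, sketched here as a suggestion rather than as part of this PR, would at least be exception-safe:

    from contextlib import contextmanager
    import torch

    @contextmanager
    def enable_profiling_mode():
        torch._C._jit_set_profiling_mode(True)
        try:
            yield
        finally:
            # restore the flag even if the profiled code raises
            torch._C._jit_set_profiling_mode(False)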

File: tools/build_variables.py

@@ -68,6 +68,7 @@ libtorch_sources = [
     "torch/csrc/jit/register_c10_ops.cpp",
     "torch/csrc/jit/subgraph_matcher.cpp",
     "torch/csrc/jit/symbolic_script.cpp",
+    "torch/csrc/jit/profiling_graph_executor_impl.cpp",
     "torch/csrc/jit/profiling_record.cpp",
     "torch/csrc/jit/operator.cpp",
     "torch/csrc/jit/passes/alias_analysis.cpp",

File: torch/csrc/jit/graph_executor.cpp

@@ -6,6 +6,7 @@
 #include <torch/csrc/jit/argument_spec.h>
 #include <torch/csrc/jit/autodiff.h>
 #include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/graph_executor_impl.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/pass_manager.h>
@@ -27,6 +28,8 @@
 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/passes/specialize_autogradzero.h>
+#include <torch/csrc/jit/profiling_graph_executor_impl.h>
+#include <torch/csrc/jit/profiling_record.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/tracer.h>
@@ -58,6 +61,11 @@ std::shared_ptr<Graph> lastExecutedOptimizedGraph() {
   return last_executed_optimized_graph.lock();
 }

+void ExecutionPlan::run(Stack& stack) const {
+  InterpreterState(code).run(stack);
+  last_executed_optimized_graph = graph;
+}
+
 namespace {

 using tensor_list = std::vector<at::Tensor>;
@@ -70,31 +78,6 @@ using autograd::variable_list;
 const size_t autodiffSubgraphNodeThreshold = 2;
 const size_t autodiffSubgraphInlineThreshold = 5;

-struct ExecutionPlan {
-  ExecutionPlan() = default;
-  ExecutionPlan(std::shared_ptr<Graph> graph)
-      : code(graph), graph(std::move(graph)) {}
-
-  void run(Stack& stack) const {
-    InterpreterState(code).run(stack);
-    last_executed_optimized_graph = graph;
-  }
-
-  operator bool() const {
-    return static_cast<bool>(graph);
-  }
-
-  ExecutionPlanState getDebugState() {
-    ExecutionPlanState state;
-    state.code = &code;
-    state.graph = graph.get();
-    return state;
-  }
-
-  Code code;
-  std::shared_ptr<Graph> graph;
-};
-
 struct CaptureList {
   CaptureList(size_t capture_size) {
     capture_types_.reserve(capture_size);
@@ -489,27 +472,15 @@ GraphExecutor* getGradExecutor(Operation& op) {
 // and different requires_grad states, and handles specializations for each
 // situation. GraphExecutor is completely unaware of tracing or module
 // parameters to keep the tracing concerns separated.
-struct GraphExecutorImpl {
-  static std::shared_ptr<Graph> prepareGraph(std::shared_ptr<Graph>& graph) {
-    auto copy = graph->copy();
-    EraseShapeInformation(copy);
-    return copy;
-  }
-
-  GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
-      : graph(prepareGraph(graph)),
-        // until we have correct alias analysis any use of mutable operators
-        // disables all optimization
-        optimize(optimize),
-        num_inputs(this->graph->inputs().size()),
-        arg_spec_creator_(*graph),
-        num_outputs(this->graph->outputs().size()) {
+struct GraphExecutorImpl : public GraphExecutorImplBase {
+  GraphExecutorImpl(const std::shared_ptr<Graph>& graph, bool optimize)
+      : GraphExecutorImplBase(graph, optimize), arg_spec_creator_(*graph) {
     logging::getLogger()->addStatValue(
         logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
   }

   // entry point where execution begins
-  void run(Stack& stack) {
+  void run(Stack& stack) override {
     AT_CHECK(
         stack.size() >= num_inputs,
         "expected ",
@@ -529,7 +500,7 @@ struct GraphExecutorImpl {
     return execution_plan.run(stack);
   }

-  GraphExecutorState getDebugState() {
+  GraphExecutorState getDebugState() override {
     GraphExecutorState state;
     state.graph = graph.get();
     if (fallback) {
@@ -541,7 +512,7 @@ struct GraphExecutorImpl {
     return state;
   }

- private:
+ protected:
   friend struct GraphExecutor;

   const ExecutionPlan& getOrCompileFallback() {
@@ -720,18 +691,9 @@ struct GraphExecutorImpl {
     }
   }

-  // The unoptimized starting graph. This field is effectively const, but we
-  // can't make it so because Graph::copy() is not const (and making it const is
-  // not that easy at this point).
-  std::shared_ptr<Graph> graph;
-  // If false, we'll run the graph as we get it, without any optimizations.
-  // Useful for debugging.
-  const bool optimize;
-  const size_t num_inputs;
+  ~GraphExecutorImpl() override = default;
+
   ArgumentSpecCreator arg_spec_creator_;
-  const size_t num_outputs;
   // Populated only when optimize is false (and in that case plan_cache will be
   // unused). The compiled version of graph.
   ExecutionPlan fallback;
@@ -739,14 +701,15 @@ struct GraphExecutorImpl {
   // Mapping from argument configurations to optimized versions of the graph
   // that are specialized to the spec.
   std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
-  // GraphExecutors can be accessed from multiple threads, so this mutex needs
-  // to be held every time we access the fallback or plan_cache.
-  std::mutex compile_mutex;
 };

 GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
-    : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
+    : pImpl(
+          getProfilingMode()
+              ? dynamic_cast<GraphExecutorImplBase*>(
+                    new ProfilingGraphExecutorImpl(graph, optimize))
+              : dynamic_cast<GraphExecutorImplBase*>(
+                    new GraphExecutorImpl(graph, optimize))) {}

 void GraphExecutor::run(Stack& inputs) {
   return pImpl->run(inputs);

File: torch/csrc/jit/graph_executor.h

@@ -26,7 +26,7 @@ struct GraphExecutorState {
   std::unordered_map<ArgumentSpec, ExecutionPlanState> execution_plans;
 };

-struct GraphExecutorImpl;
+struct GraphExecutorImplBase;
 struct TORCH_API GraphExecutor {
   GraphExecutor() = default;
   GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);

@@ -38,7 +38,7 @@ struct TORCH_API GraphExecutor {
   GraphExecutorState getDebugState();

  private:
-  std::shared_ptr<GraphExecutorImpl> pImpl;
+  std::shared_ptr<GraphExecutorImplBase> pImpl;
 };

 // These passes need to run before it is valid to pass to the interpreter

@@ -48,6 +48,8 @@ TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
 TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
 TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();

+TORCH_API bool& getProfilingMode();
+
 namespace detail {

 GraphExecutor* getGradExecutor(Operation& op);

File: torch/csrc/jit/graph_executor_impl.h (new file)

@@ -0,0 +1,101 @@
#pragma once
#include <torch/csrc/jit/graph_executor.h>

#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/tracer.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/logging.h>

#include <cstdint>
#include <iterator>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph)
      : code(graph), graph(std::move(graph)) {}

  void run(Stack& stack) const;

  operator bool() const {
    return static_cast<bool>(graph);
  }

  ExecutionPlanState getDebugState() {
    ExecutionPlanState state;
    state.code = &code;
    state.graph = graph.get();
    return state;
  }

  Code code;
  std::shared_ptr<Graph> graph;
};

// a Graph can be created via tracing, or via a language-based frontend.
// GraphExecutor runs it. It can run the same graph on many different sizes
// and different requires_grad states, and handles specializations for each
// situation. GraphExecutor is completely unaware of tracing or module
// parameters to keep the tracing concerns separated.
struct GraphExecutorImplBase {
  static std::shared_ptr<Graph> prepareGraph(
      const std::shared_ptr<Graph>& graph) {
    auto copy = graph->copy();
    EraseShapeInformation(copy);
    return copy;
  }

  GraphExecutorImplBase(const std::shared_ptr<Graph>& graph, bool optimize)
      : graph(prepareGraph(graph)),
        // until we have correct alias analysis any use of mutable operators
        // disables all optimization
        optimize(optimize),
        num_inputs(this->graph->inputs().size()),
        num_outputs(this->graph->outputs().size()) {}

  // entry point where execution begins
  virtual void run(Stack& stack) = 0;
  virtual GraphExecutorState getDebugState() = 0;
  virtual ~GraphExecutorImplBase() = default;

 protected:
  friend struct GraphExecutor;

  // The unoptimized starting graph. This field is effectively const, but we
  // can't make it so because Graph::copy() is not const (and making it const
  // is not that easy at this point).
  std::shared_ptr<Graph> graph;

  // If false, we'll run the graph as we get it, without any optimizations.
  // Useful for debugging.
  const bool optimize;
  const size_t num_inputs;
  const size_t num_outputs;

  // GraphExecutors can be accessed from multiple threads, so this mutex needs
  // to be held every time we access the fallback or plan_cache.
  std::mutex compile_mutex;
};

} // namespace jit
} // namespace torch
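The design note on GraphExecutorImplBase above is the key invariant: one executor serves many input configurations, and the optimizing implementation caches a specialized ExecutionPlan per ArgumentSpec in its plan_cache. A small Python illustration of what "many different sizes and different requires_grad states" means in practice (shapes chosen arbitrarily for the example):

    import torch

    @torch.jit.script
    def f(x, y):
        return x * y + y

    # One GraphExecutor handles all of these calls; each distinct argument
    # configuration gets its own cached plan (the plan_cache keyed by
    # ArgumentSpec in the non-profiling implementation).
    f(torch.rand(2, 3), torch.rand(2, 3))
    f(torch.rand(4, 5, 6), torch.rand(4, 5, 6))
    f(torch.rand(2, 3, requires_grad=True), torch.rand(2, 3))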

File: torch/csrc/jit/init.cpp

@@ -295,6 +295,9 @@ void initJITBindings(PyObject* module) {
         auto stack = toStack(args);
         checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
       })
+      .def(
+          "_jit_set_profiling_mode",
+          [](bool profiling_flag) { getProfilingMode() = profiling_flag; })
       .def(
           "_jit_fuser_get_fused_kernel_code",
           [](Graph& g, std::vector<at::Tensor> inps) {

File: torch/csrc/jit/profiling_graph_executor_impl.cpp (new file)

@@ -0,0 +1,42 @@
#include <torch/csrc/jit/profiling_graph_executor_impl.h>

namespace torch {
namespace jit {

thread_local bool profiling_mode = false;
bool& getProfilingMode() {
  return profiling_mode;
}

void ProfilingGraphExecutorImpl::run(Stack& stack) {
  AT_CHECK(
      stack.size() >= num_inputs,
      "expected ",
      num_inputs,
      " inputs, but got only ",
      stack.size());

  {
    std::lock_guard<std::mutex> lock(compile_mutex);
    if (!pr_) {
      auto g = graph->copy();
      runRequiredPasses(g);
      pr_ = ProfilingRecord::instrumentGraph(g);
      exec_plan_ = caffe2::make_unique<ExecutionPlan>(pr_->profiled_graph_);
    }
  }

  if (pr_->profiling_count_ > 0) {
    exec_plan_->run(stack);
  } else {
    AT_ERROR("Not yet implemented");
  }
  return;
}

GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
  AT_ERROR("not supported");
}

} // namespace jit
} // namespace torch
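The control flow of run() above reduces to: instrument the graph exactly once under compile_mutex, execute the instrumented plan while the shared profiling count is positive, and raise afterwards, since re-optimizing from the collected profiles is left for a follow-up change. A rough, self-contained Python analogue of that flow (ProfilingExecutorSketch and profiling_runs=3 are illustrative stand-ins, not the real API; the real count lives in ProfilingRecord::profiling_count_):

    import threading

    class ProfilingExecutorSketch:
        """Rough analogue of ProfilingGraphExecutorImpl::run; not the real API."""

        def __init__(self, fn, profiling_runs=3):
            self.fn = fn                  # stands in for the copied, pass-cleaned graph
            self.plan = None              # like exec_plan_, built lazily
            self.count = profiling_runs   # like ProfilingRecord::profiling_count_
            self.lock = threading.Lock()  # like compile_mutex

        def _instrument(self, fn):
            # instrumentGraph() adds a profile node whose callback decrements
            # the count once per execution (guarded, like the counter lambda
            # patched in profiling_record.cpp below)
            def plan(*args):
                with self.lock:
                    if self.count > 0:
                        self.count -= 1
                return fn(*args)
            return plan

        def run(self, *args):
            with self.lock:               # compile exactly once
                if self.plan is None:
                    self.plan = self._instrument(self.fn)
            if self.count > 0:
                return self.plan(*args)
            raise RuntimeError("Not yet implemented")

    ex = ProfilingExecutorSketch(lambda x, y: x + y)
    for _ in range(3):
        ex.run(1, 2)   # profiled runs
    # a fourth ex.run(1, 2) now raises RuntimeError("Not yet implemented")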

File: torch/csrc/jit/profiling_graph_executor_impl.h (new file)

@@ -0,0 +1,20 @@
#pragma once
#include <torch/csrc/jit/graph_executor_impl.h>

namespace torch {
namespace jit {

struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
  using GraphExecutorImplBase::GraphExecutorImplBase;
  void run(Stack& stack) override;
  GraphExecutorState getDebugState() override;
  ~ProfilingGraphExecutorImpl() override = default;

 private:
  std::unique_ptr<ProfilingRecord> pr_;
  std::unique_ptr<ExecutionPlan> exec_plan_;
};

} // namespace jit
} // namespace torch

File: torch/csrc/jit/profiling_record.cpp

@@ -66,7 +66,10 @@ std::unique_ptr<ProfilingRecord> ProfilingRecord::instrumentGraph(
   pr->instrumentBlock(new_g->block());
   std::function<void(Stack&)> counter = [raw_pr](Stack&) {
     std::lock_guard<std::mutex> lock(raw_pr->mutex_);
-    raw_pr->profiling_count_--;
+    if (raw_pr->profiling_count_ > 0)
+    {
+      raw_pr->profiling_count_--;
+    }
   };

   auto pop = pr->createProfileNode(counter, {});
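Why the guard: the executor checks profiling_count_ > 0 before launching a run, but several runs can be in flight at once, so their callbacks may fire after the count has already reached zero; the mutex plus the guard keeps the counter from underflowing. A self-contained sketch of that property (pure Python, illustrative only, not the PR's code):

    import threading

    count = 3                    # plays the role of profiling_count_
    lock = threading.Lock()      # plays the role of raw_pr->mutex_

    def profile_callback():
        # mirrors the patched counter lambda: decrement under the lock,
        # but never below zero, however many callbacks race in
        global count
        with lock:
            if count > 0:
                count -= 1

    threads = [threading.Thread(target=profile_callback) for _ in range(10)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert count == 0  # stops at zero instead of going negative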