Profiling GraphExecutor
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19994

Differential Revision: D15307752
Pulled By: Krovatkin
fbshipit-source-id: 7b35191042199ef16823487e15fe639968cbdc89
This commit is contained in:
parent f4d9bfaa4d
commit 9499c7b7ee
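Taken together, the hunks below add a ProfilingGraphExecutorImpl that is selected when a thread-local profiling flag is set, plus a Python binding to flip that flag. A minimal sketch of how the mode is exercised from Python, mirroring the new test added in this commit (the _jit_set_profiling_mode binding, the three successful profiled runs, and the "Not yet implemented" error are confirmed by the diff; the rest is illustrative):

    import torch

    @torch.jit.script
    def f(x, y):
        return x + y

    torch._C._jit_set_profiling_mode(True)  # binding added by this commit
    try:
        for _ in range(3):
            f(torch.rand(2, 3), torch.rand(2, 3))  # runs the instrumented graph
        # Once the profiling budget is spent, run() raises, because the
        # optimizing recompilation step is left for a follow-up commit.
        try:
            f(torch.rand(2, 3), torch.rand(2, 3))
        except RuntimeError as e:
            assert "Not yet implemented" in str(e)
    finally:
        torch._C._jit_set_profiling_mode(False)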
@@ -380,6 +380,7 @@ if (NOT INTERN_BUILD_MOBILE)
       ${TORCH_SRC_DIR}/csrc/jit/subgraph_matcher.cpp
       ${TORCH_SRC_DIR}/csrc/jit/symbolic_script.cpp
       ${TORCH_SRC_DIR}/csrc/jit/profiling_record.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/profiling_graph_executor_impl.cpp
       ${TORCH_SRC_DIR}/csrc/jit/passes/alias_analysis.cpp
       ${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
       ${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize.cpp
@@ -258,6 +258,11 @@ def enable_cpu_fuser_if(cond):
         return wrapper
     return noop_fuser

+@contextmanager
+def enable_profiling_mode():
+    torch._C._jit_set_profiling_mode(True)
+    yield
+    torch._C._jit_set_profiling_mode(False)

 # note: not re-entrant, use unnested only
 @contextmanager
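The helper above resets the flag only on normal exit, so an exception inside the with-block leaves profiling enabled. A hedged variant (not part of this commit) that restores the flag in a finally block; like the original it is not re-entrant, since the diff exposes no Python-side getter to save the previous value:

    from contextlib import contextmanager
    import torch

    @contextmanager
    def enable_profiling_mode_safe():
        torch._C._jit_set_profiling_mode(True)
        try:
            yield
        finally:
            # reached on normal exit and on exception alike
            torch._C._jit_set_profiling_mode(False)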
@@ -5001,6 +5006,31 @@ a")
         # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
         self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)

+    def test_profiling_graph_executor(self):
+        @torch.jit.script
+        def basic(x, y):
+            a = x + y
+            b = x * y
+            c = x + 1
+            d = a - c
+            e = b - c
+            return d + e
+
+        a = torch.rand(2, 3)
+        b = torch.rand(2, 3)
+
+        with enable_profiling_mode():
+            basic(a, b)
+            basic(a, b)
+            basic(a, b)
+
+        # this tests that a profiling count is being decremented by
+        # a profile instruction.
+        # this is the easiest way to test that a graph was instrumented
+        # from python
+        with self.assertRaisesRegex(RuntimeError, "Not yet implemented"):
+            basic(a, b)
+
     def test_resize_input_ops(self):
         # resize_ and resize_as resize the input tensor. because our shape analysis
         # is flow invariant, we set any Tensor that can alias a resized Tensor
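The test pins the exact number of successful runs. When the initial budget should not be hard-coded, one illustrative pattern (not from the commit) is to call until the executor starts raising; "Not yet implemented" is the error the new executor emits once profiling_count_ reaches zero:

    def count_profiled_runs(fn, *args, max_calls=100):
        # Illustrative helper: returns how many calls ran the instrumented
        # graph before the profiling budget was exhausted.
        for i in range(max_calls):
            try:
                fn(*args)
            except RuntimeError as e:
                if "Not yet implemented" in str(e):
                    return i
                raise
        raise AssertionError("profiling budget was never exhausted")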
@@ -68,6 +68,7 @@ libtorch_sources = [
     "torch/csrc/jit/register_c10_ops.cpp",
     "torch/csrc/jit/subgraph_matcher.cpp",
     "torch/csrc/jit/symbolic_script.cpp",
+    "torch/csrc/jit/profiling_graph_executor_impl.cpp",
     "torch/csrc/jit/profiling_record.cpp",
     "torch/csrc/jit/operator.cpp",
     "torch/csrc/jit/passes/alias_analysis.cpp",
@@ -6,6 +6,7 @@
 #include <torch/csrc/jit/argument_spec.h>
 #include <torch/csrc/jit/autodiff.h>
 #include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/graph_executor_impl.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/pass_manager.h>
@@ -27,6 +28,8 @@
 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/passes/specialize_autogradzero.h>
+#include <torch/csrc/jit/profiling_graph_executor_impl.h>
+#include <torch/csrc/jit/profiling_record.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/tracer.h>

@@ -58,6 +61,11 @@ std::shared_ptr<Graph> lastExecutedOptimizedGraph() {
   return last_executed_optimized_graph.lock();
 }

+void ExecutionPlan::run(Stack& stack) const {
+  InterpreterState(code).run(stack);
+  last_executed_optimized_graph = graph;
+}
+
 namespace {

 using tensor_list = std::vector<at::Tensor>;
@@ -70,31 +78,6 @@ using autograd::variable_list;
 const size_t autodiffSubgraphNodeThreshold = 2;
 const size_t autodiffSubgraphInlineThreshold = 5;

-struct ExecutionPlan {
-  ExecutionPlan() = default;
-  ExecutionPlan(std::shared_ptr<Graph> graph)
-      : code(graph), graph(std::move(graph)) {}
-
-  void run(Stack& stack) const {
-    InterpreterState(code).run(stack);
-    last_executed_optimized_graph = graph;
-  }
-
-  operator bool() const {
-    return static_cast<bool>(graph);
-  }
-
-  ExecutionPlanState getDebugState() {
-    ExecutionPlanState state;
-    state.code = &code;
-    state.graph = graph.get();
-    return state;
-  }
-
-  Code code;
-  std::shared_ptr<Graph> graph;
-};
-
 struct CaptureList {
   CaptureList(size_t capture_size) {
     capture_types_.reserve(capture_size);
@@ -489,27 +472,15 @@ GraphExecutor* getGradExecutor(Operation& op) {
 // and different requires_grad states, and handles specializations for each
 // situation. GraphExecutor is completely unaware of tracing or module
 // parameters to keep the tracing concerns separated.
-struct GraphExecutorImpl {
-  static std::shared_ptr<Graph> prepareGraph(std::shared_ptr<Graph>& graph) {
-    auto copy = graph->copy();
-    EraseShapeInformation(copy);
-    return copy;
-  }
-
-  GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
-      : graph(prepareGraph(graph)),
-        // until we have correct alias analysis any use of mutable operators
-        // disables all optimization
-        optimize(optimize),
-        num_inputs(this->graph->inputs().size()),
-        arg_spec_creator_(*graph),
-        num_outputs(this->graph->outputs().size()) {
+struct GraphExecutorImpl : public GraphExecutorImplBase {
+  GraphExecutorImpl(const std::shared_ptr<Graph>& graph, bool optimize)
+      : GraphExecutorImplBase(graph, optimize), arg_spec_creator_(*graph) {
     logging::getLogger()->addStatValue(
         logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
   }

   // entry point where execution begins
-  void run(Stack& stack) {
+  void run(Stack& stack) override {
     AT_CHECK(
         stack.size() >= num_inputs,
         "expected ",
@@ -529,7 +500,7 @@ struct GraphExecutorImpl {
     return execution_plan.run(stack);
   }

-  GraphExecutorState getDebugState() {
+  GraphExecutorState getDebugState() override {
     GraphExecutorState state;
     state.graph = graph.get();
     if (fallback) {
@@ -541,7 +512,7 @@ struct GraphExecutorImpl {
     return state;
   }

- private:
+ protected:
  friend struct GraphExecutor;

  const ExecutionPlan& getOrCompileFallback() {
@@ -720,18 +691,9 @@ struct GraphExecutorImpl {
     }
   }

-  // The unoptimized starting graph. This field is effectively const, but we
-  // can't make it so because Graph::copy() is not const (and making it const is
-  // not that easy at this point).
-  std::shared_ptr<Graph> graph;
+  ~GraphExecutorImpl() override = default;

-  // If false, we'll run the graph as we get it, without any optimizations.
-  // Useful for debugging.
-  const bool optimize;
-  const size_t num_inputs;
   ArgumentSpecCreator arg_spec_creator_;
-  const size_t num_outputs;

   // Populated only when optimize is false (and in that case plan_cache will be
   // unused). The compiled version of graph.
   ExecutionPlan fallback;
@@ -739,14 +701,15 @@ struct GraphExecutorImpl {
   // Mapping from argument configurations to optimized versions of the graph
   // that are specialized to the spec.
   std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;

-  // GraphExecutors can be accessed from multiple threads, so this thread needs
-  // to be held every time we access the fallback or plan_cache.
-  std::mutex compile_mutex;
 };

 GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
-    : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
+    : pImpl(
+          getProfilingMode()
+              ? dynamic_cast<GraphExecutorImplBase*>(
+                    new ProfilingGraphExecutorImpl(graph, optimize))
+              : dynamic_cast<GraphExecutorImplBase*>(
+                    new GraphExecutorImpl(graph, optimize))) {}

 void GraphExecutor::run(Stack& inputs) {
   return pImpl->run(inputs);
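Two notes on this hunk. The dynamic_cast only reconciles the two new-expressions to the common GraphExecutorImplBase* type that the ternary needs; since both are upcasts, a static_cast or a small factory function would serve equally well. More importantly, the profiling flag is consulted once, in the GraphExecutor constructor, so flipping it later does not migrate an existing executor. A sketch of the consequence from Python (assuming, as the new test suggests, that a scripted function's executor exists by the time of its first run):

    import torch

    @torch.jit.script
    def classic(x):
        return x * 2

    classic(torch.rand(3))  # executor constructed with profiling off

    torch._C._jit_set_profiling_mode(True)
    try:
        @torch.jit.script
        def profiled(x):
            return x * 2
        profiled(torch.rand(3))  # executor constructed with profiling on
    finally:
        torch._C._jit_set_profiling_mode(False)

    classic(torch.rand(3))  # keeps the executor chosen at construction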
@@ -26,7 +26,7 @@ struct GraphExecutorState {
   std::unordered_map<ArgumentSpec, ExecutionPlanState> execution_plans;
 };

-struct GraphExecutorImpl;
+struct GraphExecutorImplBase;
 struct TORCH_API GraphExecutor {
   GraphExecutor() = default;
   GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
@@ -38,7 +38,7 @@ struct TORCH_API GraphExecutor {
   GraphExecutorState getDebugState();

  private:
-  std::shared_ptr<GraphExecutorImpl> pImpl;
+  std::shared_ptr<GraphExecutorImplBase> pImpl;
 };

 // These passes need to run before it is valid to pass to the interpreter
@@ -48,6 +48,8 @@ TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
 TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
 TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();

+TORCH_API bool& getProfilingMode();
+
 namespace detail {

 GraphExecutor* getGradExecutor(Operation& op);
torch/csrc/jit/graph_executor_impl.h (new file, 101 lines)
@@ -0,0 +1,101 @@
#pragma once
#include <torch/csrc/jit/graph_executor.h>

#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/tracer.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/logging.h>

#include <cstdint>
#include <iterator>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph)
      : code(graph), graph(std::move(graph)) {}

  void run(Stack& stack) const;

  operator bool() const {
    return static_cast<bool>(graph);
  }

  ExecutionPlanState getDebugState() {
    ExecutionPlanState state;
    state.code = &code;
    state.graph = graph.get();
    return state;
  }

  Code code;
  std::shared_ptr<Graph> graph;
};

// a Graph can be created via tracing, or via a language-based frontend
// GraphExecutor runs it. It can run the same graph on many different sizes
// and different requires_grad states, and handles specializations for each
// situation. GraphExecutor is completely unaware of tracing or module
// parameters to keep the tracing concerns separated.
struct GraphExecutorImplBase {
  static std::shared_ptr<Graph> prepareGraph(
      const std::shared_ptr<Graph>& graph) {
    auto copy = graph->copy();
    EraseShapeInformation(copy);
    return copy;
  }

  GraphExecutorImplBase(const std::shared_ptr<Graph>& graph, bool optimize)
      : graph(prepareGraph(graph)),
        // until we have correct alias analysis any use of mutable operators
        // disables all optimization
        optimize(optimize),
        num_inputs(this->graph->inputs().size()),
        num_outputs(this->graph->outputs().size()) {}

  // entry point where execution begins
  virtual void run(Stack& stack) = 0;
  virtual GraphExecutorState getDebugState() = 0;
  virtual ~GraphExecutorImplBase() = default;

 protected:
  friend struct GraphExecutor;

  // The unoptimized starting graph. This field is effectively const, but we
  // can't make it so because Graph::copy() is not const (and making it const is
  // not that easy at this point).
  std::shared_ptr<Graph> graph;

  // If false, we'll run the graph as we get it, without any optimizations.
  // Useful for debugging.
  const bool optimize;
  const size_t num_inputs;
  const size_t num_outputs;

  // GraphExecutors can be accessed from multiple threads, so this mutex needs
  // to be held every time we access the fallback or plan_cache.
  std::mutex compile_mutex;
};

} // namespace jit
} // namespace torch
@@ -295,6 +295,9 @@ void initJITBindings(PyObject* module) {
         auto stack = toStack(args);
         checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
       })
+      .def(
+          "_jit_set_profiling_mode",
+          [](bool profiling_flag) { getProfilingMode() = profiling_flag; })
       .def(
           "_jit_fuser_get_fused_kernel_code",
           [](Graph& g, std::vector<at::Tensor> inps) {
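getProfilingMode() returns a reference to a flag that profiling_graph_executor_impl.cpp (below) declares thread_local, so this binding toggles profiling for the calling thread only. A sketch of the implication (illustrative; the diff adds no Python-side getter, so the value cannot be read back directly):

    import threading
    import torch

    torch._C._jit_set_profiling_mode(True)  # affects the current thread only

    def worker():
        # A fresh thread sees the thread_local default (False) and must
        # opt in on its own before running scripted functions.
        torch._C._jit_set_profiling_mode(True)

    t = threading.Thread(target=worker)
    t.start()
    t.join()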
torch/csrc/jit/profiling_graph_executor_impl.cpp (new file, 42 lines)
@@ -0,0 +1,42 @@
#include <torch/csrc/jit/profiling_graph_executor_impl.h>

namespace torch {
namespace jit {

thread_local bool profiling_mode = false;
bool& getProfilingMode() {
  return profiling_mode;
}

void ProfilingGraphExecutorImpl::run(Stack& stack) {
  AT_CHECK(
      stack.size() >= num_inputs,
      "expected ",
      num_inputs,
      " inputs, but got only ",
      stack.size());

  {
    std::lock_guard<std::mutex> lock(compile_mutex);
    if (!pr_) {
      auto g = graph->copy();
      runRequiredPasses(g);
      pr_ = ProfilingRecord::instrumentGraph(g);
      exec_plan_ = caffe2::make_unique<ExecutionPlan>(pr_->profiled_graph_);
    }
  }

  if (pr_->profiling_count_ > 0) {
    exec_plan_->run(stack);
  } else {
    AT_ERROR("Not yet implemented");
  }
  return;
}

GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
  AT_ERROR("not supported");
}

} // namespace jit
} // namespace torch
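run() does three things: on the first call it instruments a copy of the graph under compile_mutex and caches an ExecutionPlan for it; while the profiling budget lasts it executes that instrumented plan; afterwards it raises, since the optimizing path is deferred to a later commit. A Python model of that control flow (illustration only; the stand-ins are not real torch APIs, and in the real code the decrement happens inside the graph via profile-node callbacks, not in run()):

    class ProfilingExecutorModel:
        def __init__(self, graph, budget=3):
            self.graph = graph
            self.pr = None         # stands in for pr_ (ProfilingRecord)
            self.exec_plan = None  # stands in for exec_plan_
            self._budget = budget  # real default lives in ProfilingRecord,
                                   # not shown in this diff

        def run(self, stack):
            if self.pr is None:
                # first call: instrument a copy of the graph, cache the plan
                self.pr = {"profiling_count": self._budget}
                self.exec_plan = lambda s: s
            if self.pr["profiling_count"] > 0:
                out = self.exec_plan(stack)
                self.pr["profiling_count"] -= 1  # done by profile nodes in C++
                return out
            raise RuntimeError("Not yet implemented")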
torch/csrc/jit/profiling_graph_executor_impl.h (new file, 20 lines)
@@ -0,0 +1,20 @@
#pragma once
#include <torch/csrc/jit/graph_executor_impl.h>

namespace torch {
namespace jit {

struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
  using GraphExecutorImplBase::GraphExecutorImplBase;

  void run(Stack& stack) override;
  GraphExecutorState getDebugState() override;
  ~ProfilingGraphExecutorImpl() override = default;

 private:
  std::unique_ptr<ProfilingRecord> pr_;
  std::unique_ptr<ExecutionPlan> exec_plan_;
};

} // namespace jit
} // namespace torch
@@ -66,7 +66,10 @@ std::unique_ptr<ProfilingRecord> ProfilingRecord::instrumentGraph(
   pr->instrumentBlock(new_g->block());
   std::function<void(Stack&)> counter = [raw_pr](Stack&) {
     std::lock_guard<std::mutex> lock(raw_pr->mutex_);
-    raw_pr->profiling_count_--;
+    if (raw_pr->profiling_count_ > 0)
+    {
+      raw_pr->profiling_count_--;
+    }
   };

   auto pop = pr->createProfileNode(counter, {});
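This hunk replaces an unconditional decrement with one guarded by profiling_count_ > 0, so the counter saturates at zero instead of going negative as profile callbacks keep firing. The same pattern in a standalone Python sketch (illustrative, not part of the commit):

    import threading

    class ProfilingBudget:
        def __init__(self, n):
            self.count = n
            self.lock = threading.Lock()

        def on_profile(self):
            # mirrors the guarded decrement: never drops below zero
            with self.lock:
                if self.count > 0:
                    self.count -= 1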