Clean up profiling mode and profiling executor strategy (#73875)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73875

Previously we had a few settings:
- getExecutor - which toggled between the Profiling Executor and the Legacy executor
- getGraphOptimize - which, when false, overrides PE/Legacy to run with the simple executor (no optimizations)
and then...
- getProfilingMode - which, when false, would set PE to run with 0 specializations.

The last mode is redundant with getGraphOptimize; we should just remove it and use getGraphOptimize in these cases. Keeping both also allows potentially invalid combinations of settings - what does it mean if getProfilingMode is true but getExecutor is set to false? That combination triggers a bug in specialize_autograd_zero; see: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/specialize_autogradzero.cpp#L93. A sketch of the before/after Python API follows.
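
For context, a minimal sketch of the before/after from the Python bindings; the flag names match the diffs below, but the call sites are illustrative:

    import torch

    # Before: three interacting flags.
    # torch._C._jit_set_profiling_executor(True)   # PE vs. legacy executor
    # torch._C._jit_set_profiling_mode(True)       # redundant specialization toggle (its uses are removed here)
    # torch._C._set_graph_executor_optimize(True)  # False forces the simple executor

    # After: two flags cover the same space.
    torch._C._jit_set_profiling_executor(True)     # PE vs. legacy executor
    torch._C._get_graph_executor_optimize(True)    # sets optimize and returns the old value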

The tests here are failing but get fixed by the PR stacked above this one, so I'll squash for landing.

Test Plan: Imported from OSS

Reviewed By: cpuhrsch

Differential Revision: D34938130

Pulled By: eellison

fbshipit-source-id: 1a9c0ae7f6d1cfddc2ed3499a5af611053ae5e1b
(cherry picked from commit cf69ce3d155ba7d334022c42fb2cee54bb088c23)
Elias Ellison 2022-03-29 11:32:31 -07:00 committed by PyTorch MergeBot
parent ab57876420
commit 6694fdaccd
22 changed files with 93 additions and 91 deletions

View File

@ -62,7 +62,7 @@ struct BuiltinOpFunction : public Function {
return *this;
}
bool call(Stack& stack, size_t, c10::function_ref<void(const Code&)>) override {
bool call(Stack& stack, c10::optional<size_t>, c10::function_ref<void(const Code&)>) override {
run(stack);
return false;
}

View File

@ -90,7 +90,7 @@ struct TORCH_API Function {
// call() returns false.
// Overload for server interpreter, a bailout size is needed for graph executor.
virtual bool call(Stack&, size_t, c10::function_ref<void(const Code&)>) {
virtual bool call(Stack&, c10::optional<size_t>, c10::function_ref<void(const Code&)>) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
return false;
}

View File

@ -4,18 +4,18 @@ def set_fuser(fuser_name, executor_name):
assert fuser_name in ['te', 'old', 'none', 'default']
if fuser_name == 'te':
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._get_graph_executor_optimize(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_texpr_fuser_enabled(True)
elif fuser_name == 'old':
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._get_graph_executor_optimize(False)
torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_texpr_fuser_enabled(False)
elif fuser_name == 'none':
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._get_graph_executor_optimize(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
@ -25,12 +25,11 @@ def set_fuser(fuser_name, executor_name):
# --executor overrides settings of --fuser
if executor_name == 'profiling':
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._get_graph_executor_optimize(True)
elif executor_name == 'simple':
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)
torch._C._get_graph_executor_optimize(False)
elif executor_name == 'legacy':
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._get_graph_executor_optimize(True)
elif executor_name == 'default':
pass

View File

@ -137,7 +137,7 @@ Works only with Python3.\n A few examples:
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_profiling_mode(True)
torch._C._get_graph_executor_optimize(True)
elif args.cuda_fuser == "old":
import torch
torch._C._jit_set_profiling_executor(False)
@ -148,7 +148,7 @@ Works only with Python3.\n A few examples:
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_profiling_mode(True)
torch._C._get_graph_executor_optimize(True)
else :
raise ValueError("Undefined fuser: {}".format(args.cuda_fuser))

View File

@ -289,14 +289,11 @@ class AutodiffRemoveUnusedGradientsTest : public ::testing::Test {
void SetUp() override {
prev_exec = getExecutorMode();
getExecutorMode() = true;
prev_profiling = getProfilingMode();
getProfilingMode() = true;
prev_inline_autodiff = getAutodiffSubgraphInlining();
debugSetAutodiffSubgraphInlining(false);
}
void TearDown() override {
getExecutorMode() = prev_exec;
getProfilingMode() = prev_profiling;
debugSetAutodiffSubgraphInlining(prev_inline_autodiff);
}

View File

@ -18,7 +18,7 @@ if __name__ == '__main__':
class TestProfiler(JitTestCase):
def setUp(self):
self.prev_exec = torch._C._jit_set_profiling_executor(True)
self.prev_profiling = torch._C._jit_set_profiling_mode(True)
self.prev_profiling = torch._C._get_graph_executor_optimize(True)
self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False)
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
self.can_fuse_on_cpu = torch._C._jit_can_fuse_on_cpu()
@ -34,7 +34,7 @@ class TestProfiler(JitTestCase):
def tearDown(self):
torch._C._jit_set_profiling_executor(self.prev_exec)
torch._C._jit_set_profiling_mode(self.prev_profiling)
torch._C._get_graph_executor_optimize(self.prev_profiling)
torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff)
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)

View File

@ -204,11 +204,6 @@ def doAutodiffCheck(testname):
# TODO: enable TE in PE when all tests are fixed
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
# even though FULL_PROFILER should be our default
# we haven't tested every single test in this file
# but we enable FULL_PROFILER for a large subset
# of the tests with "with enable_profiling_mode_for_profiling_tests"
torch._C._jit_set_profiling_mode(False)
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
hx, cx = hidden
@ -7360,7 +7355,7 @@ a")
g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4))
FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *, requires_grad=0, device=cpu) = aten::as_tensor").run(g)
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "testing legacy behavior")
def test_tensor_requires_grad(self):
@torch.jit.script
def test(b):

View File

@ -18,7 +18,7 @@ import warnings
# inferred erroneously runs or skips
# some tests
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._get_graph_executor_optimize(True)
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \
enable_profiling_mode_for_profiling_tests, slowTest
@ -2608,7 +2608,7 @@ class TestLoopnestRandomization(TestLoopnestRandomizationParent):
torch._C._jit_override_can_fuse_on_gpu(True)
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
self.old_profiling_mode = torch._C._jit_set_profiling_mode(True)
self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
torch._C._debug_set_fusion_group_inlining(False)
@ -2625,7 +2625,7 @@ class TestLoopnestRandomization(TestLoopnestRandomizationParent):
def tearDown(self):
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
torch._C._jit_set_profiling_mode(self.old_profiling_mode)
torch._C._get_graph_executor_optimize(self.old_profiling_mode)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)

View File

@ -283,7 +283,7 @@ def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...
def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
def _get_graph_executor_optimize() -> _bool: ...
def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
def _set_graph_executor_optimize(optimize: _bool): ...
def _export_opnames(module: ScriptModule) -> List[str]: ...
def _create_function_from_trace(

View File

@ -99,7 +99,7 @@ struct TORCH_API GraphFunction : public Function {
using Function::call;
bool call(
Stack& stack,
size_t bailOut,
c10::optional<size_t> bailOut,
c10::function_ref<void(const Code&)> f) override {
f(get_executor().getPlanFor(stack, bailOut).code);
return true;

View File

@ -90,7 +90,7 @@ struct AutogradZeroSpecializer {
if (!isBackwardGraph()) {
return;
}
if (getProfilingMode()) {
if (getExecutorMode()) {
if (auto versioning_if = guardSpecializations()) {
specializeAutogradOps(versioning_if->blocks()[0]);
GRAPH_DUMP("After versioning graph", graph_);

View File

@ -2009,7 +2009,16 @@ void initJitScriptBindings(PyObject* module) {
setGraphExecutorOptimize(optimize);
});
m.def("_get_graph_executor_optimize", &torch::jit::getGraphExecutorOptimize);
m.def(
"_get_graph_executor_optimize",
[](c10::optional<bool> new_setting = c10::nullopt) {
bool old_value = getGraphExecutorOptimize();
if (new_setting) {
setGraphExecutorOptimize(*new_setting);
}
return old_value;
},
py::arg("new_settings") = nullptr);
m.def(
"_enable_mobile_interface_call_export",

View File

@ -42,6 +42,7 @@
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/logging.h>
#include <cstdint>
@ -56,17 +57,16 @@ namespace torch {
namespace jit {
EnableProfilingGuard::EnableProfilingGuard() {
auto& profiling_mode = getProfilingMode();
old_profiling_mode = profiling_mode;
profiling_mode = true;
auto& executor_mode = getExecutorMode();
old_executor_mode = executor_mode;
executor_mode = true;
old_get_optimize = getGraphExecutorOptimize();
setGraphExecutorOptimize(true);
}
EnableProfilingGuard::~EnableProfilingGuard() {
getProfilingMode() = old_profiling_mode;
getExecutorMode() = old_executor_mode;
setGraphExecutorOptimize(old_get_optimize);
}
namespace {
@ -408,8 +408,7 @@ struct DifferentiableGraphOp {
detachVariables(stack);
if (IsNewExecutorEnabled()) {
const ExecutionPlan& plan =
f_ptr->getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts());
const ExecutionPlan& plan = f_ptr->getPlanFor(stack);
InterpreterState(plan.code).run(stack);
} else {
InterpreterState(legacy_f).run(stack);
@ -550,8 +549,7 @@ void GraphExecutorImplBase::run(Stack& stack) {
logging::getLogger()->addStatValue(
logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);
const ExecutionPlan& plan =
getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts());
const ExecutionPlan& plan = getPlanFor(stack);
InterpreterState(plan.code).run(stack);
last_executed_optimized_graph = plan.graph;
}
@ -576,9 +574,8 @@ c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(
ExecutionPlan plan;
InterpreterState state;
};
auto frame = std::make_shared<Frame>(
getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()),
std::move(taskLauncher));
auto frame =
std::make_shared<Frame>(getPlanFor(stack), std::move(taskLauncher));
auto res = frame->state.runAsync(stack);
last_executed_optimized_graph = frame->plan.graph;
if (!res->completed()) {
@ -603,8 +600,9 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
}
const ExecutionPlan& getPlanFor(Stack& stack, size_t remaining_bailout_depth)
override {
const ExecutionPlan& getPlanFor(
Stack& stack,
c10::optional<size_t> remaining_bailout_depth) override {
return getGraphExecutorOptimize() ? getOrCompile(stack)
: getOrCompileFallback();
}
@ -783,13 +781,9 @@ c10::intrusive_ptr<Future> GraphExecutor::runAsync(
return pImpl->runAsync(stack, std::move(taskLauncher));
}
size_t GraphExecutor::getDefaultNumBailOuts() {
return getProfilingMode() ? getBailoutDepth() : 0;
}
const ExecutionPlan& GraphExecutor::getPlanFor(
Stack& inputs,
size_t remaining_bailout_depth) {
c10::optional<size_t> remaining_bailout_depth) {
return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}
@ -887,10 +881,8 @@ void runNondiffOptimization(
// decomposition pass, decompose certain ops that will be used in the
// following passes (like batchmm and jit fusion)
if (!getProfilingMode()) {
DecomposeOps(graph);
GRAPH_DEBUG("After DecomposeOps\n", *graph);
}
DecomposeOps(graph);
GRAPH_DEBUG("After DecomposeOps\n", *graph);
// TupleConstruct / TupleUnpack pairs can still be present at this point
// and must be removed for fusion.
@ -901,7 +893,7 @@ void runNondiffOptimization(
BatchMM(graph);
GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
if (getProfilingMode()) {
if (getExecutorMode()) {
if (tensorExprFuserEnabled()) {
auto min_size = getFusionGroupInlining() ? 2 : 1;
auto dyn_shapes = tensorExprDynamicShapeFusionEnabled();

View File

@ -18,12 +18,8 @@ struct Code;
struct ExecutionPlan {
ExecutionPlan() = default;
ExecutionPlan(
std::shared_ptr<Graph> graph,
std::string function_name,
size_t remaining_bailout_depth = 0)
: code(graph, std::move(function_name), remaining_bailout_depth),
graph(std::move(graph)) {}
ExecutionPlan(std::shared_ptr<Graph> graph, std::string function_name)
: code(graph, std::move(function_name)), graph(std::move(graph)) {}
operator bool() const {
return static_cast<bool>(graph);
@ -34,8 +30,8 @@ struct ExecutionPlan {
};
// Notice that those structs don't manage lifetime of their members.
// They is only valid only right after you call getDebugState() and should never
// be used again once another GraphExecutor function is called.
// They are only valid only right after you call getDebugState() and should
// never be used again once another GraphExecutor function is called.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GraphExecutorState {
@ -50,7 +46,7 @@ struct TORCH_API EnableProfilingGuard {
private:
bool old_executor_mode = false;
bool old_profiling_mode = false;
bool old_get_optimize = false;
};
struct GraphExecutorImplBase;
@ -72,13 +68,13 @@ struct TORCH_API GraphExecutor {
// profiled information whenever a bailout check is failed/triggered, a new
// `GraphExecutor` will be created. This new `GraphExecutor`'s
// remaining_bailout_depth will be reduced by 1.
// If no bailout depth is passed, the depth will be initialized from the
// current global fusion strategy settings.
const ExecutionPlan& getPlanFor(
Stack& inputs,
size_t remaining_bailout_depth);
c10::optional<size_t> remaining_bailout_depth = c10::nullopt);
GraphExecutorState getDebugState();
static size_t getDefaultNumBailOuts();
void debugFlushCompilationCache();
bool isOptimized() const;

View File

@ -79,7 +79,7 @@ struct GraphExecutorImplBase {
virtual const ExecutionPlan& getPlanFor(
Stack& stack,
size_t remaining_bailout_depth) = 0;
c10::optional<size_t> remaining_bailout_depth = c10::nullopt) = 0;
virtual GraphExecutorState getDebugState() = 0;
virtual ~GraphExecutorImplBase() = default;

View File

@ -175,7 +175,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
void callFunction(
Function& f,
Stack& stack,
size_t bailOut = GraphExecutor::getDefaultNumBailOuts(),
c10::optional<size_t> bailOut = c10::nullopt,
bool next = true) {
bool newFrame = f.call(stack, bailOut, [&](const Code& code) {
enterFrame(code, stack.size() - code.num_inputs());
@ -716,10 +716,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
auto& forked_fn =
toGraphFunction(*frame.function->function_table_[inst.X]);
InterpreterState forked_interpreter(
forked_fn.get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code,
taskLauncher_);
forked_fn.get_executor().getPlanFor(stack).code, taskLauncher_);
InterpreterContinuation continuation(
forked_interpreter,
Stack(stack.end() - inst.N, stack.end()),

View File

@ -609,13 +609,23 @@ ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
fusion_strategy_ = getFusionStrategy();
}
size_t ProfilingGraphExecutorImpl::getInstantiatedBailoutDepth() {
// Initialize bailout_depth from command-line flag.
size_t depth = 0;
for (const auto& pair : fusion_strategy_) {
depth += pair.second;
}
return depth;
}
const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
Stack& stack,
size_t remaining_bailout_depth) {
c10::optional<size_t> remaining_bailout_depth) {
GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this);
// TODO: instantiate simple executor when getProfilingMode() is false
// no opt mode
if (!getGraphExecutorOptimize()) {
if (!getGraphExecutorOptimize() || !getProfilingMode()) {
if (!fallback_plan_) {
auto copy = graph->copy();
GRAPH_DEBUG(
@ -635,7 +645,11 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
// getPlanFor(remaining_bailout_depth) is corrected and persisted by the Code
// object in interpreter.
if (!remaining_bailout_depth_.has_value() || !tensorExprFuserEnabled()) {
remaining_bailout_depth_ = remaining_bailout_depth;
if (remaining_bailout_depth.has_value()) {
remaining_bailout_depth_ = *remaining_bailout_depth;
} else {
remaining_bailout_depth_ = getInstantiatedBailoutDepth();
}
}
// simple executor
@ -683,14 +697,13 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
replaceFallbackGraphWithFallbackFunction(copy->block());
runFinalOptimizations(copy);
GRAPH_DUMP("Optimized Graph: ", copy);
optimized_plan_ =
ExecutionPlan(copy, function_name_, *remaining_bailout_depth_);
optimized_plan_ = ExecutionPlan(copy, function_name_);
return *optimized_plan_;
}
const ExecutionPlan& ProfilingGraphExecutorImpl::getPlanFor(
Stack& stack,
size_t remaining_bailout_depth) {
c10::optional<size_t> remaining_bailout_depth) {
std::lock_guard<std::mutex> lock(compile_mutex);
// IMPORTANT: This is a hot path of calling a torchscript function. Try not to
@ -698,7 +711,7 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getPlanFor(
if (optimized_plan_) {
return *optimized_plan_;
}
// if depth is not set, use
return getOptimizedPlanFor(stack, remaining_bailout_depth);
}
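
With GraphExecutor::getDefaultNumBailOuts() removed, the new getInstantiatedBailoutDepth() above derives the default depth by summing the per-entry depths of the current fusion strategy. A hedged illustration through the public binding (the strategy values are arbitrary examples):

    import torch

    # Each (behavior, depth) pair contributes its depth, so this strategy
    # would yield an instantiated bailout depth of 2 + 10 = 12.
    old_strategy = torch._C._jit_set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 10)])
    torch._C._jit_set_fusion_strategy(old_strategy)  # restore the previous strategy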

View File

@ -15,8 +15,9 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
const std::shared_ptr<Graph>& graph,
std::string function_name);
const ExecutionPlan& getPlanFor(Stack& stack, size_t remaining_bailout_depth)
override;
const ExecutionPlan& getPlanFor(
Stack& stack,
c10::optional<size_t> remaining_bailout_depth) override;
GraphExecutorState getDebugState() override;
~ProfilingGraphExecutorImpl() override = default;
@ -29,6 +30,8 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
// prevent memory leaks
fallback_functions_.clear();
remaining_bailout_depth_.reset();
// TODO - would be nice to have it initialized in subsequent use
fusion_strategy_ = getFusionStrategy();
}
bool isOptimized() const override {
@ -38,13 +41,14 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
private:
const ExecutionPlan& getOptimizedPlanFor(
Stack& stack,
size_t remaining_bailout_depth);
c10::optional<size_t> remaining_bailout_depth);
void runProfilingInsensitiveOptimizations(std::shared_ptr<Graph>& graph);
void runProfilingOptimizations(
std::shared_ptr<Graph>& graph,
size_t remaining_depth);
void replaceFallbackGraphWithFallbackFunction(Block* b);
FusionBehavior getCurrentBehavior(size_t remaining_depth);
size_t getInstantiatedBailoutDepth();
void runNoGradOptimizations(
std::shared_ptr<Graph>& graph,
size_t remaining_bailout_depth);

View File

@ -38,7 +38,7 @@ def fuser(name):
torch._C._jit_set_nvfuser_enabled(False)
elif name == 'fuser1': # NNC
old_profiling_executor = torch._C._jit_set_profiling_executor(True)
old_profiling_mode = torch._C._jit_set_profiling_mode(True)
old_profiling_mode = torch._C._get_graph_executor_optimize(True)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_texpr_fuser_enabled(True)
@ -55,7 +55,7 @@ def fuser(name):
finally:
if name == 'fuser1': # NNC
torch._C._jit_set_profiling_executor(old_profiling_executor)
torch._C._jit_set_profiling_mode(old_profiling_mode)
torch._C._get_graph_executor_optimize(old_profiling_mode)
# recover the previous values
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)

View File

@ -370,7 +370,7 @@ if __name__ == '__main__':
# Turn off profiling executor
if not args.profiling_executor:
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._get_graph_executor_optimize(False)
# factor sorta control the depth of the model
GRAPH_FACTOR = args.depth_factor

View File

@ -369,9 +369,9 @@ class ProfilingMode(Enum):
def cppProfilingFlagsToProfilingMode():
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._jit_set_profiling_mode(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
torch._C._jit_set_profiling_executor(old_prof_exec_state)
torch._C._jit_set_profiling_mode(old_prof_mode_state)
torch._C._get_graph_executor_optimize(old_prof_mode_state)
if old_prof_exec_state:
if old_prof_mode_state:
@ -385,23 +385,23 @@ def cppProfilingFlagsToProfilingMode():
def enable_profiling_mode_for_profiling_tests():
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._jit_set_profiling_mode(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
try:
yield
finally:
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
torch._C._jit_set_profiling_executor(old_prof_exec_state)
torch._C._jit_set_profiling_mode(old_prof_mode_state)
torch._C._get_graph_executor_optimize(old_prof_mode_state)
@contextmanager
def enable_profiling_mode():
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._jit_set_profiling_mode(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
try:
yield
finally:
torch._C._jit_set_profiling_executor(old_prof_exec_state)
torch._C._jit_set_profiling_mode(old_prof_mode_state)
torch._C._get_graph_executor_optimize(old_prof_mode_state)
@contextmanager
def num_profiled_runs(num_runs):

View File

@ -767,7 +767,7 @@ def _get_py3_code(code, fn_name):
class TensorExprTestOptions():
def __init__(self):
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
self.old_profiling_mode = torch._C._jit_set_profiling_mode(True)
self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)
self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
@ -782,7 +782,7 @@ class TensorExprTestOptions():
def restore(self):
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
torch._C._jit_set_profiling_mode(self.old_profiling_mode)
torch._C._get_graph_executor_optimize(self.old_profiling_mode)
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)