mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[JIT][SR] Introduce prim::IfThenElse (#72587)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72587 This pattern frequently appears in a few graphs: ``` %result = prim::If(%condition) block0(): -> (%a) block1(): -> (%b) ``` This is slow, particularly in static runtime. Static runtime creates memory planners/block runners for each sub-block, which eats up a lot of memory and introduces a lot of extra overhead for this relatively simple operation. This diff introduces a new op that replaces nodes like the above with a single op meant to act like a ternary operator: ``` %result = prim::IfThenElse(%condition, %a, %b) ``` Test Plan: New unit tests Reviewed By: eellison Differential Revision: D34091789 fbshipit-source-id: eb6a8c460c39b4c019a1f4ab1f3f1e5b6edc400c
This commit is contained in:
parent
b3a1923331
commit
0f1b335e5b
|
|
@ -96,6 +96,7 @@ namespace c10 {
|
|||
_(prim, With) \
|
||||
_(prim, Enter) \
|
||||
_(prim, Exit) \
|
||||
_(prim, IfThenElse) \
|
||||
_(aten, Bool) \
|
||||
_(aten, Int) \
|
||||
_(aten, FloatImplicit) \
|
||||
|
|
|
|||
|
|
@ -2720,3 +2720,19 @@ TEST(StaticRuntime, ToList) {
|
|||
)JIT";
|
||||
testStaticRuntime(src, {at::randn({2, 2})});
|
||||
}
|
||||
|
||||
// Exercises the prim::IfThenElse ternary op through static runtime on both
// branches. The aten::clone forces the borrowed output to be materialized.
TEST(StaticRuntime, IfThenElse) {
  const auto graph_src = R"IR(
    graph(%cond: bool, %a: Tensor, %b: Tensor):
        %none: NoneType = prim::Constant()
        %c: Tensor = prim::IfThenElse(%cond, %a, %b)
        %d: Tensor = aten::clone(%c, %none)
        return (%d)
  )IR";

  // One argument set per branch of the ternary.
  std::vector<IValue> args_true_branch{true, at::randn({1}), at::randn({1})};
  std::vector<IValue> args_false_branch{false, at::randn({1}), at::randn({1})};

  testStaticRuntime(graph_src, args_true_branch);
  testStaticRuntime(graph_src, args_false_branch);
}
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ endif()
|
|||
|
||||
# Build the cpp gtest binary containing the cpp-only tests.
|
||||
set(JIT_TEST_SRCS
|
||||
${JIT_TEST_ROOT}/test_add_if_then_else.cpp
|
||||
${JIT_TEST_ROOT}/test_alias_analysis.cpp
|
||||
${JIT_TEST_ROOT}/test_argument_spec.cpp
|
||||
${JIT_TEST_ROOT}/test_autodiff.cpp
|
||||
|
|
|
|||
53
test/cpp/jit/test_add_if_then_else.cpp
Normal file
53
test/cpp/jit/test_add_if_then_else.cpp
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/passes/add_if_then_else.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// A single-output prim::If with empty sub-blocks should be rewritten into a
// prim::IfThenElse node, and the pass should report that it made a change.
TEST(AddIfThenElseOpTest, AddIfThenElseOpSimple) {
  const auto ir = R"IR(
        graph(%cond: bool, %a: Tensor, %b: Tensor):
            %result: Tensor = prim::If(%cond)
                block0():
                    -> (%a)
                block1():
                    -> (%b)
            return (%result)
    )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  // The pass returns true when at least one node was replaced.
  EXPECT_TRUE(AddIfThenElseOp(graph));

  // Exactly one IfThenElse must appear, and the original If must be gone.
  auto file_check = testing::FileCheck();
  file_check.check_count("= prim::IfThenElse", 1, /*exactly*/ true)
      ->check_count("= prim::If", 0, /*exactly*/ true)
      ->run(*graph);
}
|
||||
|
||||
// A prim::If with more than one output is not eligible for the rewrite:
// the pass must leave the graph untouched and report no change.
TEST(AddIfThenElseOpTest, NoIfThenElseOpMultipleOutputs) {
  const auto ir = R"IR(
        graph(%cond: bool, %a: Tensor, %b: Tensor):
            %result1: Tensor, %result2: Tensor = prim::If(%cond)
                block0():
                    -> (%a, %b)
                block1():
                    -> (%b, %a)
            return (%result1, %result2)
    )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  // No eligible node, so the pass reports that nothing changed.
  EXPECT_FALSE(AddIfThenElseOp(graph));

  // The original If must survive and no IfThenElse may be introduced.
  auto file_check = testing::FileCheck();
  file_check.check_count("= prim::IfThenElse", 0, /*exactly*/ true)
      ->check_count("= prim::If", 1, /*exactly*/ true)
      ->run(*graph);
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -213,6 +213,7 @@ core_sources_full_mobile_no_backend_interface = [
|
|||
"torch/csrc/jit/operator_upgraders/utils.cpp",
|
||||
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
|
||||
"torch/csrc/jit/operator_upgraders/upgraders_entry.cpp",
|
||||
"torch/csrc/jit/passes/add_if_then_else.cpp",
|
||||
"torch/csrc/jit/passes/annotate_warns.cpp",
|
||||
"torch/csrc/jit/passes/bailout_graph.cpp",
|
||||
"torch/csrc/jit/passes/batch_mm.cpp",
|
||||
|
|
|
|||
55
torch/csrc/jit/passes/add_if_then_else.cpp
Normal file
55
torch/csrc/jit/passes/add_if_then_else.cpp
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
#include <torch/csrc/jit/passes/add_if_then_else.h>
|
||||
#include <torch/csrc/jit/runtime/graph_iterator.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
// True iff the block contains no nodes (its node list is empty).
bool hasNoNodes(Block* block) {
  const auto& node_list = block->nodes();
  return node_list.begin() == node_list.end();
}
|
||||
|
||||
// True iff both sub-blocks of a prim::If node are trivial, i.e. each block
// simply forwards a value without executing any nodes of its own.
bool hasTrivialSubBlocks(Node* node) {
  const auto blocks = node->blocks();
  // A prim::If always carries exactly a true-block and a false-block.
  DCHECK_EQ(blocks.size(), 2);

  const bool true_block_empty = hasNoNodes(blocks[0]);
  const bool false_block_empty = hasNoNodes(blocks[1]);
  return true_block_empty && false_block_empty;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Rewrites every single-output prim::If whose sub-blocks contain no nodes
// (each block just forwards one value) into the flat ternary
//   %out = prim::IfThenElse(%cond, %true_val, %false_val)
// Returns true iff at least one node was replaced.
bool AddIfThenElseOp(std::shared_ptr<Graph>& graph) {
  // Phase 1: collect candidates first so the graph is not mutated while the
  // depth-first iterator is walking it.
  std::vector<Node*> to_replace;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() != prim::If) {
      continue;
    }
    // IfThenElse produces exactly one value; multi-output Ifs are ineligible.
    if (node->outputs().size() != 1) {
      continue;
    }
    if (hasTrivialSubBlocks(node)) {
      to_replace.push_back(node);
    }
  }

  // Phase 2: replace each candidate in place.
  for (auto* node : to_replace) {
    auto* if_then_else_node = graph->create(prim::IfThenElse, 1);
    // Input 0: the condition; inputs 1 and 2: the values each trivial block
    // would have returned.
    if_then_else_node->addInput(node->input());
    auto blocks = node->blocks();
    if_then_else_node->addInput(blocks[0]->return_node()->input());
    if_then_else_node->addInput(blocks[1]->return_node()->input());

    if_then_else_node->insertBefore(node);
    // Preserve the output's type/debug info from the original If.
    if_then_else_node->output()->copyMetadata(node->output());

    // Redirect all consumers to the new node before destroying the old one.
    node->output()->replaceAllUsesWith(if_then_else_node->output());
    node->destroy();
  }
  return !to_replace.empty();
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
11
torch/csrc/jit/passes/add_if_then_else.h
Normal file
11
torch/csrc/jit/passes/add_if_then_else.h
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Replaces each single-output prim::If whose sub-blocks contain no nodes
// with a flat prim::IfThenElse(cond, true_value, false_value) node.
// Returns true iff the graph was modified.
TORCH_API bool AddIfThenElseOp(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
||||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/jit/passes/add_if_then_else.h>
|
||||
#include <torch/csrc/jit/passes/bailout_graph.h>
|
||||
#include <torch/csrc/jit/passes/batch_mm.h>
|
||||
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
|
||||
|
|
@ -650,6 +651,7 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
|
|||
// replaces a fallback graph inserted by
|
||||
// specialize_autogradzero if one exists
|
||||
replaceFallbackGraphWithFallbackFunction(copy->block());
|
||||
runFinalOptimizations(copy);
|
||||
GRAPH_DUMP("Optimized Graph: ", copy);
|
||||
optimized_plan_ =
|
||||
ExecutionPlan(copy, function_name_, *remaining_bailout_depth_);
|
||||
|
|
@ -749,5 +751,10 @@ void ProfilingGraphExecutorImpl::replaceFallbackGraphWithFallbackFunction(
|
|||
}
|
||||
}
|
||||
|
||||
// Last optimization step before an ExecutionPlan is built: currently only
// flattens trivial prim::If nodes into prim::IfThenElse.
void ProfilingGraphExecutorImpl::runFinalOptimizations(
    std::shared_ptr<Graph>& graph) {
  AddIfThenElseOp(graph);
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
|
|||
std::shared_ptr<Graph>& graph,
|
||||
size_t remaining_depth);
|
||||
void replaceFallbackGraphWithFallbackFunction(Block* b);
|
||||
void runFinalOptimizations(std::shared_ptr<Graph>& graph);
|
||||
std::unique_ptr<ProfilingRecord> pr_;
|
||||
c10::optional<ExecutionPlan>
|
||||
profiling_plan_; // plan to run in order to profiling the code
|
||||
|
|
|
|||
|
|
@ -700,6 +700,17 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
|||
push(stack, at::stack(inputs, dim));
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"prim::IfThenElse(bool cond, Any(a) x, Any(b) y) -> Any(a|b)"),
|
||||
[](Stack& stack) {
|
||||
const auto cond = stack[stack.size() - 3].toBool();
|
||||
stack[stack.size() - 3] =
|
||||
std::move(stack[stack.size() - (cond ? 2 : 1)]);
|
||||
stack.pop_back();
|
||||
stack.pop_back();
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool"),
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include <caffe2/core/timer.h>
|
||||
#include <torch/csrc/jit/ir/alias_analysis.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/jit/passes/add_if_then_else.h>
|
||||
#include <torch/csrc/jit/passes/canonicalize.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/eliminate_no_ops.h>
|
||||
|
|
@ -173,6 +174,7 @@ void OptimizeGraph(
|
|||
UseVariadicGroupedAccessor(graph);
|
||||
EliminateNoOps(
|
||||
graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
|
||||
AddIfThenElseOp(graph);
|
||||
GRAPH_DUMP("Final graph after optimizations: ", graph);
|
||||
}
|
||||
|
||||
|
|
@ -1846,8 +1848,9 @@ static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) {
|
|||
}
|
||||
|
||||
bool ProcessedNode::verify_no_memory_overlap(bool force_check) const {
|
||||
const static std::array<c10::Symbol, 5> special_case_ops = {
|
||||
const static std::array<c10::Symbol, 6> special_case_ops = {
|
||||
fromQualString("prim::TypeCheck"),
|
||||
fromQualString("prim::IfThenElse"),
|
||||
fromQualString("static_runtime::select_tensor"),
|
||||
fromQualString("static_runtime::VarTupleUnpack"),
|
||||
fromQualString("static_runtime::dict_unpack"),
|
||||
|
|
|
|||
|
|
@ -58,10 +58,11 @@ TORCH_API inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) {
|
|||
}
|
||||
|
||||
TORCH_API inline bool borrowsOutputs(c10::Symbol kind) {
|
||||
static const std::array<c10::Symbol, 3> symbols_with_borrowed_outputs = {
|
||||
static const std::array<c10::Symbol, 4> symbols_with_borrowed_outputs = {
|
||||
c10::Symbol::fromQualString("static_runtime::select_tensor"),
|
||||
c10::Symbol::fromQualString("static_runtime::dict_unpack"),
|
||||
c10::Symbol::fromQualString("static_runtime::VarTupleUnpack"),
|
||||
c10::Symbol::fromQualString("prim::IfThenElse"),
|
||||
};
|
||||
return std::find(
|
||||
symbols_with_borrowed_outputs.begin(),
|
||||
|
|
|
|||
|
|
@ -946,5 +946,17 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
|||
};
|
||||
});
|
||||
|
||||
// See [Borrowed IValue Outputs]
// Static-runtime implementation of prim::IfThenElse: selects input 1 when the
// condition (input 0) is true, otherwise input 2. The output is a *borrowed*
// IValue, so no refcount bump / copy occurs; borrowsOutputs() must list this
// op so the memory planner handles the borrow correctly.
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::IfThenElse,
    prim_IfThenElse,
    [](Node*) -> SROperator {
      return [](ProcessedNode* pnode) {
        const auto condition = pnode->Input(0).toBool();
        pnode->Output(0) = condition ? createBorrowedIValue(pnode->Input(1))
                                     : createBorrowedIValue(pnode->Input(2));
      };
    });
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user