Fix dead code elimination in onnx export (#22476)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22476

Dead code elimination assumes a valid jit graph because it checks if operators have side effects.
The onnx export path destroys the jit graph right before calling dead code elimination, but it actually doesn't care about side effects.
We can just call dead code elimination and disable side effect lookup and things should work.

Reviewed By: houseroad

Differential Revision: D16100172

fbshipit-source-id: 8c790055e0d76c4227394cafa93b07d1310f2cea
This commit is contained in:
Sebastian Messmer 2019-07-02 21:18:43 -07:00 committed by Facebook Github Bot
parent 76e14c1e51
commit 17cc79865d
6 changed files with 45 additions and 18 deletions

View File

@ -106,7 +106,9 @@ void validateGraph(
const std::shared_ptr<Graph>& graph,
onnx_torch::OperatorExportTypes operator_export_type) {
validateBlock(graph->block(), operator_export_type);
EliminateDeadCode(graph->block());
// this is run on an onnx graph which doesn't have side effects.
// ignore side effects in dead code elimination.
EliminateDeadCode(graph->block(), true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
class EncoderBase {

View File

@ -126,6 +126,11 @@ void initJITBindings(PyObject* module) {
[](std::shared_ptr<Graph>& g) {
return EliminateDeadCode(g->block()); // overload resolution
})
.def(
"_jit_pass_dce_allow_deleting_nodes_with_side_effects",
[](std::shared_ptr<Graph>& g) {
return EliminateDeadCode(g->block(), true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); // overload resolution
})
.def(
"_jit_pass_cse",
[](std::shared_ptr<Graph>& g) {

View File

@ -15,9 +15,10 @@ using namespace ::c10::prim;
class DeadCodeEliminator {
public:
explicit DeadCodeEliminator(std::shared_ptr<Graph> graph)
: aliasDb_(torch::make_unique<AliasDb>(std::move(graph))) {}
DeadCodeEliminator() = default;
explicit DeadCodeEliminator(std::shared_ptr<Graph> graph, DCESideEffectPolicy sideEffectPolicy)
: sideEffectPolicy_(sideEffectPolicy), aliasDb_(torch::make_unique<AliasDb>(std::move(graph))) {}
DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)
: sideEffectPolicy_(sideEffectPolicy) {}
// The algorithm is an inverse mark-and-sweep. Starting from the return node,
// we mark "live" nodes that are necessary for the output. Nodes that have
@ -127,7 +128,7 @@ class DeadCodeEliminator {
void mark(Block* block) {
// Mark all nodes with side effects.
for (auto node : block->nodes()) {
if (hasSideEffects(node)) {
if (sideEffectPolicy_ == DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS && hasSideEffects(node)) {
mark(node);
}
}
@ -284,6 +285,7 @@ class DeadCodeEliminator {
}
}
DCESideEffectPolicy sideEffectPolicy_;
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
std::unordered_map<Node*, bool> memo_;
std::unordered_set<Node*> marked_;
@ -292,18 +294,19 @@ class DeadCodeEliminator {
[](const std::unordered_set<const Value*>&) {};
};
void EliminateDeadCode(const std::shared_ptr<Graph>& graph) {
DeadCodeEliminator(graph).run(graph->block(), /*recurse=*/true);
void EliminateDeadCode(const std::shared_ptr<Graph>& graph, DCESideEffectPolicy sideEffectPolicy) {
DeadCodeEliminator(graph, sideEffectPolicy).run(graph->block(), /*recurse=*/true);
}
void EliminateDeadCode(Block* block, bool recurse) {
DeadCodeEliminator().run(block, recurse);
void EliminateDeadCode(Block* block, bool recurse, DCESideEffectPolicy sideEffectPolicy) {
DeadCodeEliminator(sideEffectPolicy).run(block, recurse);
}
void EliminateDeadCode(
Block* block,
std::function<void(const std::unordered_set<const Value*>&)> cb) {
DeadCodeEliminator eliminator;
std::function<void(const std::unordered_set<const Value*>&)> cb,
DCESideEffectPolicy sideEffectPolicy) {
DeadCodeEliminator eliminator(sideEffectPolicy);
eliminator.setDeleteCallback(std::move(cb));
eliminator.run(block, /*recurse=*/true);
}

View File

@ -11,12 +11,23 @@ namespace jit {
// eliminate mutable ops.
//
// So, prefer to use the graph version if you can.
TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
TORCH_API void EliminateDeadCode(Block* block, bool recurse = true);
enum class DCESideEffectPolicy : uint8_t {
// default behavior: dead code elimination will check if a node has side effects
// and not delete it if it does.
DONT_DELETE_NODES_WITH_SIDE_EFFECTS,
// with this flag, dead code elimination will not check if a node has side
// effects and treat nodes with side effects like any other node,
// i.e. delete them if their outputs aren't used anywhere.
ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS
};
TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph, DCESideEffectPolicy sideEffectPolicy = DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
TORCH_API void EliminateDeadCode(Block* block, bool recurse = true, DCESideEffectPolicy sideEffectPolicy = DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
// Invoke the user-provided callback on all live values before deleting anything
TORCH_API void EliminateDeadCode(
Block* block,
std::function<void(const std::unordered_set<const Value*>&)> cb);
std::function<void(const std::unordered_set<const Value*>&)> cb,
DCESideEffectPolicy sideEffectPolicy = DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
} // namespace jit
} // namespace torch

View File

@ -136,7 +136,7 @@ void preprocessCaffe2Ops(Block* block) {
}
}
}
EliminateDeadCode(block);
EliminateDeadCode(block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
@ -344,8 +344,7 @@ void BlockToONNX(
ctx.block->registerOutput(env.at(output));
env.at(output)->setType(output->type());
}
EliminateDeadCode(ctx.block);
EliminateDeadCode(ctx.block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
} // namespace jit

View File

@ -223,6 +223,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_onnx_remove_print(graph)
torch._C._jit_pass_onnx_preprocess_caffe2(graph)
# onnx only supports tensors, so we turn all out number types into tensors
@ -232,7 +233,13 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_onnx_peephole(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_dce(graph)
# graph is not a valid jit graph anymore because types have been replaced
# (e.g. int with Tensor), so it now contains operators that don't actually
# exist. We can't run normal dead code elimination because it'd fail trying
# to look up if an operator has side effects, but we can run a dead code
# elimination variant that doesn't need to look up if an op has side effects.
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_fixup_onnx_loops(graph)
torch._C._jit_pass_lint(graph)