mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix dead code elimination in onnx export (#22476)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22476 Dead code elimination assumes a valid jit graph because it checks if operators have side effects. The onnx export path destroys the jit graph right before calling dead code elimination, but it actually doesn't care about side effects. We can just call dead code elimination and disable side effect lookup and things should work. Reviewed By: houseroad Differential Revision: D16100172 fbshipit-source-id: 8c790055e0d76c4227394cafa93b07d1310f2cea
This commit is contained in:
parent
76e14c1e51
commit
17cc79865d
|
|
@ -106,7 +106,9 @@ void validateGraph(
|
|||
const std::shared_ptr<Graph>& graph,
|
||||
onnx_torch::OperatorExportTypes operator_export_type) {
|
||||
validateBlock(graph->block(), operator_export_type);
|
||||
EliminateDeadCode(graph->block());
|
||||
// this is run on an onnx graph which doesn't have side effects.
|
||||
// ignore side effects in dead code elimination.
|
||||
EliminateDeadCode(graph->block(), true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
|
||||
}
|
||||
|
||||
class EncoderBase {
|
||||
|
|
|
|||
|
|
@ -126,6 +126,11 @@ void initJITBindings(PyObject* module) {
|
|||
[](std::shared_ptr<Graph>& g) {
|
||||
return EliminateDeadCode(g->block()); // overload resolution
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_dce_allow_deleting_nodes_with_side_effects",
|
||||
[](std::shared_ptr<Graph>& g) {
|
||||
return EliminateDeadCode(g->block(), true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); // overload resolution
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_cse",
|
||||
[](std::shared_ptr<Graph>& g) {
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ using namespace ::c10::prim;
|
|||
|
||||
class DeadCodeEliminator {
|
||||
public:
|
||||
explicit DeadCodeEliminator(std::shared_ptr<Graph> graph)
|
||||
: aliasDb_(torch::make_unique<AliasDb>(std::move(graph))) {}
|
||||
DeadCodeEliminator() = default;
|
||||
explicit DeadCodeEliminator(std::shared_ptr<Graph> graph, DCESideEffectPolicy sideEffectPolicy)
|
||||
: sideEffectPolicy_(sideEffectPolicy), aliasDb_(torch::make_unique<AliasDb>(std::move(graph))) {}
|
||||
DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)
|
||||
: sideEffectPolicy_(sideEffectPolicy) {}
|
||||
|
||||
// The algorithm is an inverse mark-and-sweep. Starting from the return node,
|
||||
// we mark "live" nodes that are necessary for the output. Nodes that have
|
||||
|
|
@ -127,7 +128,7 @@ class DeadCodeEliminator {
|
|||
void mark(Block* block) {
|
||||
// Mark all nodes with side effects.
|
||||
for (auto node : block->nodes()) {
|
||||
if (hasSideEffects(node)) {
|
||||
if (sideEffectPolicy_ == DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS && hasSideEffects(node)) {
|
||||
mark(node);
|
||||
}
|
||||
}
|
||||
|
|
@ -284,6 +285,7 @@ class DeadCodeEliminator {
|
|||
}
|
||||
}
|
||||
|
||||
DCESideEffectPolicy sideEffectPolicy_;
|
||||
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
|
||||
std::unordered_map<Node*, bool> memo_;
|
||||
std::unordered_set<Node*> marked_;
|
||||
|
|
@ -292,18 +294,19 @@ class DeadCodeEliminator {
|
|||
[](const std::unordered_set<const Value*>&) {};
|
||||
};
|
||||
|
||||
void EliminateDeadCode(const std::shared_ptr<Graph>& graph) {
|
||||
DeadCodeEliminator(graph).run(graph->block(), /*recurse=*/true);
|
||||
void EliminateDeadCode(const std::shared_ptr<Graph>& graph, DCESideEffectPolicy sideEffectPolicy) {
|
||||
DeadCodeEliminator(graph, sideEffectPolicy).run(graph->block(), /*recurse=*/true);
|
||||
}
|
||||
|
||||
void EliminateDeadCode(Block* block, bool recurse) {
|
||||
DeadCodeEliminator().run(block, recurse);
|
||||
void EliminateDeadCode(Block* block, bool recurse, DCESideEffectPolicy sideEffectPolicy) {
|
||||
DeadCodeEliminator(sideEffectPolicy).run(block, recurse);
|
||||
}
|
||||
|
||||
void EliminateDeadCode(
|
||||
Block* block,
|
||||
std::function<void(const std::unordered_set<const Value*>&)> cb) {
|
||||
DeadCodeEliminator eliminator;
|
||||
std::function<void(const std::unordered_set<const Value*>&)> cb,
|
||||
DCESideEffectPolicy sideEffectPolicy) {
|
||||
DeadCodeEliminator eliminator(sideEffectPolicy);
|
||||
eliminator.setDeleteCallback(std::move(cb));
|
||||
eliminator.run(block, /*recurse=*/true);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,12 +11,23 @@ namespace jit {
|
|||
// eliminate mutable ops.
|
||||
//
|
||||
// So, prefer to use the graph version if you can.
|
||||
TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
|
||||
TORCH_API void EliminateDeadCode(Block* block, bool recurse = true);
|
||||
enum class DCESideEffectPolicy : uint8_t {
|
||||
// default behavior: dead code elimination will check if a node has side effects
|
||||
// and not delete it if it does.
|
||||
DONT_DELETE_NODES_WITH_SIDE_EFFECTS,
|
||||
// with this flag, dead code elimination will not check if a node has side
|
||||
// effects and treat nodes with side effects like any other node,
|
||||
// i.e. delete them if their outputs aren't used anywhere.
|
||||
ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS
|
||||
};
|
||||
|
||||
TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph, DCESideEffectPolicy sideEffectPolicy = DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
|
||||
TORCH_API void EliminateDeadCode(Block* block, bool recurse = true, DCESideEffectPolicy sideEffectPolicy = DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
|
||||
|
||||
// Invoke the user-provided callback on all live values before deleting anything
|
||||
TORCH_API void EliminateDeadCode(
|
||||
Block* block,
|
||||
std::function<void(const std::unordered_set<const Value*>&)> cb);
|
||||
std::function<void(const std::unordered_set<const Value*>&)> cb,
|
||||
DCESideEffectPolicy sideEffectPolicy = DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ void preprocessCaffe2Ops(Block* block) {
|
|||
}
|
||||
}
|
||||
}
|
||||
EliminateDeadCode(block);
|
||||
EliminateDeadCode(block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
|
||||
}
|
||||
|
||||
void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
|
||||
|
|
@ -344,8 +344,7 @@ void BlockToONNX(
|
|||
ctx.block->registerOutput(env.at(output));
|
||||
env.at(output)->setType(output->type());
|
||||
}
|
||||
|
||||
EliminateDeadCode(ctx.block);
|
||||
EliminateDeadCode(ctx.block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -223,6 +223,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
|
|||
torch._C._jit_pass_lint(graph)
|
||||
|
||||
torch._C._jit_pass_onnx_remove_print(graph)
|
||||
|
||||
torch._C._jit_pass_onnx_preprocess_caffe2(graph)
|
||||
|
||||
# onnx only supports tensors, so we turn all out number types into tensors
|
||||
|
|
@ -232,7 +233,13 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
|
|||
torch._C._jit_pass_lint(graph)
|
||||
torch._C._jit_pass_onnx_peephole(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
torch._C._jit_pass_dce(graph)
|
||||
|
||||
# graph is not a valid jit graph anymore because types have been replaced
|
||||
# (e.g. int with Tensor), so it now contains operators that don't actually
|
||||
# exist. We can't run normal dead code elimination because it'd fail trying
|
||||
# to look up if an operator has side effects, but we can run a dead code
|
||||
# elimination variant that doesn't need to look up if an op has side effects.
|
||||
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
torch._C._jit_pass_fixup_onnx_loops(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user