From 05b802d4e00651af0aec05766ed24e0d1cb1512f Mon Sep 17 00:00:00 2001
From: Meghan Lele
Date: Wed, 28 Jul 2021 11:42:44 -0700
Subject: [PATCH] [pytorch] Bring back RemoveInplaceOps() (#62200)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62200

This commit brings back the `RemoveInplaceOps` pass removed in D29523283
(https://github.com/pytorch/pytorch/commit/dec5aa2260cef540b622bd9a9504b6f11cb1f607)
that apparently had a bunch of internal users.

Test Plan: Imported from OSS

Reviewed By: danthe3rd

Differential Revision: D29833316

fbshipit-source-id: 6cf13d463ab0a5e50ba3eb3243f79a9c51623809
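[Editor's note] For reviewers who want to poke at the restored pass, here is a
minimal usage sketch. It is not part of the patch: it assumes a PyTorch build
that includes this change, and `f` is a made-up function whose scripted graph
contains an `aten::add_` node.

    import torch

    def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = x.clone()  # mutate a fresh tensor so the in-place op is local
        x.add_(y)      # scripted as the in-place op aten::add_
        return x

    scripted = torch.jit.script(f)
    print(scripted.graph)  # shows aten::add_

    # The restored pass mutates the graph in place and returns None.
    torch._C._jit_pass_remove_inplace_ops(scripted.graph)
    print(scripted.graph)  # now shows aten::add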
---
 torch/_C/__init__.pyi.in                     |  1 +
 torch/csrc/jit/passes/remove_inplace_ops.cpp | 97 ++++++++++++++++++--
 torch/csrc/jit/passes/remove_inplace_ops.h   |  2 +
 torch/csrc/jit/python/init.cpp               |  4 +
 4 files changed, 95 insertions(+), 9 deletions(-)

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index baf37a8b5ca..8b6cf781228 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -295,6 +295,7 @@ def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str,
 def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
 def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ...
 def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Module) -> None: ...
+def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
 def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
 def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
 def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
diff --git a/torch/csrc/jit/passes/remove_inplace_ops.cpp b/torch/csrc/jit/passes/remove_inplace_ops.cpp
index 10c9ab42c85..9260f0e5a65 100644
--- a/torch/csrc/jit/passes/remove_inplace_ops.cpp
+++ b/torch/csrc/jit/passes/remove_inplace_ops.cpp
@@ -1,8 +1,81 @@
-#include
 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
 
 namespace torch {
 namespace jit {
+namespace {
+static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
+    {aten::add_, aten::add},
+    {aten::sub_, aten::sub},
+    {aten::div_, aten::div},
+    {aten::mul_, aten::mul},
+    {aten::masked_fill_, aten::masked_fill},
+    {aten::zero_, aten::zeros_like},
+    {aten::fill_, aten::full_like}};
+
+// This is a horrible no good awful hack to "fill in" the TensorOptions
+// arguments of zeros_like and full_like so that the defaults are filled
+// in. Ugh. Would be better to just run the frontend to get the correct
+// arity here.
+static const std::unordered_map<NodeKind, int> expectedInputCount = {
+    {aten::zero_, 6},
+    {aten::fill_, 7}};
+
+bool isInplaceOp(const Node* node) {
+  return inPlaceToOutOfPlace.count(node->kind()) != 0;
+}
+
+// Remove all in-place ops and replace them with out-of-place equivalents.
+// e.g.
+//   %foo = aten::add_(%foo, %n)
+// becomes
+//   %foo.2 = aten::add(%foo, %n)
+//
+// NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
+// another value. This is only to avoid breaking ONNX export; when alias
+// analysis is done we can emit a warning if someone tries to export.
+void RemoveInplaceOps(Block* block) {
+  auto graph = block->owningGraph();
+  auto it = block->nodes().begin();
+  while (it != block->nodes().end()) {
+    auto node = *it;
+    ++it;
+    for (auto block : node->blocks()) {
+      RemoveInplaceOps(block);
+    }
+
+    if (isInplaceOp(node)) {
+      // create a replacement out of place op
+      auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
+      newNode->insertBefore(node);
+      newNode->setScope(node->scope());
+      // copy inputs
+      for (auto input : node->inputs()) {
+        newNode->addInput(input);
+      }
+
+      int additionalInputCount = 0;
+      if (expectedInputCount.find(node->kind()) != expectedInputCount.end()) {
+        additionalInputCount = expectedInputCount.at(node->kind()) -
+            static_cast<int>(newNode->inputs().size());
+      }
+
+      for (int i = 0; i < additionalInputCount; ++i) {
+        auto noneNode = graph->createNone();
+        noneNode->insertBefore(newNode);
+        newNode->addInput(noneNode->output());
+      }
+
+      // Create a new output node and replace all uses of self with it
+      newNode->output()->copyMetadata(node->output());
+      node->replaceAllUsesWith(newNode);
+      node->inputs()[0]->replaceAllUsesAfterNodeWith(
+          newNode, newNode->output());
+      node->destroy();
+    }
+  }
+}
+} // namespace
+
 // Handles special case of binary inplace ops, where the first input node
 // has a lower type precedence than the second input node. When the
 // inplace node is converted to a regular op, this information is lost and
@@ -12,19 +85,20 @@ namespace jit {
 // are the same.
 // An example scenario would be:
 // Before:
-// graph(%0 : Half),
-//       %1 : Float):
+// graph(%0 : Float),
+//       %1 : Half):
 //   # Should result in a Half, but after translation to out-of-place,
 //   # would become a Float b/c Half+Float -> Float.
-//   %2 : Float = aten::add_(%0, %1)
+//   %4 : Float = onnx::Cast[to=1](%1)
+//   %5 : Float = onnx::Add(%4, %0)
 //   ...
 // After:
-// graph(%0 : Half),
-//       %1 : Float):
-//   %2 : Half = aten::type_as(%1, %0)
-//   # Half + Half will result in correct dtype.
-//   %3 : Half = aten::add_(%0, %2)
+// graph(%0 : Float),
+//       %1 : Half):
+//   %4 : Half = onnx::Cast[to=10](%0)
+//   %5 : Half = onnx::Add(%1, %4)
 //   ...
+
 void ImplicitCastForBinaryInplaceOps(Block* b) {
   for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
     for (auto* child_block : it->blocks()) {
@@ -54,5 +128,10 @@
     }
   }
 }
+
+void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
+  ImplicitCastForBinaryInplaceOps(graph->block());
+  RemoveInplaceOps(graph->block());
+}
 } // namespace jit
 } // namespace torch
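[Editor's note] To make the NOTE in `RemoveInplaceOps` above concrete, here is
a hypothetical example (not from the patch) of the aliasing hazard: in eager
mode the in-place update is visible through the view `v`, but after the
rewrite `aten::add` produces a fresh tensor and only uses of `x` *after* the
op are rewired, so `v` keeps the stale value.

    import torch

    def g(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = x.clone()
        v = x.view(-1)  # v aliases x's storage
        x.add_(y)       # eager mode: the update is visible through v
        return v

    scripted = torch.jit.script(g)
    torch._C._jit_pass_remove_inplace_ops(scripted.graph)
    # The transformed graph no longer matches eager semantics: it returns
    # the view of the old buffer, not the updated values.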
diff --git a/torch/csrc/jit/passes/remove_inplace_ops.h b/torch/csrc/jit/passes/remove_inplace_ops.h
index 50cf9ab33c6..e597da64860 100644
--- a/torch/csrc/jit/passes/remove_inplace_ops.h
+++ b/torch/csrc/jit/passes/remove_inplace_ops.h
@@ -7,6 +7,8 @@ namespace torch {
 namespace jit {
 
 // see .cpp for docs
+TORCH_API void RemoveInplaceOps(const std::shared_ptr<Graph>& graph);
+
 TORCH_API void ImplicitCastForBinaryInplaceOps(Block* block);
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index bb528e1ea71..2f7a5761a98 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -408,6 +408,10 @@ void initJITBindings(PyObject* module) {
           py::arg("value_name_pairs") =
               std::vector<std::pair<std::string, std::string>>())
       .def("_jit_pass_constant_pooling", ConstantPooling)
+      // RemoveInplaceOps is used by CoreML so it must be removed with care.
+      .def(
+          "_jit_pass_remove_inplace_ops",
+          [](const std::shared_ptr<Graph>& g) { return RemoveInplaceOps(g); })
       .def(
           "_jit_pass_create_functional_graphs",
           [](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
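[Editor's note] Finally, a sketch (again not part of the patch) of the arity
hack from `expectedInputCount` above: `aten::zero_` maps to
`aten::zeros_like`, whose schema takes six inputs (self plus five
TensorOptions-style arguments), so the pass pads the new node with None
constants up to the expected count of 6.

    import torch

    def h(x: torch.Tensor) -> torch.Tensor:
        x = x.clone()
        x.zero_()  # in-place; mapped to aten::zeros_like by the pass
        return x

    scripted = torch.jit.script(h)
    torch._C._jit_pass_remove_inplace_ops(scripted.graph)
    # Expect aten::zeros_like called on the original tensor plus five
    # None constants standing in for the default TensorOptions.
    print(scripted.graph)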