[pytorch] Bring back RemoveInplaceOps() (#62200)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62200

This commit brings back the `RemoveInplaceOps` pass removed in D29523283 (dec5aa2260) that apparently had a bunch of internal users.

Test Plan: danthe3rd

Reviewed By: danthe3rd

Differential Revision: D29833316

fbshipit-source-id: 6cf13d463ab0a5e50ba3eb3243f79a9c51623809
Meghan Lele 2021-07-28 11:42:44 -07:00 committed by Facebook GitHub Bot
parent b91a917616
commit 05b802d4e0
4 changed files with 95 additions and 9 deletions


@@ -295,6 +295,7 @@ def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str,
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ...
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Module) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
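For context, the stub above re-exposes the restored pass to Python: it takes a Graph, rewrites it in place, and returns None. A minimal, hedged usage sketch (the function f below is invented for illustration, not part of this PR):

import torch

@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x.add_(y)  # scripted as aten::add_
    return x

torch._C._jit_pass_remove_inplace_ops(f.graph)  # mutates the graph in place, returns None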


@@ -1,8 +1,81 @@
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
namespace torch {
namespace jit {
namespace {
static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
    {aten::add_, aten::add},
    {aten::sub_, aten::sub},
    {aten::div_, aten::div},
    {aten::mul_, aten::mul},
    {aten::masked_fill_, aten::masked_fill},
    {aten::zero_, aten::zeros_like},
    {aten::fill_, aten::full_like}};
// This is a horrible no good awful hack to "fill in" the TensorOptions
// arguments of zeros_like and full_like so that the defaults are filled
// in. Ugh. Would be better to just run the frontend to get the correct
// arity here.
static const std::unordered_map<NodeKind, int> expectedInputCount = {
    {aten::zero_, 6},
    {aten::fill_, 7}};
bool isInplaceOp(const Node* node) {
  return inPlaceToOutOfPlace.count(node->kind()) != 0;
}
// Remove all in-place ops and replace them with out-of-place equivalents.
// e.g.
// %foo = aten::add_(%foo, %n)
// becomes
// %foo.2 = aten::add(%foo, %n)
//
// NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
// another value. This is only to avoid breaking ONNX export; when alias
// analysis is done we can emit a warning if someone tries to export.
void RemoveInplaceOps(Block* block) {
  auto graph = block->owningGraph();
  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    auto node = *it;
    ++it;
    for (auto block : node->blocks()) {
      RemoveInplaceOps(block);
    }
    if (isInplaceOp(node)) {
      // create a replacement out of place op
      auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
      newNode->insertBefore(node);
      newNode->setScope(node->scope());
      // copy inputs
      for (auto input : node->inputs()) {
        newNode->addInput(input);
      }
      int additionalInputCount = 0;
      if (expectedInputCount.find(node->kind()) != expectedInputCount.end()) {
        additionalInputCount = expectedInputCount.at(node->kind()) -
            static_cast<int>(newNode->inputs().size());
      }
      for (int i = 0; i < additionalInputCount; ++i) {
        auto noneNode = graph->createNone();
        noneNode->insertBefore(newNode);
        newNode->addInput(noneNode->output());
      }
      // Create a new output node and replace all uses of self with it
      newNode->output()->copyMetadata(node->output());
      node->replaceAllUsesWith(newNode);
      node->inputs()[0]->replaceAllUsesAfterNodeWith(
          newNode, newNode->output());
      node->destroy();
    }
  }
}
} // namespace
// Handles special case of binary inplace ops, where the first input node
// has a lower type precedence than the second input node. When the
// inplace node is converted to a regular op, this information is lost and
@@ -12,19 +85,20 @@ namespace jit {
// are the same.
// An example scenario would be:
// Before:
-// graph(%0 : Half),
-// %1 : Float):
+// graph(%0 : Float),
+// %1 : Half):
// # Should result in a Half, but after translation to out-of-place,
// # would become a Float b/c Half+Float -> Float.
-// Float : = aten::add_(%0, %1)
+// %4 : Float = onnx::Cast[to=1](%1)
+// %5 : Float = onnx::Add(%4, %0)
// ...
// After:
-// graph(%0 : Half),
-// %1 : Float):
-// %2 : Half = aten::type_as(%1, %0)
-// # Half + Half will result in correct dtype.
-// Half : = aten::add_(%0, %2)
+// graph(%0 : Float),
+// %1 : Half):
+// %4 : Half = onnx::Cast[to=10](%0)
+// %5 : Half = onnx::Add(%1, %4)
// ...
void ImplicitCastForBinaryInplaceOps(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
@@ -54,5 +128,10 @@ void ImplicitCastForBinaryInplaceOps(Block* b) {
    }
  }
}
void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
  ImplicitCastForBinaryInplaceOps(graph->block());
  RemoveInplaceOps(graph->block());
}
} // namespace jit
} // namespace torch
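To make the restored pass's behavior concrete, here is a hedged sketch (not part of the commit) of its effect on a scripted graph; the update function below is invented for illustration. In-place ops are swapped for their out-of-place counterparts, aten::zero_ becomes aten::zeros_like padded with None constants up to the expected arity, and uses of the mutated value after each op are rewired to the new output.

import torch

@torch.jit.script
def update(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x.add_(y)  # aten::add_ in the graph
    y.zero_()  # aten::zero_ in the graph
    return x

g = update.graph
print(g)  # contains aten::add_ and aten::zero_
torch._C._jit_pass_remove_inplace_ops(g)
print(g)
# aten::add_ -> aten::add; aten::zero_ -> aten::zeros_like(%y, None, None, None, None, None);
# later uses of %x and %y now refer to the new out-of-place outputs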


@@ -7,6 +7,8 @@
namespace torch {
namespace jit {
// see .cpp for docs
TORCH_API void RemoveInplaceOps(const std::shared_ptr<Graph>& graph);
TORCH_API void ImplicitCastForBinaryInplaceOps(Block* block);
} // namespace jit
} // namespace torch


@@ -408,6 +408,10 @@ void initJITBindings(PyObject* module) {
          py::arg("value_name_pairs") =
              std::vector<std::pair<std::string, std::string>>())
      .def("_jit_pass_constant_pooling", ConstantPooling)
      // RemoveInplaceOps is used by CoreML so it must be removed with care.
      .def(
          "_jit_pass_remove_inplace_ops",
          [](const std::shared_ptr<Graph>& g) { return RemoveInplaceOps(g); })
      .def(
          "_jit_pass_create_functional_graphs",
          [](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
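Because the binding sits alongside the other _jit_pass_* entry points, a downstream consumer (a CoreML-style converter, per the comment above) could slot it into a graph-preprocessing pipeline along the lines of the hedged sketch below; the pass ordering is illustrative only, not taken from this PR or any specific converter.

import torch

def preprocess(graph: torch._C.Graph) -> torch._C.Graph:
    # Inline calls first so in-place ops inside callees are visible to the pass.
    torch._C._jit_pass_inline(graph)
    # The pass restored by this commit: rewrite aten::*_ ops out-of-place.
    torch._C._jit_pass_remove_inplace_ops(graph)
    # Pool duplicate constants (e.g. the None nodes created above); registered just above.
    torch._C._jit_pass_constant_pooling(graph)
    return graph

A caller would typically feed it a scripted module's graph, e.g. preprocess(torch.jit.script(model).graph).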