mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pytorch] Bring back RemoveInplaceOps() (#62200)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62200
This commit brings back the `RemoveInplaceOps` pass removed in D29523283 (dec5aa2260) that apparently had a bunch of internal users.
Test Plan: danthe3rd
Reviewed By: danthe3rd
Differential Revision: D29833316
fbshipit-source-id: 6cf13d463ab0a5e50ba3eb3243f79a9c51623809
This commit is contained in:
parent
b91a917616
commit
05b802d4e0
|
|
@ -295,6 +295,7 @@ def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str,
|
||||||
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
|
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
|
||||||
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ...
|
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ...
|
||||||
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Module) -> None: ...
|
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Module) -> None: ...
|
||||||
|
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
|
||||||
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
|
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
|
||||||
def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
|
def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
|
||||||
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
|
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,81 @@
|
||||||
#include <torch/csrc/jit/jit_log.h>
|
|
||||||
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
|
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
namespace {
|
||||||
|
static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
|
||||||
|
{aten::add_, aten::add},
|
||||||
|
{aten::sub_, aten::sub},
|
||||||
|
{aten::div_, aten::div},
|
||||||
|
{aten::mul_, aten::mul},
|
||||||
|
{aten::masked_fill_, aten::masked_fill},
|
||||||
|
{aten::zero_, aten::zeros_like},
|
||||||
|
{aten::fill_, aten::full_like}};
|
||||||
|
|
||||||
|
// This is a horrible no good awful hack to "fill in" the TensorOptions
|
||||||
|
// arguments of zeros_like and full_like so that the defaults are filled
|
||||||
|
// in. Ugh. Would be better to just run the frontend to get the correct
|
||||||
|
// arity here.
|
||||||
|
static const std::unordered_map<NodeKind, int> expectedInputCount = {
|
||||||
|
{aten::zero_, 6},
|
||||||
|
{aten::fill_, 7}};
|
||||||
|
|
||||||
|
bool isInplaceOp(const Node* node) {
|
||||||
|
return inPlaceToOutOfPlace.count(node->kind()) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove all in-place ops and replace them with out-of-place equivalents.
|
||||||
|
// e.g.
|
||||||
|
// %foo = aten::add_(%foo, %n)
|
||||||
|
// becomes
|
||||||
|
// %foo.2 = aten::add(%foo, %n)
|
||||||
|
//
|
||||||
|
// NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
|
||||||
|
// another value. This is only to avoid breaking ONNX export; when alias
|
||||||
|
// analysis is done we can emit a warning if someone tries to export.
|
||||||
|
void RemoveInplaceOps(Block* block) {
|
||||||
|
auto graph = block->owningGraph();
|
||||||
|
auto it = block->nodes().begin();
|
||||||
|
while (it != block->nodes().end()) {
|
||||||
|
auto node = *it;
|
||||||
|
++it;
|
||||||
|
for (auto block : node->blocks()) {
|
||||||
|
RemoveInplaceOps(block);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isInplaceOp(node)) {
|
||||||
|
// create a replacement out of place op
|
||||||
|
auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
|
||||||
|
newNode->insertBefore(node);
|
||||||
|
newNode->setScope(node->scope());
|
||||||
|
// copy inputs
|
||||||
|
for (auto input : node->inputs()) {
|
||||||
|
newNode->addInput(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
int additionalInputCount = 0;
|
||||||
|
if (expectedInputCount.find(node->kind()) != expectedInputCount.end()) {
|
||||||
|
additionalInputCount = expectedInputCount.at(node->kind()) -
|
||||||
|
static_cast<int>(newNode->inputs().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < additionalInputCount; ++i) {
|
||||||
|
auto noneNode = graph->createNone();
|
||||||
|
noneNode->insertBefore(newNode);
|
||||||
|
newNode->addInput(noneNode->output());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new output node and replace all uses of self with it
|
||||||
|
newNode->output()->copyMetadata(node->output());
|
||||||
|
node->replaceAllUsesWith(newNode);
|
||||||
|
node->inputs()[0]->replaceAllUsesAfterNodeWith(
|
||||||
|
newNode, newNode->output());
|
||||||
|
node->destroy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Handles special case of binary inplace ops, where the first input node
|
// Handles special case of binary inplace ops, where the first input node
|
||||||
// has a lower type precedence than the second input node. When the
|
// has a lower type precedence than the second input node. When the
|
||||||
// inplace node is converted to a regular op, this information is lost and
|
// inplace node is converted to a regular op, this information is lost and
|
||||||
|
|
@ -12,19 +85,20 @@ namespace jit {
|
||||||
// are the same.
|
// are the same.
|
||||||
// An example scenario would be:
|
// An example scenario would be:
|
||||||
// Before:
|
// Before:
|
||||||
// graph(%0 : Half),
|
// graph(%0 : Float),
|
||||||
// %1 : Float):
|
// %1 : Half):
|
||||||
// # Should result in a Half, but after translation to out-of-place,
|
// # Should result in a Half, but after translation to out-of-place,
|
||||||
// # would become a Float b/c Half+Float -> Float.
|
// # would become a Float b/c Half+Float -> Float.
|
||||||
// Float : = aten::add_(%0, %1)
|
// %4 : Float = onnx::Cast[to=1](%1)
|
||||||
|
// %5 : Float = onnx::Add(%4, %0)
|
||||||
// ...
|
// ...
|
||||||
// After:
|
// After:
|
||||||
// graph(%0 : Half),
|
// graph(%0 : Float),
|
||||||
// %1 : Float):
|
// %1 : Half):
|
||||||
// %2 : Half = aten::type_as(%1, %0)
|
// %4 : Half = onnx::Cast[to=10](%0)
|
||||||
// # Half + Half will result in correct dtype.
|
// %5 : Half = onnx::Add(%1, %4)
|
||||||
// Half : = aten::add_(%0, %2)
|
|
||||||
// ...
|
// ...
|
||||||
|
|
||||||
void ImplicitCastForBinaryInplaceOps(Block* b) {
|
void ImplicitCastForBinaryInplaceOps(Block* b) {
|
||||||
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
|
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
|
||||||
for (auto* child_block : it->blocks()) {
|
for (auto* child_block : it->blocks()) {
|
||||||
|
|
@ -54,5 +128,10 @@ void ImplicitCastForBinaryInplaceOps(Block* b) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
|
||||||
|
ImplicitCastForBinaryInplaceOps(graph->block());
|
||||||
|
RemoveInplaceOps(graph->block());
|
||||||
|
}
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
// see .cpp for docs
|
// see .cpp for docs
|
||||||
|
TORCH_API void RemoveInplaceOps(const std::shared_ptr<Graph>& graph);
|
||||||
|
|
||||||
TORCH_API void ImplicitCastForBinaryInplaceOps(Block* block);
|
TORCH_API void ImplicitCastForBinaryInplaceOps(Block* block);
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -408,6 +408,10 @@ void initJITBindings(PyObject* module) {
|
||||||
py::arg("value_name_pairs") =
|
py::arg("value_name_pairs") =
|
||||||
std::vector<std::pair<std::string, std::string>>())
|
std::vector<std::pair<std::string, std::string>>())
|
||||||
.def("_jit_pass_constant_pooling", ConstantPooling)
|
.def("_jit_pass_constant_pooling", ConstantPooling)
|
||||||
|
// RemoveInplaceOps is used by CoreML so it must be removed with care.
|
||||||
|
.def(
|
||||||
|
"_jit_pass_remove_inplace_ops",
|
||||||
|
[](const std::shared_ptr<Graph>& g) { return RemoveInplaceOps(g); })
|
||||||
.def(
|
.def(
|
||||||
"_jit_pass_create_functional_graphs",
|
"_jit_pass_create_functional_graphs",
|
||||||
[](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
|
[](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user