mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pytorch] Bring back RemoveInplaceOps() (#62200)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62200
This commit brings back the `RemoveInplaceOps` pass removed in D29523283 (dec5aa2260) that apparently had a bunch of internal users.
Test Plan: danthe3rd
Reviewed By: danthe3rd
Differential Revision: D29833316
fbshipit-source-id: 6cf13d463ab0a5e50ba3eb3243f79a9c51623809
This commit is contained in:
parent
b91a917616
commit
05b802d4e0
|
|
@ -295,6 +295,7 @@ def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str,
|
|||
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
|
||||
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ...
|
||||
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Module) -> None: ...
|
||||
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
|
||||
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
|
||||
def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
|
||||
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
|
||||
|
|
|
|||
|
|
@ -1,8 +1,81 @@
|
|||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace {
|
||||
static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
|
||||
{aten::add_, aten::add},
|
||||
{aten::sub_, aten::sub},
|
||||
{aten::div_, aten::div},
|
||||
{aten::mul_, aten::mul},
|
||||
{aten::masked_fill_, aten::masked_fill},
|
||||
{aten::zero_, aten::zeros_like},
|
||||
{aten::fill_, aten::full_like}};
|
||||
|
||||
// This is a horrible no good awful hack to "fill in" the TensorOptions
|
||||
// arguments of zeros_like and full_like so that the defaults are filled
|
||||
// in. Ugh. Would be better to just run the frontend to get the correct
|
||||
// arity here.
|
||||
static const std::unordered_map<NodeKind, int> expectedInputCount = {
|
||||
{aten::zero_, 6},
|
||||
{aten::fill_, 7}};
|
||||
|
||||
bool isInplaceOp(const Node* node) {
|
||||
return inPlaceToOutOfPlace.count(node->kind()) != 0;
|
||||
}
|
||||
|
||||
// Remove all in-place ops and replace them with out-of-place equivalents.
|
||||
// e.g.
|
||||
// %foo = aten::add_(%foo, %n)
|
||||
// becomes
|
||||
// %foo.2 = aten::add(%foo, %n)
|
||||
//
|
||||
// NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
|
||||
// another value. This is only to avoid breaking ONNX export; when alias
|
||||
// analysis is done we can emit a warning if someone tries to export.
|
||||
void RemoveInplaceOps(Block* block) {
|
||||
auto graph = block->owningGraph();
|
||||
auto it = block->nodes().begin();
|
||||
while (it != block->nodes().end()) {
|
||||
auto node = *it;
|
||||
++it;
|
||||
for (auto block : node->blocks()) {
|
||||
RemoveInplaceOps(block);
|
||||
}
|
||||
|
||||
if (isInplaceOp(node)) {
|
||||
// create a replacement out of place op
|
||||
auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
|
||||
newNode->insertBefore(node);
|
||||
newNode->setScope(node->scope());
|
||||
// copy inputs
|
||||
for (auto input : node->inputs()) {
|
||||
newNode->addInput(input);
|
||||
}
|
||||
|
||||
int additionalInputCount = 0;
|
||||
if (expectedInputCount.find(node->kind()) != expectedInputCount.end()) {
|
||||
additionalInputCount = expectedInputCount.at(node->kind()) -
|
||||
static_cast<int>(newNode->inputs().size());
|
||||
}
|
||||
|
||||
for (int i = 0; i < additionalInputCount; ++i) {
|
||||
auto noneNode = graph->createNone();
|
||||
noneNode->insertBefore(newNode);
|
||||
newNode->addInput(noneNode->output());
|
||||
}
|
||||
|
||||
// Create a new output node and replace all uses of self with it
|
||||
newNode->output()->copyMetadata(node->output());
|
||||
node->replaceAllUsesWith(newNode);
|
||||
node->inputs()[0]->replaceAllUsesAfterNodeWith(
|
||||
newNode, newNode->output());
|
||||
node->destroy();
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Handles special case of binary inplace ops, where the first input node
|
||||
// has a lower type precedence than the second input node. When the
|
||||
// inplace node is converted to a regular op, this information is lost and
|
||||
|
|
@ -12,19 +85,20 @@ namespace jit {
|
|||
// are the same.
|
||||
// An example scenario would be:
|
||||
// Before:
|
||||
// graph(%0 : Half),
|
||||
// %1 : Float):
|
||||
// graph(%0 : Float),
|
||||
// %1 : Half):
|
||||
// # Should result in a Half, but after translation to out-of-place,
|
||||
// # would become a Float b/c Half+Float -> Float.
|
||||
// Float : = aten::add_(%0, %1)
|
||||
// %4 : Float = onnx::Cast[to=1](%1)
|
||||
// %5 : Float = onnx::Add(%4, %0)
|
||||
// ...
|
||||
// After:
|
||||
// graph(%0 : Half),
|
||||
// %1 : Float):
|
||||
// %2 : Half = aten::type_as(%1, %0)
|
||||
// # Half + Half will result in correct dtype.
|
||||
// Half : = aten::add_(%0, %2)
|
||||
// graph(%0 : Float),
|
||||
// %1 : Half):
|
||||
// %4 : Half = onnx::Cast[to=10](%0)
|
||||
// %5 : Half = onnx::Add(%1, %4)
|
||||
// ...
|
||||
|
||||
void ImplicitCastForBinaryInplaceOps(Block* b) {
|
||||
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
|
||||
for (auto* child_block : it->blocks()) {
|
||||
|
|
@ -54,5 +128,10 @@ void ImplicitCastForBinaryInplaceOps(Block* b) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
|
||||
ImplicitCastForBinaryInplaceOps(graph->block());
|
||||
RemoveInplaceOps(graph->block());
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
// see .cpp for docs
|
||||
TORCH_API void RemoveInplaceOps(const std::shared_ptr<Graph>& graph);
|
||||
|
||||
TORCH_API void ImplicitCastForBinaryInplaceOps(Block* block);
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -408,6 +408,10 @@ void initJITBindings(PyObject* module) {
|
|||
py::arg("value_name_pairs") =
|
||||
std::vector<std::pair<std::string, std::string>>())
|
||||
.def("_jit_pass_constant_pooling", ConstantPooling)
|
||||
// RemoveInplaceOps is used by CoreML so it must be removed with care.
|
||||
.def(
|
||||
"_jit_pass_remove_inplace_ops",
|
||||
[](const std::shared_ptr<Graph>& g) { return RemoveInplaceOps(g); })
|
||||
.def(
|
||||
"_jit_pass_create_functional_graphs",
|
||||
[](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user