mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Redesign onnx pass to enable shape type dependent pattern conversion - cont (#51795) (#53304)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53304 With the introduction of ONNX shape inference, shape and type are inferred on the fly as operators get converted from ATen to ONNX when running symbolic function. This resolves the shape/type requirement for the symbolic functions. The pre-onnx passes however, can not be supported by shape inference, since at that stage the operators in the graph are still ATen operators. This PR is to update the design of ONNX pass, to enable a mechanism of capturing subgraphs of ATen operators of certain patterns, and convert them later, when shape/type information of upstream operators are available. The new design will require pre-onnx passes that need shape/type to be written in two parts, encapsulation and conversion. The encapsulation part will find the nodes of patterns, like how pre-onnx passes were written previously. But instead of converting the nodes, it will encapsulate them into a sub-block of a new placeholder node. This part is called before onnx pass, so it runs before calling symbolic functions. The conversion part will be called inside the onnx pass. In onnx pass, run_symbolic_func will be called for each node in topological order. When it reaches the placeholder node, the conversion part will be invoked. It will convert the nodes inside the sub-block based on pattern. By that time, it will have shape/type of upstream operators available. After the conversion is complete, the placeholder node will be removed, and nodes inside its sub-block converted. Run_symbolic_func will be called for these nodes, and they will be converted from ATen operator to ONNX operator. This PR includes several other fixes, listed below. * ~~replace helper.cpp with onnx_utils.cpp for holding utility functions.~~ * fix EraseNumberTypes on Bool type, the code was outdated that back then Bool type doesn't exist. * ~~enable onnx shape inference in export with parameter/initializer data.~~ * other code clean ups. * fix insertion of identity nodes for loop opset 13 sequence output. ~~PR depends on #51603~~ Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D26922417 Pulled By: malfet fbshipit-source-id: 14ed06158d539e2451c2e5e63ba1b32fb0f75095
This commit is contained in:
parent
5648fe6093
commit
3f9c803fe8
|
|
@ -1922,11 +1922,12 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
self.run_test(CopyModel(), (x, update))
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
# TODO: Limited scripting support with ellipsis indexing.
|
||||
# Due to dependency on input tensor rank being known.
|
||||
def test_copy_ellipsis_tracing(self):
|
||||
def test_copy_ellipsis_script(self):
|
||||
class CopyModel(torch.nn.Module):
|
||||
def forward(self, x, update):
|
||||
# Insert reshape node to ensure no shape/type info for
|
||||
# x in scripting, without onnx shape inference.
|
||||
x = x.reshape(4, 3, 5, 6)
|
||||
x[2, ..., 1:3] = update
|
||||
return x
|
||||
|
||||
|
|
@ -4387,7 +4388,6 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
self.run_test(MaskedSelectModel(), x)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
@disableScriptTest() # dtype not available
|
||||
def test_index_put_to_masked_fill(self):
|
||||
class MaskedFillModel(torch.nn.Module):
|
||||
def forward(self, input_mask, some_const):
|
||||
|
|
@ -4401,7 +4401,6 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
self.run_test(MaskedFillModel(), (mask, constant))
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
@disableScriptTest() # dtype not available
|
||||
def test_index_put_to_masked_scatter(self):
|
||||
class MaskedScatterModel(torch.nn.Module):
|
||||
def forward(self, input_mask, some_const):
|
||||
|
|
|
|||
|
|
@ -565,6 +565,9 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp",
|
||||
"torch/csrc/jit/passes/onnx/shape_type_inference.cpp",
|
||||
"torch/csrc/jit/python/pybind_utils.cpp",
|
||||
"torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp",
|
||||
"torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp",
|
||||
"torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp",
|
||||
"torch/csrc/jit/python/python_arg_flatten.cpp",
|
||||
"torch/csrc/jit/python/python_custom_class.cpp",
|
||||
"torch/csrc/jit/python/python_interpreter.cpp",
|
||||
|
|
|
|||
|
|
@ -309,6 +309,7 @@ def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
|
|||
def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
|
||||
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_node_shape_type_inference(n: Node, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
|
||||
def _jit_onnx_convert_pattern_from_subblock(block: Block, n: Node, env: Dict[Value, Value]) -> List[Value]: ...
|
||||
def _jit_pass_onnx_block(
|
||||
old_block: Block,
|
||||
new_block: Block,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
static void EraseNumberTypesOnBlock(Block* block) {
|
||||
void EraseNumberTypesOnBlock(Block* block) {
|
||||
for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
|
||||
++it) {
|
||||
for (auto inp : it->inputs()) {
|
||||
|
|
@ -25,7 +25,7 @@ static void EraseNumberTypesOnBlock(Block* block) {
|
|||
it->output()->type()->isSubtypeOf(BoolType::get())) {
|
||||
at::Scalar s;
|
||||
if (it->output()->type()->isSubtypeOf(BoolType::get())) {
|
||||
s = static_cast<int64_t>(*constant_as<bool>(it->output()));
|
||||
s = *constant_as<bool>(it->output());
|
||||
} else {
|
||||
s = *constant_as<at::Scalar>(it->output());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ namespace jit {
|
|||
//
|
||||
// The pass assumes that DCE will be called sometime after.
|
||||
TORCH_API void EraseNumberTypes(const std::shared_ptr<Graph>& graph);
|
||||
TORCH_API void EraseNumberTypesOnBlock(Block* block);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -177,17 +177,6 @@ void BlockToONNX(
|
|||
std::unordered_map<Value*, Value*> env) {
|
||||
torch::autograd::SymbolicContext ctx{};
|
||||
ctx.block = new_block;
|
||||
py::object onnx = py::module::import("torch.onnx");
|
||||
py::object onnx_symbolic = py::module::import("torch.onnx.symbolic_helper");
|
||||
py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");
|
||||
|
||||
// Returns a node that n maps to in the new graph
|
||||
auto envFn = [&env](Value* n) -> Value* {
|
||||
auto it = env.find(n);
|
||||
TORCH_CHECK(it != env.end(), "Dangling node reference");
|
||||
TORCH_CHECK(it->second, "Unused node was subsequently used");
|
||||
return it->second;
|
||||
};
|
||||
|
||||
GRAPH_DEBUG(
|
||||
"BlockToONNX: graph of old block: ",
|
||||
|
|
@ -199,6 +188,40 @@ void BlockToONNX(
|
|||
env[input] = n;
|
||||
}
|
||||
|
||||
// Finally, visit all nodes in the graph
|
||||
for (auto node : old_block->nodes()) {
|
||||
NodeToONNX(node, ctx.block, operator_export_type, env);
|
||||
}
|
||||
for (auto output : old_block->outputs()) {
|
||||
ctx.block->registerOutput(env.at(output));
|
||||
}
|
||||
|
||||
// Run dce to clean-up unused functional and inplace ops.
|
||||
EliminateDeadCode(
|
||||
ctx.block,
|
||||
true,
|
||||
DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
|
||||
}
|
||||
|
||||
void NodeToONNX(
|
||||
Node* old_node,
|
||||
Block* new_block,
|
||||
::torch::onnx::OperatorExportTypes operator_export_type,
|
||||
std::unordered_map<Value*, Value*>& env) {
|
||||
py::object onnx = py::module::import("torch.onnx");
|
||||
py::object onnx_symbolic = py::module::import("torch.onnx.symbolic_helper");
|
||||
py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");
|
||||
|
||||
// Setup all the lambda helper functions.
|
||||
|
||||
// Returns a node that n maps to in the new graph
|
||||
auto envFn = [&env](Value* n) -> Value* {
|
||||
auto it = env.find(n);
|
||||
TORCH_CHECK(it != env.end(), "Dangling node reference");
|
||||
TORCH_CHECK(it->second, "Unused node was subsequently used");
|
||||
return it->second;
|
||||
};
|
||||
|
||||
// Put the new outputs in our environment map, and copy the type from the
|
||||
// input graph if they were not set by the symbolic. This is called only
|
||||
// with results of symbolic call (not for nodes that are just cloned).
|
||||
|
|
@ -250,11 +273,11 @@ void BlockToONNX(
|
|||
|
||||
// Clone the node and add it to the new graph
|
||||
auto cloneNode = [&](Node* node) {
|
||||
auto n_ = ctx.block->appendNode(
|
||||
ctx.block->owningGraph()->createClone(node, envFn));
|
||||
auto n_ = new_block->appendNode(
|
||||
new_block->owningGraph()->createClone(node, envFn));
|
||||
for (size_t i = 0; i < node->outputs().size(); i++) {
|
||||
// n_->outputs()[i]->setType(node->outputs()[i]->type());
|
||||
env[node->outputs()[i]] = n_->outputs()[i];
|
||||
env[node->output(i)] = n_->output(i);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -295,15 +318,20 @@ void BlockToONNX(
|
|||
py_inputs[input_nr++] = py::cast(envFn(input));
|
||||
}
|
||||
|
||||
WithInsertPoint insert_point_guard(ctx.block);
|
||||
WithCurrentScope scope_guard(*ctx.block->owningGraph(), n->scope());
|
||||
WithInsertPoint insert_point_guard(new_block);
|
||||
WithCurrentScope scope_guard(*new_block->owningGraph(), n->scope());
|
||||
py::object raw_output = onnx.attr("_run_symbolic_function")(
|
||||
ctx.block->owningGraph(), n, py_inputs, env, operator_export_type);
|
||||
new_block->owningGraph(),
|
||||
new_block,
|
||||
n,
|
||||
py_inputs,
|
||||
env,
|
||||
operator_export_type);
|
||||
|
||||
// TODO: Assert it's an ATen identifier???
|
||||
// (Sometimes it's not...)
|
||||
processSymbolicOutput(n->kind().toUnqualString(), n, raw_output);
|
||||
GRAPH_DUMP("after process output:", ctx.block->owningGraph());
|
||||
GRAPH_DUMP("after process output:", new_block->owningGraph());
|
||||
};
|
||||
|
||||
auto callPySymbolicMethod = [&](ConcretePythonOp* op) {
|
||||
|
|
@ -322,7 +350,7 @@ void BlockToONNX(
|
|||
// by regular args, with Variables replaced by corresponding nodes.
|
||||
Py_ssize_t input_nr = 0;
|
||||
py::tuple py_symbolic_args(1 + op->cconv.size());
|
||||
py_symbolic_args[input_nr++] = py::cast(ctx.block->owningGraph());
|
||||
py_symbolic_args[input_nr++] = py::cast(new_block->owningGraph());
|
||||
auto inputs = op->inputs();
|
||||
auto node_it = inputs.begin();
|
||||
auto scalar_it = op->scalar_args.begin();
|
||||
|
|
@ -343,8 +371,8 @@ void BlockToONNX(
|
|||
py_symbolic_args[input_nr++] = obj;
|
||||
}
|
||||
|
||||
WithInsertPoint insert_point_guard(ctx.block);
|
||||
WithCurrentScope scope_guard(*ctx.block->owningGraph(), op->scope());
|
||||
WithInsertPoint insert_point_guard(new_block);
|
||||
WithCurrentScope scope_guard(*new_block->owningGraph(), op->scope());
|
||||
// Call the symbolic function
|
||||
// Use a little trampoline function so we can give good error messages
|
||||
// upon argument mismatch
|
||||
|
|
@ -357,24 +385,15 @@ void BlockToONNX(
|
|||
processSymbolicOutput(op->name(), op, raw_output);
|
||||
};
|
||||
|
||||
// Finally, visit all nodes in the graph
|
||||
for (auto node : old_block->nodes()) {
|
||||
if (node->kind().is_caffe2()) {
|
||||
// Pass on Caffe2 operator, since we already preprocess it
|
||||
cloneNode(node);
|
||||
} else if (node->kind() == prim::PythonOp) {
|
||||
callPySymbolicMethod(static_cast<ConcretePythonOp*>(node));
|
||||
} else {
|
||||
callPySymbolicFunction(node);
|
||||
}
|
||||
auto k = old_node->kind();
|
||||
if (k.is_caffe2()) {
|
||||
// Pass on Caffe2 operator, since we already preprocess it
|
||||
cloneNode(old_node);
|
||||
} else if (k == prim::PythonOp) {
|
||||
callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node));
|
||||
} else {
|
||||
callPySymbolicFunction(old_node);
|
||||
}
|
||||
for (auto output : old_block->outputs()) {
|
||||
ctx.block->registerOutput(env.at(output));
|
||||
}
|
||||
EliminateDeadCode(
|
||||
ctx.block,
|
||||
true,
|
||||
DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -14,6 +14,11 @@ TORCH_API void BlockToONNX(
|
|||
Block* new_block,
|
||||
::torch::onnx::OperatorExportTypes operator_export_type,
|
||||
std::unordered_map<Value*, Value*> env);
|
||||
TORCH_API void NodeToONNX(
|
||||
Node* old_node,
|
||||
Block* new_block,
|
||||
::torch::onnx::OperatorExportTypes operator_export_type,
|
||||
std::unordered_map<Value*, Value*>& env);
|
||||
TORCH_API void RemovePrintOps(std::shared_ptr<Graph>& graph);
|
||||
TORCH_API void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
|
||||
#include <torch/csrc/jit/passes/onnx/helper.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -36,12 +37,12 @@ void CastAllConstantToFloating(Block* block) {
|
|||
case at::ScalarType::Int:
|
||||
case at::ScalarType::Short:
|
||||
case at::ScalarType::Bool:
|
||||
to_type = 6; // ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
|
||||
to_type = ATenTypeToOnnxType(val.scalar_type());
|
||||
val = val.to(at::ScalarType::Float);
|
||||
break;
|
||||
|
||||
case at::ScalarType::Long:
|
||||
to_type = 7; // ::ONNX_NAMESPACE::TensorProto_DataType_INT64;
|
||||
to_type = ATenTypeToOnnxType(val.scalar_type());
|
||||
val = val.to(at::ScalarType::Double);
|
||||
break;
|
||||
|
||||
|
|
|
|||
|
|
@ -183,14 +183,23 @@ std::vector<Value*> ConvertSequenceDependencies(Node* node, int opset_version) {
|
|||
return new_outputs;
|
||||
}
|
||||
|
||||
void ConvertSequenceDependencies(Block* block, int opset_version) {
|
||||
for (auto* node : block->nodes()) {
|
||||
for (Block* block : node->blocks()) {
|
||||
ConvertSequenceDependencies(block, opset_version);
|
||||
// Resolving limitation from ONNX that the block output can not be
|
||||
// a value from outside the block. Inserting an Identity node inside
|
||||
// the block, linking with the value outside as workaround.
|
||||
void FixupONNXSubblockOutputs(Node* n) {
|
||||
for (Block* block : n->blocks()) {
|
||||
for (Value* output : block->outputs()) {
|
||||
if (output->node()->owningBlock() != block) {
|
||||
Node* id_node = block->owningGraph()->create(onnx::Identity);
|
||||
id_node->insertBefore(block->return_node());
|
||||
id_node->addInput(output);
|
||||
id_node->output()->copyMetadata(output);
|
||||
block->return_node()->replaceInputWith(output, id_node->output());
|
||||
}
|
||||
}
|
||||
ConvertSequenceDependencies(node, opset_version);
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void FixupONNXLoopNodeInputs(Node* node) {
|
||||
|
|
@ -201,21 +210,21 @@ void FixupONNXLoopNodeInputs(Node* node) {
|
|||
auto* graph = node->owningGraph();
|
||||
|
||||
// add cast to condition input outside the loop.
|
||||
Value* cond_val = node->inputs()[1];
|
||||
Value* cond_val = node->input(1);
|
||||
if (IsCondCastRequired(cond_val))
|
||||
InsertCastForCond(cond_val, graph, node);
|
||||
|
||||
// Setup Loop input cond and i.
|
||||
TORCH_INTERNAL_ASSERT(node->blocks().size() == 1);
|
||||
auto* sub_block = node->blocks()[0];
|
||||
auto* sub_block = node->blocks().at(0);
|
||||
Value* cond = sub_block->insertInput(1, "cond");
|
||||
cond->setType(BoolType::create());
|
||||
|
||||
Value* i = sub_block->inputs()[0];
|
||||
Value* i = sub_block->inputs().at(0);
|
||||
i->setType(TensorType::fromNumberType(IntType::get()));
|
||||
|
||||
// add cast to condition input inside the loop.
|
||||
Value* next_cond_val = sub_block->outputs()[0];
|
||||
Value* next_cond_val = sub_block->outputs().at(0);
|
||||
if (IsCondCastRequired(next_cond_val))
|
||||
InsertCastForCond(next_cond_val, graph, sub_block->return_node());
|
||||
}
|
||||
|
|
@ -223,6 +232,9 @@ void FixupONNXLoopNodeInputs(Node* node) {
|
|||
std::vector<Value*> FixupONNXLoopNode(Node* node, int opset_version) {
|
||||
auto output_size = node->outputs().size();
|
||||
FixupONNXLoopNodeInputs(node);
|
||||
FixupONNXSubblockOutputs(node);
|
||||
// NOTE: the output order is deliberately changed to match expected order
|
||||
// since onnx loop requires scan outputs to be the last outputs.
|
||||
auto new_outputs = ConvertSequenceDependencies(node, opset_version);
|
||||
TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
|
||||
return new_outputs;
|
||||
|
|
@ -340,17 +352,7 @@ std::vector<Value*> FixupONNXIfNode(Node* node, int opset_version) {
|
|||
GRAPH_DUMP("Graph before fixing controlflow: ", node->owningGraph());
|
||||
auto* if_node = node;
|
||||
auto* graph = if_node->owningGraph();
|
||||
for (Block* block : node->blocks()) {
|
||||
for (Value* output : block->outputs()) {
|
||||
if (output->node()->owningBlock() != block) {
|
||||
Node* id_node = graph->create(onnx::Identity);
|
||||
id_node->insertBefore(block->return_node());
|
||||
id_node->addInput(output);
|
||||
id_node->output()->copyMetadata(output);
|
||||
block->return_node()->replaceInputWith(output, id_node->output());
|
||||
}
|
||||
}
|
||||
}
|
||||
FixupONNXSubblockOutputs(if_node);
|
||||
ONNXFixupUninitializedOutput(if_node);
|
||||
GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
|
||||
return if_node->outputs().vec();
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ namespace jit {
|
|||
namespace onnx {
|
||||
using namespace ::c10::onnx;
|
||||
|
||||
}
|
||||
} // namespace onnx
|
||||
|
||||
ValueToParamPairMap buildValueToParamsMap(
|
||||
Block* b,
|
||||
|
|
@ -97,6 +97,44 @@ Value* addInputToBlock(Block* block) {
|
|||
return block->addInput();
|
||||
}
|
||||
|
||||
namespace {
|
||||
::ONNX_NAMESPACE::TensorProto_DataType ATenTypeToOnnxType_aux(
|
||||
at::ScalarType at_type) {
|
||||
switch (at_type) {
|
||||
case at::kDouble:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
|
||||
case at::kFloat:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
|
||||
case at::kHalf:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
|
||||
case at::kByte:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_UINT8;
|
||||
case at::kChar:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_INT8;
|
||||
case at::kShort:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_INT16;
|
||||
case at::kInt:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
|
||||
case at::kLong:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_INT64;
|
||||
case at::kBool:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_BOOL;
|
||||
case at::kQInt8:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_INT8;
|
||||
case at::kQUInt8:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_UINT8;
|
||||
case at::kQInt32:
|
||||
return ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
|
||||
default:
|
||||
AT_ERROR("unexpected tensor scalar type");
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int ATenTypeToOnnxType(at::ScalarType at_type) {
|
||||
return static_cast<int>(ATenTypeToOnnxType_aux(at_type));
|
||||
}
|
||||
|
||||
Node* createONNXUnsqueeze(
|
||||
Graph* graph,
|
||||
Node* n_to_insert_before,
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@
|
|||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Utility functions for PyTorch to ONNX conversion.
|
||||
|
||||
static const int OPSET_VERSION_1 = 1;
|
||||
static const int OPSET_VERSION_9 = 9;
|
||||
static const int OPSET_VERSION_10 = 10;
|
||||
|
|
@ -19,22 +19,30 @@ using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;
|
|||
|
||||
using ParamMap = std::map<std::string, IValue>;
|
||||
|
||||
void buildParamsMapFromValueToParamsMap(
|
||||
TORCH_API void buildParamsMapFromValueToParamsMap(
|
||||
const ValueToParamPairMap& valsToParamsMap,
|
||||
ParamMap& paramsDict);
|
||||
ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
|
||||
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
|
||||
void eraseUnusedBlockInputs(Block* b);
|
||||
void buildParamsMapFromValueToParamsMap(
|
||||
TORCH_API ValueToParamPairMap
|
||||
buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
|
||||
TORCH_API void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
|
||||
TORCH_API void eraseUnusedBlockInputs(Block* b);
|
||||
TORCH_API void buildParamsMapFromValueToParamsMap(
|
||||
const ValueToParamPairMap& valsToParamsMap,
|
||||
ParamMap& paramsDict);
|
||||
|
||||
Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef<Value*> inputs);
|
||||
TORCH_API Node* addNodeToBlock(
|
||||
Block* block,
|
||||
Symbol kind,
|
||||
ArrayRef<Value*> inputs);
|
||||
|
||||
Value* addInputToBlock(Block* block);
|
||||
TORCH_API Value* addInputToBlock(Block* block);
|
||||
|
||||
TORCH_API c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type);
|
||||
|
||||
// Use int return type as no sable way exists to forward declare protobuf enum
|
||||
TORCH_API int ATenTypeToOnnxType(
|
||||
at::ScalarType at_type);
|
||||
|
||||
Node* createONNXUnsqueeze(
|
||||
Graph* graph,
|
||||
Node* n_to_insert_before,
|
||||
|
|
|
|||
45
torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp
Normal file
45
torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
#include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) {
|
||||
const auto& source_n = n->sourceRange().source();
|
||||
const auto& source_m = m->sourceRange().source();
|
||||
return (
|
||||
(source_n->text() == source_m->text()) &&
|
||||
(source_n->starting_line_no() == source_m->starting_line_no()));
|
||||
}
|
||||
|
||||
// Trace back all the slice & select nodes associated with the index_put node.
|
||||
// E.g. The IR for x[1:3, 0] = update
|
||||
// ...
|
||||
// %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
|
||||
// ...
|
||||
// %11 : Float(2) = aten::select(%8, %9, %10)
|
||||
// ...
|
||||
// %13 : Tensor?[] = prim::ListConstruct()
|
||||
// ...
|
||||
// %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
|
||||
//
|
||||
// We collect %11 and %8, to construct the index tensors.
|
||||
// The vector slice_and_select_node contains all the associated slice and
|
||||
// select node, in the reversed order.
|
||||
std::vector<Node*> IndexingPatternFinder::FetchSliceAndSelect(
|
||||
const Node* node) {
|
||||
std::vector<Node*> slice_and_select_node;
|
||||
auto src_node = node->input(0)->node();
|
||||
while (src_node) {
|
||||
if ((src_node->kind() == aten::slice || src_node->kind() == aten::select) &&
|
||||
IsSameSource(src_node, node)) {
|
||||
slice_and_select_node.emplace_back(src_node);
|
||||
src_node = src_node->input(0)->node();
|
||||
} else {
|
||||
src_node = nullptr;
|
||||
}
|
||||
}
|
||||
return slice_and_select_node;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
19
torch/csrc/jit/passes/onnx/pattern_conversion/common.h
Normal file
19
torch/csrc/jit/passes/onnx/pattern_conversion/common.h
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
// Functions used by both encapsulation and conversion.
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
struct IndexingPatternFinder {
|
||||
public:
|
||||
static std::vector<Node*> FetchSliceAndSelect(const Node* node);
|
||||
|
||||
private:
|
||||
static bool IsSameSource(const Node* n, const Node* m);
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -0,0 +1,373 @@
|
|||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/erase_number_types.h>
|
||||
#include <torch/csrc/jit/passes/onnx.h>
|
||||
#include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
|
||||
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
|
||||
|
||||
// EDITING THIS FILE? READ THIS FIRST!
|
||||
// see Note [Edit Pattern Conversion] in pattern_conversion.h
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Converting inplace index_put to ONNX
|
||||
namespace {
|
||||
|
||||
Value* CreateSizeOfDim(Value* input, int64_t dim, Node* insertBefore) {
|
||||
auto graph = input->owningGraph();
|
||||
WithInsertPoint guard(insertBefore);
|
||||
auto size = graph->insert(aten::size, {input, dim});
|
||||
return size;
|
||||
}
|
||||
|
||||
Value* ConvertSelectToIndex(Value* index, Node* insertBefore) {
|
||||
// Create index tensor based on index input of aten::select node.
|
||||
auto graph = insertBefore->owningGraph();
|
||||
WithInsertPoint guard(insertBefore);
|
||||
return graph->insert(aten::unsqueeze, {index, 0});
|
||||
}
|
||||
|
||||
Value* ConvertSliceToIndex(Node* slice, Value* size, Node* insertBefore) {
|
||||
// Create index tensor based on aten::slice node.
|
||||
auto graph = slice->owningGraph();
|
||||
WithInsertPoint guard(insertBefore);
|
||||
TORCH_INTERNAL_ASSERT((slice->inputs()).size() == 5);
|
||||
auto start = slice->inputs()[2];
|
||||
auto end = slice->inputs()[3];
|
||||
auto step = slice->inputs()[4];
|
||||
auto index =
|
||||
graph->insert(aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
|
||||
auto sliced_index_n = graph->create(
|
||||
aten::slice,
|
||||
{index,
|
||||
graph->insertConstant(
|
||||
scalar_to_tensor(at::Scalar(0)), c10::nullopt, slice->scope()),
|
||||
start,
|
||||
end,
|
||||
step});
|
||||
|
||||
auto sliced_index = sliced_index_n->insertBefore(insertBefore)->output();
|
||||
return sliced_index;
|
||||
}
|
||||
|
||||
struct ConvertedIndex {
|
||||
ConvertedIndex(Value* index, c10::Symbol orig_node_kind)
|
||||
: index(index), orig_node_kind(orig_node_kind) {}
|
||||
|
||||
Value* index = nullptr;
|
||||
c10::Symbol orig_node_kind;
|
||||
};
|
||||
|
||||
std::unordered_map<int64_t, ConvertedIndex> MergeSliceAndSelectToIndices(
|
||||
Graph* graph,
|
||||
Node* index_put_node,
|
||||
const std::vector<Node*>& slice_and_select_nodes,
|
||||
Value* orig_data,
|
||||
const std::unordered_map<Value*, Value*>& env) {
|
||||
std::unordered_map<int64_t, ConvertedIndex> dim_index_map;
|
||||
|
||||
// Loop over fetched slice and select nodes and convert them to index tensors.
|
||||
// keep track of which dimension the current slice/select node is applying to.
|
||||
int64_t cur_dim = 0;
|
||||
int64_t dim_offset = 0;
|
||||
const auto orig_tensor_indices = index_put_node->input(1)->node()->inputs();
|
||||
for (auto it = slice_and_select_nodes.rbegin();
|
||||
it != slice_and_select_nodes.rend();
|
||||
++it) {
|
||||
auto node = *it;
|
||||
// select does not keep dims,
|
||||
// this creates offset for latter slice and select nodes.
|
||||
// NOTE: Cannot rely on get(attr::dim), because op no longer match schema.
|
||||
int64_t dim = node->inputs().at(1)->node()->t(attr::value).item().toLong();
|
||||
|
||||
if (dim < 0) {
|
||||
auto input_type = env.at(orig_data)->type()->expect<TensorType>();
|
||||
if (input_type->dim().has_value()) {
|
||||
auto rank = static_cast<int64_t>(input_type->dim().value());
|
||||
// Rank of original tensor to index on.
|
||||
// Minus the offset created by select operators.
|
||||
dim = dim + rank - dim_offset;
|
||||
} else {
|
||||
std::cerr
|
||||
<< "Error: ONNX Remove Inplace Ops - Cannot export ellipsis indexing for input "
|
||||
<< "of unknown rank.";
|
||||
}
|
||||
}
|
||||
dim = dim + dim_offset;
|
||||
while (cur_dim < dim) {
|
||||
// Handle skipped dims, these are created from ..., or tensor indices
|
||||
// E.g.: x[torch.tensor([1, 0]), ..., 0] = update, where x has rank 3.
|
||||
// Both torch.tensor([1, 0]) and ... are skipped, we only observe
|
||||
// aten::select node with dim == 2. Tensor indices will be handled later.
|
||||
// Ellipsis(...) are treated as a complete slice over the axes, thus we
|
||||
// create index tensors here accordingly.
|
||||
if (cur_dim - dim_offset >= (int64_t)orig_tensor_indices.size() ||
|
||||
index_put_node->input(1)
|
||||
->node()
|
||||
->input(cur_dim - dim_offset)
|
||||
->node()
|
||||
->mustBeNone()) {
|
||||
auto size = CreateSizeOfDim(orig_data, cur_dim, index_put_node);
|
||||
WithInsertPoint guard(index_put_node);
|
||||
auto index_tensor = graph->insert(
|
||||
aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(cur_dim),
|
||||
std::forward_as_tuple(index_tensor, aten::slice));
|
||||
} else if (cur_dim - dim_offset < (int64_t)orig_tensor_indices.size()) {
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(cur_dim),
|
||||
std::forward_as_tuple(
|
||||
orig_tensor_indices[cur_dim - dim_offset], aten::index));
|
||||
}
|
||||
cur_dim++;
|
||||
}
|
||||
|
||||
AT_ASSERT(cur_dim == dim);
|
||||
if (node->kind() == aten::slice) {
|
||||
auto size = CreateSizeOfDim(orig_data, dim, index_put_node);
|
||||
auto index_tensor = ConvertSliceToIndex(node, size, index_put_node);
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(dim),
|
||||
std::forward_as_tuple(index_tensor, aten::slice));
|
||||
} else if (node->kind() == aten::select) {
|
||||
auto index_tensor = ConvertSelectToIndex(node->input(2), index_put_node);
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(dim),
|
||||
std::forward_as_tuple(index_tensor, aten::select));
|
||||
dim_offset++;
|
||||
} else {
|
||||
AT_ERROR(
|
||||
"Unexpected node kind ",
|
||||
node->kind().toDisplayString(),
|
||||
" Expected aten::slice or aten::select.");
|
||||
}
|
||||
|
||||
cur_dim++;
|
||||
}
|
||||
|
||||
while (cur_dim - dim_offset < (int64_t)orig_tensor_indices.size()) {
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(cur_dim),
|
||||
std::forward_as_tuple(
|
||||
orig_tensor_indices[cur_dim - dim_offset], aten::index));
|
||||
cur_dim++;
|
||||
}
|
||||
|
||||
// Each dimension should have its associated index tensor.
|
||||
AT_ASSERT((int64_t)dim_index_map.size() == cur_dim);
|
||||
return dim_index_map;
|
||||
}
|
||||
|
||||
// Convert slice/select operators to tensor indices.
|
||||
// Reshape the tensor indices according to their axis.
|
||||
// E.g. x[1:3, 0, ind1, ind2] = y
|
||||
// slice index shape: [2, 1, 1 ]
|
||||
// select index shape: [ 1, 1 ]
|
||||
// ind1 shape: [ _ ]
|
||||
// ind2 shape: [ _ ]
|
||||
// where _ is the original size of ind1 and ind2.
|
||||
// ind1 and ind2 are both 1-d tensors since currently we only supports 1-d
|
||||
// tensor indices.
|
||||
std::vector<Value*> ReshapeToAdvancedIndexingFormat(
|
||||
Graph* graph,
|
||||
Node* index_put_node,
|
||||
std::unordered_map<int64_t, ConvertedIndex>& dim_index_map) {
|
||||
std::vector<Value*> indices;
|
||||
|
||||
size_t min_index_dim = dim_index_map.size();
|
||||
size_t max_index_dim = 0;
|
||||
size_t tensor_ind_count = 0;
|
||||
for (size_t i = 0; i < dim_index_map.size(); ++i) {
|
||||
auto index_i = dim_index_map.find(i);
|
||||
AT_ASSERT(index_i != dim_index_map.end());
|
||||
if (index_i->second.orig_node_kind == aten::index) {
|
||||
if (i < min_index_dim)
|
||||
min_index_dim = i;
|
||||
if (i > max_index_dim)
|
||||
max_index_dim = i;
|
||||
tensor_ind_count++;
|
||||
}
|
||||
}
|
||||
|
||||
if (((max_index_dim - min_index_dim + 1) != tensor_ind_count) &&
|
||||
tensor_ind_count != 0) {
|
||||
AT_ERROR(
|
||||
"Only consecutive 1-d tensor indices are supported in exporting aten::index_put to ONNX.");
|
||||
}
|
||||
|
||||
size_t tensor_ind_offset = tensor_ind_count == 0 ? 0 : tensor_ind_count - 1;
|
||||
WithInsertPoint guard(index_put_node);
|
||||
for (size_t i = 0; i < dim_index_map.size(); ++i) {
|
||||
size_t ind_size = 0;
|
||||
auto index_i = dim_index_map.find(i);
|
||||
AT_ASSERT(index_i != dim_index_map.end());
|
||||
Value* index = index_i->second.index;
|
||||
switch (index_i->second.orig_node_kind) {
|
||||
case aten::select:
|
||||
case aten::slice: {
|
||||
if (i < min_index_dim) {
|
||||
ind_size = dim_index_map.size() - tensor_ind_offset - i;
|
||||
} else {
|
||||
ind_size = dim_index_map.size() - i;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case aten::index: {
|
||||
ind_size = dim_index_map.size() - tensor_ind_offset - min_index_dim;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
AT_ERROR("Unexpected node kind ", index_i->second.orig_node_kind);
|
||||
}
|
||||
|
||||
if (ind_size != 1) {
|
||||
std::vector<int64_t> view_shape(ind_size, 1);
|
||||
view_shape[0] = -1;
|
||||
auto unsqueezed_index = graph->insert(aten::view, {index, view_shape});
|
||||
indices.emplace_back(unsqueezed_index);
|
||||
} else {
|
||||
indices.emplace_back(index);
|
||||
}
|
||||
}
|
||||
|
||||
return indices;
|
||||
}
|
||||
|
||||
// Trace back all the slice & select nodes associated with the index_put node,
|
||||
// and convert them to associated indices.
|
||||
// E.g. The IR for x[1:3, 0] = update
|
||||
// ...
|
||||
// %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
|
||||
// ...
|
||||
// %11 : Float(2) = aten::select(%8, %9, %10)
|
||||
// ...
|
||||
// %13 : Tensor?[] = prim::ListConstruct()
|
||||
// ...
|
||||
// %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
|
||||
// The aten::index_put node alone does not contain any indices (%13 : Tensor?[]
|
||||
// = prim::ListConstruct()).
|
||||
// ...
|
||||
// # Below constructs index from slice node.
|
||||
// %23 : Long() = aten::size(%0, %4)
|
||||
// %28 : Tensor = aten::arange(%23, %24, %25, %26, %27)
|
||||
// %33 : Tensor = aten::slice(%28, %4, %5, %6, %7)
|
||||
// %39 : int[] = prim::Constant[value=[-1, 1]]()
|
||||
// %40 : Tensor = aten::view(%33, %39)
|
||||
// ...
|
||||
// # Below constructs index from select node.
|
||||
// %36 : int = prim::Constant[value=0]()
|
||||
// %37 : Tensor = aten::unsqueeze(%10, %36)
|
||||
// %42 : int[] = prim::Constant[value=[-1]]()
|
||||
// %43 : Tensor = aten::view(%37, %42)
|
||||
// ...
|
||||
// # Adding the above two indices to index_put
|
||||
// %44 : Tensor?[] = prim::ListConstruct(%40, %43)
|
||||
// %45 : Float(2, 5) = aten::index_put(%0, %44, %14, %15)
|
||||
std::vector<Value*> ConvertIndexPutToONNX(
|
||||
Block* new_block,
|
||||
Node* old_node,
|
||||
std::unordered_map<Value*, Value*>& env) {
|
||||
if (old_node->kind() != Symbol::fromQualString("onnx::Placeholder") ||
|
||||
(old_node->s(attr::name) != "index_put" &&
|
||||
old_node->s(attr::name) != "index_put_")) {
|
||||
return {};
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(old_node->blocks().size() == 1);
|
||||
auto old_graph = old_node->owningGraph();
|
||||
auto subblock = old_node->blocks()[0];
|
||||
auto index_put_node = subblock->nodes().back()->prev();
|
||||
|
||||
// Find slice and select operators that are associated with this index
|
||||
// operator. E.g. x[1:3, 0] = y will generate one slice operator(1:3) and one
|
||||
// select operator(0).
|
||||
std::vector<Node*> slice_and_select_nodes =
|
||||
IndexingPatternFinder::FetchSliceAndSelect(index_put_node);
|
||||
Node* last_node = slice_and_select_nodes.size() > 0
|
||||
? slice_and_select_nodes.back()
|
||||
: index_put_node;
|
||||
// Update inner block input originates from outside.
|
||||
last_node->replaceInput(0, old_node->input(0));
|
||||
Value* orig_data = last_node->input(0);
|
||||
|
||||
// Convert slice and select operators to indices.
|
||||
std::unordered_map<int64_t, ConvertedIndex> dim_index_map =
|
||||
MergeSliceAndSelectToIndices(
|
||||
old_graph, index_put_node, slice_and_select_nodes, orig_data, env);
|
||||
|
||||
// Reshape indices to advanced indexing format.
|
||||
std::vector<Value*> indices =
|
||||
ReshapeToAdvancedIndexingFormat(old_graph, index_put_node, dim_index_map);
|
||||
|
||||
// Create new index_put node with converted indices.
|
||||
const auto list_indices =
|
||||
old_graph->createList(OptionalType::ofTensor(), indices)
|
||||
->insertBefore(index_put_node)
|
||||
->output();
|
||||
auto new_index_put_node = old_graph->create(
|
||||
aten::index_put,
|
||||
{orig_data,
|
||||
list_indices,
|
||||
index_put_node->input(2),
|
||||
index_put_node->input(3)});
|
||||
new_index_put_node->insertBefore(index_put_node);
|
||||
auto new_index_put = new_index_put_node->output();
|
||||
new_index_put->copyMetadata(index_put_node->output());
|
||||
index_put_node->output()->replaceAllUsesWith(new_index_put);
|
||||
|
||||
// Convert aten type to onnx type.
|
||||
EraseNumberTypesOnBlock(subblock);
|
||||
EliminateDeadCode(
|
||||
subblock,
|
||||
true,
|
||||
DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
|
||||
|
||||
// Convert all the new aten nodes that were just created to onnx.
|
||||
// New onnx nodes are appended at the end of new_block.
|
||||
for (auto at_n : subblock->nodes()) {
|
||||
if (at_n == subblock->param_node() || at_n == subblock->return_node()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeToONNX(at_n, new_block, torch::onnx::OperatorExportTypes::ONNX, env);
|
||||
}
|
||||
|
||||
// Find onnx outputs corresponding to the aten outputs of index_put.
|
||||
std::vector<Value*> outs;
|
||||
for (auto o : subblock->return_node()->inputs()) {
|
||||
outs.emplace_back(env[o]);
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<Value*> ConvertPatternFromSubblock(
|
||||
Block* new_block,
|
||||
Node* old_node,
|
||||
std::unordered_map<Value*, Value*>& env) {
|
||||
std::vector<Value*> res;
|
||||
|
||||
if (old_node->kind() != Symbol::fromQualString("onnx::Placeholder")) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// The pattern conversion code should not alter nodes outside the Placeholder
|
||||
// subblock.
|
||||
auto op_name = old_node->s(attr::name);
|
||||
if (op_name == "index_put" || op_name == "index_put_") {
|
||||
res = ConvertIndexPutToONNX(new_block, old_node, env);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Introduction
|
||||
//
|
||||
// The conversion part is called inside the onnx pass.
|
||||
// In onnx pass, _run_symbolic_function will be called for each node in
|
||||
// topological order. When it reaches the placeholder node, this function will
|
||||
// be invoked. It will convert the nodes inside the sub-block based on pattern.
|
||||
// By that time, it will have shape/type of upstream operators available. After
|
||||
// the conversion is complete, the placeholder node will be removed, and nodes
|
||||
// inside its sub-block converted. NodeToONNX will be called for these
|
||||
// nodes, and they will be converted from ATen operator to ONNX operator.
|
||||
//
|
||||
// Note: Edit Pattern Conversion
|
||||
//
|
||||
// Each pattern is differentiated by the name attribute of placeholder node.
|
||||
// The placeholder node is part of torch IR graph, After this function, the aten
|
||||
// nodes under placeholder node subblock will be converted to ONNX and appended
|
||||
// to the new_block, which is under the new ONNX graph. For the pattern
|
||||
// conversion code, it can be divided into three parts.
|
||||
// 1. Nodes in this pattern should be captured inside the subblock of
|
||||
// Placeholder node after pattern encapsulation[see
|
||||
// pattern_encapsulation.h]. These nodes will be converted based on
|
||||
// pattern. This part of conversion is from aten to aten. It happens on
|
||||
// the torch IR graph inside placeholder node subblock.
|
||||
// 2. The second part of conversion is to convert the aten nodes produced
|
||||
// into ONNX. This is done by calling NodeToONNX for each node. The new
|
||||
// ONNX nodes are appended to the new_block, which is under the new ONNX
|
||||
// graph.
|
||||
// 3. The last part of conversion is to find and return, in the same order,
|
||||
// the ONNX outputs corresponding to the original output for the
|
||||
// placeholder node.
|
||||
TORCH_API std::vector<Value*> ConvertPatternFromSubblock(
|
||||
Block* new_block,
|
||||
Node* old_node,
|
||||
std::unordered_map<Value*, Value*>& env);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/onnx.h>
|
||||
#include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
|
||||
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
|
||||
#include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
|
||||
|
||||
// EDITING THIS FILE? READ THIS FIRST!
|
||||
// see Note [Edit Pattern Encapsulation] in pattern_encapsulation.h
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
// Trace back all the slice & select nodes associated with the index_put node,
|
||||
// and copy them under the placeholder subblock.
|
||||
// E.g. The IR for x[1:3, 0] = update
|
||||
// ...
|
||||
// %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
|
||||
// ...
|
||||
// %11 : Float(2) = aten::select(%8, %9, %10)
|
||||
// ...
|
||||
// %13 : Tensor?[] = prim::ListConstruct()
|
||||
// ...
|
||||
// %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
|
||||
// The aten::index_put node alone does not contain any indices (%13 : Tensor?[]
|
||||
// = prim::ListConstruct()).
|
||||
Node* EncapsulateInplaceIndexPutForONNX(Node* index_put_node) {
|
||||
auto graph = index_put_node->owningGraph();
|
||||
|
||||
// Find slice and select operators that are associated with this index
|
||||
// operator. E.g. x[1:3, 0] = y will generate one slice operator(1:3) and one
|
||||
// select operator(0).
|
||||
std::vector<Node*> slice_and_select_nodes =
|
||||
IndexingPatternFinder::FetchSliceAndSelect(index_put_node);
|
||||
Node* last_node = slice_and_select_nodes.size() > 0
|
||||
? slice_and_select_nodes.back()
|
||||
: index_put_node;
|
||||
Value* orig_data = last_node->input(0);
|
||||
|
||||
// Copy related nodes into subblock of a new special placeholder node.
|
||||
Node* placeholder_node =
|
||||
graph->create(Symbol::fromQualString("onnx::Placeholder"));
|
||||
placeholder_node->s_(attr::name, index_put_node->kind().toUnqualString());
|
||||
placeholder_node->addInput(orig_data);
|
||||
|
||||
// Construct subblock
|
||||
auto subblock = placeholder_node->addBlock();
|
||||
std::unordered_map<Value*, Value*> env;
|
||||
|
||||
// slice_and_select_nodes are in reversed order.
|
||||
for (auto it = slice_and_select_nodes.rbegin();
|
||||
it != slice_and_select_nodes.rend();
|
||||
++it) {
|
||||
auto n = *it;
|
||||
auto cloned_n = subblock->appendNode(graph->createClone(
|
||||
n, [&](Value* v) { return env.find(v) != env.end() ? env[v] : v; }));
|
||||
for (size_t i = 0; i < cloned_n->outputs().size(); ++i) {
|
||||
env[n->outputs().at(i)] = cloned_n->outputs().at(i);
|
||||
}
|
||||
}
|
||||
|
||||
Node* new_index_put_node =
|
||||
subblock->appendNode(graph->createClone(index_put_node, [&](Value* v) {
|
||||
return env.find(v) != env.end() ? env[v] : v;
|
||||
}));
|
||||
for (auto o : new_index_put_node->outputs()) {
|
||||
subblock->registerOutput(o);
|
||||
}
|
||||
|
||||
placeholder_node->insertBefore(index_put_node);
|
||||
placeholder_node->copyMetadata(index_put_node);
|
||||
index_put_node->replaceAllUsesWith(placeholder_node);
|
||||
|
||||
return placeholder_node;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
c10::optional<Node*> EncapsulatePatternIntoSubblock(Node* n) {
|
||||
switch (n->kind()) {
|
||||
case aten::index_put_:
|
||||
case aten::index_put: {
|
||||
return EncapsulateInplaceIndexPutForONNX(n);
|
||||
}
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Introduction
|
||||
//
|
||||
// The encapsulation part will find the nodes of patterns, like how other
|
||||
// pre-onnx passes are written. But instead of converting the nodes, it will
|
||||
// encapsulate them into a sub-block of a new placeholder node. This part is
|
||||
// called before onnx pass, so it runs before calling symbolic functions.
|
||||
//
|
||||
// Note: Why separate the function into two parts
|
||||
//
|
||||
// The purpose is to support conversions that depend on shape and type
|
||||
// information. Shape and type information is only available after
|
||||
// _jit_pass_onnx, which converts aten nodes to onnx nodes. So there is a
|
||||
// interdependent issue. _jit_pass_onnx depends on preprocess passes to convert
|
||||
// aten nodes into convertable condition, and preprocess passes depend on
|
||||
// _jit_pass_onnx to convert upstream nodes and apply onnx shape inference.
|
||||
// Separating the pass into two parts breaks the interdependency.
|
||||
//
|
||||
// Note: Edit Pattern Encapsulation
|
||||
//
|
||||
// Encapsulation step identifies the pattern, and copies the nodes into
|
||||
// the subblock of a new placeholder node. The outputs of the new placeholder
|
||||
// node are used in place of the original nodes instead. The category of the
|
||||
// pattern is stored as attr::name.
|
||||
TORCH_API c10::optional<Node*> EncapsulatePatternIntoSubblock(Node* n);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/onnx/helper.h>
|
||||
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
|
|
@ -17,273 +18,22 @@ namespace {
|
|||
|
||||
const std::set<c10::Symbol> inplace_ops = {
|
||||
aten::append,
|
||||
aten::index_put,
|
||||
aten::index_put_,
|
||||
aten::pop,
|
||||
aten::insert,
|
||||
aten::Delete};
|
||||
|
||||
Value* CreateSizeOfDim(Value* input, int64_t dim, Node* insertBefore) {
|
||||
auto graph = input->owningGraph();
|
||||
WithInsertPoint guard(insertBefore);
|
||||
auto size = graph->insert(aten::size, {input, dim});
|
||||
return size;
|
||||
}
|
||||
|
||||
Value* ConvertSelectToIndex(Value* index, Node* insertBefore) {
|
||||
// Create index tensor based on index input of aten::select node.
|
||||
auto graph = insertBefore->owningGraph();
|
||||
WithInsertPoint guard(insertBefore);
|
||||
auto idx_tensor = graph->createNumToTensor(index);
|
||||
graph->insertNode(idx_tensor);
|
||||
return graph->insert(aten::unsqueeze, {idx_tensor->output(), 0});
|
||||
}
|
||||
|
||||
Value* ConvertSliceToIndex(Node* slice, Value* size, Node* insertBefore) {
|
||||
// Create index tensor based on aten::slice node.
|
||||
const int64_t int_max = std::numeric_limits<int>::max();
|
||||
auto graph = slice->owningGraph();
|
||||
WithInsertPoint guard(insertBefore);
|
||||
TORCH_INTERNAL_ASSERT((slice->inputs()).size() == 5);
|
||||
auto start = slice->inputs()[2];
|
||||
auto end = slice->inputs()[3];
|
||||
auto step = slice->inputs()[4];
|
||||
auto index =
|
||||
graph->insert(aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
|
||||
auto sliced_index =
|
||||
graph->insert(aten::slice, {index, {0}, start, end, step});
|
||||
return sliced_index;
|
||||
}
|
||||
|
||||
Value* CreateCompleteIndexTensor(Value* size, Node* insertBefore) {
|
||||
// Create index tensor of size.
|
||||
// The result is torch.tensor([0, 1, 2, ..., size - 1])
|
||||
auto graph = size->owningGraph();
|
||||
WithInsertPoint guard(insertBefore);
|
||||
auto index =
|
||||
graph->insert(aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
|
||||
return index;
|
||||
}
|
||||
|
||||
bool IsSameSource(const Node* n, const Node* m) {
|
||||
const auto& source_n = n->sourceRange().source();
|
||||
const auto& source_m = m->sourceRange().source();
|
||||
return (
|
||||
(source_n->text() == source_m->text()) &&
|
||||
(source_n->starting_line_no() == source_m->starting_line_no()));
|
||||
}
|
||||
|
||||
// Trace back all the slice & select nodes associated with the index_put node.
|
||||
// E.g. The IR for x[1:3, 0] = update
|
||||
// ...
|
||||
// %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
|
||||
// ...
|
||||
// %11 : Float(2) = aten::select(%8, %9, %10)
|
||||
// ...
|
||||
// %13 : Tensor?[] = prim::ListConstruct()
|
||||
// ...
|
||||
// %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
|
||||
//
|
||||
// We collect %11 and %8, to construct the index tensors.
|
||||
// The vector slice_and_select_node contains all the associated slice and
|
||||
// select node, in the reversed order.
|
||||
std::vector<Node*> FetchSliceAndSelect(const Node* index_put_node) {
|
||||
std::vector<Node*> slice_and_select_node;
|
||||
auto src_node = index_put_node->input(0)->node();
|
||||
while (src_node) {
|
||||
if ((src_node->kind() == aten::slice || src_node->kind() == aten::select) &&
|
||||
IsSameSource(src_node, index_put_node)) {
|
||||
slice_and_select_node.emplace_back(src_node);
|
||||
src_node = src_node->input(0)->node();
|
||||
} else {
|
||||
src_node = nullptr;
|
||||
}
|
||||
}
|
||||
return slice_and_select_node;
|
||||
}
|
||||
|
||||
struct ConvertedIndex {
|
||||
ConvertedIndex(Value* index, c10::Symbol orig_node_kind)
|
||||
: index(index), orig_node_kind(orig_node_kind) {}
|
||||
|
||||
Value* index = nullptr;
|
||||
c10::Symbol orig_node_kind;
|
||||
};
|
||||
|
||||
std::unordered_map<int64_t, ConvertedIndex> MergeSliceAndSelectToIndices(
|
||||
Graph* graph,
|
||||
Node* index_put_node,
|
||||
const std::vector<Node*>& slice_and_select_nodes,
|
||||
Value* orig_data) {
|
||||
std::unordered_map<int64_t, ConvertedIndex> dim_index_map;
|
||||
|
||||
// Loop over fetched slice and select nodes and convert them to index tensors.
|
||||
// keep track of which dimension the current slice/select node is applying to.
|
||||
int64_t cur_dim = 0;
|
||||
int64_t dim_offset = 0;
|
||||
const auto orig_tensor_indices = index_put_node->input(1)->node()->inputs();
|
||||
for (auto it = slice_and_select_nodes.rbegin();
|
||||
it != slice_and_select_nodes.rend();
|
||||
++it) {
|
||||
auto node = *it;
|
||||
// select does not keep dims,
|
||||
// this creates offset for latter slice and select nodes.
|
||||
auto dim = node->get(attr::dim)->toInt();
|
||||
if (dim < 0) {
|
||||
auto input_type = orig_data->type()->expect<TensorType>();
|
||||
if (input_type->dim().has_value()) {
|
||||
auto rank = input_type->dim().value();
|
||||
// Rank of original tensor to index on.
|
||||
// Minus the offset created by select operators.
|
||||
dim = dim + rank - dim_offset;
|
||||
} else {
|
||||
std::cerr
|
||||
<< "Error: ONNX Remove Inplace Ops - Cannot export ellipsis indexing for input "
|
||||
<< "of unknown rank.";
|
||||
}
|
||||
}
|
||||
dim = dim + dim_offset;
|
||||
|
||||
while (cur_dim < dim) {
|
||||
// Handle skipped dims, these are created from ..., or tensor indices
|
||||
// E.g.: x[torch.tensor([1, 0]), ..., 0] = update, where x has rank 3.
|
||||
// Both torch.tensor([1, 0]) and ... are skipped, we only observe
|
||||
// aten::select node with dim == 2. Tensor indices will be handled later.
|
||||
// Ellipsis(...) are treated as a complete slice over the axes, thus we
|
||||
// create index tensors here accordingly.
|
||||
if (cur_dim - dim_offset >= orig_tensor_indices.size() ||
|
||||
index_put_node->input(1)
|
||||
->node()
|
||||
->input(cur_dim - dim_offset)
|
||||
->node()
|
||||
->mustBeNone()) {
|
||||
auto size = CreateSizeOfDim(orig_data, cur_dim, index_put_node);
|
||||
WithInsertPoint guard(index_put_node);
|
||||
auto index_tensor = graph->insert(
|
||||
aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(cur_dim),
|
||||
std::forward_as_tuple(index_tensor, aten::slice));
|
||||
} else if (cur_dim - dim_offset < orig_tensor_indices.size()) {
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(cur_dim),
|
||||
std::forward_as_tuple(
|
||||
orig_tensor_indices[cur_dim - dim_offset], aten::index));
|
||||
}
|
||||
cur_dim++;
|
||||
}
|
||||
|
||||
AT_ASSERT(cur_dim == dim);
|
||||
if (node->kind() == aten::slice) {
|
||||
auto size = CreateSizeOfDim(orig_data, dim, index_put_node);
|
||||
auto index_tensor = ConvertSliceToIndex(node, size, index_put_node);
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(dim),
|
||||
std::forward_as_tuple(index_tensor, aten::slice));
|
||||
} else if (node->kind() == aten::select) {
|
||||
auto index_tensor = ConvertSelectToIndex(node->input(2), index_put_node);
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(dim),
|
||||
std::forward_as_tuple(index_tensor, aten::select));
|
||||
dim_offset++;
|
||||
} else {
|
||||
AT_ERROR(
|
||||
"Unexpected node kind ",
|
||||
node->kind().toDisplayString(),
|
||||
" Expected aten::slice or aten::select.");
|
||||
}
|
||||
|
||||
cur_dim++;
|
||||
bool IsInplaceNode(const Node* n) {
|
||||
if (inplace_ops.find(n->kind()) != inplace_ops.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
while (cur_dim - dim_offset < orig_tensor_indices.size()) {
|
||||
dim_index_map.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(cur_dim),
|
||||
std::forward_as_tuple(
|
||||
orig_tensor_indices[cur_dim - dim_offset], aten::index));
|
||||
cur_dim++;
|
||||
if (n->kind() == Symbol::fromQualString("onnx::Placeholder") &&
|
||||
n->s(attr::name) == "index_put_") {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Each dimension should have its associated index tensor.
|
||||
AT_ASSERT(dim_index_map.size() == cur_dim);
|
||||
return dim_index_map;
|
||||
}
|
||||
|
||||
// Convert slice/select operators to tensor indices.
|
||||
// Reshape the tensor indices according to their axis.
|
||||
// E.g. x[1:3, 0, ind1, ind2] = y
|
||||
// slice index shape: [2, 1, 1 ]
|
||||
// select index shape: [ 1, 1 ]
|
||||
// ind1 shape: [ _ ]
|
||||
// ind2 shape: [ _ ]
|
||||
// where _ is the original size of ind1 and ind2.
|
||||
// ind1 and ind2 are both 1-d tensors since currently we only supports 1-d
|
||||
// tensor indices.
|
||||
std::vector<Value*> ReshapeToAdvancedIndexingFormat(
|
||||
Graph* graph,
|
||||
Node* index_put_node,
|
||||
std::unordered_map<int64_t, ConvertedIndex>& dim_index_map) {
|
||||
std::vector<Value*> indices;
|
||||
|
||||
size_t min_index_dim = dim_index_map.size();
|
||||
size_t max_index_dim = 0;
|
||||
size_t tensor_ind_count = 0;
|
||||
for (size_t i = 0; i < dim_index_map.size(); ++i) {
|
||||
auto index_i = dim_index_map.find(i);
|
||||
AT_ASSERT(index_i != dim_index_map.end());
|
||||
if (index_i->second.orig_node_kind == aten::index) {
|
||||
if (i < min_index_dim)
|
||||
min_index_dim = i;
|
||||
if (i > max_index_dim)
|
||||
max_index_dim = i;
|
||||
tensor_ind_count++;
|
||||
}
|
||||
}
|
||||
|
||||
if (((max_index_dim - min_index_dim + 1) != tensor_ind_count) &&
|
||||
tensor_ind_count != 0) {
|
||||
AT_ERROR(
|
||||
"Only consecutive 1-d tensor indices are supported in exporting aten::index_put to ONNX.");
|
||||
}
|
||||
|
||||
size_t tensor_ind_offset = tensor_ind_count == 0 ? 0 : tensor_ind_count - 1;
|
||||
WithInsertPoint guard(index_put_node);
|
||||
for (size_t i = 0; i < dim_index_map.size(); ++i) {
|
||||
size_t ind_size = 0;
|
||||
auto index_i = dim_index_map.find(i);
|
||||
AT_ASSERT(index_i != dim_index_map.end());
|
||||
Value* index = index_i->second.index;
|
||||
switch (index_i->second.orig_node_kind) {
|
||||
case aten::select:
|
||||
case aten::slice: {
|
||||
if (i < min_index_dim) {
|
||||
ind_size = dim_index_map.size() - tensor_ind_offset - i;
|
||||
} else {
|
||||
ind_size = dim_index_map.size() - i;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case aten::index: {
|
||||
ind_size = dim_index_map.size() - tensor_ind_offset - min_index_dim;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
AT_ERROR("Unexpected node kind ", index_i->second.orig_node_kind);
|
||||
}
|
||||
|
||||
std::vector<int64_t> view_shape(ind_size, 1);
|
||||
view_shape[0] = -1;
|
||||
auto unsqueezed_index = graph->insert(aten::view, {index, view_shape});
|
||||
indices.emplace_back(unsqueezed_index);
|
||||
}
|
||||
|
||||
return indices;
|
||||
return false;
|
||||
}
|
||||
|
||||
Node* addDummyCloneToBlock(Block* b, Value* orig_data) {
|
||||
|
|
@ -318,7 +68,6 @@ Value* MatchIfBlocksOutputForValue(
|
|||
Value* origOutput) {
|
||||
if (outer_block->owningNode()->kind() != prim::If)
|
||||
return nullptr;
|
||||
|
||||
size_t output_size = outer_block->outputs().size();
|
||||
|
||||
for (size_t i = 0; i < output_size - 1; i++) {
|
||||
|
|
@ -344,83 +93,100 @@ Value* MatchIfBlocksOutputForValue(
|
|||
return outer_block->owningNode()->outputs().at(output_size - 1);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
// Register inplace op node inputs/outputs through the blocks.
|
||||
// Eg. The IR before updating:
|
||||
//%23 : bool = aten::eq(%22, %13)
|
||||
// = prim::If(%23) # test/onnx/test_pytorch_onnx_onnxruntime.py:6243:12
|
||||
// block0():
|
||||
// %24 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1,
|
||||
// %spatial_size_1.1) %25 : Tensor = aten::ones(%24, %12, %12, %12, %12) %26
|
||||
// : Tensor = aten::slice(%state.1, %13, %13, %10, %11) %27 : Tensor =
|
||||
// aten::copy_(%26, %25, %9)
|
||||
// %24 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
|
||||
// %25 : Tensor = aten::ones(%24, %12, %12, %12, %12)
|
||||
// %26 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
|
||||
// %27 : Tensor = aten::copy_(%26, %25, %9)
|
||||
// -> ()
|
||||
// block1():
|
||||
// %28 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1,
|
||||
// %spatial_size_1.1) %29 : Tensor = aten::randn(%28, %12, %12, %12, %12) %30
|
||||
// : Tensor = aten::slice(%state.1, %13, %13, %10, %11) %31 : Tensor =
|
||||
// aten::copy_(%30, %29, %9)
|
||||
// %28 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
|
||||
// %29 : Tensor = aten::randn(%28, %12, %12, %12, %12)
|
||||
// %30 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
|
||||
// %31 : Tensor = aten::copy_(%30, %29, %9)
|
||||
// -> ()
|
||||
// After updating:
|
||||
//%23 : bool = aten::eq(%22, %13)
|
||||
//%51 : Tensor = prim::If(%23)
|
||||
// block0():
|
||||
// %24 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1,
|
||||
// %spatial_size_1.1) %25 : Tensor = aten::ones(%24, %12, %12, %12, %12) %26
|
||||
// : Tensor = aten::slice(%state.1, %13, %13, %10, %11) %32 : Tensor?[] =
|
||||
// prim::ListConstruct() %33 : Tensor = aten::expand_as(%25, %26) %38 : int =
|
||||
// prim::Constant[value=0]() %39 : int = aten::size(%state.1, %38) %40 : int
|
||||
// = prim::Constant[value=4]() %41 : None = prim::Constant() %42 : None =
|
||||
// prim::Constant() %43 : None = prim::Constant() %44 : Tensor =
|
||||
// aten::arange(%39, %40, %41, %42, %43) %45 : int =
|
||||
// prim::Constant[value=0]() %46 : Tensor = aten::slice(%44, %45, %13, %10,
|
||||
// %11) %47 : int[] = prim::Constant[value=[-1]]() %48 : Tensor =
|
||||
// aten::view(%46, %47) %49 : Tensor?[] = prim::ListConstruct(%48) %50 :
|
||||
// Tensor = aten::index_put(%state.1, %49, %33, %9)
|
||||
// %24 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
|
||||
// %25 : Tensor = aten::ones(%24, %12, %12, %12, %12)
|
||||
// %26 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
|
||||
// %32 : Tensor?[] = prim::ListConstruct()
|
||||
// %33 : Tensor = aten::expand_as(%25, %26)
|
||||
// %38 : int = prim::Constant[value=0]()
|
||||
// %39 : int = aten::size(%state.1, %38)
|
||||
// %40 : int = prim::Constant[value=4]()
|
||||
// %41 : None = prim::Constant()
|
||||
// %42 : None = prim::Constant()
|
||||
// %43 : None = prim::Constant()
|
||||
// %44 : Tensor = aten::arange(%39, %40, %41, %42, %43)
|
||||
// %45 : int = prim::Constant[value=0]()
|
||||
// %46 : Tensor = aten::slice(%44, %45, %13, %10, %11)
|
||||
// %47 : int[] = prim::Constant[value=[-1]]()
|
||||
// %48 : Tensor = aten::view(%46, %47)
|
||||
// %49 : Tensor?[] = prim::ListConstruct(%48)
|
||||
// %50 : Tensor = aten::index_put(%state.1, %49, %33, %9)
|
||||
// -> (%50)
|
||||
// block1():
|
||||
// %28 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1,
|
||||
// %spatial_size_1.1) %29 : Tensor = aten::randn(%28, %12, %12, %12, %12) %30
|
||||
// : Tensor = aten::slice(%state.1, %13, %13, %10, %11) %35 : Tensor?[] =
|
||||
// prim::ListConstruct() %36 : Tensor = aten::expand_as(%29, %30) %52 : int =
|
||||
// prim::Constant[value=0]() %53 : int = aten::size(%state.1, %52) %54 : int
|
||||
// = prim::Constant[value=4]() %55 : None = prim::Constant() %56 : None =
|
||||
// prim::Constant() %57 : None = prim::Constant() %58 : Tensor =
|
||||
// aten::arange(%53, %54, %55, %56, %57) %59 : int =
|
||||
// prim::Constant[value=0]() %60 : Tensor = aten::slice(%58, %59, %13, %10,
|
||||
// %11) %61 : int[] = prim::Constant[value=[-1]]() %62 : Tensor =
|
||||
// aten::view(%60, %61) %63 : Tensor?[] = prim::ListConstruct(%62) %64 :
|
||||
// Tensor = aten::index_put(%state.1, %63, %36, %9)
|
||||
// %28 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
|
||||
// %29 : Tensor = aten::randn(%28, %12, %12, %12, %12)
|
||||
// %30 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
|
||||
// %35 : Tensor?[] = prim::ListConstruct()
|
||||
// %36 : Tensor = aten::expand_as(%29, %30)
|
||||
// %52 : int = prim::Constant[value=0]()
|
||||
// %53 : int = aten::size(%state.1, %52)
|
||||
// %54 : int = prim::Constant[value=4]()
|
||||
// %55 : None = prim::Constant()
|
||||
// %56 : None = prim::Constant()
|
||||
// %57 : None = prim::Constant()
|
||||
// %58 : Tensor = aten::arange(%53, %54, %55, %56, %57)
|
||||
// %59 : int = prim::Constant[value=0]()
|
||||
// %60 : Tensor = aten::slice(%58, %59, %13, %10, %11)
|
||||
// %61 : int[] = prim::Constant[value=[-1]]()
|
||||
// %62 : Tensor = aten::view(%60, %61)
|
||||
// %63 : Tensor?[] = prim::ListConstruct(%62)
|
||||
// %64 : Tensor = aten::index_put(%state.1, %63, %36, %9)
|
||||
// -> (%64)
|
||||
// clang-format on
|
||||
void RegisterInplaceNodeInIfBlocks(
|
||||
Value* value,
|
||||
Value* new_inplace_node,
|
||||
Block* outer_block,
|
||||
Node* initial_node,
|
||||
Value* orig_data,
|
||||
Value* new_data,
|
||||
const std::string& output_name) {
|
||||
if (initial_node->kind() != prim::If)
|
||||
auto outer_block = new_data->node()->owningBlock();
|
||||
auto initial_block_node = outer_block->owningNode();
|
||||
|
||||
if ((nullptr == initial_block_node) ||
|
||||
(initial_block_node->kind() != prim::If)) {
|
||||
return;
|
||||
|
||||
auto next_node = initial_node;
|
||||
new_inplace_node->setDebugName("_output_" + output_name);
|
||||
outer_block->registerOutput(new_inplace_node);
|
||||
// Block has a new output. Add the output for the prim::If node.
|
||||
if (next_node->outputs().size() < outer_block->outputs().size())
|
||||
next_node->addOutput()->copyMetadata(new_inplace_node);
|
||||
|
||||
auto next_block = next_node->owningBlock();
|
||||
while (nullptr != next_block->owningNode() &&
|
||||
next_block != value->node()->owningBlock()) {
|
||||
next_block->registerOutput(next_node->output(0));
|
||||
next_node = next_block->owningNode();
|
||||
// Block has a new output. Add the output for the prim::If node.
|
||||
if (next_node->outputs().size() < next_block->outputs().size())
|
||||
next_node->addOutput()->setType(new_inplace_node->type());
|
||||
next_block = next_node->owningBlock();
|
||||
}
|
||||
|
||||
value->replaceAllUsesAfterNodeWith(
|
||||
next_node->output(0)->node(),
|
||||
next_node->outputs().at(next_node->outputs().size() - 1));
|
||||
auto next_block_node = initial_block_node;
|
||||
new_data->setDebugName("_output_" + output_name);
|
||||
outer_block->registerOutput(new_data);
|
||||
// Block has a new output. Add the output for the prim::If node.
|
||||
if (next_block_node->outputs().size() < outer_block->outputs().size())
|
||||
next_block_node->addOutput()->copyMetadata(new_data);
|
||||
|
||||
auto next_block = next_block_node->owningBlock();
|
||||
while (nullptr != next_block->owningNode() &&
|
||||
next_block != orig_data->node()->owningBlock()) {
|
||||
next_block->registerOutput(next_block_node->output(0));
|
||||
next_block_node = next_block->owningNode();
|
||||
// Block has a new output. Add the output for the prim::If node.
|
||||
if (next_block_node->outputs().size() < next_block->outputs().size())
|
||||
next_block_node->addOutput()->setType(new_data->type());
|
||||
next_block = next_block_node->owningBlock();
|
||||
}
|
||||
|
||||
orig_data->replaceAllUsesAfterNodeWith(
|
||||
next_block_node->output(0)->node(),
|
||||
next_block_node->outputs().at(next_block_node->outputs().size() - 1));
|
||||
}
|
||||
|
||||
// Register inplace op node inputs/outputs through the blocks.
|
||||
|
|
@ -443,21 +209,25 @@ void RegisterInplaceNodeInIfBlocks(
|
|||
// %60 : Tensor = aten::index_put(%bias.1, %59, %45, %25)
|
||||
// -> (%27, %60)
|
||||
// -> (%27, %61)
|
||||
void RegisterInplaceNodeInLoopBlocks(
|
||||
Value* orig_data,
|
||||
Value* new_inplace_node,
|
||||
Node* block_node,
|
||||
Block* outer_block,
|
||||
Node* next_node) {
|
||||
if (next_node->kind() != prim::Loop)
|
||||
void RegisterInplaceNodeInLoopBlocks(Value* orig_data, Value* new_data) {
|
||||
Node* inplace_node = new_data->node();
|
||||
Block* outer_block = inplace_node->owningBlock();
|
||||
Node* outer_block_node = outer_block->owningNode();
|
||||
|
||||
if (nullptr == outer_block_node) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (outer_block_node->kind() != prim::Loop)
|
||||
return;
|
||||
|
||||
outer_block->registerOutput(new_inplace_node);
|
||||
outer_block->registerOutput(new_data);
|
||||
std::vector<std::pair<Block*, Node*>> node_list = {
|
||||
std::make_pair(outer_block, next_node)};
|
||||
std::make_pair(outer_block, outer_block_node)};
|
||||
|
||||
next_node->addOutput()->setType(new_inplace_node->type());
|
||||
auto next_block = next_node->owningBlock();
|
||||
outer_block_node->addOutput()->setType(new_data->type());
|
||||
auto next_block = outer_block_node->owningBlock();
|
||||
auto next_node = outer_block_node;
|
||||
|
||||
while (nullptr != next_block->owningNode() &&
|
||||
next_block != orig_data->node()->owningBlock()) {
|
||||
|
|
@ -465,7 +235,7 @@ void RegisterInplaceNodeInLoopBlocks(
|
|||
outer_block->registerOutput(
|
||||
next_node->outputs().at(next_node->outputs().size() - 1));
|
||||
next_node = outer_block->owningNode();
|
||||
next_node->addOutput()->setType(new_inplace_node->type());
|
||||
next_node->addOutput()->setType(new_data->type());
|
||||
next_block = next_node->owningBlock();
|
||||
if (next_node->kind() ==
|
||||
prim::Loop) // Do not register input if nested in If block
|
||||
|
|
@ -486,11 +256,11 @@ void RegisterInplaceNodeInLoopBlocks(
|
|||
}
|
||||
|
||||
// Update inplace node inputs inside the inner most block.
|
||||
auto prev_data = block_node->inputs().at(0);
|
||||
while (inplace_ops.find(prev_data->node()->kind()) != inplace_ops.end()) {
|
||||
auto prev_data = inplace_node->inputs().at(0);
|
||||
while (IsInplaceNode(prev_data->node())) {
|
||||
prev_data = prev_data->node()->inputs().at(0);
|
||||
}
|
||||
for (auto node : block_node->owningBlock()->nodes()) {
|
||||
for (auto node : inplace_node->owningBlock()->nodes()) {
|
||||
size_t idx = 0;
|
||||
for (auto inputs_ : node->inputs()) {
|
||||
if (inputs_ == prev_data) {
|
||||
|
|
@ -507,27 +277,26 @@ void RegisterInplaceNodeInLoopBlocks(
|
|||
}
|
||||
|
||||
// Register inplace op node inputs/outputs through the blocks.
|
||||
void RegisterInplaceNodeInBlocks(
|
||||
Value* orig_data,
|
||||
Value* new_inplace_node,
|
||||
Node* block_node,
|
||||
Block* outer_block,
|
||||
Node* next_node) {
|
||||
if (next_node == nullptr)
|
||||
void RegisterInplaceNodeInBlocks(Value* orig_data, Value* new_data) {
|
||||
Node* inplace_node = new_data->node();
|
||||
Block* outer_block = inplace_node->owningBlock();
|
||||
Node* outer_block_node = outer_block->owningNode();
|
||||
|
||||
if (outer_block_node == nullptr)
|
||||
return;
|
||||
|
||||
// Check if the value is already registered in the block
|
||||
bool registered = false;
|
||||
while (inplace_ops.find(orig_data->node()->kind()) != inplace_ops.end()) {
|
||||
while (IsInplaceNode(orig_data->node())) {
|
||||
orig_data = orig_data->node()->inputs().at(0);
|
||||
}
|
||||
for (auto use : orig_data->uses()) {
|
||||
if ((use.user->owningBlock() == outer_block) &&
|
||||
(use.user->isAfter(new_inplace_node->node()))) {
|
||||
(use.user->isAfter(inplace_node))) {
|
||||
size_t idx = 0;
|
||||
for (auto input_ : use.user->inputs()) {
|
||||
if (input_ == orig_data) {
|
||||
use.user->replaceInput(idx, new_inplace_node);
|
||||
use.user->replaceInput(idx, new_data);
|
||||
registered = true;
|
||||
}
|
||||
idx++;
|
||||
|
|
@ -538,105 +307,30 @@ void RegisterInplaceNodeInBlocks(
|
|||
return;
|
||||
|
||||
// Register inplace node outputs through the blocks.
|
||||
RegisterInplaceNodeInLoopBlocks(
|
||||
orig_data, new_inplace_node, block_node, outer_block, next_node);
|
||||
RegisterInplaceNodeInLoopBlocks(orig_data, new_data);
|
||||
|
||||
RegisterInplaceNodeInIfBlocks(
|
||||
orig_data,
|
||||
new_inplace_node,
|
||||
outer_block,
|
||||
next_node,
|
||||
orig_data->debugName());
|
||||
RegisterInplaceNodeInIfBlocks(orig_data, new_data, orig_data->debugName());
|
||||
|
||||
while (nullptr != outer_block->owningNode() &&
|
||||
outer_block != orig_data->node()->owningBlock()) {
|
||||
MatchIfBlocksOutputForValue(orig_data, outer_block, new_inplace_node);
|
||||
MatchIfBlocksOutputForValue(orig_data, outer_block, new_data);
|
||||
outer_block = outer_block->owningNode()->owningBlock();
|
||||
}
|
||||
}
|
||||
|
||||
// Trace back all the slice & select nodes associated with the index_put node,
|
||||
// and convert them to associated indices.
|
||||
// E.g. The IR for x[1:3, 0] = update
|
||||
// ...
|
||||
// %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
|
||||
// ...
|
||||
// %11 : Float(2) = aten::select(%8, %9, %10)
|
||||
// ...
|
||||
// %13 : Tensor?[] = prim::ListConstruct()
|
||||
// ...
|
||||
// %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
|
||||
// The aten::index_put node alone does not contain any indices (%13 : Tensor?[]
|
||||
// = prim::ListConstruct()).
|
||||
// ...
|
||||
// # Below constructs index from slice node.
|
||||
// %23 : Long() = aten::size(%0, %4)
|
||||
// %28 : Tensor = aten::arange(%23, %24, %25, %26, %27)
|
||||
// %33 : Tensor = aten::slice(%28, %4, %5, %6, %7)
|
||||
// %39 : int[] = prim::Constant[value=[-1, 1]]()
|
||||
// %40 : Tensor = aten::view(%33, %39)
|
||||
// ...
|
||||
// # Below constructs index from select node.
|
||||
// %36 : int = prim::Constant[value=0]()
|
||||
// %37 : Tensor = aten::unsqueeze(%10, %36)
|
||||
// %42 : int[] = prim::Constant[value=[-1]]()
|
||||
// %43 : Tensor = aten::view(%37, %42)
|
||||
// ...
|
||||
// # Adding the above two indices to index_put
|
||||
// %44 : Tensor?[] = prim::ListConstruct(%40, %43)
|
||||
// %45 : Float(2, 5) = aten::index_put(%0, %44, %14, %15)
|
||||
void SquashSliceAndSelect(Node* index_put_node) {
|
||||
auto graph = index_put_node->owningGraph();
|
||||
|
||||
// Find slice and select operators that are associated with this index
|
||||
// operator. E.g. x[1:3, 0] = y will generate one slice operator(1:3) and one
|
||||
// select operator(0).
|
||||
std::vector<Node*> slice_and_select_nodes =
|
||||
FetchSliceAndSelect(index_put_node);
|
||||
|
||||
Node* last_node = slice_and_select_nodes.size() > 0
|
||||
? slice_and_select_nodes.back()
|
||||
: index_put_node;
|
||||
Value* orig_data = last_node->input(0);
|
||||
|
||||
// Convert fetched slice/select operators into tensor indices.
|
||||
std::unordered_map<int64_t, ConvertedIndex> dim_index_map =
|
||||
MergeSliceAndSelectToIndices(
|
||||
graph, index_put_node, slice_and_select_nodes, orig_data);
|
||||
std::vector<Value*> indices =
|
||||
ReshapeToAdvancedIndexingFormat(graph, index_put_node, dim_index_map);
|
||||
|
||||
// Create new aten::index_put operator.
|
||||
WithInsertPoint guard(index_put_node);
|
||||
const auto list_indices =
|
||||
graph->insertNode(graph->createList(OptionalType::ofTensor(), indices))
|
||||
->output();
|
||||
|
||||
auto new_index_put = graph->insert(
|
||||
aten::index_put,
|
||||
{orig_data,
|
||||
list_indices,
|
||||
index_put_node->input(2),
|
||||
index_put_node->input(3)});
|
||||
new_index_put->copyMetadata(index_put_node->output());
|
||||
index_put_node->output()->replaceAllUsesWith(new_index_put);
|
||||
|
||||
auto block_node = new_index_put->node();
|
||||
auto outer_block = block_node->owningBlock();
|
||||
auto next_node = outer_block->owningNode();
|
||||
if (nullptr == next_node) {
|
||||
orig_data->replaceAllUsesAfterNodeWith(
|
||||
new_index_put->node(), new_index_put);
|
||||
return;
|
||||
}
|
||||
|
||||
RegisterInplaceNodeInBlocks(
|
||||
orig_data, new_index_put, block_node, outer_block, next_node);
|
||||
}
|
||||
|
||||
void PrepareIndexPutForONNX(Node* node) {
|
||||
if (node->kind() == aten::index_put || node->kind() == aten::index_put_) {
|
||||
SquashSliceAndSelect(node);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
node->kind() == aten::index_put || node->kind() == aten::index_put_);
|
||||
auto placeholder_node = EncapsulatePatternIntoSubblock(node).value();
|
||||
if (node->kind() == aten::index_put_) {
|
||||
auto orig_data = placeholder_node->input();
|
||||
auto new_data = placeholder_node->output();
|
||||
|
||||
if (nullptr == placeholder_node->owningBlock()->owningNode()) {
|
||||
orig_data->replaceAllUsesAfterNodeWith(placeholder_node, new_data);
|
||||
return;
|
||||
}
|
||||
RegisterInplaceNodeInBlocks(orig_data, new_data);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -662,7 +356,7 @@ void PrepareCopyForONNX(Node* node) {
|
|||
expanded_value->copyMetadata(node->input(1));
|
||||
|
||||
auto index_put = graph->insert(
|
||||
aten::index_put,
|
||||
aten::index_put_,
|
||||
{node->input(0), dummy_list, expanded_value, node->input(2)});
|
||||
index_put->node()->setSourceRange(node->sourceRange());
|
||||
index_put->copyMetadata(node->output());
|
||||
|
|
@ -695,12 +389,7 @@ static void PrepareListPopForONNX(Node* n) {
|
|||
n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
|
||||
return;
|
||||
}
|
||||
RegisterInplaceNodeInBlocks(
|
||||
n->inputs().at(0),
|
||||
n->output(),
|
||||
n,
|
||||
n->owningBlock(),
|
||||
n->owningBlock()->owningNode());
|
||||
RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -713,12 +402,7 @@ static void PrepareListDeleteForONNX(Node* n) {
|
|||
n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
|
||||
return;
|
||||
}
|
||||
RegisterInplaceNodeInBlocks(
|
||||
n->inputs().at(0),
|
||||
n->output(),
|
||||
n,
|
||||
n->owningBlock(),
|
||||
n->owningBlock()->owningNode());
|
||||
RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -733,12 +417,7 @@ static void PrepareListAppendAndInsertForONNX(Node* n) {
|
|||
n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
|
||||
return;
|
||||
}
|
||||
RegisterInplaceNodeInBlocks(
|
||||
n->inputs().at(0),
|
||||
n->output(),
|
||||
n,
|
||||
n->owningBlock(),
|
||||
n->owningBlock()->owningNode());
|
||||
RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -903,10 +582,8 @@ Value* registerSetAttrInBlocks(
|
|||
Value* origValue,
|
||||
const std::string& output_name) {
|
||||
auto cloneNode = insertCloneBeforeNode(graph, newValue, block->return_node());
|
||||
auto next_node = block->owningNode();
|
||||
|
||||
RegisterInplaceNodeInIfBlocks(
|
||||
origValue, cloneNode->output(), block, next_node, output_name);
|
||||
RegisterInplaceNodeInIfBlocks(origValue, cloneNode->output(), output_name);
|
||||
|
||||
return MatchIfBlocksOutputForValue(origValue, block, cloneNode->output());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,8 @@
|
|||
#include <torch/csrc/jit/passes/onnx/fold_if_node.h>
|
||||
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
|
||||
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
|
||||
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
|
||||
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
|
||||
#include <torch/csrc/jit/passes/onnx/peephole.h>
|
||||
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
|
||||
#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
|
||||
|
|
@ -493,6 +495,11 @@ void initJITBindings(PyObject* module) {
|
|||
python::unflatten(vars, desc));
|
||||
})
|
||||
.def("_jit_pass_onnx_block", BlockToONNX)
|
||||
.def(
|
||||
"_jit_pass_onnx_encapsulate_pattern_into_subblock",
|
||||
EncapsulatePatternIntoSubblock)
|
||||
.def(
|
||||
"_jit_onnx_convert_pattern_from_subblock", ConvertPatternFromSubblock)
|
||||
.def("_jit_pass_fixup_onnx_controlflow_node", FixupONNXControlflowNode)
|
||||
.def("_jit_pass_canonicalize_graph_fuser_ops", CanonicalizeOps)
|
||||
.def("_jit_pass_decompose_ops", DecomposeOps)
|
||||
|
|
|
|||
|
|
@ -84,95 +84,41 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
|
|||
# when inputs to the index_put node contains boolean inputs
|
||||
#
|
||||
# index_put -> masked_fill
|
||||
# * input index contains single tensor of Bool type (e.g.: %24 <- %23).
|
||||
# * input value contains single element (e.g.: %18).
|
||||
#
|
||||
# before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
|
||||
# %some_const : Float(requires_grad=0, device=cpu)):
|
||||
# %6 : None = prim::Constant()
|
||||
# Torch IR
|
||||
# %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
|
||||
# %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
|
||||
# %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
|
||||
# %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
||||
# %11 : Device = prim::Constant[value="cpu"]()
|
||||
# %12 : None = prim::Constant()
|
||||
# %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
||||
# %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
||||
# %15 : None = prim::Constant()
|
||||
# %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
|
||||
# aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
|
||||
# %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
|
||||
# %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
||||
# %22 : int[] = prim::Constant[value=[-1]]()
|
||||
# %23 : Tensor = aten::view(%16, %22)
|
||||
# %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
|
||||
# %24 : Tensor?[] = prim::ListConstruct(%23)
|
||||
# %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
|
||||
# aten::index_put(%mask, %24, %18, %30)
|
||||
# return (%25)
|
||||
#
|
||||
# after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
|
||||
# %some_const : Float(requires_grad=0, device=cpu)):
|
||||
# %3 : Tensor = onnx::Equal(%0, %some_const)
|
||||
# %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
|
||||
# %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
|
||||
# %19 : Tensor = onnx::Cast[to=9](%12)
|
||||
# %20 : Tensor = onnx::Constant[value={1}]()
|
||||
# %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
|
||||
# = onnx::Where(%19, %20, %0)
|
||||
# return (%21)
|
||||
#
|
||||
# index_put -> masked_scatter
|
||||
# * input index contains single tensor of Bool type (e.g.: %32 <- %31).
|
||||
# * input value contains multiple elements (e.g.: %28).
|
||||
#
|
||||
# before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
|
||||
# %some_const : Float(requires_grad=0, device=cpu)):
|
||||
# %6 : None = prim::Constant()
|
||||
# Torch IR
|
||||
# %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
|
||||
# %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
|
||||
# = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
|
||||
# %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
|
||||
# = aten::ne(%mask, %some_const)
|
||||
# %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
|
||||
# %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
||||
# %18 : Device = prim::Constant[value="cpu"]()
|
||||
# %19 : None = prim::Constant()
|
||||
# %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
||||
# %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
||||
# %22 : None = prim::Constant()
|
||||
# %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
|
||||
# = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
|
||||
# %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
||||
# %30 : int[] = prim::Constant[value=[-1]]()
|
||||
# %31 : Tensor = aten::view(%23, %30)
|
||||
# %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
|
||||
# %32 : Tensor?[] = prim::ListConstruct(%31)
|
||||
# %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
|
||||
# = aten::index_put(%mask, %32, %28, %38)
|
||||
# return (%33)
|
||||
#
|
||||
# after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
|
||||
# %some_const : Float(requires_grad=0, device=cpu)):
|
||||
# %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
|
||||
# = onnx::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
|
||||
# %4 : Tensor = onnx::Equal(%0, %some_const)
|
||||
# %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
|
||||
# %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
|
||||
# %19 : Tensor = onnx::Shape(%0)
|
||||
# %20 : Tensor = onnx::Expand(%13, %19)
|
||||
# %21 : Tensor = onnx::NonZero(%20)
|
||||
# %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
|
||||
# %23 : Tensor = onnx::Constant[value={-1}]()
|
||||
# %24 : Tensor = onnx::Reshape(%3, %23)
|
||||
# %25 : Tensor = onnx::Shape(%22)
|
||||
# %27 : Tensor = onnx::Constant[value={0}]()
|
||||
# %28 : Tensor = onnx::Gather[axis=0](%25, %27)
|
||||
# %29 : Tensor = onnx::Constant[value={0}]()
|
||||
# %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
|
||||
# %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
|
||||
# %32 : Tensor = onnx::Constant[value={0}]()
|
||||
# %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
|
||||
# %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
|
||||
# %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
|
||||
# = onnx::ScatterND(%0, %22, %34)
|
||||
# return (%35)
|
||||
|
||||
bool_inp = list(index.node().inputs())[0]
|
||||
bool_inp = index
|
||||
if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
|
||||
rank = sym_help._get_tensor_rank(values)
|
||||
if rank is not None and rank == 0:
|
||||
|
|
|
|||
|
|
@ -947,7 +947,7 @@ def _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_t
|
|||
return sym_registry.get_registered_op(op_name, domain, opset_version)
|
||||
|
||||
|
||||
def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
|
||||
def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
|
||||
# NB: Returning None means the node gets cloned as is into
|
||||
# the new graph
|
||||
try:
|
||||
|
|
@ -970,9 +970,12 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor
|
|||
ns_op_name = n.kind()
|
||||
ns, op_name = ns_op_name.split("::")
|
||||
if ns == "onnx":
|
||||
# Clone node to trigger ONNX shape inference
|
||||
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
|
||||
return g.op(op_name, *inputs, **attrs, outputs=n.outputsSize())
|
||||
if op_name == "Placeholder":
|
||||
return torch._C._jit_onnx_convert_pattern_from_subblock(block, n, env)
|
||||
else:
|
||||
# Use the original node directly
|
||||
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
|
||||
return g.op(op_name, *inputs, **attrs, outputs=n.outputsSize())
|
||||
|
||||
elif ns == "aten":
|
||||
is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user