mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Utilize ONNX shape inference for ONNX exporter (#40628)
Summary: It is often that the conversion from torch operator to onnx operator requires input rank/dtype/shape to be known. Previously, the conversion depends on tracer to provide these info, leaving a gap in conversion of scripted modules. We are extending the export with support from onnx shape inference. If enabled, onnx shape inference will be called whenever an onnx node is created. This is the first PR introducing the initial look of the feature. More and more cases will be supported following this PR. * Added pass to run onnx shape inference on a given node. The node has to have namespace `onnx`. * Moved helper functions from `export.cpp` to a common place for re-use. * This feature is currently experimental, and can be turned on through flag `onnx_shape_inference` in internal api `torch.onnx._export`. * Currently skipping ONNX Sequence ops, If/Loop and ConstantOfShape due to limitations. Support will be added in the future. Pull Request resolved: https://github.com/pytorch/pytorch/pull/40628 Reviewed By: mrshenli Differential Revision: D22709746 Pulled By: bzinodev fbshipit-source-id: b52aeeae00667e66e0b0c1144022f7af9a8b2948
This commit is contained in:
parent
3aeb70db0b
commit
08126c9153
3
.github/workflows/lint.yml
vendored
3
.github/workflows/lint.yml
vendored
|
|
@ -152,6 +152,9 @@ jobs:
|
|||
--verbose \
|
||||
--paths torch/csrc/ \
|
||||
--diff "$MERGE_BASE" \
|
||||
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
|
||||
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp"\
|
||||
-g"-torch/csrc/jit/serialization/onnx.cpp" \
|
||||
-g"-torch/csrc/jit/serialization/export.cpp" \
|
||||
-g"-torch/csrc/jit/serialization/import.cpp" \
|
||||
-g"-torch/csrc/jit/serialization/import_legacy.cpp" \
|
||||
|
|
|
|||
|
|
@ -283,6 +283,7 @@ namespace c10 {
|
|||
_(onnx, SequenceConstruct) \
|
||||
_(onnx, SequenceEmpty) \
|
||||
_(onnx, SequenceInsert) \
|
||||
_(onnx, SequenceErase) \
|
||||
_(onnx, ConcatFromSequence) \
|
||||
_(onnx, Identity) \
|
||||
_(onnx, SoftmaxCrossEntropyLoss) \
|
||||
|
|
|
|||
|
|
@ -459,6 +459,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||
if(NOT INTERN_BUILD_MOBILE)
|
||||
list(APPEND TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/serialization/onnx.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/serialization/import_legacy.cpp
|
||||
|
|
|
|||
|
|
@ -86,5 +86,14 @@ def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
|
|||
return wrapper
|
||||
return skip_dec
|
||||
|
||||
def skipIfONNXShapeInference(onnx_shape_inference):
|
||||
def skip_dec(func):
|
||||
def wrapper(self):
|
||||
if self.onnx_shape_inference is onnx_shape_inference:
|
||||
raise unittest.SkipTest("Skip verify test for unsupported opset_version")
|
||||
return func(self)
|
||||
return wrapper
|
||||
return skip_dec
|
||||
|
||||
def flatten(x):
|
||||
return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from model_defs.lstm_flattening_result import LstmFlatteningResult
|
|||
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
|
||||
from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, enableScriptTest,
|
||||
skipIfUnsupportedOpsetVersion, skipIfNoLapack,
|
||||
skipIfUnsupportedMaxOpsetVersion)
|
||||
skipIfUnsupportedMaxOpsetVersion, skipIfONNXShapeInference)
|
||||
from test_pytorch_common import BATCH_SIZE
|
||||
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
|
||||
import model_defs.word_language_model as word_language_model
|
||||
|
|
@ -79,7 +79,8 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
|
|||
keep_initializers_as_inputs=self.keep_initializers_as_inputs,
|
||||
dynamic_axes=dynamic_axes,
|
||||
input_names=input_names, output_names=output_names,
|
||||
fixed_batch_size=fixed_batch_size)
|
||||
fixed_batch_size=fixed_batch_size,
|
||||
onnx_shape_inference=self.onnx_shape_inference)
|
||||
|
||||
# compute onnxruntime output prediction
|
||||
ort_sess = onnxruntime.InferenceSession(f.getvalue())
|
||||
|
|
@ -103,6 +104,7 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
from torch.onnx.symbolic_helper import _export_onnx_opset_version
|
||||
opset_version = _export_onnx_opset_version
|
||||
keep_initializers_as_inputs = True # For IR version 3 type export.
|
||||
onnx_shape_inference = False
|
||||
|
||||
def setUp(self):
|
||||
torch.manual_seed(0)
|
||||
|
|
@ -496,7 +498,7 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
class ScalarInputModel(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
return torch.tensor(input.shape[1])
|
||||
return torch.tensor(input.shape[1])
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
self.run_test(ScalarInputModel(), x)
|
||||
|
|
@ -504,7 +506,7 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
class TensorInputModel(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
return torch.tensor([input.shape[0], input.shape[1]])
|
||||
return torch.tensor([input.shape[0], input.shape[1]])
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
self.run_test(TensorInputModel(), x)
|
||||
|
|
@ -520,7 +522,7 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
class InputWithDtypeModel(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
return torch.tensor(input.shape[1], dtype=torch.long)
|
||||
return torch.tensor(input.shape[1], dtype=torch.long)
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
self.run_test(InputWithDtypeModel(), x)
|
||||
|
|
@ -528,7 +530,7 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
class MixedInputModel(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
return torch.tensor([input.shape[0], int(input)])
|
||||
return torch.tensor([input.shape[0], int(input)])
|
||||
|
||||
x = torch.randn(1)
|
||||
self.run_test(MixedInputModel(), x)
|
||||
|
|
@ -686,6 +688,23 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
|
||||
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)
|
||||
|
||||
# Conversion of Transpose depends on input shape to be known.
|
||||
# The following test only works when onnx shape inference is enabled.
|
||||
@skipIfONNXShapeInference(False)
|
||||
def test_transpose_infer_shape(self):
|
||||
class TransposeModule(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(TransposeModule, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x.transpose(0, 1)
|
||||
|
||||
x = torch.randn(32, 3, 64, 64)
|
||||
self.run_test(TransposeModule(), x)
|
||||
|
||||
def squeeze_model_tests(self, d, x1, x2):
|
||||
class Squeeze(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
|
@ -842,6 +861,23 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
x = torch.randn(2, 3, 4)
|
||||
self.run_test(ArithmeticModule(), x)
|
||||
|
||||
# In scripting the first transpose node do not carry shape and dtype info.
|
||||
# The following test only works when onnx shape inference is enabled.
|
||||
@skipIfONNXShapeInference(False)
|
||||
def test_arithmetic_infer_dtype(self):
|
||||
class ArithmeticModule(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
x = x.t()
|
||||
x = x + 2
|
||||
x = x - 4
|
||||
x = x * 6
|
||||
x = x / 8
|
||||
return x
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
self.run_test(ArithmeticModule(), x)
|
||||
|
||||
def test_floor_div(self):
|
||||
class FloorDivModule(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
|
|
@ -3015,6 +3051,21 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
x = torch.randn(4, 2, 3, requires_grad=True)
|
||||
self.run_test(UnfoldModel(), x)
|
||||
|
||||
@skipIfONNXShapeInference(False)
|
||||
def test_unfold_infer_shape(self):
|
||||
class UnfoldModule(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(UnfoldModule, self).__init__()
|
||||
self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x.unfold(dimension=2, size=2, step=2)
|
||||
|
||||
x = torch.randn(32, 3, 64)
|
||||
self.run_test(UnfoldModule(), x)
|
||||
|
||||
def test_remainder(self):
|
||||
class RemainderModel(torch.nn.Module):
|
||||
def forward(self, input, other):
|
||||
|
|
@ -4187,5 +4238,11 @@ TestONNXRuntime_opset12_IRv4 = type(str("TestONNXRuntime_opset12_IRv4"),
|
|||
keep_initializers_as_inputs=False))
|
||||
|
||||
|
||||
# opset 12 tests, with _onnx_shape_inference=True.
|
||||
TestONNXRuntime_opset12_onnx_shape_inference = type(str("TestONNXRuntime_opset12_onnx_shape_inference"),
|
||||
(unittest.TestCase,),
|
||||
dict(TestONNXRuntime.__dict__, opset_version=12,
|
||||
onnx_shape_inference=True))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -322,6 +322,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
|
|||
"torch/csrc/jit/mobile/observer.cpp",
|
||||
"torch/csrc/jit/mobile/optim/sgd.cpp",
|
||||
"torch/csrc/jit/mobile/sequential.cpp",
|
||||
"torch/csrc/jit/serialization/onnx.cpp",
|
||||
"torch/csrc/jit/serialization/export.cpp",
|
||||
"torch/csrc/jit/serialization/export_module.cpp",
|
||||
"torch/csrc/jit/serialization/import_legacy.cpp",
|
||||
|
|
@ -501,6 +502,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp",
|
||||
"torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp",
|
||||
"torch/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.cpp",
|
||||
"torch/csrc/jit/passes/onnx/shape_type_inference.cpp",
|
||||
"torch/csrc/jit/python/python_arg_flatten.cpp",
|
||||
"torch/csrc/jit/python/python_custom_class.cpp",
|
||||
"torch/csrc/jit/python/python_interpreter.cpp",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ then
|
|||
python tools/clang_tidy.py \
|
||||
--paths torch/csrc \
|
||||
--diff HEAD \
|
||||
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
|
||||
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp" \
|
||||
-g"-torch/csrc/jit/serialization/onnx.cpp" \
|
||||
-g"-torch/csrc/jit/serialization/export.cpp" \
|
||||
-g"-torch/csrc/jit/serialization/import.cpp" \
|
||||
-j
|
||||
|
|
|
|||
|
|
@ -103,7 +103,10 @@ c10::optional<at::Tensor> runTorchSlice_opset9(
|
|||
c10::optional<at::Tensor> runTorchSlice_opset10(
|
||||
const Node* node,
|
||||
std::vector<at::Tensor>& inputTensorValues) {
|
||||
if (inputTensorValues.size() < 3 || inputTensorValues.size() > 5) {
|
||||
const int maxSliceInputCount = 5;
|
||||
const int minSliceInputCount = 3;
|
||||
if (inputTensorValues.size() < minSliceInputCount ||
|
||||
inputTensorValues.size() > maxSliceInputCount) {
|
||||
std::cerr
|
||||
<< "Warning: Constant folding - Invalid number of inputs found for opset 10 or 11 onnx::Slice op. "
|
||||
<< "Constant folding not applied." << std::endl;
|
||||
|
|
@ -249,11 +252,9 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
|
|||
return c10::optional<at::Tensor>(updated_val);
|
||||
} else if (node->kind() == onnx::Cast) {
|
||||
assert(inputTensorValues.size() == 1);
|
||||
if (node->hasAttributeS("to") &&
|
||||
onnxTypeToScalarTypeMap.find(node->i(attr::to)) !=
|
||||
onnxTypeToScalarTypeMap.end()) {
|
||||
updated_val =
|
||||
inputTensorValues[0].to(onnxTypeToScalarTypeMap[node->i(attr::to)]);
|
||||
if (node->hasAttributeS("to") && ONNXTypeToATenType(node->i(attr::to))) {
|
||||
updated_val = inputTensorValues[0].to(
|
||||
ONNXTypeToATenType(node->i(attr::to)).value());
|
||||
return c10::optional<at::Tensor>(updated_val);
|
||||
}
|
||||
return c10::nullopt;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
#include <torch/csrc/jit/passes/onnx/helper.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <onnx/onnx_pb.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -59,5 +59,40 @@ Node* addNodeToBlock(Block* block, Value* input, Symbol kind) {
|
|||
}
|
||||
return new_node;
|
||||
}
|
||||
|
||||
c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type) {
|
||||
switch (onnx_type) {
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
|
||||
return at::ScalarType::Undefined;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
return at::kFloat;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
return at::kByte;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
return at::kChar;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_INT16:
|
||||
return at::kShort;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
return at::kInt;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
return at::kLong;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
return at::kBool;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
return at::kHalf;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
return at::kDouble;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64:
|
||||
return at::kComplexFloat;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128:
|
||||
return at::kComplexDouble;
|
||||
case ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
|
||||
return at::kBFloat16;
|
||||
default:
|
||||
TORCH_CHECK("unexpected tensor scalar type");
|
||||
}
|
||||
return c10::optional<at::ScalarType>{};
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -8,18 +8,19 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace onnx {
|
||||
static const int OPSET_VERSION_1 = 1;
|
||||
static const int OPSET_VERSION_9 = 9;
|
||||
static const int OPSET_VERSION_10 = 10;
|
||||
static const int OPSET_VERSION_11 = 11;
|
||||
static const int OPSET_VERSION_12 = 12;
|
||||
} // namespace onnx
|
||||
|
||||
using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;
|
||||
|
||||
using ParamMap = std::map<std::string, IValue>;
|
||||
|
||||
void buildParamsMapFromValueToParamsMap(
|
||||
const ValueToParamPairMap& valsToParamsMap,
|
||||
ParamMap& paramsDict);
|
||||
ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
|
||||
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
|
||||
void eraseUnusedBlockInputs(Block* b);
|
||||
|
|
@ -28,5 +29,6 @@ void buildParamsMapFromValueToParamsMap(
|
|||
ParamMap& paramsDict);
|
||||
Node* addNodeToBlock(Block* block, Value* input, Symbol kind);
|
||||
|
||||
TORCH_API c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type);
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -576,7 +576,7 @@ static void eraseListConstruct(Block* block, int opset_version) {
|
|||
i, std::vector<Value*>({concat_node->output()}));
|
||||
|
||||
} else {
|
||||
if (opset_version < onnx::OPSET_VERSION_11) {
|
||||
if (opset_version < OPSET_VERSION_11) {
|
||||
// Tensor lists are used mostly for inputs to cat/stack. They are
|
||||
// already handled in those symbolics, and should become dead
|
||||
// afterwards.
|
||||
|
|
|
|||
166
torch/csrc/jit/passes/onnx/shape_type_inference.cpp
Normal file
166
torch/csrc/jit/passes/onnx/shape_type_inference.cpp
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/jit/passes/onnx/helper.h>
|
||||
#include <torch/csrc/jit/serialization/export.h>
|
||||
#include <torch/csrc/jit/serialization/onnx.h>
|
||||
|
||||
#include <onnx/shape_inference/implementation.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
namespace onnx_torch = ::torch::onnx;
|
||||
namespace onnx = ::ONNX_NAMESPACE;
|
||||
|
||||
void UpdateTorchValueByOnnxValueInfo(
|
||||
Value* v,
|
||||
const onnx::ValueInfoProto& p_info) {
|
||||
if (!p_info.has_type()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto p_type = p_info.type();
|
||||
if (!p_type.has_tensor_type()) {
|
||||
// TODO: Support sequence type.
|
||||
return;
|
||||
}
|
||||
|
||||
auto p_tensor_type = p_type.tensor_type();
|
||||
|
||||
c10::optional<at::ScalarType> scalar_type;
|
||||
if (p_tensor_type.has_elem_type()) {
|
||||
scalar_type = ONNXTypeToATenType(p_tensor_type.elem_type());
|
||||
}
|
||||
|
||||
auto v_type = TensorType::create(
|
||||
scalar_type,
|
||||
at::kCPU,
|
||||
c10::SymbolicShape(),
|
||||
c10::VaryingShape<c10::Stride>{},
|
||||
{});
|
||||
if (p_tensor_type.has_shape()) {
|
||||
std::vector<int64_t> sizes;
|
||||
auto p_shape = p_tensor_type.shape();
|
||||
|
||||
for (int i = 0; i < p_shape.dim_size(); ++i) {
|
||||
auto& dim = p_shape.dim(i);
|
||||
if (dim.has_dim_value()) {
|
||||
sizes.push_back(dim.dim_value());
|
||||
} else {
|
||||
// TODO: handle dim_param?
|
||||
return;
|
||||
}
|
||||
}
|
||||
v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {});
|
||||
v_type = v_type->withSizes(sizes);
|
||||
}
|
||||
|
||||
v->setType(v_type);
|
||||
}
|
||||
|
||||
bool IsSupportedNode(const Node* n) {
|
||||
auto node_kind = n->kind();
|
||||
|
||||
if (!node_kind.is_onnx()) {
|
||||
// node kind is not ONNX, skipped.
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node_kind == ::c10::onnx::SequenceAt ||
|
||||
node_kind == ::c10::onnx::SplitToSequence ||
|
||||
node_kind == ::c10::onnx::SequenceConstruct ||
|
||||
node_kind == ::c10::onnx::SequenceEmpty ||
|
||||
node_kind == ::c10::onnx::SequenceInsert ||
|
||||
node_kind == ::c10::onnx::ConcatFromSequence ||
|
||||
node_kind == ::c10::onnx::SequenceErase) {
|
||||
// TODO: ONNX unable to do shape inference for these ops.
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node_kind == ::c10::onnx::ConstantOfShape) {
|
||||
// && n->input()->node()->kind() == ::c10::prim::ListConstruct
|
||||
// TODO: ONNX shape inference segfault.
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
|
||||
// TODO: Support Loop & If shape inference by propagating input shape to
|
||||
// block input.
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ONNXShapeTypeInference(Node* n, int opset_version) {
|
||||
if (!IsSupportedNode(n)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a Graph containing only the single node n.
|
||||
// This graph is later converted to ONNX to run shape inference.
|
||||
auto n_graph = std::make_shared<Graph>();
|
||||
// Clone the node n for the new graph.
|
||||
auto clone_node = n_graph->createClone(n, [&n_graph](Value* v) {
|
||||
auto v_n = v->node();
|
||||
if (v_n->kind() == ::c10::onnx::Constant) {
|
||||
// Clone the input if it is constant.
|
||||
auto constant_n = n_graph->insertNode(
|
||||
n_graph->createClone(v_n, [](Value* v) { return v; }));
|
||||
return constant_n->output();
|
||||
} else {
|
||||
// If the input is not constant, we cannot depend on its value
|
||||
// in shape inference. Set it to graph input in the new graph,
|
||||
// and copy over metadata, such as datatype and shape.
|
||||
auto input = n_graph->addInput();
|
||||
input->copyMetadata(v);
|
||||
return input;
|
||||
}
|
||||
});
|
||||
n_graph->insertNode(clone_node);
|
||||
// Register all node outputs as graph outputs.
|
||||
for (auto output : clone_node->outputs()) {
|
||||
n_graph->registerOutput(output);
|
||||
}
|
||||
|
||||
// TODO: Some ops have conversion happen at Peephole pass.
|
||||
// The conversion here is incomplete for these ops.
|
||||
// e.g: ListConstruct, ListUnpack, etc.
|
||||
std::string model_str;
|
||||
RawDataExportMap export_map;
|
||||
std::tie(model_str, export_map) = export_onnx(
|
||||
n_graph,
|
||||
{},
|
||||
opset_version,
|
||||
{},
|
||||
false,
|
||||
onnx_torch::OperatorExportTypes::ONNX,
|
||||
true,
|
||||
true,
|
||||
{},
|
||||
true,
|
||||
false,
|
||||
std::string());
|
||||
onnx::ModelProto model_proto;
|
||||
model_proto.ParseFromString(model_str);
|
||||
|
||||
// infer shape
|
||||
onnx::shape_inference::InferShapes(model_proto);
|
||||
auto graph_proto = model_proto.graph();
|
||||
// inferred shapes are stored in value_info.
|
||||
for (size_t i = 0; i < graph_proto.value_info_size(); ++i) {
|
||||
auto v_info = graph_proto.value_info(i);
|
||||
// get data from value_info and updated original graph.
|
||||
for (size_t j = 0; j < clone_node->outputs().size(); ++j) {
|
||||
if (clone_node->output(j)->debugName() == v_info.name()) {
|
||||
UpdateTorchValueByOnnxValueInfo(n->output(j), v_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
15
torch/csrc/jit/passes/onnx/shape_type_inference.h
Normal file
15
torch/csrc/jit/passes/onnx/shape_type_inference.h
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Utilize ONNX Shape Inference for node.
|
||||
// The node must have ONNX namespace, and is valid ONNX node accroding to spec.
|
||||
// On successful ONNX shape inference runs, the function updates output types of
|
||||
// n with inferred shape and type. Otherwise n is unchanged.
|
||||
TORCH_API void ONNXShapeTypeInference(Node* n, int opset_version);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -41,6 +41,7 @@
|
|||
#include <torch/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.h>
|
||||
#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
|
||||
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
|
||||
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
|
||||
#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
|
||||
#include <torch/csrc/jit/passes/peephole.h>
|
||||
#include <torch/csrc/jit/passes/quantization/dedup_module_uses.h>
|
||||
|
|
@ -185,6 +186,7 @@ void initJITBindings(PyObject* module) {
|
|||
.def(
|
||||
"_jit_pass_onnx_prepare_inplace_ops_for_onnx",
|
||||
PrepareInplaceOpsForONNX)
|
||||
.def("_jit_pass_onnx_node_shape_type_inference", ONNXShapeTypeInference)
|
||||
.def("_jit_pass_fuse", FuseGraph)
|
||||
.def(
|
||||
"_jit_pass_dce",
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <torch/csrc/jit/serialization/import_export_constants.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_functions.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_helpers.h>
|
||||
#include <torch/csrc/jit/serialization/onnx.h>
|
||||
#include <torch/csrc/onnx/onnx.h>
|
||||
|
||||
#include <ATen/core/functional.h>
|
||||
|
|
@ -22,7 +23,6 @@
|
|||
#include <memory>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
namespace torch {
|
||||
|
|
@ -794,181 +794,6 @@ void GraphEncoder::EncodeTensor(
|
|||
}
|
||||
}
|
||||
|
||||
// Pretty printing for ONNX
|
||||
constexpr char indent_char = ' ';
|
||||
constexpr size_t indent_multiplier = 2;
|
||||
|
||||
std::string idt(size_t indent) {
|
||||
return std::string(indent * indent_multiplier, indent_char);
|
||||
}
|
||||
|
||||
std::string nlidt(size_t indent) {
|
||||
return std::string("\n") + idt(indent);
|
||||
}
|
||||
|
||||
void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
|
||||
stream << "TensorProto shape: [";
|
||||
for (int i = 0; i < tensor.dims_size(); ++i) {
|
||||
stream << tensor.dims(i) << (i == tensor.dims_size() - 1 ? "" : " ");
|
||||
}
|
||||
stream << "]";
|
||||
}
|
||||
|
||||
void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
|
||||
for (int i = 0; i < shape.dim_size(); ++i) {
|
||||
auto& dim = shape.dim(i);
|
||||
if (dim.has_dim_value()) {
|
||||
stream << dim.dim_value();
|
||||
} else {
|
||||
stream << "?";
|
||||
}
|
||||
stream << (i == shape.dim_size() - 1 ? "" : " ");
|
||||
}
|
||||
}
|
||||
|
||||
void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
|
||||
stream << "Tensor dims: ";
|
||||
dump(tensor_type.shape(), stream);
|
||||
}
|
||||
|
||||
void dump(const onnx::TypeProto& type, std::ostream& stream) {
|
||||
dump(type.tensor_type(), stream);
|
||||
}
|
||||
|
||||
void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
|
||||
stream << "{name: \"" << value_info.name() << "\", type:";
|
||||
dump(value_info.type(), stream);
|
||||
stream << "}";
|
||||
}
|
||||
|
||||
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
|
||||
|
||||
void dump(
|
||||
const onnx::AttributeProto& attr,
|
||||
std::ostream& stream,
|
||||
size_t indent) {
|
||||
stream << "{ name: '" << attr.name() << "', type: ";
|
||||
if (attr.has_f()) {
|
||||
stream << "float, value: " << attr.f();
|
||||
} else if (attr.has_i()) {
|
||||
stream << "int, value: " << attr.i();
|
||||
} else if (attr.has_s()) {
|
||||
stream << "string, value: '" << attr.s() << "'";
|
||||
} else if (attr.has_g()) {
|
||||
stream << "graph, value:\n";
|
||||
dump(attr.g(), stream, indent + 1);
|
||||
stream << nlidt(indent);
|
||||
} else if (attr.has_t()) {
|
||||
stream << "tensor, value:";
|
||||
dump(attr.t(), stream);
|
||||
} else if (attr.floats_size()) {
|
||||
stream << "floats, values: [";
|
||||
for (int i = 0; i < attr.floats_size(); ++i)
|
||||
stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
|
||||
stream << "]";
|
||||
} else if (attr.ints_size()) {
|
||||
stream << "ints, values: [";
|
||||
for (int i = 0; i < attr.ints_size(); ++i)
|
||||
stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
|
||||
stream << "]";
|
||||
} else if (attr.strings_size()) {
|
||||
stream << "strings, values: [";
|
||||
for (int i = 0; i < attr.strings_size(); ++i)
|
||||
stream << "'" << attr.strings(i) << "'"
|
||||
<< (i == attr.strings_size() - 1 ? "" : " ");
|
||||
stream << "]";
|
||||
} else if (attr.tensors_size()) {
|
||||
stream << "tensors, values: [";
|
||||
for (auto& t : attr.tensors()) {
|
||||
dump(t, stream);
|
||||
}
|
||||
stream << "]";
|
||||
} else if (attr.graphs_size()) {
|
||||
stream << "graphs, values: [";
|
||||
for (auto& g : attr.graphs()) {
|
||||
dump(g, stream, indent + 1);
|
||||
}
|
||||
stream << "]";
|
||||
} else {
|
||||
stream << "UNKNOWN";
|
||||
}
|
||||
stream << "}";
|
||||
}
|
||||
|
||||
void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
|
||||
stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
stream << node.input(i) << (i == node.input_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "], outputs: [";
|
||||
for (int i = 0; i < node.output_size(); ++i) {
|
||||
stream << node.output(i) << (i == node.output_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "], attributes: [";
|
||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||
dump(node.attribute(i), stream, indent + 1);
|
||||
stream << (i == node.attribute_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]}";
|
||||
}
|
||||
|
||||
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
|
||||
stream << idt(indent) << "GraphProto {" << nlidt(indent + 1) << "name: \""
|
||||
<< graph.name() << "\"" << nlidt(indent + 1) << "inputs: [";
|
||||
for (int i = 0; i < graph.input_size(); ++i) {
|
||||
dump(graph.input(i), stream);
|
||||
stream << (i == graph.input_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]" << nlidt(indent + 1) << "outputs: [";
|
||||
for (int i = 0; i < graph.output_size(); ++i) {
|
||||
dump(graph.output(i), stream);
|
||||
stream << (i == graph.output_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]" << nlidt(indent + 1) << "initializers: [";
|
||||
for (int i = 0; i < graph.initializer_size(); ++i) {
|
||||
dump(graph.initializer(i), stream);
|
||||
stream << (i == graph.initializer_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]" << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2);
|
||||
for (int i = 0; i < graph.node_size(); ++i) {
|
||||
dump(graph.node(i), stream, indent + 2);
|
||||
if (i != graph.node_size() - 1)
|
||||
stream << "," << nlidt(indent + 2);
|
||||
}
|
||||
stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n";
|
||||
}
|
||||
|
||||
void dump(
|
||||
const onnx::OperatorSetIdProto& operator_set_id,
|
||||
std::ostream& stream) {
|
||||
stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}";
|
||||
}
|
||||
|
||||
void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
|
||||
stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
|
||||
<< "producer_name: \"" << model.producer_name() << "\""
|
||||
<< nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
|
||||
<< nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
|
||||
if (model.has_graph()) {
|
||||
stream << nlidt(indent + 1) << "graph:\n";
|
||||
dump(model.graph(), stream, indent + 2);
|
||||
}
|
||||
if (model.opset_import_size()) {
|
||||
stream << idt(indent + 1) << "opset_import: [";
|
||||
for (auto& opset_imp : model.opset_import()) {
|
||||
dump(opset_imp, stream);
|
||||
}
|
||||
stream << "],\n";
|
||||
}
|
||||
stream << idt(indent) << "}\n";
|
||||
}
|
||||
|
||||
std::string prettyPrint(const onnx::ModelProto& model) {
|
||||
std::ostringstream ss;
|
||||
dump(model, ss, 0);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string pretty_print_onnx(
|
||||
|
|
|
|||
212
torch/csrc/jit/serialization/onnx.cpp
Normal file
212
torch/csrc/jit/serialization/onnx.cpp
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
#include <torch/csrc/jit/serialization/onnx.h>
|
||||
#include <torch/csrc/onnx/onnx.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
namespace onnx_torch = ::torch::onnx;
|
||||
namespace onnx = ::ONNX_NAMESPACE;
|
||||
|
||||
// Pretty printing for ONNX
|
||||
constexpr char indent_char = ' ';
|
||||
constexpr size_t indent_multiplier = 2;
|
||||
|
||||
std::string idt(size_t indent) {
|
||||
return std::string(indent * indent_multiplier, indent_char);
|
||||
}
|
||||
|
||||
std::string nlidt(size_t indent) {
|
||||
return std::string("\n") + idt(indent);
|
||||
}
|
||||
|
||||
void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
|
||||
stream << "TensorProto shape: [";
|
||||
for (int i = 0; i < tensor.dims_size(); ++i) {
|
||||
stream << tensor.dims(i) << (i == tensor.dims_size() - 1 ? "" : " ");
|
||||
}
|
||||
stream << "]";
|
||||
}
|
||||
|
||||
void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
|
||||
for (int i = 0; i < shape.dim_size(); ++i) {
|
||||
auto& dim = shape.dim(i);
|
||||
if (dim.has_dim_value()) {
|
||||
stream << dim.dim_value();
|
||||
} else {
|
||||
stream << "?";
|
||||
}
|
||||
stream << (i == shape.dim_size() - 1 ? "" : " ");
|
||||
}
|
||||
}
|
||||
|
||||
void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
|
||||
stream << "Tensor dtype: ";
|
||||
if (tensor_type.has_elem_type()) {
|
||||
stream << tensor_type.elem_type();
|
||||
} else {
|
||||
stream << "None.";
|
||||
}
|
||||
stream << ", ";
|
||||
stream << "Tensor dims: ";
|
||||
if (tensor_type.has_shape()) {
|
||||
dump(tensor_type.shape(), stream);
|
||||
} else {
|
||||
stream << "None.";
|
||||
}
|
||||
}
|
||||
|
||||
void dump(const onnx::TypeProto& type, std::ostream& stream) {
|
||||
dump(type.tensor_type(), stream);
|
||||
}
|
||||
|
||||
void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
|
||||
stream << "{name: \"" << value_info.name() << "\", type:";
|
||||
dump(value_info.type(), stream);
|
||||
stream << "}";
|
||||
}
|
||||
|
||||
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
|
||||
|
||||
void dump(
|
||||
const onnx::AttributeProto& attr,
|
||||
std::ostream& stream,
|
||||
size_t indent) {
|
||||
stream << "{ name: '" << attr.name() << "', type: ";
|
||||
if (attr.has_f()) {
|
||||
stream << "float, value: " << attr.f();
|
||||
} else if (attr.has_i()) {
|
||||
stream << "int, value: " << attr.i();
|
||||
} else if (attr.has_s()) {
|
||||
stream << "string, value: '" << attr.s() << "'";
|
||||
} else if (attr.has_g()) {
|
||||
stream << "graph, value:\n";
|
||||
dump(attr.g(), stream, indent + 1);
|
||||
stream << nlidt(indent);
|
||||
} else if (attr.has_t()) {
|
||||
stream << "tensor, value:";
|
||||
dump(attr.t(), stream);
|
||||
} else if (attr.floats_size()) {
|
||||
stream << "floats, values: [";
|
||||
for (int i = 0; i < attr.floats_size(); ++i) {
|
||||
stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
|
||||
}
|
||||
stream << "]";
|
||||
} else if (attr.ints_size()) {
|
||||
stream << "ints, values: [";
|
||||
for (int i = 0; i < attr.ints_size(); ++i) {
|
||||
stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
|
||||
}
|
||||
stream << "]";
|
||||
} else if (attr.strings_size()) {
|
||||
stream << "strings, values: [";
|
||||
for (int i = 0; i < attr.strings_size(); ++i) {
|
||||
stream << "'" << attr.strings(i) << "'"
|
||||
<< (i == attr.strings_size() - 1 ? "" : " ");
|
||||
}
|
||||
stream << "]";
|
||||
} else if (attr.tensors_size()) {
|
||||
stream << "tensors, values: [";
|
||||
for (auto& t : attr.tensors()) {
|
||||
dump(t, stream);
|
||||
}
|
||||
stream << "]";
|
||||
} else if (attr.graphs_size()) {
|
||||
stream << "graphs, values: [";
|
||||
for (auto& g : attr.graphs()) {
|
||||
dump(g, stream, indent + 1);
|
||||
}
|
||||
stream << "]";
|
||||
} else {
|
||||
stream << "UNKNOWN";
|
||||
}
|
||||
stream << "}";
|
||||
}
|
||||
|
||||
void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
|
||||
stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
stream << node.input(i) << (i == node.input_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "], outputs: [";
|
||||
for (int i = 0; i < node.output_size(); ++i) {
|
||||
stream << node.output(i) << (i == node.output_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "], attributes: [";
|
||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||
dump(node.attribute(i), stream, indent + 1);
|
||||
stream << (i == node.attribute_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]}";
|
||||
}
|
||||
|
||||
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
|
||||
stream << idt(indent) << "GraphProto {" << nlidt(indent + 1) << "name: \""
|
||||
<< graph.name() << "\"" << nlidt(indent + 1) << "inputs: [";
|
||||
for (int i = 0; i < graph.input_size(); ++i) {
|
||||
dump(graph.input(i), stream);
|
||||
stream << (i == graph.input_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]" << nlidt(indent + 1) << "outputs: [";
|
||||
for (int i = 0; i < graph.output_size(); ++i) {
|
||||
dump(graph.output(i), stream);
|
||||
stream << (i == graph.output_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]" << nlidt(indent + 1) << "value_infos: [";
|
||||
for (int i = 0; i < graph.value_info_size(); ++i) {
|
||||
dump(graph.value_info(i), stream);
|
||||
stream << (i == graph.value_info_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]" << nlidt(indent + 1) << "initializers: [";
|
||||
for (int i = 0; i < graph.initializer_size(); ++i) {
|
||||
dump(graph.initializer(i), stream);
|
||||
stream << (i == graph.initializer_size() - 1 ? "" : ",");
|
||||
}
|
||||
stream << "]" << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2);
|
||||
for (int i = 0; i < graph.node_size(); ++i) {
|
||||
dump(graph.node(i), stream, indent + 2);
|
||||
if (i != graph.node_size() - 1) {
|
||||
stream << "," << nlidt(indent + 2);
|
||||
}
|
||||
}
|
||||
stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n";
|
||||
}
|
||||
|
||||
void dump(
|
||||
const onnx::OperatorSetIdProto& operator_set_id,
|
||||
std::ostream& stream) {
|
||||
stream << "OperatorSetIdProto { domain: " << operator_set_id.domain()
|
||||
<< ", version: " << operator_set_id.version() << "}";
|
||||
}
|
||||
|
||||
void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
|
||||
stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
|
||||
<< "producer_name: \"" << model.producer_name() << "\""
|
||||
<< nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
|
||||
<< nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
|
||||
if (model.has_graph()) {
|
||||
stream << nlidt(indent + 1) << "graph:\n";
|
||||
dump(model.graph(), stream, indent + 2);
|
||||
}
|
||||
if (model.opset_import_size()) {
|
||||
stream << idt(indent + 1) << "opset_import: [";
|
||||
for (auto& opset_imp : model.opset_import()) {
|
||||
dump(opset_imp, stream);
|
||||
}
|
||||
stream << "],\n";
|
||||
}
|
||||
stream << idt(indent) << "}\n";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string prettyPrint(const ::ONNX_NAMESPACE::ModelProto& model) {
|
||||
std::ostringstream ss;
|
||||
dump(model, ss, 0);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
12
torch/csrc/jit/serialization/onnx.h
Normal file
12
torch/csrc/jit/serialization/onnx.h
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
#pragma once
|
||||
|
||||
#include <onnx/onnx_pb.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
TORCH_API std::string prettyPrint(const ::ONNX_NAMESPACE::ModelProto& model);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -517,6 +517,12 @@ def _set_training_mode(training_mode):
|
|||
global _training_mode
|
||||
_training_mode = training_mode
|
||||
|
||||
_onnx_shape_inference = False
|
||||
def _set_onnx_shape_inference(onnx_shape_inference):
|
||||
global _onnx_shape_inference
|
||||
_onnx_shape_inference = onnx_shape_inference
|
||||
|
||||
|
||||
# Metaprogram symbolics for each ATen native specialized cast operator.
|
||||
# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
|
||||
# ONNX cast node with `to` attribute 'UINT8'
|
||||
|
|
|
|||
|
|
@ -521,7 +521,8 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
|
|||
opset_version=None, _retain_param_name=False, do_constant_folding=True,
|
||||
strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None,
|
||||
fixed_batch_size=False, custom_opsets=None, add_node_names=True,
|
||||
enable_onnx_checker=True, use_external_data_format=False):
|
||||
enable_onnx_checker=True, use_external_data_format=False,
|
||||
onnx_shape_inference=False):
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
raise ValueError('torch.nn.DataParallel is not supported by ONNX '
|
||||
'exporter, please use \'attribute\' module to '
|
||||
|
|
@ -531,6 +532,9 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
|
|||
assert __IN_ONNX_EXPORT is False
|
||||
__IN_ONNX_EXPORT = True
|
||||
try:
|
||||
from torch.onnx.symbolic_helper import _set_onnx_shape_inference
|
||||
_set_onnx_shape_inference(onnx_shape_inference)
|
||||
|
||||
from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
|
||||
from torch.onnx.symbolic_helper import _set_operator_export_type
|
||||
if opset_version is None:
|
||||
|
|
@ -757,6 +761,12 @@ def _graph_op(g, opname, *raw_args, **kwargs):
|
|||
|
||||
args = list(const_if_tensor(arg) for arg in raw_args)
|
||||
n = g.insertNode(_newNode(g, opname, outputs, *args, **kwargs))
|
||||
|
||||
from torch.onnx.symbolic_helper import _onnx_shape_inference
|
||||
if _onnx_shape_inference:
|
||||
from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
|
||||
torch._C._jit_pass_onnx_node_shape_type_inference(n, opset_version)
|
||||
|
||||
if outputs == 1:
|
||||
return n.output()
|
||||
return tuple(o for o in n.outputs())
|
||||
|
|
@ -852,6 +862,10 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor
|
|||
new_block = new_node.addBlock()
|
||||
torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
|
||||
new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(new_node, opset_version)
|
||||
# Process Loop and If after subblock is converted.
|
||||
from torch.onnx.symbolic_helper import _onnx_shape_inference
|
||||
if _onnx_shape_inference:
|
||||
torch._C._jit_pass_onnx_node_shape_type_inference(new_node, opset_version)
|
||||
return new_op_outputs
|
||||
else:
|
||||
symbolic_name = 'prim_' + op_name
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user