[ONNX] Add dim_param support in export with onnx shape inference (#44920)
Summary:
* Support propagating `dim_param` in ONNX by encoding it as `ShapeSymbol` in the `SymbolicShape` of outputs. If export is called with `dynamic_axes` provided, shape inference starts with those axes set as dynamic.
* Add a new test file, `test_pytorch_onnx_shape_inference.py`, which reuses all test cases from `test_pytorch_onnx_onnxruntime.py` but focuses on validating the shape of every node in the graph. It is not yet enabled in CI, since there are still quite a few existing issues and corner cases to fix. By default the test runs only at opset 12.
* Bug fixes, e.g. for div, _len, and the peephole.cpp passes for PackPadded and LogSoftmaxCrossEntropy.
* This PR depends on existing PRs such as #44332.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44920
Reviewed By: eellison
Differential Revision: D23958398
Pulled By: bzinodev
fbshipit-source-id: 00479d9bd19c867d526769a15ba97ec16d56e51d
This commit is contained in:
parent
ffcb0989e7
commit
3da4cea658
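For context, a minimal sketch of what the dim_param propagation is meant to enable from the user's side. The toy model, file name, and tensor names below are illustrative assumptions, not taken from this PR's tests, and depending on the build the ONNX shape inference flag may need to be enabled explicitly.

import onnx
import torch

# Hypothetical toy model, for illustration only.
class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Only the input axis is declared dynamic; with dim_param propagation,
# ONNX shape inference can carry a symbolic dimension through to the output.
torch.onnx.export(AddOne(), torch.randn(2, 3, 4), "add_one.onnx",
                  opset_version=12,
                  input_names=["x"], output_names=["y"],
                  dynamic_axes={"x": {0: "batch"}})

m = onnx.load("add_one.onnx")
out_dim0 = m.graph.output[0].type.tensor_type.shape.dim[0]
# Expected to show a symbolic dim_param rather than a fixed dim_value of 2.
print(out_dim0.dim_param or out_dim0.dim_value)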
@@ -70,4 +70,6 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
    pytest "${args[@]}" \
      "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
  done
  pytest "${args[@]}" \
    "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference"
fi
@@ -640,6 +640,20 @@ class TestONNXRuntime(unittest.TestCase):
        self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
        self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)

    def test_conv_shape_inference(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))

            def forward(self, input):
                return self.conv2(input) + 2

        x = torch.randn(20, 16, 50, 100)
        self.run_test(Model(), x, atol=10e-5,
                      input_names=['x'],
                      dynamic_axes={'x': [0]})

    def test_conv_transpose(self):
        class TraceModel(torch.nn.Module):
            def __init__(self):
test/onnx/test_pytorch_onnx_shape_inference.py (new file, 78 lines)
@@ -0,0 +1,78 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import torch

import copy

import test_pytorch_onnx_onnxruntime
from test_pytorch_onnx_onnxruntime import TestONNXRuntime
from torch.onnx import utils, OperatorExportTypes, TrainingMode
from torch.onnx.utils import _validate_dynamic_axes
from torch.onnx.symbolic_helper import (_set_opset_version, _set_operator_export_type,
                                        _set_onnx_shape_inference, _set_training_mode,
                                        _is_tensor_list, _is_tensor, _is_none)


def verify_inferred_shape(graph):
    # Check every node in graph has type properly assigned.
    for n in graph.nodes():
        for out in n.outputs():
            if not _is_tensor_list(out) and not _is_tensor(out) and not _is_none(out):
                raise RuntimeError("Output of node is neither type Tensor nor type list of Tensor: ", out)
            if _is_tensor(out) and out.type().scalarType() is None:
                raise RuntimeError("Output of node does not have type assigned", out)
            if _is_tensor(out) and out.type().dim() is None:
                raise RuntimeError("Output of node does not have shape assigned", out)


def run_model_test(self, model, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   example_outputs=None, do_constant_folding=True,
                   dynamic_axes=None, test_with_inputs=None,
                   input_names=None, output_names=None,
                   fixed_batch_size=False):
    model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        if isinstance(input, torch.Tensor):
            input = (input,)
        # In-place operators will update input tensor data as well.
        # Thus inputs are replicated before every forward call.
        input_copy = copy.deepcopy(input)
        output = model(*input_copy)
        if isinstance(output, torch.Tensor):
            output = (output,)

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        _set_onnx_shape_inference(True)
        _set_training_mode(False)
        if dynamic_axes is None:
            dynamic_axes = {}
        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        input_copy = copy.deepcopy(input)
        graph, _, _ = utils._model_to_graph(model, input_copy,
                                            input_names=input_names,
                                            output_names=output_names,
                                            operator_export_type=OperatorExportTypes.ONNX,
                                            example_outputs=output,
                                            do_constant_folding=do_constant_folding,
                                            training=TrainingMode.EVAL,
                                            use_new_jit_passes=self.use_new_jit_passes,
                                            dynamic_axes=dynamic_axes)
        verify_inferred_shape(graph)


if __name__ == '__main__':
    TestONNXRuntime.opset_version = 12
    test_pytorch_onnx_onnxruntime.run_model_test = run_model_test

    unittest.main()
@@ -310,7 +310,13 @@ void pushPackingPastRnn(Block* b) {
      std::vector<int64_t> new_sizes;
      new_sizes.push_back(*oldType->sizes()[0]);
      new_sizes.push_back(*oldType->sizes()[1]);
      if (next->kind() == onnx::Reshape) {
        // bidirection
        new_sizes.push_back(rnn->i(attr::hidden_size) * 2);
      } else {
        // unidirection
        new_sizes.push_back(rnn->i(attr::hidden_size));
      }
      TensorTypePtr newType = TensorType::createContiguous(
          *oldType->scalarType(), *oldType->device(), new_sizes);
      next->outputs().at(0)->setType(newType);

@@ -747,6 +753,7 @@ static void fuseLogSoftmaxNllLoss(Block* b) {
            prim::ListConstruct);
        // make output of reshape the output of nllloss
        nllloss_output->replaceAllUsesWith(origNllLossNode);
        origNllLossNode->output(0)->copyMetadata(nllloss_output->output(0));
      }
    } else {
      continue;
@@ -40,8 +40,8 @@ TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type) {
    return new_tensor_type;
  }
  auto type = old_tensor_type;
  if (new_tensor_type->sizes().isComplete()) {
    type = type->withSizes(new_tensor_type->sizes().concrete_sizes().value());
  if (new_tensor_type->dim()) {
    type = type->withSymbolicShapes(new_tensor_type->symbolic_sizes());
  }
  if (new_tensor_type->scalarType().has_value()) {
    type = type->withScalarType(new_tensor_type->scalarType());

@@ -69,7 +69,8 @@ namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;

TensorTypePtr TorchTensorTypeFromONNX(
    const onnx::TypeProto_Tensor& onnx_tensor_type) {
    const onnx::TypeProto_Tensor& onnx_tensor_type,
    const SymbolDimMap& symbol_map) {
  c10::optional<at::ScalarType> scalar_type;
  if (onnx_tensor_type.has_elem_type()) {
    scalar_type = ONNXTypeToATenType(onnx_tensor_type.elem_type());

@@ -82,33 +83,51 @@ TensorTypePtr TorchTensorTypeFromONNX(
      c10::VaryingShape<c10::Stride>{},
      {});
  if (onnx_tensor_type.has_shape()) {
    std::vector<int64_t> sizes;
    std::vector<c10::ShapeSymbol> sizes;
    auto onnx_shape = onnx_tensor_type.shape();

    for (int i = 0; i < onnx_shape.dim_size(); ++i) {
      auto& dim = onnx_shape.dim(i);
      if (dim.has_dim_value()) {
        sizes.push_back(dim.dim_value());
        sizes.emplace_back(c10::ShapeSymbol::fromStaticSize(dim.dim_value()));
      } else {
        // TODO: handle dim_param?
        return v_type;
        GRAPH_UPDATE("Got dim_param:", dim.dim_param());
        c10::optional<c10::ShapeSymbol> sym = c10::nullopt;
        for (auto pair : symbol_map) {
          if (pair.second == dim.dim_param()) {
            sym = pair.first;
            break;
          }
        }
        if (!sym) {
          sym = c10::ShapeSymbol::newSymbol();
        }
        sizes.emplace_back(sym.value());
      }
    }
    v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {});
    v_type = v_type->withSizes(sizes);
    v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes));

    if (v_type->sizes().concrete_sizes().has_value()) {
      // Populate strides based on sizes info, if sizes are all static.
      // Creating strides ensures yielding True for isCompleteTensor.
      v_type = v_type->contiguous();
    }
  }

  return v_type;
}

ListTypePtr TorchListTypeFromONNX(
    const onnx::TypeProto_Sequence& onnx_sequence_type) {
    const onnx::TypeProto_Sequence& onnx_sequence_type,
    SymbolDimMap symbol_map) {
  c10::optional<at::ScalarType> scalar_type;
  if (onnx_sequence_type.has_elem_type()) {
    auto onnx_seq_elem_type = onnx_sequence_type.elem_type();
    if (onnx_seq_elem_type.has_tensor_type()) {
      auto onnx_tensor_type = onnx_seq_elem_type.tensor_type();
      auto v_tensor_type = TorchTensorTypeFromONNX(onnx_tensor_type);
      auto v_tensor_type =
          TorchTensorTypeFromONNX(onnx_tensor_type, symbol_map);
      auto v_type = ListType::create(v_tensor_type);
      return v_type;
    }

@@ -118,21 +137,24 @@ ListTypePtr TorchListTypeFromONNX(

void UpdateTorchValueByOnnxValueInfo(
    Value* v,
    const onnx::ValueInfoProto& p_info) {
    const onnx::ValueInfoProto& p_info,
    SymbolDimMap symbol_map) {
  if (!p_info.has_type()) {
    return;
  }

  auto p_type = p_info.type();
  if (p_type.has_tensor_type()) {
    auto torch_tensor_type = TorchTensorTypeFromONNX(p_type.tensor_type());
    auto torch_tensor_type =
        TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_map);
    if (torch_tensor_type) {
      v->setType(torch_tensor_type);
      v->setType(MergeInferredType(v->type(), torch_tensor_type));
    }
  } else if (p_type.has_sequence_type()) {
    auto torch_list_type = TorchListTypeFromONNX(p_type.sequence_type());
    auto torch_list_type =
        TorchListTypeFromONNX(p_type.sequence_type(), symbol_map);
    if (torch_list_type) {
      v->setType(torch_list_type);
      v->setType(MergeInferredType(v->type(), torch_list_type));
    }
  }
}

@@ -148,10 +170,18 @@ bool IsSupportedNode(const Node* n) {
  // Skip when block size is zero. This is when the node is first created,
  // doesn't have subblocks attached yet. Run shape inference for these nodes
  // when the subgraph has already completed shape inferencing.
  if ((node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) &&
      n->blocks().size() == 0) {
  if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
    if (n->blocks().size() == 0) {
      return false;
    }
    for (auto b : n->blocks()) {
      for (auto b_n : b->nodes()) {
        if (!IsSupportedNode(b_n)) {
          return false;
        }
      }
    }
  }

  return true;
}

@@ -234,9 +264,10 @@ bool IsGraphValidForInference(std::shared_ptr<Graph> graph) {
void ConvertGraphToONNXProto(
    std::shared_ptr<Graph> graph,
    std::shared_ptr<onnx::ModelProto>& model_proto,
    SymbolDimMap& symbol_map,
    int opset_version) {
  RawDataExportMap export_map;
  std::tie(model_proto, export_map) = export_onnx(
  std::tie(model_proto, export_map, symbol_map) = export_onnx(
      graph,
      {},
      opset_version,

@@ -282,7 +313,8 @@ void SpecialPostProcess(Node* n) {
void UpdateOutputTypeByONNXProto(
    Node* n,
    Node* clone_node,
    const onnx::ModelProto& model_proto) {
    const onnx::ModelProto& model_proto,
    SymbolDimMap symbol_map) {
  auto graph_proto = model_proto.graph();
  // inferred shapes are stored in value_info.
  for (size_t i = 0; i < graph_proto.value_info_size(); ++i) {

@@ -290,12 +322,10 @@ void UpdateOutputTypeByONNXProto(
    // get data from value_info and updated original graph.
    for (size_t j = 0; j < clone_node->outputs().size(); ++j) {
      if (clone_node->output(j)->debugName() == v_info.name()) {
        UpdateTorchValueByOnnxValueInfo(n->output(j), v_info);
        UpdateTorchValueByOnnxValueInfo(n->output(j), v_info, symbol_map);
      }
    }
  }

  SpecialPostProcess(n);
}

} // namespace

@@ -320,26 +350,94 @@ void ONNXShapeTypeInference(Node* n, int opset_version) {
  GRAPH_DEBUG(
      "Cloned torch graph to run shape inference: ", n_graph->toString());

  if (!IsGraphValidForInference(n_graph)) {
    GRAPH_UPDATE("Skipping ONNX shape inference for this node.");
    return;
  }

  if (IsGraphValidForInference(n_graph)) {
    // TODO: Some ops have conversion happen at Peephole pass.
    // The conversion here is incomplete for these ops.
    // e.g: ListConstruct, ListUnpack, etc.
    std::shared_ptr<onnx::ModelProto> model_proto;
    ConvertGraphToONNXProto(n_graph, model_proto, opset_version);
    GRAPH_DEBUG("ONNX graph to run shape inference: ", prettyPrint(*model_proto));
    SymbolDimMap symbol_map;
    ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version);
    GRAPH_DEBUG(
        "ONNX graph to run shape inference: ", prettyPrint(*model_proto));

    // infer shape
    onnx::shape_inference::InferShapes(*model_proto);
    GRAPH_DEBUG("ONNX graph after shape inference: ", prettyPrint(*model_proto));
    GRAPH_DEBUG(
        "ONNX graph after shape inference: ", prettyPrint(*model_proto));

    UpdateOutputTypeByONNXProto(n, clone_node, *model_proto);
    UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map);
  }

  SpecialPostProcess(n);
  GRAPH_DEBUG(
      "Torch graph after shape inference:", n->owningGraph()->toString());
}

void ONNXSetDynamicInputShape(
    std::shared_ptr<Graph>& graph,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    const std::vector<std::string>& input_names) {
  GRAPH_UPDATE("ONNX set dynamic input shape.");
  GRAPH_UPDATE("dynamic axes tensor names:", [&]() {
    std::vector<std::string> res(dynamic_axes.size());
    std::transform(
        dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) {
          return pair.first;
        });
    return res;
  }());

  std::map<std::string, ::c10::ShapeSymbol> name_to_sym;

  for (int i = 0; i < input_names.size(); ++i) {
    auto input_name = input_names[i];
    if (dynamic_axes.find(input_name) != dynamic_axes.end()) {
      auto axes_names = dynamic_axes.find(input_name)->second;
      TORCH_INTERNAL_ASSERT(i < graph->inputs().size());
      auto input_tensor_type = graph->inputs()[i]->type()->cast<TensorType>();
      if (!input_tensor_type) {
        continue;
      }

      auto shape = input_tensor_type->symbolic_sizes().sizes().value();

      for (auto pair : axes_names) {
        auto axis = pair.first;
        auto name = pair.second;
        if (name_to_sym.find(name) == name_to_sym.end()) {
          name_to_sym[name] = ::c10::ShapeSymbol::newSymbol();
        }
        shape[axis] = name_to_sym[name];
      }

      graph->inputs()[i]->setType(
          input_tensor_type->withSymbolicShapes(::c10::SymbolicShape(shape)));
    }
  }
}

void ONNXAssignOutputShape(
    std::shared_ptr<Graph>& graph,
    at::ArrayRef<at::Tensor> outputs,
    bool onnx_shape_inference) {
  TORCH_INTERNAL_ASSERT(graph->outputs().size() == outputs.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (onnx_shape_inference) {
      graph->outputs()[i]->setType(MergeInferredType(
          TensorType::create(outputs[i]), graph->outputs()[i]->type()));
    } else {
      graph->outputs()[i]->inferTypeFrom(outputs[i]);
    }
  }
}

void ONNXShapeTypeInference(std::shared_ptr<Graph>& graph, int opset_version) {
  for (auto n : graph->nodes()) {
    ONNXShapeTypeInference(n, opset_version);
  }
}

} // namespace jit
} // namespace torch
@@ -8,11 +8,39 @@ namespace jit {
TORCH_API TypePtr
MergeInferredType(TypePtr existing_type, TypePtr inferred_type);

// Update graph input types with dynamic axes info.
// Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol.
// Note it is possible for multiple axes to share the same ShapeSymbol,
// if they are defined as such in dynamic_axes.
TORCH_API void ONNXSetDynamicInputShape(
    std::shared_ptr<Graph>& graph,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    const std::vector<std::string>& input_names);

// Update graph output with types of output Tensors.
// If onnx_shape_inference is true, types of output Tensors will be compared and
// merged with inferred types. It is possible that inferred types contain
// dynamic axes, hence it takes precedence over types of output Tensors.
TORCH_API void ONNXAssignOutputShape(
    std::shared_ptr<Graph>& graph,
    at::ArrayRef<at::Tensor> outputs,
    bool onnx_shape_inference);

// Utilize ONNX Shape Inference for node.
// The node must have ONNX namespace, and is valid ONNX node according to spec.
// On successful ONNX shape inference runs, the function updates output types of
// n with inferred shape and type. Otherwise n is unchanged.
TORCH_API void ONNXShapeTypeInference(Node* n, int opset_version);

// Utilize ONNX Shape Inference for graph.
// Internally calls ONNXShapeTypeInference for each node, to achieve more
// coverage that skips only individual nodes if illegal, instead of skipping for
// the entire graph.
TORCH_API void ONNXShapeTypeInference(
    std::shared_ptr<Graph>& g,
    int opset_version);

} // namespace jit
} // namespace torch
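The comments on ONNXSetDynamicInputShape above note that axes sharing a name in `dynamic_axes` map to one shared ShapeSymbol. A minimal sketch, assuming a hypothetical two-input model and illustrative file/tensor names, of what that looks like in the exported proto:

import onnx
import torch

# Hypothetical two-input model, for illustration only.
class Pair(torch.nn.Module):
    def forward(self, a, b):
        return a + b

torch.onnx.export(Pair(), (torch.randn(2, 4), torch.randn(2, 4)), "pair.onnx",
                  opset_version=12,
                  input_names=["a", "b"],
                  dynamic_axes={"a": {0: "batch"}, "b": {0: "batch"}})

m = onnx.load("pair.onnx")
# Both inputs are expected to carry the same symbolic first dimension ("batch"),
# i.e. one shared ShapeSymbol internally and one dim_param name in the proto.
print([i.type.tensor_type.shape.dim[0].dim_param for i in m.graph.input])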
@@ -139,6 +139,13 @@ void initJITBindings(PyObject* module) {
      .def("_jit_pass_onnx_remove_print", RemovePrintOps)
      .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
      .def("_jit_pass_onnx", ToONNX)
      .def(
          "_jit_pass_onnx_assign_output_shape",
          [](std::shared_ptr<Graph>& graph,
             const std::vector<at::Tensor>& tensors,
             bool onnx_shape_inference = false) {
            ONNXAssignOutputShape(graph, tensors, onnx_shape_inference);
          })
      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
      .def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)
      .def(

@@ -188,7 +195,17 @@ void initJITBindings(PyObject* module) {
      .def(
          "_jit_pass_onnx_prepare_inplace_ops_for_onnx",
          PrepareInplaceOpsForONNX)
      .def("_jit_pass_onnx_node_shape_type_inference", ONNXShapeTypeInference)
      .def(
          "_jit_pass_onnx_node_shape_type_inference",
          [](Node* n, int opset_version) {
            ONNXShapeTypeInference(n, opset_version);
          })
      .def(
          "_jit_pass_onnx_graph_shape_type_inference",
          [](std::shared_ptr<Graph>& graph, int opset_version) {
            ONNXShapeTypeInference(graph, opset_version);
          })
      .def("_jit_pass_onnx_set_dynamic_input_shape", ONNXSetDynamicInputShape)
      .def("_jit_pass_fuse", FuseGraph)
      .def(
          "_jit_pass_dce",
@@ -238,7 +238,8 @@ void initPythonIRBindings(PyObject* module_) {
            std::string graph;
            std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
            RawDataExportMap export_map;
            std::tie(model_proto, export_map) = export_onnx(
            SymbolDimMap symbol_map;
            std::tie(model_proto, export_map, symbol_map) = export_onnx(
                g,
                initializers,
                onnx_opset_version,

@@ -251,6 +252,7 @@ void initPythonIRBindings(PyObject* module_) {
                add_node_names,
                use_external_data_format,
                onnx_file_path);
            graph = serialize_model_proto_to_string(model_proto);
            std::unordered_map<std::string, py::bytes>
                python_serialized_export_map;
            for (auto& kv : export_map) {
@@ -147,11 +147,11 @@ struct PythonResolver : public Resolver {
  ClassTypePtr classType_;
};

std::shared_ptr<PythonResolver> pythonResolver(ResolutionCallback rcb) {
std::shared_ptr<PythonResolver> pythonResolver(const ResolutionCallback& rcb) {
  return std::make_shared<PythonResolver>(rcb);
}
std::shared_ptr<PythonResolver> pythonResolver(
    ResolutionCallback rcb,
    const ResolutionCallback& rcb,
    std::string classname,
    ClassTypePtr classType) {
  return std::make_shared<PythonResolver>(

@@ -491,21 +491,6 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
  return retval;
}

static std::shared_ptr<Graph> _assign_output_shapes(
    Graph& graph,
    std::vector<at::Tensor> outputs) {
  auto retval = graph.copy();
  AT_ASSERT(retval->outputs().size() == outputs.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto scalar_type = outputs[i].scalar_type();
    auto sizes = outputs[i].sizes();
    auto type =
        torch::jit::TensorType::createContiguous(scalar_type, at::kCPU, sizes);
    retval->outputs()[i]->setType(type);
  }
  return retval;
}

void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
  // Make a graph with a fake self argument
  auto graph = func.function_->graph()->copy();

@@ -641,7 +626,7 @@ struct slot_dict_impl {
template <typename T>
py::list debugMakeList(const T& list) {
  py::list result;
  for (auto elem : list) {
  for (const auto& elem : list) {
    result.append(py::cast(elem));
  }
  return result;

@@ -681,7 +666,7 @@ static py::dict _jit_debug_module_iterators(Module& module) {
  return result;
}

static constexpr const char* magic_method_names[] = {
static constexpr std::array<const char*, 47> magic_method_names = {
    "__lt__", "__le__", "__eq__", "__ne__",
    "__ge__", "__gt__", "__not__", "__abs__",
    "__add__", "__and__", "__floordiv__", "__index__",

@@ -806,7 +791,8 @@ void initJitScriptBindings(PyObject* module) {
            err << "which does not have a __getstate__ method defined!";
            throw std::runtime_error(err.str());
          },
          [](std::tuple<py::object, std::string> state_tup) -> Object {
          [](const std::tuple<py::object, std::string>& state_tup)
              -> Object {
            py::object state;
            std::string qualname;
            std::tie(state, qualname) = state_tup;

@@ -970,7 +956,7 @@ void initJitScriptBindings(PyObject* module) {
          [](Module& m,
             std::shared_ptr<ConcreteModuleType> concreteType,
             const std::string& script,
             ResolutionCallback rcb) {
             const ResolutionCallback& rcb) {
            const auto self = ModuleSelf(std::move(concreteType));
            m._ivalue()->compilation_unit()->define(
                *m.type()->name(), script, pythonResolver(rcb), &self);

@@ -980,7 +966,7 @@ void initJitScriptBindings(PyObject* module) {
          "_register_attribute",
          [](Module& m,
             const std::string& name,
             TypePtr type,
             const TypePtr& type,
             py::handle value) {
            m.register_attribute(name, type, toIValue(value, type));
          })

@@ -988,9 +974,9 @@ void initJitScriptBindings(PyObject* module) {
          "_create_method_from_trace",
          [](Module& self,
             const std::string& name,
             py::function func,
             py::tuple input_tuple,
             py::function var_lookup_fn,
             const py::function& func,
             const py::tuple& input_tuple,
             const py::function& var_lookup_fn,
             bool strict,
             bool force_outplace) {
            // prereq: Module's buffers and parameters are unique

@@ -1106,7 +1092,7 @@ void initJitScriptBindings(PyObject* module) {
          "define",
          [](CompilationUnit& cu,
             const std::string& src,
             ResolutionCallback rcb) {
             const ResolutionCallback& rcb) {
            cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr);
          })
      .def(

@@ -1279,10 +1265,10 @@ void initJitScriptBindings(PyObject* module) {
      });
  m.def(
      "_create_function_from_trace",
      [](std::string qualname,
         py::function func,
         py::tuple input_tuple,
         py::function var_lookup_fn,
      [](const std::string& qualname,
         const py::function& func,
         const py::tuple& input_tuple,
         const py::function& var_lookup_fn,
         bool strict,
         bool force_outplace) {
        auto typed_inputs = toTraceableStack(input_tuple);

@@ -1303,7 +1289,7 @@ void initJitScriptBindings(PyObject* module) {
      [](const std::string& qualifiedName,
         const ClassDef& classDef,
         const ClassMethodDefaults& defaults,
         ResolutionCallback rcb) {
         const ResolutionCallback& rcb) {
        C10_LOG_API_USAGE_ONCE("torch.script.class");
        if (classDef.superclass().present()) {
          throw ErrorReport(classDef.range())

@@ -1465,7 +1451,6 @@ void initJitScriptBindings(PyObject* module) {
  m.def("_propagate_shapes", _propagate_shapes);
  m.def(
      "_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes);
  m.def("_assign_output_shapes", _assign_output_shapes);
  m.def(
      "_last_executed_optimized_graph",
      []() { return lastExecutedOptimizedGraph(); },

@@ -1668,18 +1653,23 @@ void initJitScriptBindings(PyObject* module) {

  m.def(
      "_resolve_type",
      [](const std::string& name, SourceRange range, ResolutionCallback rcb) {
      [](const std::string& name,
         const SourceRange& range,
         const ResolutionCallback& rcb) {
        return pythonResolver(rcb)->resolveType(name, range);
      });
  m.def(
      "_resolve_type_from_object",
      [](const py::object& obj, SourceRange range, ResolutionCallback rcb) {
      [](const py::object& obj,
         const SourceRange& range,
         const ResolutionCallback& rcb) {
        return pythonResolver(rcb)->resolveTypeFromObject(obj, range);
      });

  m.def(
      "_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); });

  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
      m, "LoggerBase");
  py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
@@ -178,6 +178,10 @@ class EncoderBase {
    return model_proto_;
  }

  SymbolDimMap get_symbol_dim_param_map() {
    return symbol_dim_map_;
  }

 protected:
  // Using std::map instead of std::unordered_map for initializers
  // in EncodeGraph constructor so that the order in which initializers

@@ -243,6 +247,7 @@ class EncoderBase {
      const bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  SymbolDimMap symbol_dim_map_;
  onnx::ModelProto model_proto_;
  size_t num_blocks_;
  size_t num_op_nodes_;
@@ -316,33 +321,38 @@ void EncoderBase::EncodeValueInfo(
        std::unordered_map<int64_t, std::string>>& dynamic_axes) {
  std::string name = n->debugName();
  v->set_name(name);
  auto tensorTypeToONNXType = [&dynamic_axes, &name](
  auto tensorTypeToONNXType = [&dynamic_axes, &name, this](
                                  TensorTypePtr t,
                                  onnx::TypeProto_Tensor* tensor_type) {
    if (t->sizes().isComplete()) {
      // onnx::TypeProto* onnx_type = v->mutable_type();
      // onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type();
    if (t->dim()) {
      onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
      std::vector<std::int64_t> sizes = t->sizes().concrete_sizes().value();
      auto sizes = t->symbolic_sizes().sizes().value();
      for (size_t i = 0; i < sizes.size(); i++) {
        shape->add_dim();
        if ((dynamic_axes.find(name) != dynamic_axes.end()) &&
            (dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) {
          shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i));
          if (!sizes[i].is_static()) {
            symbol_dim_map_[sizes[i]] = dynamic_axes.at(name).at(i);
          }
        } else if (sizes[i].is_static()) {
          shape->mutable_dim(i)->set_dim_value(sizes[i].static_size());
        } else {
          shape->mutable_dim(i)->set_dim_value(sizes[i]);
          if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) {
            symbol_dim_map_[sizes[i]] = name + "_" + std::to_string(i);
          }
          shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]);
        }
      }
    }
    if (t->scalarType()) {
      // onnx::TypeProto* onnx_type = v->mutable_type();
      // onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type();
      tensor_type->set_elem_type(ATenTypeToOnnxType(t->scalarType().value()));
    }
  };

  if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
    if (node_type->sizes().isComplete() || node_type->scalarType()) {
    if (node_type->dim() || node_type->scalarType()) {
      // Encode type if either shape or dtype exists.
      onnx::TypeProto* onnx_type = v->mutable_type();
      onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type();
      tensorTypeToONNXType(node_type, tensor_type);

@@ -854,7 +864,10 @@ std::string pretty_print_onnx(
// conform to the ONNX op specification. Thus, the output will not
// be interpretable by a ONNX-compatible framework. However, PyTorch or
// libtorch will be able to import the IR and play it back.
std::tuple<std::shared_ptr<::ONNX_NAMESPACE::ModelProto>, RawDataExportMap>
std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,

@@ -888,10 +901,12 @@ export_onnx(
      proto_size <= INT_MAX,
      "Exporting model exceed maximum protobuf size of 2GB. "
      "Please call torch.onnx.export with use_external_data_format=True.");
  GRAPH_UPDATE("onnx proto:", prettyPrint(graph_encoder.get_model_proto()));
  std::shared_ptr<onnx::ModelProto> model_proto =
      std::make_shared<onnx::ModelProto>(graph_encoder.get_model_proto());
  return std::make_tuple(model_proto, graph_encoder.get_raw_data_export_map());
  GRAPH_DEBUG("onnx proto:", prettyPrint(graph_encoder.get_model_proto()));
  return std::make_tuple(
      std::make_shared<::ONNX_NAMESPACE::ModelProto>(
          graph_encoder.get_model_proto()),
      graph_encoder.get_raw_data_export_map(),
      graph_encoder.get_symbol_dim_param_map());
}

std::string serialize_model_proto_to_string(
@@ -25,8 +25,12 @@ namespace jit {
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;

TORCH_API std::
    tuple<std::shared_ptr<::ONNX_NAMESPACE::ModelProto>, RawDataExportMap>
using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;

TORCH_API std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
@@ -171,6 +171,8 @@ def _is_none(x):
def _is_value(x):
    return isinstance(x, torch._C.Value)

def _is_tensor(x):
    return x.type().isSubtypeOf(torch._C.TensorType.get())

def _is_tensor_list(x):
    return isinstance(x.type(), torch._C.ListType) and isinstance(x.type().getElementType(), torch._C.TensorType)
@@ -6,7 +6,7 @@ import torch.onnx.symbolic_helper as sym_help
import warnings
import numpy

from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list
from torch.onnx.symbolic_opset9 import expand, unused
from torch.nn.modules.utils import _single, _pair, _triple
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block

@@ -272,7 +272,7 @@ def masked_scatter(g, self, mask, source):


def _len(g, self):
    if self.type().isSubtypeOf(torch._C.ListType.ofTensors()) or self.node().kind() == "onnx::SplitToSequence":
    if _is_tensor_list(self) or self.node().kind() == "onnx::SplitToSequence":
        return g.op("SequenceLength", self)
    return g.op("Size", self)
@@ -121,6 +121,7 @@ def floor_divide(g, self, other):
    # - self is not fp and other is not fp, the output's type is self's output type
    # - the output type defaults to Float
    scalar_type = self.type().scalarType()

    if scalar_type is not None:
        if not sym_help._is_fp(self) and \
                other.type().scalarType() is not None and \
@@ -17,7 +17,7 @@ import warnings
from torch._six import string_classes
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _assign_output_shapes, _check_onnx_proto
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto


# the flag to tell the user whether it's in the middle of ONNX export or not
@@ -121,7 +121,7 @@ def _split_tensor_list_constants(g, block):


def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False, fixed_batch_size=False,
                    params_dict=None, use_new_jit_passes=False):
                    params_dict=None, use_new_jit_passes=False, dynamic_axes=None, input_names=None):
    # Inline everything
    torch._C._jit_pass_inline(graph)

@@ -195,6 +195,11 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False,
    # onnx only supports tensors, so we turn all out number types into tensors
    torch._C._jit_pass_erase_number_types(graph)

    from torch.onnx.symbolic_helper import _onnx_shape_inference
    if _onnx_shape_inference:
        input_names = [] if input_names is None else input_names
        dynamic_axes = {} if dynamic_axes is None else dynamic_axes
        torch._C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    torch._C._jit_pass_lint(graph)

@@ -214,6 +219,9 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False,
    torch._C._jit_pass_lint(graph)
    graph = torch._C._jit_pass_canonicalize(graph)
    torch._C._jit_pass_lint(graph)
    from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version
    if _onnx_shape_inference:
        torch._C._jit_pass_onnx_graph_shape_type_inference(graph, _export_onnx_opset_version)
    return graph
@@ -388,7 +396,8 @@ def _model_to_graph(model, args, verbose=False,
                    example_outputs=None,
                    _retain_param_name=False, do_constant_folding=True,
                    _disable_torch_constant_prop=False, fixed_batch_size=False,
                    training=None, use_new_jit_passes=False):
                    training=None, use_new_jit_passes=False,
                    dynamic_axes=None):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):

@@ -408,19 +417,20 @@ def _model_to_graph(model, args, verbose=False,
    graph = _optimize_graph(graph, operator_export_type,
                            _disable_torch_constant_prop=_disable_torch_constant_prop,
                            fixed_batch_size=fixed_batch_size, params_dict=params_dict,
                            use_new_jit_passes=use_new_jit_passes)
                            use_new_jit_passes=use_new_jit_passes,
                            dynamic_axes=dynamic_axes, input_names=input_names)
    from torch.onnx.symbolic_helper import _onnx_shape_inference
    if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule or " \
                                            "ScriptFunction."
        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
        graph = _assign_output_shapes(graph, out_vars)
        torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, _onnx_shape_inference)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
        output_tensors, _ = torch._C._jit_flatten(torch_out)
        for output, tensor in zip(graph.outputs(), output_tensors):
            output.inferTypeFrom(tensor)
        torch._C._jit_pass_onnx_assign_output_shape(graph, output_tensors, _onnx_shape_inference)

    _set_input_and_output_names(graph, input_names, output_names)
@@ -614,6 +624,10 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
        val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format,
                                                                                         operator_export_type,
                                                                                         f)
        if dynamic_axes is None:
            dynamic_axes = {}
        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        graph, params_dict, torch_out = \
            _model_to_graph(model, args, verbose, input_names,
                            output_names, operator_export_type,
@@ -621,17 +635,14 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
                            val_do_constant_folding,
                            fixed_batch_size=fixed_batch_size,
                            training=training,
                            use_new_jit_passes=use_new_jit_passes)
                            use_new_jit_passes=use_new_jit_passes,
                            dynamic_axes=dynamic_axes)

        # TODO: Don't allocate a in-memory string for the protobuf
        defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
        if dynamic_axes is None:
            dynamic_axes = {}
        if custom_opsets is None:
            custom_opsets = {}

        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        if export_params:
            proto, export_map = graph._export_onnx(
                params_dict, opset_version, dynamic_axes, defer_weight_export,