[ONNX] Add dim_param support in export with onnx shape inference (#44920)

Summary:
* Support propagating `dim_param` in ONNX by encoding as `ShapeSymbol` in `SymbolicShape` of outputs. If export is called with `dynamic_axes` provided, shape inference will start with these axes set as dynamic.
* Add new test file `test_pytorch_onnx_shape_inference.py`, reusing all test cases from `test_pytorch_onnx_onnxruntime.py`, but focus on validating shape for all nodes in graph. Currently this is not enabled in the CI, since there are still quite some existing issues and corner cases to fix. The test is default to run only at opset 12.
* Bug fixes, such as div, _len, and peephole.cpp passes for PackPadded, and LogSoftmaxCrossEntropy.
* This PR depends on existing PR such as 44332.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44920

Reviewed By: eellison

Differential Revision: D23958398

Pulled By: bzinodev

fbshipit-source-id: 00479d9bd19c867d526769a15ba97ec16d56e51d
This commit is contained in:
BowenBao 2020-09-30 21:54:15 -07:00 committed by Facebook GitHub Bot
parent ffcb0989e7
commit 3da4cea658
15 changed files with 393 additions and 124 deletions

View File

@ -70,4 +70,6 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
pytest "${args[@]}" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
done
pytest "${args[@]}" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference"
fi

View File

@ -640,6 +640,20 @@ class TestONNXRuntime(unittest.TestCase):
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)
def test_conv_shape_inference(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
def forward(self, input):
return self.conv2(input) + 2
x = torch.randn(20, 16, 50, 100)
self.run_test(Model(), x, atol=10e-5,
input_names=['x'],
dynamic_axes={'x': [0]})
def test_conv_transpose(self):
class TraceModel(torch.nn.Module):
def __init__(self):

View File

@ -0,0 +1,78 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import torch
import copy
import test_pytorch_onnx_onnxruntime
from test_pytorch_onnx_onnxruntime import TestONNXRuntime
from torch.onnx import utils, OperatorExportTypes, TrainingMode
from torch.onnx.utils import _validate_dynamic_axes
from torch.onnx.symbolic_helper import (_set_opset_version, _set_operator_export_type,
_set_onnx_shape_inference, _set_training_mode,
_is_tensor_list, _is_tensor, _is_none)
def verify_inferred_shape(graph):
# Check every node in graph has type properly assigned.
for n in graph.nodes():
for out in n.outputs():
if not _is_tensor_list(out) and not _is_tensor(out) and not _is_none(out):
raise RuntimeError("Output of node is neither type Tensor nor type list of Tensor: ", out)
if _is_tensor(out) and out.type().scalarType() is None:
raise RuntimeError("Output of node does not have type assigned", out)
if _is_tensor(out) and out.type().dim() is None:
raise RuntimeError("Output of node does not have shape assigned", out)
def run_model_test(self, model, batch_size=2, state_dict=None,
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
example_outputs=None, do_constant_folding=True,
dynamic_axes=None, test_with_inputs=None,
input_names=None, output_names=None,
fixed_batch_size=False):
model.eval()
if input is None:
input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
with torch.no_grad():
if isinstance(input, torch.Tensor):
input = (input,)
# In-place operators will update input tensor data as well.
# Thus inputs are replicated before every forward call.
input_copy = copy.deepcopy(input)
output = model(*input_copy)
if isinstance(output, torch.Tensor):
output = (output,)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
_set_onnx_shape_inference(True)
_set_training_mode(False)
if dynamic_axes is None:
dynamic_axes = {}
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
input_copy = copy.deepcopy(input)
graph, _, _ = utils._model_to_graph(model, input_copy,
input_names=input_names,
output_names=output_names,
operator_export_type=OperatorExportTypes.ONNX,
example_outputs=output,
do_constant_folding=do_constant_folding,
training=TrainingMode.EVAL,
use_new_jit_passes=self.use_new_jit_passes,
dynamic_axes=dynamic_axes)
verify_inferred_shape(graph)
if __name__ == '__main__':
TestONNXRuntime.opset_version = 12
test_pytorch_onnx_onnxruntime.run_model_test = run_model_test
unittest.main()

View File

@ -310,7 +310,13 @@ void pushPackingPastRnn(Block* b) {
std::vector<int64_t> new_sizes;
new_sizes.push_back(*oldType->sizes()[0]);
new_sizes.push_back(*oldType->sizes()[1]);
new_sizes.push_back(rnn->i(attr::hidden_size));
if (next->kind() == onnx::Reshape) {
// bidirection
new_sizes.push_back(rnn->i(attr::hidden_size) * 2);
} else {
// unidirection
new_sizes.push_back(rnn->i(attr::hidden_size));
}
TensorTypePtr newType = TensorType::createContiguous(
*oldType->scalarType(), *oldType->device(), new_sizes);
next->outputs().at(0)->setType(newType);
@ -747,6 +753,7 @@ static void fuseLogSoftmaxNllLoss(Block* b) {
prim::ListConstruct);
// make output of reshape the output of nllloss
nllloss_output->replaceAllUsesWith(origNllLossNode);
origNllLossNode->output(0)->copyMetadata(nllloss_output->output(0));
}
} else {
continue;

View File

@ -40,8 +40,8 @@ TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type) {
return new_tensor_type;
}
auto type = old_tensor_type;
if (new_tensor_type->sizes().isComplete()) {
type = type->withSizes(new_tensor_type->sizes().concrete_sizes().value());
if (new_tensor_type->dim()) {
type = type->withSymbolicShapes(new_tensor_type->symbolic_sizes());
}
if (new_tensor_type->scalarType().has_value()) {
type = type->withScalarType(new_tensor_type->scalarType());
@ -69,7 +69,8 @@ namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
TensorTypePtr TorchTensorTypeFromONNX(
const onnx::TypeProto_Tensor& onnx_tensor_type) {
const onnx::TypeProto_Tensor& onnx_tensor_type,
const SymbolDimMap& symbol_map) {
c10::optional<at::ScalarType> scalar_type;
if (onnx_tensor_type.has_elem_type()) {
scalar_type = ONNXTypeToATenType(onnx_tensor_type.elem_type());
@ -82,33 +83,51 @@ TensorTypePtr TorchTensorTypeFromONNX(
c10::VaryingShape<c10::Stride>{},
{});
if (onnx_tensor_type.has_shape()) {
std::vector<int64_t> sizes;
std::vector<c10::ShapeSymbol> sizes;
auto onnx_shape = onnx_tensor_type.shape();
for (int i = 0; i < onnx_shape.dim_size(); ++i) {
auto& dim = onnx_shape.dim(i);
if (dim.has_dim_value()) {
sizes.push_back(dim.dim_value());
sizes.emplace_back(c10::ShapeSymbol::fromStaticSize(dim.dim_value()));
} else {
// TODO: handle dim_param?
return v_type;
GRAPH_UPDATE("Got dim_param:", dim.dim_param());
c10::optional<c10::ShapeSymbol> sym = c10::nullopt;
for (auto pair : symbol_map) {
if (pair.second == dim.dim_param()) {
sym = pair.first;
break;
}
}
if (!sym) {
sym = c10::ShapeSymbol::newSymbol();
}
sizes.emplace_back(sym.value());
}
}
v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {});
v_type = v_type->withSizes(sizes);
v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes));
if (v_type->sizes().concrete_sizes().has_value()) {
// Populate strides based on sizes info, if sizes are all static.
// Creating strides ensures yielding True for isCompleteTensor.
v_type = v_type->contiguous();
}
}
return v_type;
}
ListTypePtr TorchListTypeFromONNX(
const onnx::TypeProto_Sequence& onnx_sequence_type) {
const onnx::TypeProto_Sequence& onnx_sequence_type,
SymbolDimMap symbol_map) {
c10::optional<at::ScalarType> scalar_type;
if (onnx_sequence_type.has_elem_type()) {
auto onnx_seq_elem_type = onnx_sequence_type.elem_type();
if (onnx_seq_elem_type.has_tensor_type()) {
auto onnx_tensor_type = onnx_seq_elem_type.tensor_type();
auto v_tensor_type = TorchTensorTypeFromONNX(onnx_tensor_type);
auto v_tensor_type =
TorchTensorTypeFromONNX(onnx_tensor_type, symbol_map);
auto v_type = ListType::create(v_tensor_type);
return v_type;
}
@ -118,21 +137,24 @@ ListTypePtr TorchListTypeFromONNX(
void UpdateTorchValueByOnnxValueInfo(
Value* v,
const onnx::ValueInfoProto& p_info) {
const onnx::ValueInfoProto& p_info,
SymbolDimMap symbol_map) {
if (!p_info.has_type()) {
return;
}
auto p_type = p_info.type();
if (p_type.has_tensor_type()) {
auto torch_tensor_type = TorchTensorTypeFromONNX(p_type.tensor_type());
auto torch_tensor_type =
TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_map);
if (torch_tensor_type) {
v->setType(torch_tensor_type);
v->setType(MergeInferredType(v->type(), torch_tensor_type));
}
} else if (p_type.has_sequence_type()) {
auto torch_list_type = TorchListTypeFromONNX(p_type.sequence_type());
auto torch_list_type =
TorchListTypeFromONNX(p_type.sequence_type(), symbol_map);
if (torch_list_type) {
v->setType(torch_list_type);
v->setType(MergeInferredType(v->type(), torch_list_type));
}
}
}
@ -148,9 +170,17 @@ bool IsSupportedNode(const Node* n) {
// Skip when block size is zero. This is when the node is first created,
// doesn't have subblocks attached yet. Run shape inference for these nodes
// when the subgraph has already completed shape inferencing.
if ((node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) &&
n->blocks().size() == 0) {
return false;
if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
if (n->blocks().size() == 0) {
return false;
}
for (auto b : n->blocks()) {
for (auto b_n : b->nodes()) {
if (!IsSupportedNode(b_n)) {
return false;
}
}
}
}
return true;
@ -234,9 +264,10 @@ bool IsGraphValidForInference(std::shared_ptr<Graph> graph) {
void ConvertGraphToONNXProto(
std::shared_ptr<Graph> graph,
std::shared_ptr<onnx::ModelProto>& model_proto,
SymbolDimMap& symbol_map,
int opset_version) {
RawDataExportMap export_map;
std::tie(model_proto, export_map) = export_onnx(
std::tie(model_proto, export_map, symbol_map) = export_onnx(
graph,
{},
opset_version,
@ -282,7 +313,8 @@ void SpecialPostProcess(Node* n) {
void UpdateOutputTypeByONNXProto(
Node* n,
Node* clone_node,
const onnx::ModelProto& model_proto) {
const onnx::ModelProto& model_proto,
SymbolDimMap symbol_map) {
auto graph_proto = model_proto.graph();
// inferred shapes are stored in value_info.
for (size_t i = 0; i < graph_proto.value_info_size(); ++i) {
@ -290,12 +322,10 @@ void UpdateOutputTypeByONNXProto(
// get data from value_info and updated original graph.
for (size_t j = 0; j < clone_node->outputs().size(); ++j) {
if (clone_node->output(j)->debugName() == v_info.name()) {
UpdateTorchValueByOnnxValueInfo(n->output(j), v_info);
UpdateTorchValueByOnnxValueInfo(n->output(j), v_info, symbol_map);
}
}
}
SpecialPostProcess(n);
}
} // namespace
@ -320,26 +350,94 @@ void ONNXShapeTypeInference(Node* n, int opset_version) {
GRAPH_DEBUG(
"Cloned torch graph to run shape inference: ", n_graph->toString());
if (!IsGraphValidForInference(n_graph)) {
GRAPH_UPDATE("Skipping ONNX shape inference for this node.");
return;
if (IsGraphValidForInference(n_graph)) {
// TODO: Some ops have conversion happen at Peephole pass.
// The conversion here is incomplete for these ops.
// e.g: ListConstruct, ListUnpack, etc.
std::shared_ptr<onnx::ModelProto> model_proto;
SymbolDimMap symbol_map;
ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version);
GRAPH_DEBUG(
"ONNX graph to run shape inference: ", prettyPrint(*model_proto));
// infer shape
onnx::shape_inference::InferShapes(*model_proto);
GRAPH_DEBUG(
"ONNX graph after shape inference: ", prettyPrint(*model_proto));
UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map);
}
// TODO: Some ops have conversion happen at Peephole pass.
// The conversion here is incomplete for these ops.
// e.g: ListConstruct, ListUnpack, etc.
std::shared_ptr<onnx::ModelProto> model_proto;
ConvertGraphToONNXProto(n_graph, model_proto, opset_version);
GRAPH_DEBUG("ONNX graph to run shape inference: ", prettyPrint(*model_proto));
// infer shape
onnx::shape_inference::InferShapes(*model_proto);
GRAPH_DEBUG("ONNX graph after shape inference: ", prettyPrint(*model_proto));
UpdateOutputTypeByONNXProto(n, clone_node, *model_proto);
SpecialPostProcess(n);
GRAPH_DEBUG(
"Torch graph after shape inference:", n->owningGraph()->toString());
}
void ONNXSetDynamicInputShape(
std::shared_ptr<Graph>& graph,
const std::unordered_map<
std::string,
std::unordered_map<int64_t, std::string>>& dynamic_axes,
const std::vector<std::string>& input_names) {
GRAPH_UPDATE("ONNX set dynamic input shape.");
GRAPH_UPDATE("dynamic axes tensor names:", [&]() {
std::vector<std::string> res(dynamic_axes.size());
std::transform(
dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) {
return pair.first;
});
return res;
}());
std::map<std::string, ::c10::ShapeSymbol> name_to_sym;
for (int i = 0; i < input_names.size(); ++i) {
auto input_name = input_names[i];
if (dynamic_axes.find(input_name) != dynamic_axes.end()) {
auto axes_names = dynamic_axes.find(input_name)->second;
TORCH_INTERNAL_ASSERT(i < graph->inputs().size());
auto input_tensor_type = graph->inputs()[i]->type()->cast<TensorType>();
if (!input_tensor_type) {
continue;
}
auto shape = input_tensor_type->symbolic_sizes().sizes().value();
for (auto pair : axes_names) {
auto axis = pair.first;
auto name = pair.second;
if (name_to_sym.find(name) == name_to_sym.end()) {
name_to_sym[name] = ::c10::ShapeSymbol::newSymbol();
}
shape[axis] = name_to_sym[name];
}
graph->inputs()[i]->setType(
input_tensor_type->withSymbolicShapes(::c10::SymbolicShape(shape)));
}
}
}
void ONNXAssignOutputShape(
std::shared_ptr<Graph>& graph,
at::ArrayRef<at::Tensor> outputs,
bool onnx_shape_inference) {
TORCH_INTERNAL_ASSERT(graph->outputs().size() == outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
if (onnx_shape_inference) {
graph->outputs()[i]->setType(MergeInferredType(
TensorType::create(outputs[i]), graph->outputs()[i]->type()));
} else {
graph->outputs()[i]->inferTypeFrom(outputs[i]);
}
}
}
void ONNXShapeTypeInference(std::shared_ptr<Graph>& graph, int opset_version) {
for (auto n : graph->nodes()) {
ONNXShapeTypeInference(n, opset_version);
}
}
} // namespace jit
} // namespace torch

View File

@ -8,11 +8,39 @@ namespace jit {
TORCH_API TypePtr
MergeInferredType(TypePtr existing_type, TypePtr inferred_type);
// Update graph input types with dynamic axes info.
// Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol.
// Note it is possible for multiple axes to share the same ShapeSymbol,
// if they are defined as such in dynamic_axes.
TORCH_API void ONNXSetDynamicInputShape(
std::shared_ptr<Graph>& graph,
const std::unordered_map<
std::string,
std::unordered_map<int64_t, std::string>>& dynamic_axes,
const std::vector<std::string>& input_names);
// Update graph output with types of output Tensors.
// If onnx_shape_inference is true, types of output Tensors will be compared and
// merged with inferred types. It is possible that inferred types contain
// dynamic axes, hence it takes precedence over types of output Tensors.
TORCH_API void ONNXAssignOutputShape(
std::shared_ptr<Graph>& graph,
at::ArrayRef<at::Tensor> outputs,
bool onnx_shape_inference);
// Utilize ONNX Shape Inference for node.
// The node must have ONNX namespace, and is valid ONNX node accroding to spec.
// On successful ONNX shape inference runs, the function updates output types of
// n with inferred shape and type. Otherwise n is unchanged.
TORCH_API void ONNXShapeTypeInference(Node* n, int opset_version);
// Utilize ONNX Shape Inference for graph.
// Internally calls ONNXShapeTypeInference for each node, to achieve more
// coverage that skips only individual nodes if illegal, instead of skipping for
// the entire graph.
TORCH_API void ONNXShapeTypeInference(
std::shared_ptr<Graph>& g,
int opset_version);
} // namespace jit
} // namespace torch

View File

@ -139,6 +139,13 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_onnx_remove_print", RemovePrintOps)
.def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
.def("_jit_pass_onnx", ToONNX)
.def(
"_jit_pass_onnx_assign_output_shape",
[](std::shared_ptr<Graph>& graph,
const std::vector<at::Tensor>& tensors,
bool onnx_shape_inference = false) {
ONNXAssignOutputShape(graph, tensors, onnx_shape_inference);
})
.def("_jit_pass_lower_all_tuples", LowerAllTuples)
.def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)
.def(
@ -188,7 +195,17 @@ void initJITBindings(PyObject* module) {
.def(
"_jit_pass_onnx_prepare_inplace_ops_for_onnx",
PrepareInplaceOpsForONNX)
.def("_jit_pass_onnx_node_shape_type_inference", ONNXShapeTypeInference)
.def(
"_jit_pass_onnx_node_shape_type_inference",
[](Node* n, int opset_version) {
ONNXShapeTypeInference(n, opset_version);
})
.def(
"_jit_pass_onnx_graph_shape_type_inference",
[](std::shared_ptr<Graph>& graph, int opset_version) {
ONNXShapeTypeInference(graph, opset_version);
})
.def("_jit_pass_onnx_set_dynamic_input_shape", ONNXSetDynamicInputShape)
.def("_jit_pass_fuse", FuseGraph)
.def(
"_jit_pass_dce",

View File

@ -238,7 +238,8 @@ void initPythonIRBindings(PyObject* module_) {
std::string graph;
std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
RawDataExportMap export_map;
std::tie(model_proto, export_map) = export_onnx(
SymbolDimMap symbol_map;
std::tie(model_proto, export_map, symbol_map) = export_onnx(
g,
initializers,
onnx_opset_version,
@ -251,6 +252,7 @@ void initPythonIRBindings(PyObject* module_) {
add_node_names,
use_external_data_format,
onnx_file_path);
graph = serialize_model_proto_to_string(model_proto);
std::unordered_map<std::string, py::bytes>
python_serialized_export_map;
for (auto& kv : export_map) {

View File

@ -147,11 +147,11 @@ struct PythonResolver : public Resolver {
ClassTypePtr classType_;
};
std::shared_ptr<PythonResolver> pythonResolver(ResolutionCallback rcb) {
std::shared_ptr<PythonResolver> pythonResolver(const ResolutionCallback& rcb) {
return std::make_shared<PythonResolver>(rcb);
}
std::shared_ptr<PythonResolver> pythonResolver(
ResolutionCallback rcb,
const ResolutionCallback& rcb,
std::string classname,
ClassTypePtr classType) {
return std::make_shared<PythonResolver>(
@ -491,21 +491,6 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
return retval;
}
static std::shared_ptr<Graph> _assign_output_shapes(
Graph& graph,
std::vector<at::Tensor> outputs) {
auto retval = graph.copy();
AT_ASSERT(retval->outputs().size() == outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
auto scalar_type = outputs[i].scalar_type();
auto sizes = outputs[i].sizes();
auto type =
torch::jit::TensorType::createContiguous(scalar_type, at::kCPU, sizes);
retval->outputs()[i]->setType(type);
}
return retval;
}
void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
// Make a graph with a fake self argument
auto graph = func.function_->graph()->copy();
@ -641,7 +626,7 @@ struct slot_dict_impl {
template <typename T>
py::list debugMakeList(const T& list) {
py::list result;
for (auto elem : list) {
for (const auto& elem : list) {
result.append(py::cast(elem));
}
return result;
@ -681,7 +666,7 @@ static py::dict _jit_debug_module_iterators(Module& module) {
return result;
}
static constexpr const char* magic_method_names[] = {
static constexpr std::array<const char*, 47> magic_method_names = {
"__lt__", "__le__", "__eq__", "__ne__",
"__ge__", "__gt__", "__not__", "__abs__",
"__add__", "__and__", "__floordiv__", "__index__",
@ -806,7 +791,8 @@ void initJitScriptBindings(PyObject* module) {
err << "which does not have a __getstate__ method defined!";
throw std::runtime_error(err.str());
},
[](std::tuple<py::object, std::string> state_tup) -> Object {
[](const std::tuple<py::object, std::string>& state_tup)
-> Object {
py::object state;
std::string qualname;
std::tie(state, qualname) = state_tup;
@ -970,7 +956,7 @@ void initJitScriptBindings(PyObject* module) {
[](Module& m,
std::shared_ptr<ConcreteModuleType> concreteType,
const std::string& script,
ResolutionCallback rcb) {
const ResolutionCallback& rcb) {
const auto self = ModuleSelf(std::move(concreteType));
m._ivalue()->compilation_unit()->define(
*m.type()->name(), script, pythonResolver(rcb), &self);
@ -980,7 +966,7 @@ void initJitScriptBindings(PyObject* module) {
"_register_attribute",
[](Module& m,
const std::string& name,
TypePtr type,
const TypePtr& type,
py::handle value) {
m.register_attribute(name, type, toIValue(value, type));
})
@ -988,9 +974,9 @@ void initJitScriptBindings(PyObject* module) {
"_create_method_from_trace",
[](Module& self,
const std::string& name,
py::function func,
py::tuple input_tuple,
py::function var_lookup_fn,
const py::function& func,
const py::tuple& input_tuple,
const py::function& var_lookup_fn,
bool strict,
bool force_outplace) {
// prereq: Module's buffers and parameters are unique
@ -1106,7 +1092,7 @@ void initJitScriptBindings(PyObject* module) {
"define",
[](CompilationUnit& cu,
const std::string& src,
ResolutionCallback rcb) {
const ResolutionCallback& rcb) {
cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr);
})
.def(
@ -1279,10 +1265,10 @@ void initJitScriptBindings(PyObject* module) {
});
m.def(
"_create_function_from_trace",
[](std::string qualname,
py::function func,
py::tuple input_tuple,
py::function var_lookup_fn,
[](const std::string& qualname,
const py::function& func,
const py::tuple& input_tuple,
const py::function& var_lookup_fn,
bool strict,
bool force_outplace) {
auto typed_inputs = toTraceableStack(input_tuple);
@ -1303,7 +1289,7 @@ void initJitScriptBindings(PyObject* module) {
[](const std::string& qualifiedName,
const ClassDef& classDef,
const ClassMethodDefaults& defaults,
ResolutionCallback rcb) {
const ResolutionCallback& rcb) {
C10_LOG_API_USAGE_ONCE("torch.script.class");
if (classDef.superclass().present()) {
throw ErrorReport(classDef.range())
@ -1465,7 +1451,6 @@ void initJitScriptBindings(PyObject* module) {
m.def("_propagate_shapes", _propagate_shapes);
m.def(
"_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes);
m.def("_assign_output_shapes", _assign_output_shapes);
m.def(
"_last_executed_optimized_graph",
[]() { return lastExecutedOptimizedGraph(); },
@ -1668,18 +1653,23 @@ void initJitScriptBindings(PyObject* module) {
m.def(
"_resolve_type",
[](const std::string& name, SourceRange range, ResolutionCallback rcb) {
[](const std::string& name,
const SourceRange& range,
const ResolutionCallback& rcb) {
return pythonResolver(rcb)->resolveType(name, range);
});
m.def(
"_resolve_type_from_object",
[](const py::object& obj, SourceRange range, ResolutionCallback rcb) {
[](const py::object& obj,
const SourceRange& range,
const ResolutionCallback& rcb) {
return pythonResolver(rcb)->resolveTypeFromObject(obj, range);
});
m.def(
"_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); });
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
m, "LoggerBase");
py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")

View File

@ -178,6 +178,10 @@ class EncoderBase {
return model_proto_;
}
SymbolDimMap get_symbol_dim_param_map() {
return symbol_dim_map_;
}
protected:
// Using std::map instead of std::unordered_map for initializers
// in EncodeGraph constructor so that the order in which initializers
@ -243,6 +247,7 @@ class EncoderBase {
const bool use_external_data_format = false,
const std::string& onnx_file_path = std::string());
SymbolDimMap symbol_dim_map_;
onnx::ModelProto model_proto_;
size_t num_blocks_;
size_t num_op_nodes_;
@ -316,33 +321,38 @@ void EncoderBase::EncodeValueInfo(
std::unordered_map<int64_t, std::string>>& dynamic_axes) {
std::string name = n->debugName();
v->set_name(name);
auto tensorTypeToONNXType = [&dynamic_axes, &name](
auto tensorTypeToONNXType = [&dynamic_axes, &name, this](
TensorTypePtr t,
onnx::TypeProto_Tensor* tensor_type) {
if (t->sizes().isComplete()) {
// onnx::TypeProto* onnx_type = v->mutable_type();
// onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type();
if (t->dim()) {
onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
std::vector<std::int64_t> sizes = t->sizes().concrete_sizes().value();
auto sizes = t->symbolic_sizes().sizes().value();
for (size_t i = 0; i < sizes.size(); i++) {
shape->add_dim();
if ((dynamic_axes.find(name) != dynamic_axes.end()) &&
(dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) {
shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i));
if (!sizes[i].is_static()) {
symbol_dim_map_[sizes[i]] = dynamic_axes.at(name).at(i);
}
} else if (sizes[i].is_static()) {
shape->mutable_dim(i)->set_dim_value(sizes[i].static_size());
} else {
shape->mutable_dim(i)->set_dim_value(sizes[i]);
if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) {
symbol_dim_map_[sizes[i]] = name + "_" + std::to_string(i);
}
shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]);
}
}
}
if (t->scalarType()) {
// onnx::TypeProto* onnx_type = v->mutable_type();
// onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type();
tensor_type->set_elem_type(ATenTypeToOnnxType(t->scalarType().value()));
}
};
if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
if (node_type->sizes().isComplete() || node_type->scalarType()) {
if (node_type->dim() || node_type->scalarType()) {
// Encode type if either shape or dtype exists.
onnx::TypeProto* onnx_type = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type();
tensorTypeToONNXType(node_type, tensor_type);
@ -854,7 +864,10 @@ std::string pretty_print_onnx(
// conform to the ONNX op specification. Thus, the output will not
// be interpretable by a ONNX-compatible framework. However, PyTorch or
// libtorch will be able to import the IR and play it back.
std::tuple<std::shared_ptr<::ONNX_NAMESPACE::ModelProto>, RawDataExportMap>
std::tuple<
std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
RawDataExportMap,
SymbolDimMap>
export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
@ -888,10 +901,12 @@ export_onnx(
proto_size <= INT_MAX,
"Exporting model exceed maximum protobuf size of 2GB. "
"Please call torch.onnx.export with use_external_data_format=True.");
GRAPH_UPDATE("onnx proto:", prettyPrint(graph_encoder.get_model_proto()));
std::shared_ptr<onnx::ModelProto> model_proto =
std::make_shared<onnx::ModelProto>(graph_encoder.get_model_proto());
return std::make_tuple(model_proto, graph_encoder.get_raw_data_export_map());
GRAPH_DEBUG("onnx proto:", prettyPrint(graph_encoder.get_model_proto()));
return std::make_tuple(
std::make_shared<::ONNX_NAMESPACE::ModelProto>(
graph_encoder.get_model_proto()),
graph_encoder.get_raw_data_export_map(),
graph_encoder.get_symbol_dim_param_map());
}
std::string serialize_model_proto_to_string(

View File

@ -25,24 +25,28 @@ namespace jit {
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
TORCH_API std::
tuple<std::shared_ptr<::ONNX_NAMESPACE::ModelProto>, RawDataExportMap>
export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
const std::unordered_map<
std::string,
std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool defer_weight_export = false,
::torch::onnx::OperatorExportTypes operator_export_type =
::torch::onnx::OperatorExportTypes::ONNX,
bool strip_doc_string = true,
bool keep_initializers_as_inputs = true,
const std::map<std::string, int>& custom_opsets = {},
bool add_node_names = true,
bool use_external_data_format = false,
const std::string& onnx_file_path = std::string());
using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
TORCH_API std::tuple<
std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
RawDataExportMap,
SymbolDimMap>
export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
const std::unordered_map<
std::string,
std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool defer_weight_export = false,
::torch::onnx::OperatorExportTypes operator_export_type =
::torch::onnx::OperatorExportTypes::ONNX,
bool strip_doc_string = true,
bool keep_initializers_as_inputs = true,
const std::map<std::string, int>& custom_opsets = {},
bool add_node_names = true,
bool use_external_data_format = false,
const std::string& onnx_file_path = std::string());
TORCH_API std::string serialize_model_proto_to_string(
const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto);

View File

@ -171,6 +171,8 @@ def _is_none(x):
def _is_value(x):
return isinstance(x, torch._C.Value)
def _is_tensor(x):
return x.type().isSubtypeOf(torch._C.TensorType.get())
def _is_tensor_list(x):
return isinstance(x.type(), torch._C.ListType) and isinstance(x.type().getElementType(), torch._C.TensorType)

View File

@ -6,7 +6,7 @@ import torch.onnx.symbolic_helper as sym_help
import warnings
import numpy
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list
from torch.onnx.symbolic_opset9 import expand, unused
from torch.nn.modules.utils import _single, _pair, _triple
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
@ -272,7 +272,7 @@ def masked_scatter(g, self, mask, source):
def _len(g, self):
if self.type().isSubtypeOf(torch._C.ListType.ofTensors()) or self.node().kind() == "onnx::SplitToSequence":
if _is_tensor_list(self) or self.node().kind() == "onnx::SplitToSequence":
return g.op("SequenceLength", self)
return g.op("Size", self)

View File

@ -121,6 +121,7 @@ def floor_divide(g, self, other):
# - self is not fp and other is not fp, the output's type is self's output type
# - the output type defaults to Float
scalar_type = self.type().scalarType()
if scalar_type is not None:
if not sym_help._is_fp(self) and \
other.type().scalarType() is not None and \

View File

@ -17,7 +17,7 @@ import warnings
from torch._six import string_classes
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _assign_output_shapes, _check_onnx_proto
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto
# the flag to tell the user whether it's in the middle of ONNX export or not
@ -121,7 +121,7 @@ def _split_tensor_list_constants(g, block):
def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False, fixed_batch_size=False,
params_dict=None, use_new_jit_passes=False):
params_dict=None, use_new_jit_passes=False, dynamic_axes=None, input_names=None):
# Inline everything
torch._C._jit_pass_inline(graph)
@ -195,6 +195,11 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
# onnx only supports tensors, so we turn all out number types into tensors
torch._C._jit_pass_erase_number_types(graph)
from torch.onnx.symbolic_helper import _onnx_shape_inference
if _onnx_shape_inference:
input_names = [] if input_names is None else input_names
dynamic_axes = {} if dynamic_axes is None else dynamic_axes
torch._C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
torch._C._jit_pass_lint(graph)
@ -214,6 +219,9 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
torch._C._jit_pass_lint(graph)
graph = torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version
if _onnx_shape_inference:
torch._C._jit_pass_onnx_graph_shape_type_inference(graph, _export_onnx_opset_version)
return graph
@ -388,7 +396,8 @@ def _model_to_graph(model, args, verbose=False,
example_outputs=None,
_retain_param_name=False, do_constant_folding=True,
_disable_torch_constant_prop=False, fixed_batch_size=False,
training=None, use_new_jit_passes=False):
training=None, use_new_jit_passes=False,
dynamic_axes=None):
from torch.onnx.symbolic_helper import _export_onnx_opset_version
# Special case for common case of passing a single Tensor
if isinstance(args, torch.Tensor):
@ -408,19 +417,20 @@ def _model_to_graph(model, args, verbose=False,
graph = _optimize_graph(graph, operator_export_type,
_disable_torch_constant_prop=_disable_torch_constant_prop,
fixed_batch_size=fixed_batch_size, params_dict=params_dict,
use_new_jit_passes=use_new_jit_passes)
use_new_jit_passes=use_new_jit_passes,
dynamic_axes=dynamic_axes, input_names=input_names)
from torch.onnx.symbolic_helper import _onnx_shape_inference
if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):
assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule or " \
"ScriptFunction."
out_vars, _ = torch.jit._flatten(tuple(example_outputs))
graph = _assign_output_shapes(graph, out_vars)
torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, _onnx_shape_inference)
# NB: ONNX requires complete information about output types, which might be
# erased by some optimizations, so we need to set it explicitly again.
if torch_out is not None:
output_tensors, _ = torch._C._jit_flatten(torch_out)
for output, tensor in zip(graph.outputs(), output_tensors):
output.inferTypeFrom(tensor)
torch._C._jit_pass_onnx_assign_output_shape(graph, output_tensors, _onnx_shape_inference)
_set_input_and_output_names(graph, input_names, output_names)
@ -513,12 +523,12 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini
input_names=None, output_names=None, opset_version=None, dynamic_axes=None):
r"""
This diagnostic tool runs your model with operator_export_type set to
OperatorExportTypes.ONNX_FALLTHROUGH once in order to get a list of
OperatorExportTypes.ONNX_FALLTHROUGH once in order to get a list of
all the ops that are not supported/implemented by the current exporter
operator_export_type is set to OperatorExportTypes.ONNX_FALLTHROUGH by default
OperatorExportTypes.ONNX_FALLTHROUGH: If an op is not supported
in ONNX, fall through and export the operator as is, as a custom
in ONNX, fall through and export the operator as is, as a custom
ONNX op. Using this mode, the op can be exported and implemented by
the user for their runtime backend.
Example graph::
@ -537,7 +547,7 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini
%5 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0
return (%5)
In the above example, aten::cumsum in not implemented in opset 9, hence exporter falls
In the above example, aten::cumsum in not implemented in opset 9, hence exporter falls
through and provides a list of unsupported ops, the result being:
Unsupported ops : [aten:cumsum]
"""
@ -614,6 +624,10 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format,
operator_export_type,
f)
if dynamic_axes is None:
dynamic_axes = {}
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
graph, params_dict, torch_out = \
_model_to_graph(model, args, verbose, input_names,
output_names, operator_export_type,
@ -621,17 +635,14 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
val_do_constant_folding,
fixed_batch_size=fixed_batch_size,
training=training,
use_new_jit_passes=use_new_jit_passes)
use_new_jit_passes=use_new_jit_passes,
dynamic_axes=dynamic_axes)
# TODO: Don't allocate a in-memory string for the protobuf
defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
if dynamic_axes is None:
dynamic_axes = {}
if custom_opsets is None:
custom_opsets = {}
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
if export_params:
proto, export_map = graph._export_onnx(
params_dict, opset_version, dynamic_axes, defer_weight_export,