[ONNX] Handle sequence output for models (#50599)

Summary:
Duplicate of https://github.com/pytorch/pytorch/issues/46542
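
This change adds export support for models whose outputs contain sequences (Python lists of tensors), which map to ONNX sequence types. A minimal sketch of the kind of model this enables, not part of this commit and assuming the torch.onnx.export / example_outputs API of this PyTorch version:

import torch
from typing import List

class SplitToSequenceModel(torch.nn.Module):
    def forward(self, x) -> List[torch.Tensor]:
        # The split pieces are returned directly, so the model output is a sequence of tensors.
        return x.split([2, 1, 2])

scripted = torch.jit.script(SplitToSequenceModel())
x = torch.randn(5, 4, 3)
torch.onnx.export(scripted, (x,), "seq_output.onnx", opset_version=11,
                  example_outputs=scripted(x))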

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50599

Reviewed By: SplitInfinity

Differential Revision: D25928897

Pulled By: bzinodev

fbshipit-source-id: a898cef7b2d15a287aedd9798ce1423cebf378d4
Author: neginraoof
Date: 2021-01-21 15:29:19 -08:00
Committed by: Facebook GitHub Bot
Parent: c082e2184d
Commit: 137f2a385a
20 changed files with 243 additions and 72 deletions

View File

@@ -56,6 +56,7 @@ pytest "${args[@]}" \
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_caffe2.py" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \
"${test_paths[@]}"
@@ -68,7 +69,8 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test1* ]]; then
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime" \
"$top_dir/test/onnx/test_custom_ops.py" \
"$top_dir/test/onnx/test_models_onnxruntime.py" \
"$top_dir/test/onnx/test_utility_funs.py"
"$top_dir/test/onnx/test_utility_funs.py" \
"$top_dir/test/onnx/test_pytorch_onnx_caffe2.py"
fi
if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
# Update the loop for new opsets

View File

@@ -146,7 +146,7 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "13_0"
dim_param: "Range13_dim_0"
}
}
}

View File

@@ -47,10 +47,10 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "2_0"
dim_param: "ConstantOfShape2_dim_0"
}
dim {
dim_param: "2_1"
dim_param: "ConstantOfShape2_dim_1"
}
}
}

View File

@@ -137,10 +137,10 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "10_0"
dim_param: "ConstantOfShape10_dim_0"
}
dim {
dim_param: "10_1"
dim_param: "ConstantOfShape10_dim_1"
}
}
}

View File

@@ -47,10 +47,10 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "2_0"
dim_param: "ConstantOfShape2_dim_0"
}
dim {
dim_param: "2_1"
dim_param: "ConstantOfShape2_dim_1"
}
}
}

View File

@@ -47,10 +47,10 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "2_0"
dim_param: "ConstantOfShape2_dim_0"
}
dim {
dim_param: "2_1"
dim_param: "ConstantOfShape2_dim_1"
}
}
}

View File

@@ -67,7 +67,7 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "4_0"
dim_param: "TopK4_dim_0"
}
}
}
@@ -80,7 +80,7 @@ graph {
elem_type: 7
shape {
dim {
dim_param: "5_0"
dim_param: "TopK5_dim_0"
}
}
}

View File

@@ -77,7 +77,7 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "4_0"
dim_param: "TopK4_dim_0"
}
}
}
@@ -90,7 +90,7 @@ graph {
elem_type: 7
shape {
dim {
dim_param: "5_0"
dim_param: "TopK5_dim_0"
}
}
}

View File

@@ -51,7 +51,7 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "1_0"
dim_param: "Unique1_dim_0"
}
dim {
dim_value: 3
@@ -73,7 +73,7 @@ graph {
elem_type: 7
shape {
dim {
dim_param: "4_0"
dim_param: "Unique4_dim_0"
}
}
}

View File

@@ -50,16 +50,16 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "4_0"
dim_param: "Upsample4_dim_0"
}
dim {
dim_param: "4_1"
dim_param: "Upsample4_dim_1"
}
dim {
dim_param: "4_2"
dim_param: "Upsample4_dim_2"
}
dim {
dim_param: "4_3"
dim_param: "Upsample4_dim_3"
}
}
}

View File

@@ -50,16 +50,16 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "4_0"
dim_param: "Upsample4_dim_0"
}
dim {
dim_param: "4_1"
dim_param: "Upsample4_dim_1"
}
dim {
dim_param: "4_2"
dim_param: "Upsample4_dim_2"
}
dim {
dim_param: "4_3"
dim_param: "Upsample4_dim_3"
}
}
}

View File

@@ -47,10 +47,10 @@ graph {
elem_type: 1
shape {
dim {
dim_param: "2_0"
dim_param: "ConstantOfShape2_dim_0"
}
dim {
dim_param: "2_1"
dim_param: "ConstantOfShape2_dim_1"
}
}
}

View File

@@ -58,6 +58,13 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
ort_sess = onnxruntime.InferenceSession(f.getvalue())
return ort_sess
def inline_flatten_list(inputs, res_list):
for i in inputs:
res_list.append(i) if not isinstance(i, (list, tuple)) else inline_flatten_list(i, res_list)
return res_list
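
The inline_flatten_list helper above recursively flattens nested lists and tuples of ONNX Runtime outputs into a single flat list, so sequence outputs can be compared element-wise with the flattened PyTorch outputs. A hedged illustration (not part of this commit):

    flat = inline_flatten_list([a, [b, (c, d)]], [])
    # flat == [a, b, c, d]
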
def run_ort(ort_sess, input):
input_copy = copy.deepcopy(input)
input, _ = torch.jit._flatten(input_copy)
@@ -66,7 +73,8 @@ def run_ort(ort_sess, input):
ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
ort_outs = ort_sess.run(None, ort_inputs)
return ort_outs
return inline_flatten_list(ort_outs, [])
def ort_compare_with_pytorch(ort_outs, output, rtol, atol):
output, _ = torch.jit._flatten(output)
@@ -115,7 +123,7 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
output_names=output_names, fixed_batch_size=fixed_batch_size, training=None,
onnx_shape_inference=self.onnx_shape_inference,
use_new_jit_passes=self.use_new_jit_passes)
# compute onnxruntime output prediction
ort_outs = run_ort(ort_sess, input)
ort_compare_with_pytorch(ort_outs, output, rtol, atol)
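
Putting the pieces together, the check run_model_test performs can be summarized by this simplified sketch (helper names are those defined above; the tolerances mirror the existing assertions in this file):

    pt_outs = model(*input)                 # may contain nested lists/tuples
    ort_outs = run_ort(ort_sess, input)     # flat list via inline_flatten_list
    flat_pt, _ = torch.jit._flatten(pt_outs)
    for pt_o, ort_o in zip(flat_pt, ort_outs):
        np.testing.assert_allclose(pt_o.detach().numpy(), ort_o, rtol=1e-3, atol=1e-7)
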
@@ -3591,27 +3599,24 @@ class TestONNXRuntime(unittest.TestCase):
def test_split(self):
class SplitModel(torch.nn.Module):
def forward(self, input):
out1, out2, out3 = input.split([2, 1, 2])
return out1, out2, out3
return input.split([2, 1, 2]), input.split([3, 2])[0]
x = torch.randn(5, 4, 3)
self.run_test(SplitModel(), x)
class SplitModel2(torch.nn.Module):
def forward(self, input):
out1, out2, out3 = input.split([2, 1, 1], -2)
return out1, out2, out3
return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1]
x = torch.randn(5, 4, 3)
self.run_test(SplitModel2(), x)
class SplitModel3(torch.nn.Module):
def forward(self, input):
out1, out2, out3 = input.split([2, 1, 2])
return out3, out1
return input.split([2, 1, 2])
x = torch.randn(5, 4, 3)
self.run_test(torch.jit.script(SplitModel3()), x)
self.run_test(SplitModel3(), x)
@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@@ -3769,7 +3774,7 @@ class TestONNXRuntime(unittest.TestCase):
res2 += 1
res3 = res3 + [arr[i].sum(0, False)]
res4 += [arr[-1 - i].sum(0, False)]
return torch.stack(res), torch.stack(res1), res2, torch.stack(res3), torch.stack(res4)
return res, res1, res2, torch.stack(res3), torch.stack(res4)
model = ListLoopModel()
inputs = torch.randn(16)
@@ -5508,7 +5513,6 @@ class TestONNXRuntime(unittest.TestCase):
self.assertRaises(RuntimeError, check_proto)
@disableScriptTest() # dtype mismatch
def test_split_tensor_scalar(self):
class SplitModel(torch.nn.Module):
def forward(self, x):
@@ -5861,6 +5865,37 @@ class TestONNXRuntime(unittest.TestCase):
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in
zip(ort_outs1, ort_outs2)]
def test_script_custom_class_error(self):
class BoxCoder(object):
def __init__(self, bbox_xform_clip):
# type: (float) -> None
self.bbox_xform_clip = bbox_xform_clip
def decode(self, rel_codes, boxes):
# type: (Tensor, List[Tensor]) -> Tensor
boxes = torch.cat(boxes, dim=0)
pred_ctr_x = torch.clamp(rel_codes[:, 0::4], max=self.bbox_xform_clip) * boxes[:, 2]
return pred_ctr_x
class MyModule(torch.nn.Module):
__annotations__ = {
'box_coder': BoxCoder,
}
def __init__(self):
super(MyModule, self).__init__()
self.box_coder = BoxCoder(1.4)
def forward(self, box_regression: torch.Tensor, proposals: List[torch.Tensor]):
return self.box_coder.decode(box_regression, proposals)
model = torch.jit.script(MyModule())
box_regression = torch.randn([4, 4])
proposal = [torch.randn(2, 4), torch.randn(2, 4)]
with self.assertRaises(RuntimeError) as cm:
self.run_test(model, (box_regression, proposal))
@skipIfUnsupportedOpsetVersion([13])
def test_initializer_sequence(self):
class MyModule(torch.nn.Module):

View File

@@ -257,7 +257,7 @@ def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def
def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ...
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ...
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ...
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
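
For context, the new desc parameter carries the IODescriptor returned by torch.jit._flatten, which lets the C++ pass recover the original, possibly nested, output structure. The calling pattern, as updated later in this commit in torch/onnx/utils.py:

    out_vars, desc = torch.jit._flatten(tuple(example_outputs))
    torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, desc, _onnx_shape_inference)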

View File

@@ -34,9 +34,10 @@ std::deque<std::string> findSubModuleAttr(
if (node->kind() == prim::GetAttr) {
moduleNames.push_front(node->s(attr::name));
node = node->inputs()[0]->node();
} else {
return moduleNames;
}
}
// Assign the inner module to attrModule.
for (auto& moduleName : moduleNames) {
attrModule = attrModule.attr(moduleName).toModule();
@@ -127,24 +128,16 @@ std::vector<IValue> getParamAttributes(
paramConst = addParamAsArgument(function_, fullName, attr);
} else if (
attr.isObject() && !attr.toObjectRef().type()->is_module()) {
// Only below registered torch classes are supported.
auto type = attr.type();
TORCH_CHECK(
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.LinearPackedParamsBase")),
"Unknown type ",
type->repr_str(),
" encountered in handling model params. This type is not supported in ONNX export.");
attrValues.emplace_back(
script::Object(attr.toObject()).run_method("__getstate__"));
paramConst = addParamAsArgument(function_, fullName, attr);
try {
attrValues.emplace_back(
script::Object(attr.toObject()).run_method("__getstate__"));
paramConst = addParamAsArgument(function_, fullName, attr);
} catch (const std::exception&) {
auto type = attr.type();
throw ErrorReport(n->sourceRange())
<< "Unknown type " << type->repr_str()
<< " encountered in handling model params. This class type does not extend __getstate__ method.";
}
} else if (attr.isNone() || name == "training") {
auto attrVal = tryInsertConstant(*graph, attr);
paramConst = *attrVal;

View File

@@ -4,6 +4,7 @@
#include <torch/csrc/jit/passes/onnx/fold_if_node.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/onnx.h>
@@ -552,19 +553,141 @@ void ONNXSetDynamicInputShape(
}
}
bool HasSequenceTypeOutput(Node* node) {
if (node->kind() == ::c10::onnx::SplitToSequence ||
node->kind() == ::c10::onnx::SequenceInsert ||
node->kind() == ::c10::onnx::SequenceEmpty ||
node->kind() == ::c10::onnx::SequenceErase ||
node->kind() == ::c10::onnx::SequenceConstruct)
return true;
return false;
}
void ONNXUpdateTypeFromTensor(
Value* graph_output,
const at::Tensor& output,
bool onnx_shape_inference) {
if (onnx_shape_inference) {
graph_output->setType(
MergeInferredType(TensorType::create(output), graph_output->type()));
} else {
graph_output->inferTypeFrom(output);
}
}
void ONNXAssignOutputShape(
std::shared_ptr<Graph>& graph,
at::ArrayRef<at::Tensor> outputs,
const python::IODescriptor& desc,
bool onnx_shape_inference) {
TORCH_INTERNAL_ASSERT(graph->outputs().size() == outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
if (onnx_shape_inference) {
graph->outputs()[i]->setType(MergeInferredType(
TensorType::create(outputs[i]), graph->outputs()[i]->type()));
} else {
graph->outputs()[i]->inferTypeFrom(outputs[i]);
size_t outputs_index = 0;
PyObject* py_obj = unflatten(outputs, desc);
TORCH_INTERNAL_ASSERT(PyTuple_Check(py_obj));
for (size_t i = 0; i < PyTuple_GET_SIZE(py_obj); ++i) {
PyObject* elem = PyTuple_GET_ITEM(py_obj, i);
if (PyList_Check(elem)) {
size_t list_len = PyList_GET_SIZE(elem);
if (HasSequenceTypeOutput(graph->outputs()[outputs_index]->node())) {
if (list_len > 0) {
auto& var =
reinterpret_cast<THPVariable*>(PyList_GET_ITEM(elem, 0))->cdata;
for (size_t j = 1; j < list_len; ++j) {
PyObject* list_elem = PyList_GET_ITEM(elem, j);
TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
auto& new_var = reinterpret_cast<THPVariable*>(list_elem)->cdata;
TORCH_CHECK(
var.scalar_type() == new_var.scalar_type(),
"Unsupported sequence type in model outputs. ONNX supports sequences of elements of the same data type.");
}
auto elem_type = graph->outputs()[outputs_index]
->type()
->cast<ListType>()
->getElementType()
->cast<TensorType>();
elem_type = elem_type->withScalarType(var.scalar_type());
graph->outputs()[outputs_index]->setType(MergeInferredType(
graph->outputs()[outputs_index]->type(),
ListType::create(elem_type)));
outputs_index++;
TORCH_INTERNAL_ASSERT(
outputs_index <= graph->outputs().size(),
"Incorrect number of elements provided as example outputs.");
}
} else { // When torch output is a list type, but ONNX node is not a
// sequence type. Like prim::ListConstruct
size_t list_len = PyList_GET_SIZE(elem);
if (list_len > 0) {
for (size_t j = 0; j < list_len; ++j) {
PyObject* list_elem = PyList_GET_ITEM(elem, j);
TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
auto& var = reinterpret_cast<THPVariable*>(list_elem)->cdata;
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
graph->outputs()[outputs_index + j]->type(),
TensorType::create(var)));
}
outputs_index += list_len;
TORCH_INTERNAL_ASSERT(
outputs_index <= graph->outputs().size(),
"Incorrect number of elements provided as example outputs.");
}
}
} else if (PyTuple_Check(elem)) {
size_t tuple_len = PyTuple_GET_SIZE(elem);
if (tuple_len > 0) {
for (size_t j = 0; j < tuple_len; ++j) {
PyObject* tuple_elem = PyTuple_GET_ITEM(elem, j);
TORCH_INTERNAL_ASSERT(THPVariable_Check(tuple_elem));
auto& var = reinterpret_cast<THPVariable*>(tuple_elem)->cdata;
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
graph->outputs()[outputs_index + j]->type(),
TensorType::create(var)));
}
outputs_index += tuple_len;
TORCH_INTERNAL_ASSERT(
outputs_index <= graph->outputs().size(),
"Incorrect number of elements provided as example outputs.");
}
} else if (THPVariable_Check(elem)) {
at::Tensor var = reinterpret_cast<THPVariable*>(elem)->cdata;
ONNXUpdateTypeFromTensor(
graph->outputs()[outputs_index], var, onnx_shape_inference);
outputs_index++;
TORCH_INTERNAL_ASSERT(
outputs_index <= graph->outputs().size(),
"Incorrect number of elements provided as example outputs.");
} else { // Dict
// Support for dict data type is limited to fixed size dictionaries in
// ONNX.
// Dictionary values are unrolled and keys are not preserved.
TORCH_INTERNAL_ASSERT(PyDict_Check(elem));
auto unrolled_dict = py::reinterpret_borrow<py::list>(PyDict_Items(elem));
TORCH_INTERNAL_ASSERT(PyList_Check(unrolled_dict.ptr()));
for (size_t j = 0; j < unrolled_dict.size(); ++j) {
PyObject* tuple_elem = PyList_GET_ITEM(unrolled_dict.ptr(), j);
TORCH_INTERNAL_ASSERT(PyTuple_Check(tuple_elem));
TORCH_INTERNAL_ASSERT(PyTuple_GET_SIZE(tuple_elem) == 2);
auto& var =
reinterpret_cast<THPVariable*>(PyTuple_GET_ITEM(tuple_elem, 1))
->cdata;
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
graph->outputs()[outputs_index + j]->type(),
TensorType::create(var)));
}
outputs_index += unrolled_dict.size();
TORCH_INTERNAL_ASSERT(
outputs_index <= graph->outputs().size(),
"Incorrect number of elements provided as example outputs.");
}
}
TORCH_INTERNAL_ASSERT(
outputs_index == graph->outputs().size(),
"Incorrect number of elements provided as example outputs.");
Py_DECREF(py_obj);
}
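
In short, ONNXAssignOutputShape now unflattens the example outputs back into their original Python structure and walks that structure in step with the graph outputs. A rough illustration of the mapping (not part of the commit):

    # model outputs: (tensor_a, [t0, t1], {"key": tensor_b})
    #   tensor_a          -> one tensor-typed graph output
    #   [t0, t1]          -> one sequence-typed graph output when the producing node is an
    #                        onnx::Sequence* op, otherwise one graph output per element
    #   {"key": tensor_b} -> dict values unrolled into tensor outputs; keys are not preserved
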
void ONNXShapeTypeInference(std::shared_ptr<Graph>& graph, int opset_version) {

View File

@@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
namespace torch {
namespace jit {
@@ -26,6 +27,7 @@ TORCH_API void ONNXSetDynamicInputShape(
TORCH_API void ONNXAssignOutputShape(
std::shared_ptr<Graph>& graph,
at::ArrayRef<at::Tensor> outputs,
const python::IODescriptor& desc,
bool onnx_shape_inference);
// Utilize ONNX Shape Inference for node.

View File

@@ -150,8 +150,9 @@ void initJITBindings(PyObject* module) {
"_jit_pass_onnx_assign_output_shape",
[](std::shared_ptr<Graph>& graph,
const std::vector<at::Tensor>& tensors,
const python::IODescriptor& desc,
bool onnx_shape_inference = false) {
ONNXAssignOutputShape(graph, tensors, onnx_shape_inference);
ONNXAssignOutputShape(graph, tensors, desc, onnx_shape_inference);
})
.def("_jit_pass_lower_all_tuples", LowerAllTuples)
.def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)

View File

@@ -331,7 +331,8 @@ void EncoderBase::EncodeValueInfo(
std::unordered_map<int64_t, std::string>>& dynamic_axes) {
std::string name = n->debugName();
v->set_name(name);
auto tensorTypeToONNXType = [&dynamic_axes, &name, this](
auto tensorTypeToONNXType = [&dynamic_axes, &name, n, this](
const TensorTypePtr& t,
onnx::TypeProto_Tensor* tensor_type) {
if (t->dim()) {
@@ -349,7 +350,13 @@ void EncoderBase::EncodeValueInfo(
shape->mutable_dim(i)->set_dim_value(sizes[i].static_size());
} else {
if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) {
symbol_dim_map_[sizes[i]] = name + "_" + std::to_string(i);
if (n->node()->kind() == prim::Param) {
symbol_dim_map_[sizes[i]] = name + "_dim_" + std::to_string(i);
} else {
std::string op_type = n->node()->kind().toUnqualString();
symbol_dim_map_[sizes[i]] =
op_type + name + "_dim_" + std::to_string(i);
}
}
shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]);
}
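
The dim_param renames in the expected-output files earlier in this diff follow from this change. The naming rule, paraphrased as a hedged Python sketch (the authoritative logic is the C++ above):

    def symbolic_dim_name(value_name, op_type, dim_index, is_graph_input):
        # Graph inputs (prim::Param) keep the plain form; other values are
        # prefixed with the producing op type, e.g. "TopK4_dim_0".
        if is_graph_input:
            return value_name + "_dim_" + str(dim_index)
        return op_type + value_name + "_dim_" + str(dim_index)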

View File

@@ -437,7 +437,7 @@ def _model_to_graph(model, args, verbose=False,
args = (args, )
if isinstance(example_outputs, torch.Tensor):
example_outputs = [example_outputs]
example_outputs = (example_outputs,)
graph, params, torch_out = _create_jit_graph(model, args,
_retain_param_name,
@@ -456,14 +456,22 @@ def _model_to_graph(model, args, verbose=False,
if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):
assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule or " \
"ScriptFunction."
out_vars, _ = torch.jit._flatten(tuple(example_outputs))
torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, _onnx_shape_inference)
if isinstance(example_outputs, list):
example_outputs = [example_outputs]
out_vars, desc = torch.jit._flatten(tuple(example_outputs))
torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, desc, _onnx_shape_inference)
# NB: ONNX requires complete information about output types, which might be
# erased by some optimizations, so we need to set it explicitly again.
if torch_out is not None:
output_tensors, _ = torch._C._jit_flatten(torch_out)
torch._C._jit_pass_onnx_assign_output_shape(graph, output_tensors, _onnx_shape_inference)
if not (isinstance(torch_out, list) or isinstance(torch_out, tuple)):
output_wrapped = [torch_out] # type: ignore
else:
output_wrapped = torch_out # type: ignore
output_tensors, out_desc = torch._C._jit_flatten(tuple(output_wrapped))
torch._C._jit_pass_onnx_assign_output_shape(graph, output_tensors, out_desc, _onnx_shape_inference)
_set_input_and_output_names(graph, input_names, output_names)
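
Finally, once a model with sequence outputs has been exported, ONNX Runtime returns each sequence as a Python list of arrays. A minimal consumption sketch (assumes onnxruntime is installed; the file name matches the illustrative export in the summary):

    import onnxruntime
    sess = onnxruntime.InferenceSession("seq_output.onnx")
    outs = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})
    # A sequence-typed output arrives as a list of numpy arrays.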