From 586c2e8d624ca9bc1756470d6581d9993de2e085 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Thu, 4 Feb 2021 12:35:27 -0800
Subject: [PATCH] [ONNX] Fix graph sequence output from loop node (#51305)
 (#51521)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51521

* Add `Loop` and `If` nodes to the list of nodes that can produce
  sequence-type outputs (the exported pattern is sketched after the patch).
* Switch from `[]` to `at()` to avoid a segfault on out-of-range access
  (see the sketch below).

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26203112

Pulled By: SplitInfinity

fbshipit-source-id: e990eeed933124b195be0be159271e33fb485063
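As a standalone sketch of the `[]` vs `at()` distinction the second bullet
relies on (illustrative only; `std::vector` stands in for the container
returned by `graph->outputs()`, whose `at()` is likewise bounds-checked):

```cpp
#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  std::vector<int> outputs{1, 2, 3};

  // outputs[10] would be undefined behavior: it typically reads past the
  // buffer and crashes the process with no diagnostic.
  // outputs.at(10) bounds-checks and throws, so the error is catchable and
  // reportable instead of a hard segfault.
  try {
    std::cout << outputs.at(10) << "\n";
  } catch (const std::out_of_range& e) {
    std::cerr << "caught: " << e.what() << "\n";
  }
  return 0;
}
```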
---
 scripts/onnx/test.sh                           |  1 -
 test/onnx/test_pytorch_onnx_onnxruntime.py     | 13 ++----
 .../jit/passes/onnx/shape_type_inference.cpp   | 42 ++++++++++++-------
 3 files changed, 29 insertions(+), 27 deletions(-)

diff --git a/scripts/onnx/test.sh b/scripts/onnx/test.sh
index 20863157b01..a5c797e2166 100755
--- a/scripts/onnx/test.sh
+++ b/scripts/onnx/test.sh
@@ -83,7 +83,6 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
       "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
   done
   pytest "${args[@]}" \
-    "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference" \
     "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_IRv4_old_jit_API"
 fi
 
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 8a11df49de2..68bff713819 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -18,6 +18,9 @@ from test_pytorch_common import BATCH_SIZE
 from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
 from typing import List, Tuple, Optional
 import model_defs.word_language_model as word_language_model
+
+import onnx
+
 import torchvision
 from torchvision import ops
 from torchvision.models.detection.image_list import ImageList
@@ -26,7 +29,6 @@ from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionPro
 from torchvision.models.detection.roi_heads import RoIHeads
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
 from collections import OrderedDict
-import onnx
 
 def to_numpy(tensor):
     if tensor.requires_grad:
@@ -3876,7 +3878,6 @@ class TestONNXRuntime(unittest.TestCase):
         inputs = torch.zeros(1, 2, 3, dtype=torch.long)
         self.run_test(model, inputs)
 
-    @skipIfUnsupportedOpsetVersion([13])
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_loop_with_list(self):
         class ListLoopModel(torch.jit.ScriptModule):
@@ -6063,7 +6064,6 @@
         convert_to_onnx(model, input=(box_regression, proposal),
                         example_outputs=outputs, use_new_jit_passes=True)
 
-    @skipIfUnsupportedOpsetVersion([13])
     def test_initializer_sequence(self):
         class MyModule(torch.nn.Module):
             def __init__(self, input_size, hidden_size, num_classes):
@@ -6681,12 +6681,5 @@ TestONNXRuntime_opset12_IRv4_old_jit_API = type(str("TestONNXRuntime_opset12_IRv
                                                 keep_initializers_as_inputs=False,
                                                 use_new_jit_passes=False))
 
-
-# opset 12 tests, with _onnx_shape_inference=True.
-TestONNXRuntime_opset12_onnx_shape_inference = type(str("TestONNXRuntime_opset12_onnx_shape_inference"),
-                                                    (unittest.TestCase,),
-                                                    dict(TestONNXRuntime.__dict__, opset_version=12,
-                                                         onnx_shape_inference=True))
-
 if __name__ == '__main__':
     unittest.main()
diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
index c642da9c247..531006b4ef4 100644
--- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
+++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -586,7 +586,8 @@ bool HasSequenceTypeOutput(Node* node) {
       node->kind() == ::c10::onnx::SequenceInsert ||
       node->kind() == ::c10::onnx::SequenceEmpty ||
      node->kind() == ::c10::onnx::SequenceErase ||
-      node->kind() == ::c10::onnx::SequenceConstruct)
+      node->kind() == ::c10::onnx::SequenceConstruct ||
+      node->kind() == ::c10::onnx::Loop || node->kind() == ::c10::onnx::If)
     return true;
   return false;
 }
@@ -618,7 +619,7 @@ void ONNXAssignOutputShape(
 
   if (PyList_Check(elem)) {
     size_t list_len = PyList_GET_SIZE(elem);
-    if (HasSequenceTypeOutput(graph->outputs()[outputs_index]->node())) {
+    if (HasSequenceTypeOutput(graph->outputs().at(outputs_index)->node())) {
       if (list_len > 0) {
         auto& var =
             reinterpret_cast<THPVariable*>(PyList_GET_ITEM(elem, 0))->cdata;
@@ -630,15 +631,18 @@
               var.scalar_type() == new_var.scalar_type(),
               "Unsupported sequence type in model outputs. ONNX supports sequences of elements of the same data type.");
         }
-        auto elem_type = graph->outputs()[outputs_index]
+        auto elem_type = graph->outputs()
+                             .at(outputs_index)
                              ->type()
                              ->castRaw<ListType>()
                              ->getElementType()
                              ->cast<TensorType>();
         elem_type = elem_type->withScalarType(var.scalar_type());
-        graph->outputs()[outputs_index]->setType(MergeInferredType(
-            graph->outputs()[outputs_index]->type(),
-            ListType::create(elem_type)));
+        graph->outputs()
+            .at(outputs_index)
+            ->setType(MergeInferredType(
+                graph->outputs().at(outputs_index)->type(),
+                ListType::create(elem_type)));
         outputs_index++;
         TORCH_INTERNAL_ASSERT(
             outputs_index <= graph->outputs().size(),
@@ -652,9 +656,11 @@
           PyObject* list_elem = PyList_GET_ITEM(elem, j);
           TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
           auto& var = reinterpret_cast<THPVariable*>(list_elem)->cdata;
-          graph->outputs()[outputs_index + j]->setType(MergeInferredType(
-              graph->outputs()[outputs_index + j]->type(),
-              TensorType::create(var)));
+          graph->outputs()
+              .at(outputs_index + j)
+              ->setType(MergeInferredType(
+                  graph->outputs().at(outputs_index + j)->type(),
+                  TensorType::create(var)));
         }
         outputs_index += list_len;
         TORCH_INTERNAL_ASSERT(
@@ -669,9 +675,11 @@
         PyObject* tuple_elem = PyTuple_GET_ITEM(elem, j);
         TORCH_INTERNAL_ASSERT(THPVariable_Check(tuple_elem));
         auto& var = reinterpret_cast<THPVariable*>(tuple_elem)->cdata;
-        graph->outputs()[outputs_index + j]->setType(MergeInferredType(
-            graph->outputs()[outputs_index + j]->type(),
-            TensorType::create(var)));
+        graph->outputs()
+            .at(outputs_index + j)
+            ->setType(MergeInferredType(
+                graph->outputs().at(outputs_index + j)->type(),
+                TensorType::create(var)));
       }
       outputs_index += tuple_len;
       TORCH_INTERNAL_ASSERT(
@@ -681,7 +689,7 @@
   } else if (THPVariable_Check(elem)) {
     at::Tensor var = reinterpret_cast<THPVariable*>(elem)->cdata;
     ONNXUpdateTypeFromTensor(
-        graph->outputs()[outputs_index], var, onnx_shape_inference);
+        graph->outputs().at(outputs_index), var, onnx_shape_inference);
     outputs_index++;
     TORCH_INTERNAL_ASSERT(
         outputs_index <= graph->outputs().size(),
@@ -700,9 +708,11 @@ void ONNXAssignOutputShape(
           auto& var =
               reinterpret_cast<THPVariable*>(PyTuple_GET_ITEM(tuple_elem, 1))
                   ->cdata;
-          graph->outputs()[outputs_index + j]->setType(MergeInferredType(
-              graph->outputs()[outputs_index + j]->type(),
-              TensorType::create(var)));
+          graph->outputs()
+              .at(outputs_index + j)
+              ->setType(MergeInferredType(
+                  graph->outputs().at(outputs_index + j)->type(),
+                  TensorType::create(var)));
         }
         outputs_index += unrolled_dict.size();
         TORCH_INTERNAL_ASSERT(
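For context, here is a minimal sketch (illustrative, not the test verbatim) of
the pattern exercised by `test_loop_with_list`, whose opset-13 skip this patch
removes: a scripted loop appends tensors to a list, so the exported ONNX graph
ends in a `Loop` node whose output is a sequence rather than a tensor. The
`example_outputs` argument reflects the PyTorch-1.8-era export API for
ScriptModules.

```python
import io
from typing import List

import torch


class ListLoopModel(torch.nn.Module):
    def forward(self, x) -> List[torch.Tensor]:
        res: List[torch.Tensor] = []
        # A data-dependent loop survives scripting as a Loop node; the list it
        # carries is exported as an ONNX sequence, so the graph output is
        # sequence-typed rather than tensor-typed.
        for i in range(x.size(0)):
            res.append(x[i])
        return res


model = torch.jit.script(ListLoopModel())
dummy = torch.randn(5, 3)
buf = io.BytesIO()
# Sequence ops (SequenceEmpty/SequenceInsert) require opset >= 11.
torch.onnx.export(model, (dummy,), buf, opset_version=11,
                  example_outputs=model(dummy))
```

With `Loop`/`If` recognized by `HasSequenceTypeOutput`, the output of a graph
like this stays typed as a sequence (`ListType`) during shape inference, and a
mismatch between example outputs and graph outputs now surfaces as a
bounds-check failure from `at()` instead of a crash.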