From cf7046697064db44ed573f6fe21ec657ccb28054 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Mon, 31 Jan 2022 23:58:53 +0000 Subject: [PATCH] [ONNX] Improve scope inference in function extraction Cover more cases of scope inferencing where consecutive nodes don't have valid scope information. Usually these nodes are created in some pass where authors forgot to assign meaningful scope to them. * One rule of `InferScope` is to check if the current node's outputs' users share the same scope. Recursively run `InferScope` on the user nodes if they are missing scope as well. Since the graph is SSA, the depth is finite. * Fix one pass that missed scope information for a new node. Pull Request resolved: https://github.com/pytorch/pytorch/pull/71897 --- test/onnx/test_utility_funs.py | 22 +++++++++++++++++++ .../jit/passes/onnx/function_extraction.cpp | 7 ++++++ torch/csrc/jit/passes/onnx/helper.cpp | 1 + torch/onnx/__init__.py | 2 -- torch/onnx/symbolic_helper.py | 1 + torch/onnx/utils.py | 3 ++- 6 files changed, 33 insertions(+), 3 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index d7bef668118..dca45fc5c31 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -752,6 +752,24 @@ class TestUtilityFuns_opset9(_BaseTestCase): self.assertIn("NWithOverloads.1", func_names) self.assertIn("NWithOverloads.2", func_names) + @skipIfUnsupportedMinOpsetVersion(15) + def test_local_function_infer_scopes(self): + class M(torch.nn.Module): + def forward(self, x): + # Concatenation of scalars inserts unscoped tensors in IR graph. + new_tensor_shape = x.size()[:-1] + (1, 1, -1) + tensor = x.view(*new_tensor_shape) + return tensor + + x = torch.randn(4, 5) + f = io.BytesIO() + torch.onnx.export(M(), (x,), f, export_modules_as_functions=True, + opset_version=self.opset_version, do_constant_folding=False) + + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + funcs = onnx_model.functions + self.assertIn("M", [f.name for f in funcs]) + def test_aten_fallthrough(self): # Test aten export of op with no symbolic class Module(torch.nn.Module): @@ -1222,5 +1240,9 @@ class TestUtilityFuns_opset14(TestUtilityFuns_opset9): opset_version = 14 +class TestUtilityFuns_opset15(TestUtilityFuns_opset9): + opset_version = 15 + + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index 5a0f8592f3b..1840d96fd13 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -353,6 +353,13 @@ c10::optional FunctionExtractor::InferScope(Node* n) { } for (auto output : n->outputs()) { for (auto use : output->uses()) { + if (!IsValidScope(use.user->scope())) { + auto inferred_output_scope = InferScope(use.user); + if (inferred_output_scope.has_value() && + IsValidScope(inferred_output_scope.value())) { + use.user->setScope(inferred_output_scope.value()); + } + } output_scopes.emplace_back(use.user->scope()); } } diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 83206935048..f76b606c181 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -189,6 +189,7 @@ Node* transformToONNXConcatNode( Node* unsqueezed_node = createONNXUnsqueeze(g, new_node, new_input, 0, opset_version); + unsqueezed_node->copyMetadata(lc_node); unsqueezed.emplace_back(unsqueezed_node->output()); } diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 31bffdabca1..2a049f37e35 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -9,8 +9,6 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO" producer_name = "pytorch" producer_version = _C._onnx.PRODUCER_VERSION -constant_folding_opset_versions = [9, 10, 11, 12, 13, 14] - class ExportTypes: r""""Specifies how the ONNX model is stored.""" diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 3d62dc2df3c..d6746967d76 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -839,6 +839,7 @@ _default_onnx_opset_version = 9 _onnx_main_opset = 15 _onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13, 14] _export_onnx_opset_version = _default_onnx_opset_version +_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1)) def _set_opset_version(opset_version): diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index ad960a4b3c1..50023313ee7 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -532,7 +532,8 @@ def _model_to_graph(model, args, verbose=False, if training is None or training == TrainingMode.EVAL: params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) - if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions: + from torch.onnx.symbolic_helper import _constant_folding_opset_versions + if do_constant_folding and _export_onnx_opset_version in _constant_folding_opset_versions: params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict, _export_onnx_opset_version) torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)