diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index d7bef668118..dca45fc5c31 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -752,6 +752,24 @@ class TestUtilityFuns_opset9(_BaseTestCase): self.assertIn("NWithOverloads.1", func_names) self.assertIn("NWithOverloads.2", func_names) + @skipIfUnsupportedMinOpsetVersion(15) + def test_local_function_infer_scopes(self): + class M(torch.nn.Module): + def forward(self, x): + # Concatenation of scalars inserts unscoped tensors in IR graph. + new_tensor_shape = x.size()[:-1] + (1, 1, -1) + tensor = x.view(*new_tensor_shape) + return tensor + + x = torch.randn(4, 5) + f = io.BytesIO() + torch.onnx.export(M(), (x,), f, export_modules_as_functions=True, + opset_version=self.opset_version, do_constant_folding=False) + + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + funcs = onnx_model.functions + self.assertIn("M", [f.name for f in funcs]) + def test_aten_fallthrough(self): # Test aten export of op with no symbolic class Module(torch.nn.Module): @@ -1222,5 +1240,9 @@ class TestUtilityFuns_opset14(TestUtilityFuns_opset9): opset_version = 14 +class TestUtilityFuns_opset15(TestUtilityFuns_opset9): + opset_version = 15 + + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index 5a0f8592f3b..1840d96fd13 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -353,6 +353,13 @@ c10::optional FunctionExtractor::InferScope(Node* n) { } for (auto output : n->outputs()) { for (auto use : output->uses()) { + if (!IsValidScope(use.user->scope())) { + auto inferred_output_scope = InferScope(use.user); + if (inferred_output_scope.has_value() && + IsValidScope(inferred_output_scope.value())) { + use.user->setScope(inferred_output_scope.value()); + } + } output_scopes.emplace_back(use.user->scope()); } } diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 83206935048..f76b606c181 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -189,6 +189,7 @@ Node* transformToONNXConcatNode( Node* unsqueezed_node = createONNXUnsqueeze(g, new_node, new_input, 0, opset_version); + unsqueezed_node->copyMetadata(lc_node); unsqueezed.emplace_back(unsqueezed_node->output()); } diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 31bffdabca1..2a049f37e35 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -9,8 +9,6 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO" producer_name = "pytorch" producer_version = _C._onnx.PRODUCER_VERSION -constant_folding_opset_versions = [9, 10, 11, 12, 13, 14] - class ExportTypes: r""""Specifies how the ONNX model is stored.""" diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 3d62dc2df3c..d6746967d76 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -839,6 +839,7 @@ _default_onnx_opset_version = 9 _onnx_main_opset = 15 _onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13, 14] _export_onnx_opset_version = _default_onnx_opset_version +_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1)) def _set_opset_version(opset_version): diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index ad960a4b3c1..50023313ee7 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -532,7 +532,8 @@ def _model_to_graph(model, args, verbose=False, if training is None or training == TrainingMode.EVAL: params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) - if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions: + from torch.onnx.symbolic_helper import _constant_folding_opset_versions + if do_constant_folding and _export_onnx_opset_version in _constant_folding_opset_versions: params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict, _export_onnx_opset_version) torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)