[ONNX] Improve scope inference in function extraction

Cover more cases of scope inferencing where consecutive nodes don't have valid scope information. Usually these nodes are created in some pass where authors forgot to assign meaningful scope to them.
* One rule of `InferScope` is to check if the current node's outputs' users share the same scope. Recursively run `InferScope` on the user nodes if they are missing scope as well. Since the graph is SSA, the depth is finite.
* Fix one pass that missed scope information for a new node.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71897
This commit is contained in:
BowenBao 2022-01-31 23:58:53 +00:00 committed by PyTorch MergeBot
parent a83cf17807
commit cf70466970
6 changed files with 33 additions and 3 deletions

View File

@ -752,6 +752,24 @@ class TestUtilityFuns_opset9(_BaseTestCase):
self.assertIn("NWithOverloads.1", func_names)
self.assertIn("NWithOverloads.2", func_names)
@skipIfUnsupportedMinOpsetVersion(15)
def test_local_function_infer_scopes(self):
class M(torch.nn.Module):
def forward(self, x):
# Concatenation of scalars inserts unscoped tensors in IR graph.
new_tensor_shape = x.size()[:-1] + (1, 1, -1)
tensor = x.view(*new_tensor_shape)
return tensor
x = torch.randn(4, 5)
f = io.BytesIO()
torch.onnx.export(M(), (x,), f, export_modules_as_functions=True,
opset_version=self.opset_version, do_constant_folding=False)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
funcs = onnx_model.functions
self.assertIn("M", [f.name for f in funcs])
def test_aten_fallthrough(self):
# Test aten export of op with no symbolic
class Module(torch.nn.Module):
@ -1222,5 +1240,9 @@ class TestUtilityFuns_opset14(TestUtilityFuns_opset9):
opset_version = 14
class TestUtilityFuns_opset15(TestUtilityFuns_opset9):
opset_version = 15
if __name__ == "__main__":
run_tests()

View File

@ -353,6 +353,13 @@ c10::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
}
for (auto output : n->outputs()) {
for (auto use : output->uses()) {
if (!IsValidScope(use.user->scope())) {
auto inferred_output_scope = InferScope(use.user);
if (inferred_output_scope.has_value() &&
IsValidScope(inferred_output_scope.value())) {
use.user->setScope(inferred_output_scope.value());
}
}
output_scopes.emplace_back(use.user->scope());
}
}

View File

@ -189,6 +189,7 @@ Node* transformToONNXConcatNode(
Node* unsqueezed_node =
createONNXUnsqueeze(g, new_node, new_input, 0, opset_version);
unsqueezed_node->copyMetadata(lc_node);
unsqueezed.emplace_back(unsqueezed_node->output());
}

View File

@ -9,8 +9,6 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
producer_name = "pytorch"
producer_version = _C._onnx.PRODUCER_VERSION
constant_folding_opset_versions = [9, 10, 11, 12, 13, 14]
class ExportTypes:
r""""Specifies how the ONNX model is stored."""

View File

@ -839,6 +839,7 @@ _default_onnx_opset_version = 9
_onnx_main_opset = 15
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13, 14]
_export_onnx_opset_version = _default_onnx_opset_version
_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))
def _set_opset_version(opset_version):

View File

@ -532,7 +532,8 @@ def _model_to_graph(model, args, verbose=False,
if training is None or training == TrainingMode.EVAL:
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
from torch.onnx.symbolic_helper import _constant_folding_opset_versions
if do_constant_folding and _export_onnx_opset_version in _constant_folding_opset_versions:
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
_export_onnx_opset_version)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)