mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Improve scope inference in function extraction
Cover more cases of scope inferencing where consecutive nodes don't have valid scope information. Usually these nodes are created in some pass where authors forgot to assign meaningful scope to them. * One rule of `InferScope` is to check if the current node's outputs' users share the same scope. Recursively run `InferScope` on the user nodes if they are missing scope as well. Since the graph is SSA, the depth is finite. * Fix one pass that missed scope information for a new node. Pull Request resolved: https://github.com/pytorch/pytorch/pull/71897
This commit is contained in:
parent
a83cf17807
commit
cf70466970
|
|
@ -752,6 +752,24 @@ class TestUtilityFuns_opset9(_BaseTestCase):
|
|||
self.assertIn("NWithOverloads.1", func_names)
|
||||
self.assertIn("NWithOverloads.2", func_names)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(15)
|
||||
def test_local_function_infer_scopes(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
# Concatenation of scalars inserts unscoped tensors in IR graph.
|
||||
new_tensor_shape = x.size()[:-1] + (1, 1, -1)
|
||||
tensor = x.view(*new_tensor_shape)
|
||||
return tensor
|
||||
|
||||
x = torch.randn(4, 5)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(M(), (x,), f, export_modules_as_functions=True,
|
||||
opset_version=self.opset_version, do_constant_folding=False)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
funcs = onnx_model.functions
|
||||
self.assertIn("M", [f.name for f in funcs])
|
||||
|
||||
def test_aten_fallthrough(self):
|
||||
# Test aten export of op with no symbolic
|
||||
class Module(torch.nn.Module):
|
||||
|
|
@ -1222,5 +1240,9 @@ class TestUtilityFuns_opset14(TestUtilityFuns_opset9):
|
|||
opset_version = 14
|
||||
|
||||
|
||||
class TestUtilityFuns_opset15(TestUtilityFuns_opset9):
|
||||
opset_version = 15
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -353,6 +353,13 @@ c10::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
|
|||
}
|
||||
for (auto output : n->outputs()) {
|
||||
for (auto use : output->uses()) {
|
||||
if (!IsValidScope(use.user->scope())) {
|
||||
auto inferred_output_scope = InferScope(use.user);
|
||||
if (inferred_output_scope.has_value() &&
|
||||
IsValidScope(inferred_output_scope.value())) {
|
||||
use.user->setScope(inferred_output_scope.value());
|
||||
}
|
||||
}
|
||||
output_scopes.emplace_back(use.user->scope());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -189,6 +189,7 @@ Node* transformToONNXConcatNode(
|
|||
|
||||
Node* unsqueezed_node =
|
||||
createONNXUnsqueeze(g, new_node, new_input, 0, opset_version);
|
||||
unsqueezed_node->copyMetadata(lc_node);
|
||||
unsqueezed.emplace_back(unsqueezed_node->output());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
|
|||
|
||||
producer_name = "pytorch"
|
||||
producer_version = _C._onnx.PRODUCER_VERSION
|
||||
constant_folding_opset_versions = [9, 10, 11, 12, 13, 14]
|
||||
|
||||
|
||||
class ExportTypes:
|
||||
r""""Specifies how the ONNX model is stored."""
|
||||
|
|
|
|||
|
|
@ -839,6 +839,7 @@ _default_onnx_opset_version = 9
|
|||
_onnx_main_opset = 15
|
||||
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13, 14]
|
||||
_export_onnx_opset_version = _default_onnx_opset_version
|
||||
_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))
|
||||
|
||||
|
||||
def _set_opset_version(opset_version):
|
||||
|
|
|
|||
|
|
@ -532,7 +532,8 @@ def _model_to_graph(model, args, verbose=False,
|
|||
if training is None or training == TrainingMode.EVAL:
|
||||
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
|
||||
|
||||
if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
|
||||
from torch.onnx.symbolic_helper import _constant_folding_opset_versions
|
||||
if do_constant_folding and _export_onnx_opset_version in _constant_folding_opset_versions:
|
||||
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
|
||||
_export_onnx_opset_version)
|
||||
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user