[export] Support preserving submodule calling convention in non-strict export (#117796)
Summary: Title
Test Plan: CI
Reviewed By: zhxchen17
Differential Revision: D52889236
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117796
Approved by: https://github.com/angelayi
This commit is contained in:
parent 249a226113
commit f316c35a34
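For context, a minimal sketch of the workflow this change enables, modeled on the test_unflatten_preserve_signature test updated below; the toy modules and the "foo.nested" path are illustrative stand-ins, not the test's exact classes:

import torch
from torch.export import export, unflatten

class Nested(torch.nn.Module):
    def forward(self, x, y):
        return x + y, x - y

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.nested = Nested()

    def forward(self, x, y):
        a, b = self.nested(x, y)
        return a * b

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = Foo()

    def forward(self, x, y):
        return self.foo(x, y)

inps = (torch.rand(2, 3), torch.rand(2, 3))
# strict=False (non-strict export) can now be combined with
# preserve_module_call_signature, which previously asserted it was empty.
ep = export(
    MyModule(),
    inps,
    {},
    preserve_module_call_signature=("foo.nested",),
    strict=False,
)
# unflatten() rebuilds a module hierarchy in which foo.nested keeps its
# original calling convention (input/output pytree structure).
unflattened = unflatten(ep)
unflattened(*inps)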
@@ -30,7 +30,7 @@ from torch._export.utils import (
 from torch.export import Constraint, Dim, export
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import skipIfTorchDynamo, run_tests, TestCase
 from torch.utils._pytree import (
     LeafSpec,
     tree_flatten,
@@ -199,6 +199,7 @@ class TestUnflatten(TestCase):
             id(getattr(unflattened_module.sub_net, "2")),
         )

+    @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
     def test_unflatten_preserve_signature(self):
         class NestedChild(torch.nn.Module):
             def forward(self, zx, y):
@@ -234,11 +235,13 @@ class TestUnflatten(TestCase):

         orig_eager = MyModule()
         inps = torch.rand(2, 3), torch.rand(2, 3)
+        for strict in [True, False]:
             export_module = export(
                 orig_eager,
                 inps,
                 {},
                 preserve_module_call_signature=("foo.nested",),
+                strict=strict
             )
             unflattened = unflatten(export_module)
             self.compare_outputs(export_module, unflattened, inps)
@@ -61,7 +61,9 @@ def _wrap_submodule(mod, path, module_call_specs):
         submodule = getattr(submodule, name)

     def update_module_call_signatures(path, in_spec, out_spec):
-        assert path not in module_call_specs
+        if path in module_call_specs:
+            assert module_call_specs[path]["in_spec"] == in_spec
+            assert module_call_specs[path]["out_spec"] == out_spec
         module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

     assert "forward" not in submodule.__dict__
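The relaxed check above allows the same submodule path to be recorded more than once, provided the specs match. A standalone sketch of that idempotent-update pattern; the names here are illustrative, not the library's API:

import torch.utils._pytree as pytree

# Hypothetical registry mirroring module_call_specs above:
# path -> {"in_spec": TreeSpec, "out_spec": TreeSpec}
specs = {}

def update(path, in_spec, out_spec):
    if path in specs:
        # Re-registration is tolerated only if the specs are identical.
        assert specs[path]["in_spec"] == in_spec
        assert specs[path]["out_spec"] == out_spec
    specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

_, in_spec = pytree.tree_flatten(((1, 2), {}))
_, out_spec = pytree.tree_flatten([3, 4])
update("foo.nested", in_spec, out_spec)
update("foo.nested", in_spec, out_spec)  # second call is a no-op, not an error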
@@ -517,12 +517,24 @@ def _export(
     constraints = constraints or []
     kwargs = kwargs or {}

+    flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
+
     if not strict:
         assert isinstance(f, torch.nn.Module)
-        assert len(preserve_module_call_signature) == 0
         assert len(kwargs) == 0, "keyword arguments NYI"
         out_spec = None

+        module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
+
+        def strip_root(x):
+            if isinstance(x, str) and x.startswith("_export_root"):
+                stripped = x[len("_export_root") :]
+                return stripped[1:] if stripped.startswith(".") else stripped
+            return x
+
+        def fixup_key(x):
+            return "L__self__" + strip_root(x)
+
         def _tuplify_outputs(aot_export):
             def _aot_export_non_strict(mod, args, **kwargs):
                 class Wrapper(torch.nn.Module):
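A quick illustration of what the strip_root / fixup_key helpers added above do to fully qualified names once the user's module is nested under the "_export_root" wrapper attribute; the helpers are copied verbatim so the snippet runs on its own, and the example names are made up:

def strip_root(x):
    if isinstance(x, str) and x.startswith("_export_root"):
        stripped = x[len("_export_root") :]
        return stripped[1:] if stripped.startswith(".") else stripped
    return x

def fixup_key(x):
    return "L__self__" + strip_root(x)

assert strip_root("_export_root.foo.nested") == "foo.nested"
assert strip_root("_export_root") == ""
assert strip_root("foo.nested") == "foo.nested"   # non-wrapped names pass through
assert fixup_key("_export_root.foo.nested") == "L__self__foo.nested"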
@@ -537,16 +549,16 @@ def _export(
                         )
                         return tuple(flat_outs)

-                gm, sig = aot_export(Wrapper(mod), args, **kwargs)
-
-                def strip_root(x):
-                    if isinstance(x, str) and x.startswith("_export_root"):
-                        stripped = x[len("_export_root") :]
-                        return stripped[1:] if stripped.startswith(".") else stripped
-                    return x
-
-                def fixup_key(x):
-                    return "L__self__" + strip_root(x)
+                wrapped_mod = Wrapper(mod)
+                # Patch export_root to the signatures so that wrapper module correctly populates the
+                # in/out spec
+                new_preserved_call_signatures = [
+                    "_export_root." + i for i in preserve_module_call_signature
+                ]
+                with _wrap_submodules(
+                    wrapped_mod, new_preserved_call_signatures, module_call_specs
+                ):
+                    gm, sig = aot_export(wrapped_mod, args, **kwargs)

                 sig.parameters = pytree.tree_map(strip_root, sig.parameters)
                 sig.buffers = pytree.tree_map(strip_root, sig.buffers)
@@ -585,9 +597,39 @@ def _export(
             fake_mode, src_equalities, original_signature, ep_non_strict.gm
         )
         assert out_spec is not None
+
+        gm = ep_non_strict.gm
+
+        module_call_signatures = {
+            strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs)
+            for fqn, specs in module_call_specs.items()
+        }
+
+        if len(preserve_module_call_signature) > 0:
+            for node in gm.graph.nodes:
+                if node.target == torch.ops.higher_order._export_tracepoint:
+                    if "path" in node.kwargs:
+                        path = strip_root(node.kwargs["path"])
+                        with gm.graph.inserting_before(node):
+                            new_node = gm.graph.create_node(
+                                "call_function",
+                                torch.ops.higher_order._export_tracepoint,
+                                args=node.args,
+                                kwargs={
+                                    "path": path,
+                                    "kind": node.kwargs["kind"],
+                                },
+                            )
+                            node.replace_all_uses_with(new_node)
+                            gm.graph.erase_node(node)
+
+            res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm)
+            assert res is not None
+            gm = res.graph_module
+
         return ExportedProgram(
-            root=ep_non_strict.gm,
-            graph=ep_non_strict.gm.graph,
+            root=gm,
+            graph=gm.graph,
             graph_signature=ep_non_strict.sig,
             state_dict=_get_params_buffers(f),
             range_constraints=range_constraints,
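The tracepoint rewrite above uses the standard torch.fx node-replacement recipe: create a replacement node before the original, redirect all uses, then erase the original. A toy version of the same recipe on an unrelated op, just to show the mechanics; the module and the alpha keyword are invented for illustration:

import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, 1)

gm = torch.fx.symbolic_trace(M())

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target == torch.add:
        with gm.graph.inserting_before(node):
            # Recreate the call with an extra keyword argument.
            new_node = gm.graph.create_node(
                "call_function",
                torch.add,
                args=node.args,
                kwargs={"alpha": 2},
            )
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.recompile()
print(gm(torch.ones(2)))  # now computes x + 1 * 2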
@@ -595,9 +637,12 @@ def _export(
                 ModuleCallEntry(
                     "",
                     ModuleCallSignature(
-                        [], [], pytree.tree_flatten((args, {}))[1], out_spec
+                        inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec
                     ),
                 )
             ]
+            + [
+                ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()
+            ],
             example_inputs=(args, kwargs),
             constants=ep_non_strict.constants,
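orig_in_spec and out_spec above are pytree TreeSpecs describing a call's input and output structure. A small sketch of what a TreeSpec captures, using the torch.utils._pytree helpers referenced as pytree throughout the hunks above:

import torch
import torch.utils._pytree as pytree

args = (torch.rand(2, 3), torch.rand(2, 3))
kwargs = {}

# The spec records how (args, kwargs) were nested before flattening ...
flat_args, in_spec = pytree.tree_flatten((args, kwargs))
assert len(flat_args) == 2  # just the two tensor leaves

# ... so the flat leaves can be reassembled into the original calling
# convention, which is what ModuleCallSignature's in_spec/out_spec enable.
rebuilt_args, rebuilt_kwargs = pytree.tree_unflatten(flat_args, in_spec)
assert rebuilt_kwargs == {}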
@@ -768,7 +813,6 @@ def _export(
         ),
         len(export_graph_signature.input_specs),
     )
-    flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
     range_constraints = _process_constraints(
         gm,
         num_lifted,