mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Support preserving submodule callling convention in non-strict export (#117796)
Summary: Title Test Plan: CI Reviewed By: zhxchen17 Differential Revision: D52889236 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117796 Approved by: https://github.com/angelayi
This commit is contained in:
parent
249a226113
commit
f316c35a34
|
|
@ -30,7 +30,7 @@ from torch._export.utils import (
|
||||||
from torch.export import Constraint, Dim, export
|
from torch.export import Constraint, Dim, export
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
from torch.testing._internal.common_utils import skipIfTorchDynamo, run_tests, TestCase
|
||||||
from torch.utils._pytree import (
|
from torch.utils._pytree import (
|
||||||
LeafSpec,
|
LeafSpec,
|
||||||
tree_flatten,
|
tree_flatten,
|
||||||
|
|
@ -199,6 +199,7 @@ class TestUnflatten(TestCase):
|
||||||
id(getattr(unflattened_module.sub_net, "2")),
|
id(getattr(unflattened_module.sub_net, "2")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
|
||||||
def test_unflatten_preserve_signature(self):
|
def test_unflatten_preserve_signature(self):
|
||||||
class NestedChild(torch.nn.Module):
|
class NestedChild(torch.nn.Module):
|
||||||
def forward(self, zx, y):
|
def forward(self, zx, y):
|
||||||
|
|
@ -234,11 +235,13 @@ class TestUnflatten(TestCase):
|
||||||
|
|
||||||
orig_eager = MyModule()
|
orig_eager = MyModule()
|
||||||
inps = torch.rand(2, 3), torch.rand(2, 3)
|
inps = torch.rand(2, 3), torch.rand(2, 3)
|
||||||
|
for strict in [True, False]:
|
||||||
export_module = export(
|
export_module = export(
|
||||||
orig_eager,
|
orig_eager,
|
||||||
inps,
|
inps,
|
||||||
{},
|
{},
|
||||||
preserve_module_call_signature=("foo.nested",),
|
preserve_module_call_signature=("foo.nested",),
|
||||||
|
strict=strict
|
||||||
)
|
)
|
||||||
unflattened = unflatten(export_module)
|
unflattened = unflatten(export_module)
|
||||||
self.compare_outputs(export_module, unflattened, inps)
|
self.compare_outputs(export_module, unflattened, inps)
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,9 @@ def _wrap_submodule(mod, path, module_call_specs):
|
||||||
submodule = getattr(submodule, name)
|
submodule = getattr(submodule, name)
|
||||||
|
|
||||||
def update_module_call_signatures(path, in_spec, out_spec):
|
def update_module_call_signatures(path, in_spec, out_spec):
|
||||||
assert path not in module_call_specs
|
if path in module_call_specs:
|
||||||
|
assert module_call_specs[path]["in_spec"] == in_spec
|
||||||
|
assert module_call_specs[path]["out_spec"] == out_spec
|
||||||
module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}
|
module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}
|
||||||
|
|
||||||
assert "forward" not in submodule.__dict__
|
assert "forward" not in submodule.__dict__
|
||||||
|
|
|
||||||
|
|
@ -517,12 +517,24 @@ def _export(
|
||||||
constraints = constraints or []
|
constraints = constraints or []
|
||||||
kwargs = kwargs or {}
|
kwargs = kwargs or {}
|
||||||
|
|
||||||
|
flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
|
||||||
|
|
||||||
if not strict:
|
if not strict:
|
||||||
assert isinstance(f, torch.nn.Module)
|
assert isinstance(f, torch.nn.Module)
|
||||||
assert len(preserve_module_call_signature) == 0
|
|
||||||
assert len(kwargs) == 0, "keyword arguments NYI"
|
assert len(kwargs) == 0, "keyword arguments NYI"
|
||||||
out_spec = None
|
out_spec = None
|
||||||
|
|
||||||
|
module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
|
||||||
|
|
||||||
|
def strip_root(x):
|
||||||
|
if isinstance(x, str) and x.startswith("_export_root"):
|
||||||
|
stripped = x[len("_export_root") :]
|
||||||
|
return stripped[1:] if stripped.startswith(".") else stripped
|
||||||
|
return x
|
||||||
|
|
||||||
|
def fixup_key(x):
|
||||||
|
return "L__self__" + strip_root(x)
|
||||||
|
|
||||||
def _tuplify_outputs(aot_export):
|
def _tuplify_outputs(aot_export):
|
||||||
def _aot_export_non_strict(mod, args, **kwargs):
|
def _aot_export_non_strict(mod, args, **kwargs):
|
||||||
class Wrapper(torch.nn.Module):
|
class Wrapper(torch.nn.Module):
|
||||||
|
|
@ -537,16 +549,16 @@ def _export(
|
||||||
)
|
)
|
||||||
return tuple(flat_outs)
|
return tuple(flat_outs)
|
||||||
|
|
||||||
gm, sig = aot_export(Wrapper(mod), args, **kwargs)
|
wrapped_mod = Wrapper(mod)
|
||||||
|
# Patch export_root to the signatures so that wrapper module correctly populates the
|
||||||
def strip_root(x):
|
# in/out spec
|
||||||
if isinstance(x, str) and x.startswith("_export_root"):
|
new_preserved_call_signatures = [
|
||||||
stripped = x[len("_export_root") :]
|
"_export_root." + i for i in preserve_module_call_signature
|
||||||
return stripped[1:] if stripped.startswith(".") else stripped
|
]
|
||||||
return x
|
with _wrap_submodules(
|
||||||
|
wrapped_mod, new_preserved_call_signatures, module_call_specs
|
||||||
def fixup_key(x):
|
):
|
||||||
return "L__self__" + strip_root(x)
|
gm, sig = aot_export(wrapped_mod, args, **kwargs)
|
||||||
|
|
||||||
sig.parameters = pytree.tree_map(strip_root, sig.parameters)
|
sig.parameters = pytree.tree_map(strip_root, sig.parameters)
|
||||||
sig.buffers = pytree.tree_map(strip_root, sig.buffers)
|
sig.buffers = pytree.tree_map(strip_root, sig.buffers)
|
||||||
|
|
@ -585,9 +597,39 @@ def _export(
|
||||||
fake_mode, src_equalities, original_signature, ep_non_strict.gm
|
fake_mode, src_equalities, original_signature, ep_non_strict.gm
|
||||||
)
|
)
|
||||||
assert out_spec is not None
|
assert out_spec is not None
|
||||||
|
|
||||||
|
gm = ep_non_strict.gm
|
||||||
|
|
||||||
|
module_call_signatures = {
|
||||||
|
strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs)
|
||||||
|
for fqn, specs in module_call_specs.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(preserve_module_call_signature) > 0:
|
||||||
|
for node in gm.graph.nodes:
|
||||||
|
if node.target == torch.ops.higher_order._export_tracepoint:
|
||||||
|
if "path" in node.kwargs:
|
||||||
|
path = strip_root(node.kwargs["path"])
|
||||||
|
with gm.graph.inserting_before(node):
|
||||||
|
new_node = gm.graph.create_node(
|
||||||
|
"call_function",
|
||||||
|
torch.ops.higher_order._export_tracepoint,
|
||||||
|
args=node.args,
|
||||||
|
kwargs={
|
||||||
|
"path": path,
|
||||||
|
"kind": node.kwargs["kind"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
node.replace_all_uses_with(new_node)
|
||||||
|
gm.graph.erase_node(node)
|
||||||
|
|
||||||
|
res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm)
|
||||||
|
assert res is not None
|
||||||
|
gm = res.graph_module
|
||||||
|
|
||||||
return ExportedProgram(
|
return ExportedProgram(
|
||||||
root=ep_non_strict.gm,
|
root=gm,
|
||||||
graph=ep_non_strict.gm.graph,
|
graph=gm.graph,
|
||||||
graph_signature=ep_non_strict.sig,
|
graph_signature=ep_non_strict.sig,
|
||||||
state_dict=_get_params_buffers(f),
|
state_dict=_get_params_buffers(f),
|
||||||
range_constraints=range_constraints,
|
range_constraints=range_constraints,
|
||||||
|
|
@ -595,9 +637,12 @@ def _export(
|
||||||
ModuleCallEntry(
|
ModuleCallEntry(
|
||||||
"",
|
"",
|
||||||
ModuleCallSignature(
|
ModuleCallSignature(
|
||||||
[], [], pytree.tree_flatten((args, {}))[1], out_spec
|
inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()
|
||||||
],
|
],
|
||||||
example_inputs=(args, kwargs),
|
example_inputs=(args, kwargs),
|
||||||
constants=ep_non_strict.constants,
|
constants=ep_non_strict.constants,
|
||||||
|
|
@ -768,7 +813,6 @@ def _export(
|
||||||
),
|
),
|
||||||
len(export_graph_signature.input_specs),
|
len(export_graph_signature.input_specs),
|
||||||
)
|
)
|
||||||
flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
|
|
||||||
range_constraints = _process_constraints(
|
range_constraints = _process_constraints(
|
||||||
gm,
|
gm,
|
||||||
num_lifted,
|
num_lifted,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user