[export] Support preserving submodule callling convention in non-strict export (#117796)

Summary: Title

Test Plan: CI

Reviewed By: zhxchen17

Differential Revision: D52889236

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117796
Approved by: https://github.com/angelayi
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar 2024-01-19 17:16:45 +00:00 committed by PyTorch MergeBot
parent 249a226113
commit f316c35a34
3 changed files with 98 additions and 49 deletions

View File

@ -30,7 +30,7 @@ from torch._export.utils import (
from torch.export import Constraint, Dim, export
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfTorchDynamo, run_tests, TestCase
from torch.utils._pytree import (
LeafSpec,
tree_flatten,
@ -199,6 +199,7 @@ class TestUnflatten(TestCase):
id(getattr(unflattened_module.sub_net, "2")),
)
@skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
def test_unflatten_preserve_signature(self):
class NestedChild(torch.nn.Module):
def forward(self, zx, y):
@ -234,41 +235,43 @@ class TestUnflatten(TestCase):
orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
export_module = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo.nested",),
)
unflattened = unflatten(export_module)
self.compare_outputs(export_module, unflattened, inps)
unflattened.foo.nested = NestedChild()
self.compare_outputs(export_module, unflattened, inps)
for strict in [True, False]:
export_module = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo.nested",),
strict=strict
)
unflattened = unflatten(export_module)
self.compare_outputs(export_module, unflattened, inps)
unflattened.foo.nested = NestedChild()
self.compare_outputs(export_module, unflattened, inps)
# Test tree spec mismatched input
orig_outs = export_module(*inps)
new_inps = *inps, torch.rand(2, 3)
with self.assertRaisesRegex(
TypeError,
"There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
):
unflattened(new_inps)
# Test tree spec mismatched input
orig_outs = export_module(*inps)
new_inps = *inps, torch.rand(2, 3)
with self.assertRaisesRegex(
TypeError,
"There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
):
unflattened(new_inps)
# With flat args adapter
class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
def adapt(
self,
target_spec: TreeSpec,
input_spec: TreeSpec,
input_args: List[Any],
) -> List[Any]:
while len(input_args) > 2:
input_args.pop(-1)
return input_args
# With flat args adapter
class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
def adapt(
self,
target_spec: TreeSpec,
input_spec: TreeSpec,
input_args: List[Any],
) -> List[Any]:
while len(input_args) > 2:
input_args.pop(-1)
return input_args
unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
new_outs = unflattened(*new_inps)
self.assertTrue(torch.allclose(orig_outs, new_outs))
unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
new_outs = unflattened(*new_inps)
self.assertTrue(torch.allclose(orig_outs, new_outs))
def test_unflatten_param_list_dict(self):
class Mod(torch.nn.Module):

View File

@ -61,7 +61,9 @@ def _wrap_submodule(mod, path, module_call_specs):
submodule = getattr(submodule, name)
def update_module_call_signatures(path, in_spec, out_spec):
assert path not in module_call_specs
if path in module_call_specs:
assert module_call_specs[path]["in_spec"] == in_spec
assert module_call_specs[path]["out_spec"] == out_spec
module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}
assert "forward" not in submodule.__dict__

View File

@ -517,12 +517,24 @@ def _export(
constraints = constraints or []
kwargs = kwargs or {}
flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
if not strict:
assert isinstance(f, torch.nn.Module)
assert len(preserve_module_call_signature) == 0
assert len(kwargs) == 0, "keyword arguments NYI"
out_spec = None
module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
def strip_root(x):
if isinstance(x, str) and x.startswith("_export_root"):
stripped = x[len("_export_root") :]
return stripped[1:] if stripped.startswith(".") else stripped
return x
def fixup_key(x):
return "L__self__" + strip_root(x)
def _tuplify_outputs(aot_export):
def _aot_export_non_strict(mod, args, **kwargs):
class Wrapper(torch.nn.Module):
@ -537,16 +549,16 @@ def _export(
)
return tuple(flat_outs)
gm, sig = aot_export(Wrapper(mod), args, **kwargs)
def strip_root(x):
if isinstance(x, str) and x.startswith("_export_root"):
stripped = x[len("_export_root") :]
return stripped[1:] if stripped.startswith(".") else stripped
return x
def fixup_key(x):
return "L__self__" + strip_root(x)
wrapped_mod = Wrapper(mod)
# Patch export_root to the signatures so that wrapper module correctly populates the
# in/out spec
new_preserved_call_signatures = [
"_export_root." + i for i in preserve_module_call_signature
]
with _wrap_submodules(
wrapped_mod, new_preserved_call_signatures, module_call_specs
):
gm, sig = aot_export(wrapped_mod, args, **kwargs)
sig.parameters = pytree.tree_map(strip_root, sig.parameters)
sig.buffers = pytree.tree_map(strip_root, sig.buffers)
@ -585,9 +597,39 @@ def _export(
fake_mode, src_equalities, original_signature, ep_non_strict.gm
)
assert out_spec is not None
gm = ep_non_strict.gm
module_call_signatures = {
strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs)
for fqn, specs in module_call_specs.items()
}
if len(preserve_module_call_signature) > 0:
for node in gm.graph.nodes:
if node.target == torch.ops.higher_order._export_tracepoint:
if "path" in node.kwargs:
path = strip_root(node.kwargs["path"])
with gm.graph.inserting_before(node):
new_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order._export_tracepoint,
args=node.args,
kwargs={
"path": path,
"kind": node.kwargs["kind"],
},
)
node.replace_all_uses_with(new_node)
gm.graph.erase_node(node)
res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm)
assert res is not None
gm = res.graph_module
return ExportedProgram(
root=ep_non_strict.gm,
graph=ep_non_strict.gm.graph,
root=gm,
graph=gm.graph,
graph_signature=ep_non_strict.sig,
state_dict=_get_params_buffers(f),
range_constraints=range_constraints,
@ -595,9 +637,12 @@ def _export(
ModuleCallEntry(
"",
ModuleCallSignature(
[], [], pytree.tree_flatten((args, {}))[1], out_spec
inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec
),
)
]
+ [
ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()
],
example_inputs=(args, kwargs),
constants=ep_non_strict.constants,
@ -768,7 +813,6 @@ def _export(
),
len(export_graph_signature.input_specs),
)
flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
range_constraints = _process_constraints(
gm,
num_lifted,