[export] Support preserving submodule callling convention in non-strict export (#117796)

Summary: Title

Test Plan: CI

Reviewed By: zhxchen17

Differential Revision: D52889236

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117796
Approved by: https://github.com/angelayi
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar 2024-01-19 17:16:45 +00:00 committed by PyTorch MergeBot
parent 249a226113
commit f316c35a34
3 changed files with 98 additions and 49 deletions

View File

@ -30,7 +30,7 @@ from torch._export.utils import (
from torch.export import Constraint, Dim, export from torch.export import Constraint, Dim, export
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.common_utils import skipIfTorchDynamo, run_tests, TestCase
from torch.utils._pytree import ( from torch.utils._pytree import (
LeafSpec, LeafSpec,
tree_flatten, tree_flatten,
@ -199,6 +199,7 @@ class TestUnflatten(TestCase):
id(getattr(unflattened_module.sub_net, "2")), id(getattr(unflattened_module.sub_net, "2")),
) )
@skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
def test_unflatten_preserve_signature(self): def test_unflatten_preserve_signature(self):
class NestedChild(torch.nn.Module): class NestedChild(torch.nn.Module):
def forward(self, zx, y): def forward(self, zx, y):
@ -234,41 +235,43 @@ class TestUnflatten(TestCase):
orig_eager = MyModule() orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3) inps = torch.rand(2, 3), torch.rand(2, 3)
export_module = export( for strict in [True, False]:
orig_eager, export_module = export(
inps, orig_eager,
{}, inps,
preserve_module_call_signature=("foo.nested",), {},
) preserve_module_call_signature=("foo.nested",),
unflattened = unflatten(export_module) strict=strict
self.compare_outputs(export_module, unflattened, inps) )
unflattened.foo.nested = NestedChild() unflattened = unflatten(export_module)
self.compare_outputs(export_module, unflattened, inps) self.compare_outputs(export_module, unflattened, inps)
unflattened.foo.nested = NestedChild()
self.compare_outputs(export_module, unflattened, inps)
# Test tree spec mismatched input # Test tree spec mismatched input
orig_outs = export_module(*inps) orig_outs = export_module(*inps)
new_inps = *inps, torch.rand(2, 3) new_inps = *inps, torch.rand(2, 3)
with self.assertRaisesRegex( with self.assertRaisesRegex(
TypeError, TypeError,
"There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?", "There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
): ):
unflattened(new_inps) unflattened(new_inps)
# With flat args adapter # With flat args adapter
class KeepTwoFlatArgsAdapter(FlatArgsAdapter): class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
def adapt( def adapt(
self, self,
target_spec: TreeSpec, target_spec: TreeSpec,
input_spec: TreeSpec, input_spec: TreeSpec,
input_args: List[Any], input_args: List[Any],
) -> List[Any]: ) -> List[Any]:
while len(input_args) > 2: while len(input_args) > 2:
input_args.pop(-1) input_args.pop(-1)
return input_args return input_args
unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter()) unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
new_outs = unflattened(*new_inps) new_outs = unflattened(*new_inps)
self.assertTrue(torch.allclose(orig_outs, new_outs)) self.assertTrue(torch.allclose(orig_outs, new_outs))
def test_unflatten_param_list_dict(self): def test_unflatten_param_list_dict(self):
class Mod(torch.nn.Module): class Mod(torch.nn.Module):

View File

@ -61,7 +61,9 @@ def _wrap_submodule(mod, path, module_call_specs):
submodule = getattr(submodule, name) submodule = getattr(submodule, name)
def update_module_call_signatures(path, in_spec, out_spec): def update_module_call_signatures(path, in_spec, out_spec):
assert path not in module_call_specs if path in module_call_specs:
assert module_call_specs[path]["in_spec"] == in_spec
assert module_call_specs[path]["out_spec"] == out_spec
module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}
assert "forward" not in submodule.__dict__ assert "forward" not in submodule.__dict__

View File

@ -517,12 +517,24 @@ def _export(
constraints = constraints or [] constraints = constraints or []
kwargs = kwargs or {} kwargs = kwargs or {}
flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
if not strict: if not strict:
assert isinstance(f, torch.nn.Module) assert isinstance(f, torch.nn.Module)
assert len(preserve_module_call_signature) == 0
assert len(kwargs) == 0, "keyword arguments NYI" assert len(kwargs) == 0, "keyword arguments NYI"
out_spec = None out_spec = None
module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
def strip_root(x):
if isinstance(x, str) and x.startswith("_export_root"):
stripped = x[len("_export_root") :]
return stripped[1:] if stripped.startswith(".") else stripped
return x
def fixup_key(x):
return "L__self__" + strip_root(x)
def _tuplify_outputs(aot_export): def _tuplify_outputs(aot_export):
def _aot_export_non_strict(mod, args, **kwargs): def _aot_export_non_strict(mod, args, **kwargs):
class Wrapper(torch.nn.Module): class Wrapper(torch.nn.Module):
@ -537,16 +549,16 @@ def _export(
) )
return tuple(flat_outs) return tuple(flat_outs)
gm, sig = aot_export(Wrapper(mod), args, **kwargs) wrapped_mod = Wrapper(mod)
# Patch export_root to the signatures so that wrapper module correctly populates the
def strip_root(x): # in/out spec
if isinstance(x, str) and x.startswith("_export_root"): new_preserved_call_signatures = [
stripped = x[len("_export_root") :] "_export_root." + i for i in preserve_module_call_signature
return stripped[1:] if stripped.startswith(".") else stripped ]
return x with _wrap_submodules(
wrapped_mod, new_preserved_call_signatures, module_call_specs
def fixup_key(x): ):
return "L__self__" + strip_root(x) gm, sig = aot_export(wrapped_mod, args, **kwargs)
sig.parameters = pytree.tree_map(strip_root, sig.parameters) sig.parameters = pytree.tree_map(strip_root, sig.parameters)
sig.buffers = pytree.tree_map(strip_root, sig.buffers) sig.buffers = pytree.tree_map(strip_root, sig.buffers)
@ -585,9 +597,39 @@ def _export(
fake_mode, src_equalities, original_signature, ep_non_strict.gm fake_mode, src_equalities, original_signature, ep_non_strict.gm
) )
assert out_spec is not None assert out_spec is not None
gm = ep_non_strict.gm
module_call_signatures = {
strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs)
for fqn, specs in module_call_specs.items()
}
if len(preserve_module_call_signature) > 0:
for node in gm.graph.nodes:
if node.target == torch.ops.higher_order._export_tracepoint:
if "path" in node.kwargs:
path = strip_root(node.kwargs["path"])
with gm.graph.inserting_before(node):
new_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order._export_tracepoint,
args=node.args,
kwargs={
"path": path,
"kind": node.kwargs["kind"],
},
)
node.replace_all_uses_with(new_node)
gm.graph.erase_node(node)
res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm)
assert res is not None
gm = res.graph_module
return ExportedProgram( return ExportedProgram(
root=ep_non_strict.gm, root=gm,
graph=ep_non_strict.gm.graph, graph=gm.graph,
graph_signature=ep_non_strict.sig, graph_signature=ep_non_strict.sig,
state_dict=_get_params_buffers(f), state_dict=_get_params_buffers(f),
range_constraints=range_constraints, range_constraints=range_constraints,
@ -595,9 +637,12 @@ def _export(
ModuleCallEntry( ModuleCallEntry(
"", "",
ModuleCallSignature( ModuleCallSignature(
[], [], pytree.tree_flatten((args, {}))[1], out_spec inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec
), ),
) )
]
+ [
ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()
], ],
example_inputs=(args, kwargs), example_inputs=(args, kwargs),
constants=ep_non_strict.constants, constants=ep_non_strict.constants,
@ -768,7 +813,6 @@ def _export(
), ),
len(export_graph_signature.input_specs), len(export_graph_signature.input_specs),
) )
flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
range_constraints = _process_constraints( range_constraints = _process_constraints(
gm, gm,
num_lifted, num_lifted,