Add original forward names to schema so that prettify pass works (#136887)

When we run_decomp, we retrace if it is training IR. As a result, we need to reliably store the original forward names when we run decomp (a standalone repro is sketched below, after the commit metadata).

Differential Revision: [D63064453](https://our.internmc.facebook.com/intern/diff/D63064453/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136887
Approved by: https://github.com/angelayi
Authored by Tugsbayasgalan Manlaibaatar on 2024-10-07 15:45:03 -07:00; committed by PyTorch MergeBot
parent 46525abb71
commit bb31e3f57e
7 changed files with 59 additions and 13 deletions
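
For illustration, a minimal standalone sketch of the round-trip this change fixes, mirroring the new TestSerialize test below (the module M and the (4, 4) input are just the test's example; variable names are illustrative):

import io

import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()


# Export training IR, then round-trip through save/load.
ep = torch.export.export_for_training(M(), (torch.randn(4, 4),))
buf = io.BytesIO()
torch.export.save(ep, buf)
buf.seek(0)
loaded = torch.export.load(buf)

# run_decompositions retraces training IR. With forward_arg_names stored in
# the serialized schema, the prettify pass can keep the placeholder named `x`
# instead of a generated name such as `arg0_1`.
decomposed = loaded.run_decompositions({})
print(decomposed.graph_module.code)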

@@ -176,6 +176,28 @@ class TestSerialize(TestCase):
         loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
         self.assertEqual(orig_out, loaded_out)
 
+    def test_metadata_run_decomp_serder(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x.sin()
+
+        exp_program = torch.export.export_for_training(M(), (torch.randn(4, 4),))
+        output_buffer = io.BytesIO()
+
+        # Tests that example forward arg names are preserved when saving and loading module.
+        torch.export.save(exp_program, output_buffer)
+        loaded_model = torch.export.load(output_buffer)
+        ep = loaded_model.run_decompositions({})
+
+        # We should preserve the original module name
+        self.assertExpectedInline(
+            str(ep.graph_module.code).strip(),
+            """\
+def forward(self, x):
+    sin = torch.ops.aten.sin.default(x); x = None
+    return (sin,)""",
+        )
+
     def test_metadata_parsing_with_layer_split(self):
         # Tests that modules with more complicated layer patterns can be serialized
         # and deserialized correctly.

@@ -345,6 +345,10 @@ class ModuleCallSignature:
     in_spec: str
     out_spec: str
+
+    # This field is used to prettify the graph placeholders
+    # after we ser/der and retrace
+    forward_arg_names: Optional[List[str]] = None
 
 
 @dataclass
 class ModuleCallEntry:

@@ -1,5 +1,5 @@
 # @generated by update_schema.py
-# checksum<<8c4e7dd0f6fdb1fcb9225325cd7637139c1a69c5941847a86af9d4f1fbc0610c>>
+# checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
 Argument:
   kind: union
   fields:
@@ -264,6 +264,9 @@ ModuleCallSignature:
       type: str
     out_spec:
       type: str
+    forward_arg_names:
+      type: Optional[List[str]]
+      default: None
 NamedArgument:
   kind: struct
   fields:

@@ -1109,6 +1109,7 @@ class GraphModuleSerializer(metaclass=Final):
             ],
             in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
             out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
+            forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None,
         )
 
     def serialize_module_call_graph(
@@ -2257,6 +2258,7 @@ class GraphModuleDeserializer(metaclass=Final):
             ],
             in_spec=treespec_loads(module_call_signature.in_spec),
             out_spec=treespec_loads(module_call_signature.out_spec),
+            forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None,
         )
 
     def deserialize_module_call_graph(

@@ -507,6 +507,7 @@ def _make_module_call_graph(
     in_spec: TreeSpec,
     out_spec: TreeSpec,
     module_call_signatures: Dict[str, ModuleCallSignature],
+    forward_arg_names: Optional[List[str]] = None,
 ) -> List[ModuleCallEntry]:
     ret = [
         ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
@@ -514,7 +515,11 @@
     ]
     assert ret[0].fqn == ""
     ret[0].signature = ModuleCallSignature(
-        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
+        inputs=[],
+        outputs=[],
+        in_spec=in_spec,
+        out_spec=out_spec,
+        forward_arg_names=forward_arg_names,
     )
     return ret
 
@@ -1063,6 +1068,7 @@ def _get_module_call_graph(
     original_in_spec: TreeSpec,
     preserve_module_call_signature: Tuple[str, ...],
     strict_mode_export: bool,
+    forward_arg_names: Optional[List[str]] = None,
 ):
     """
     In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and
@@ -1080,7 +1086,11 @@
     for fqn, specs in module_call_specs.items():
         mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
         module_call_signatures[mod_fqn] = ModuleCallSignature(
-            inputs=[], outputs=[], **specs
+            inputs=[],
+            outputs=[],
+            in_spec=specs["in_spec"],
+            out_spec=specs["out_spec"],
+            forward_arg_names=None,  # we only propagate forward_arg_names for the top-level module
         )
 
     if len(preserve_module_call_signature) > 0:
@@ -1096,6 +1106,7 @@
             original_in_spec,
             out_spec,
             module_call_signatures,
+            forward_arg_names,
         )
     return gm, module_call_graph
 
@@ -1807,12 +1818,13 @@ def _export_for_training(
     )
 
     # The returned gm is in-place modified
     gm, module_call_graph = _get_module_call_graph(
-        export_artifact, orig_in_spec, preserve_module_call_signature, strict
+        export_artifact,
+        orig_in_spec,
+        preserve_module_call_signature,
+        strict,
+        forward_arg_names,
     )
-    # Add forward args metadata.
-    gm.meta["forward_arg_names"] = forward_arg_names
-
     _verify_nn_module_stack(gm)
     _verify_stack_trace(gm)
     _verify_placeholder_names(gm, export_graph_signature)
@@ -1953,12 +1965,13 @@ def _export(
         dynamic_shapes,
     )
 
     gm, module_call_graph = _get_module_call_graph(
-        export_artifact, original_in_spec, preserve_module_call_signature, strict
+        export_artifact,
+        original_in_spec,
+        preserve_module_call_signature,
+        strict,
+        forward_arg_names,
     )
-    # Add forward args metadata.
-    gm.meta["forward_arg_names"] = forward_arg_names
-
     _verify_nn_module_stack(gm)
     _verify_stack_trace(gm)
     if not _is_torch_jit_trace:

@@ -335,8 +335,9 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
     ep = _remove_effect_tokens(ep)
     new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
     _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
-    forward_arg_names = ep.graph_module.meta.get("forward_arg_names")
-
+    forward_arg_names = (
+        sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
+    )
     lifted_inputs: List[Optional[str]] = [
         (
             in_spec.target

@@ -92,6 +92,7 @@ class ModuleCallSignature:
     outputs: List[ArgumentSpec]
     in_spec: pytree.TreeSpec
     out_spec: pytree.TreeSpec
+    forward_arg_names: Optional[List[str]] = None
 
     def replace_all_uses_with(self, original_node, new_node):
         for i in self.inputs:
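
For context on how the stored names get used, a simplified sketch of a prettify pass; the helper name _prettify_placeholders and the exact renaming strategy are assumptions for illustration, not part of this diff:

from typing import List, Optional

import torch.fx


def _prettify_placeholders(
    gm: torch.fx.GraphModule, forward_arg_names: Optional[List[str]]
) -> None:
    # Hypothetical helper: restore the user's original forward() argument
    # names on the graph placeholders after a retrace has mangled them
    # (e.g. `x` -> `arg0_1`). The real logic lives in torch.export internals.
    if not forward_arg_names:
        return
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    for node, name in zip(placeholders, forward_arg_names):
        node.name = name    # name used in the printed graph body
        node.target = name  # target rendered as the arg in `def forward(...)`
    gm.recompile()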