From bb31e3f57e9480a1d427b5f5c3b51eb3ae4aa16a Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Mon, 7 Oct 2024 15:45:03 -0700
Subject: [PATCH] Add original forward names to schema so that prettify pass
 works (#136887)

When we call run_decompositions, we retrace if the program is a training
IR. As a result, we need to reliably store the original forward argument
names so that they survive the retrace.

Differential Revision: [D63064453](https://our.internmc.facebook.com/intern/diff/D63064453/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136887
Approved by: https://github.com/angelayi
---
 test/export/test_serialize.py    | 22 +++++++++++++++++++++
 torch/_export/serde/schema.py    |  4 ++++
 torch/_export/serde/schema.yaml  |  5 ++++-
 torch/_export/serde/serialize.py |  2 ++
 torch/export/_trace.py           | 33 ++++++++++++++++++++++----------
 torch/export/_unlift.py          |  5 +++--
 torch/export/exported_program.py |  1 +
 7 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index 2a4097f6de4..19e3db9ed29 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -176,6 +176,28 @@ class TestSerialize(TestCase):
         loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
         self.assertEqual(orig_out, loaded_out)
 
+    def test_metadata_run_decomp_serder(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x.sin()
+
+        exp_program = torch.export.export_for_training(M(), (torch.randn(4, 4),))
+
+        output_buffer = io.BytesIO()
+        # Tests that example forward arg names are preserved when saving and loading module.
+        torch.export.save(exp_program, output_buffer)
+        loaded_model = torch.export.load(output_buffer)
+
+        ep = loaded_model.run_decompositions({})
+        # We should preserve the original module name
+        self.assertExpectedInline(
+            str(ep.graph_module.code).strip(),
+            """\
+def forward(self, x):
+    sin = torch.ops.aten.sin.default(x); x = None
+    return (sin,)""",
+        )
+
     def test_metadata_parsing_with_layer_split(self):
         # Tests that modules with more complicated layer patterns can be serialized
         # and deserialized correctly.
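The new test above exercises the full save/load/decompose round trip. For reference, the same behavior can be reproduced outside the test suite; the sketch below is illustrative (it assumes a torch build that ships `torch.export.export_for_training`), and without this patch the retrace would emit generic placeholder names such as `arg0_1`:

```python
# Standalone sketch of the round trip the new test pins down; the exact
# placeholder names printed are illustrative.
import io

import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()


ep = torch.export.export_for_training(M(), (torch.randn(4, 4),))

buf = io.BytesIO()
torch.export.save(ep, buf)
buf.seek(0)
loaded = torch.export.load(buf)

# run_decompositions retraces the training IR; with forward_arg_names
# stored in the schema, the placeholder is still called "x" afterwards.
decomposed = loaded.run_decompositions({})
print(decomposed.graph_module.code)
```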
diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py
index 3c901183964..00b41f8a37c 100644
--- a/torch/_export/serde/schema.py
+++ b/torch/_export/serde/schema.py
@@ -345,6 +345,10 @@ class ModuleCallSignature:
     in_spec: str
     out_spec: str
 
+    # This field is used to prettify the graph placeholders
+    # after we ser/der and retrace
+    forward_arg_names: Optional[List[str]] = None
+
 
 @dataclass
 class ModuleCallEntry:
diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml
index 5da198529b4..c34576f9096 100644
--- a/torch/_export/serde/schema.yaml
+++ b/torch/_export/serde/schema.yaml
@@ -1,5 +1,5 @@
 # @generated by update_schema.py
-# checksum<<8c4e7dd0f6fdb1fcb9225325cd7637139c1a69c5941847a86af9d4f1fbc0610c>>
+# checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
 Argument:
   kind: union
   fields:
@@ -264,6 +264,9 @@ ModuleCallSignature:
       type: str
     out_spec:
       type: str
+    forward_arg_names:
+      type: Optional[List[str]]
+      default: None
 NamedArgument:
   kind: struct
   fields:
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index d3186f2d3d2..09c50438620 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -1109,6 +1109,7 @@ class GraphModuleSerializer(metaclass=Final):
             ],
             in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
             out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
+            forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None
         )
 
     def serialize_module_call_graph(
@@ -2257,6 +2258,7 @@ class GraphModuleDeserializer(metaclass=Final):
             ],
             in_spec=treespec_loads(module_call_signature.in_spec),
             out_spec=treespec_loads(module_call_signature.out_spec),
+            forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None,
         )
 
     def deserialize_module_call_graph(
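The serializer changes above piggyback the names onto the existing `ModuleCallSignature` ser/der path. Concretely, the names ride on the root `ModuleCallEntry` (the one with `fqn == ""`), which the `torch/export/_trace.py` changes below populate. A small sketch of where to find them, assuming the field names from this patch:

```python
# Sketch: after this change, the top-level ModuleCallSignature carries the
# original forward argument names, so they are serialized with the program
# instead of sitting only in gm.meta. Output shown is the expected result.
import torch


class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y


ep = torch.export.export_for_training(M(), (torch.randn(2), torch.randn(2)))
root = ep.module_call_graph[0]  # root entry; the patch asserts its fqn is ""
print(root.fqn)                         # ""
print(root.signature.forward_arg_names)  # expected: ['x', 'y']
```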
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 77f1297d678..42a32f2fa6e 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -507,6 +507,7 @@ def _make_module_call_graph(
     in_spec: TreeSpec,
     out_spec: TreeSpec,
     module_call_signatures: Dict[str, ModuleCallSignature],
+    forward_arg_names: Optional[List[str]] = None,
 ) -> List[ModuleCallEntry]:
     ret = [
         ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
@@ -514,7 +515,11 @@
     ]
     assert ret[0].fqn == ""
     ret[0].signature = ModuleCallSignature(
-        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
+        inputs=[],
+        outputs=[],
+        in_spec=in_spec,
+        out_spec=out_spec,
+        forward_arg_names=forward_arg_names,
     )
     return ret
@@ -1063,6 +1068,7 @@ def _get_module_call_graph(
     original_in_spec: TreeSpec,
     preserve_module_call_signature: Tuple[str, ...],
     strict_mode_export: bool,
+    forward_arg_names: Optional[List[str]] = None,
 ):
     """
     In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and
@@ -1080,7 +1086,11 @@
     for fqn, specs in module_call_specs.items():
         mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
         module_call_signatures[mod_fqn] = ModuleCallSignature(
-            inputs=[], outputs=[], **specs
+            inputs=[],
+            outputs=[],
+            in_spec=specs["in_spec"],
+            out_spec=specs["out_spec"],
+            forward_arg_names=None,  # we only propagate forward_arg_names for the top level module
         )
 
     if len(preserve_module_call_signature) > 0:
@@ -1096,6 +1106,7 @@
            original_in_spec,
            out_spec,
            module_call_signatures,
+           forward_arg_names,
        )
    return gm, module_call_graph
@@ -1807,12 +1818,13 @@ def _export_for_training(
     )
     # The returned gm is in-place modified
     gm, module_call_graph = _get_module_call_graph(
-        export_artifact, orig_in_spec, preserve_module_call_signature, strict
+        export_artifact,
+        orig_in_spec,
+        preserve_module_call_signature,
+        strict,
+        forward_arg_names,
     )
 
-    # Add forward args metadata.
-    gm.meta["forward_arg_names"] = forward_arg_names
-
     _verify_nn_module_stack(gm)
     _verify_stack_trace(gm)
     _verify_placeholder_names(gm, export_graph_signature)
@@ -1953,12 +1965,13 @@ def _export(
         dynamic_shapes,
     )
     gm, module_call_graph = _get_module_call_graph(
-        export_artifact, original_in_spec, preserve_module_call_signature, strict
+        export_artifact,
+        original_in_spec,
+        preserve_module_call_signature,
+        strict,
+        forward_arg_names,
     )
 
-    # Add forward args metadata.
-    gm.meta["forward_arg_names"] = forward_arg_names
-
     _verify_nn_module_stack(gm)
     _verify_stack_trace(gm)
     if not _is_torch_jit_trace:
diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py
index 5a4cd2a559e..4f6a45585ca 100644
--- a/torch/export/_unlift.py
+++ b/torch/export/_unlift.py
@@ -335,8 +335,9 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
     ep = _remove_effect_tokens(ep)
     new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
     _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
-    forward_arg_names = ep.graph_module.meta.get("forward_arg_names")
-
+    forward_arg_names = (
+        sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
+    )
     lifted_inputs: List[Optional[str]] = [
         (
             in_spec.target
diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py
index 3f7fabd3615..e357f8f067e 100644
--- a/torch/export/exported_program.py
+++ b/torch/export/exported_program.py
@@ -92,6 +92,7 @@ class ModuleCallSignature:
     outputs: List[ArgumentSpec]
     in_spec: pytree.TreeSpec
     out_spec: pytree.TreeSpec
+    forward_arg_names: Optional[List[str]] = None
 
     def replace_all_uses_with(self, original_node, new_node):
         for i in self.inputs:
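On the consumer side, `_unlift_exported_program_lifted_states` (which backs `ep.module()`) now pulls the names from the root signature instead of `gm.meta`, so they survive deserialization. A hedged sketch of the observable effect; the exact placeholder names are illustrative:

```python
# Sketch: ep.module() goes through the unlift path patched above, so the
# unlifted module's placeholders keep the user-facing argument names even
# after a save/load round trip. Expected names shown are illustrative.
import io

import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x.cos()


ep = torch.export.export_for_training(M(), (torch.randn(3),))
buf = io.BytesIO()
torch.export.save(ep, buf)
buf.seek(0)
loaded = torch.export.load(buf)

unlifted = loaded.module()
placeholders = [n.name for n in unlifted.graph.nodes if n.op == "placeholder"]
print(placeholders)  # expected to include "x" rather than a generic arg0_1
```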