mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add original forward names to schema so that prettify pass works (#136887)
When we run_decomp, we retrace if it is training IR. As a result, we do need to reliably store the oroiginal forward names when we run decomp. Differential Revision: [D63064453](https://our.internmc.facebook.com/intern/diff/D63064453/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136887 Approved by: https://github.com/angelayi
This commit is contained in:
parent
46525abb71
commit
bb31e3f57e
|
|
@ -176,6 +176,28 @@ class TestSerialize(TestCase):
|
|||
loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
|
||||
self.assertEqual(orig_out, loaded_out)
|
||||
|
||||
def test_metadata_run_decomp_serder(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.sin()
|
||||
|
||||
exp_program = torch.export.export_for_training(M(), (torch.randn(4, 4),))
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
# Tests that example forward arg names are preserved when saving and loading module.
|
||||
torch.export.save(exp_program, output_buffer)
|
||||
loaded_model = torch.export.load(output_buffer)
|
||||
|
||||
ep = loaded_model.run_decompositions({})
|
||||
# We should preserve the original module name
|
||||
self.assertExpectedInline(
|
||||
str(ep.graph_module.code).strip(),
|
||||
"""\
|
||||
def forward(self, x):
|
||||
sin = torch.ops.aten.sin.default(x); x = None
|
||||
return (sin,)""",
|
||||
)
|
||||
|
||||
def test_metadata_parsing_with_layer_split(self):
|
||||
# Tests that modules with more complicated layer patterns can be serialized
|
||||
# and deserialized correctly.
|
||||
|
|
|
|||
|
|
@ -345,6 +345,10 @@ class ModuleCallSignature:
|
|||
in_spec: str
|
||||
out_spec: str
|
||||
|
||||
# This field is used to prettify the graph placeholders
|
||||
# after we ser/der and retrace
|
||||
forward_arg_names: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleCallEntry:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# @generated by update_schema.py
|
||||
# checksum<<8c4e7dd0f6fdb1fcb9225325cd7637139c1a69c5941847a86af9d4f1fbc0610c>>
|
||||
# checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
|
||||
Argument:
|
||||
kind: union
|
||||
fields:
|
||||
|
|
@ -264,6 +264,9 @@ ModuleCallSignature:
|
|||
type: str
|
||||
out_spec:
|
||||
type: str
|
||||
forward_arg_names:
|
||||
type: Optional[List[str]]
|
||||
default: None
|
||||
NamedArgument:
|
||||
kind: struct
|
||||
fields:
|
||||
|
|
|
|||
|
|
@ -1109,6 +1109,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
],
|
||||
in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
|
||||
out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
|
||||
forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None
|
||||
)
|
||||
|
||||
def serialize_module_call_graph(
|
||||
|
|
@ -2257,6 +2258,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
],
|
||||
in_spec=treespec_loads(module_call_signature.in_spec),
|
||||
out_spec=treespec_loads(module_call_signature.out_spec),
|
||||
forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None,
|
||||
)
|
||||
|
||||
def deserialize_module_call_graph(
|
||||
|
|
|
|||
|
|
@ -507,6 +507,7 @@ def _make_module_call_graph(
|
|||
in_spec: TreeSpec,
|
||||
out_spec: TreeSpec,
|
||||
module_call_signatures: Dict[str, ModuleCallSignature],
|
||||
forward_arg_names: Optional[List[str]] = None,
|
||||
) -> List[ModuleCallEntry]:
|
||||
ret = [
|
||||
ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
|
||||
|
|
@ -514,7 +515,11 @@ def _make_module_call_graph(
|
|||
]
|
||||
assert ret[0].fqn == ""
|
||||
ret[0].signature = ModuleCallSignature(
|
||||
inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
in_spec=in_spec,
|
||||
out_spec=out_spec,
|
||||
forward_arg_names=forward_arg_names,
|
||||
)
|
||||
return ret
|
||||
|
||||
|
|
@ -1063,6 +1068,7 @@ def _get_module_call_graph(
|
|||
original_in_spec: TreeSpec,
|
||||
preserve_module_call_signature: Tuple[str, ...],
|
||||
strict_mode_export: bool,
|
||||
forward_arg_names: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and
|
||||
|
|
@ -1080,7 +1086,11 @@ def _get_module_call_graph(
|
|||
for fqn, specs in module_call_specs.items():
|
||||
mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
|
||||
module_call_signatures[mod_fqn] = ModuleCallSignature(
|
||||
inputs=[], outputs=[], **specs
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
in_spec=specs["in_spec"],
|
||||
out_spec=specs["out_spec"],
|
||||
forward_arg_names=None, # we only propage forward_arg_names for the top level module
|
||||
)
|
||||
|
||||
if len(preserve_module_call_signature) > 0:
|
||||
|
|
@ -1096,6 +1106,7 @@ def _get_module_call_graph(
|
|||
original_in_spec,
|
||||
out_spec,
|
||||
module_call_signatures,
|
||||
forward_arg_names,
|
||||
)
|
||||
return gm, module_call_graph
|
||||
|
||||
|
|
@ -1807,12 +1818,13 @@ def _export_for_training(
|
|||
)
|
||||
# The returned the gm is in-place modified
|
||||
gm, module_call_graph = _get_module_call_graph(
|
||||
export_artifact, orig_in_spec, preserve_module_call_signature, strict
|
||||
export_artifact,
|
||||
orig_in_spec,
|
||||
preserve_module_call_signature,
|
||||
strict,
|
||||
forward_arg_names,
|
||||
)
|
||||
|
||||
# Add forward args metadata.
|
||||
gm.meta["forward_arg_names"] = forward_arg_names
|
||||
|
||||
_verify_nn_module_stack(gm)
|
||||
_verify_stack_trace(gm)
|
||||
_verify_placeholder_names(gm, export_graph_signature)
|
||||
|
|
@ -1953,12 +1965,13 @@ def _export(
|
|||
dynamic_shapes,
|
||||
)
|
||||
gm, module_call_graph = _get_module_call_graph(
|
||||
export_artifact, original_in_spec, preserve_module_call_signature, strict
|
||||
export_artifact,
|
||||
original_in_spec,
|
||||
preserve_module_call_signature,
|
||||
strict,
|
||||
forward_arg_names,
|
||||
)
|
||||
|
||||
# Add forward args metadata.
|
||||
gm.meta["forward_arg_names"] = forward_arg_names
|
||||
|
||||
_verify_nn_module_stack(gm)
|
||||
_verify_stack_trace(gm)
|
||||
if not _is_torch_jit_trace:
|
||||
|
|
|
|||
|
|
@ -335,8 +335,9 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu
|
|||
ep = _remove_effect_tokens(ep)
|
||||
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
|
||||
_register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
|
||||
forward_arg_names = ep.graph_module.meta.get("forward_arg_names")
|
||||
|
||||
forward_arg_names = (
|
||||
sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
|
||||
)
|
||||
lifted_inputs: List[Optional[str]] = [
|
||||
(
|
||||
in_spec.target
|
||||
|
|
|
|||
|
|
@ -92,6 +92,7 @@ class ModuleCallSignature:
|
|||
outputs: List[ArgumentSpec]
|
||||
in_spec: pytree.TreeSpec
|
||||
out_spec: pytree.TreeSpec
|
||||
forward_arg_names: Optional[List[str]] = None
|
||||
|
||||
def replace_all_uses_with(self, original_node, new_node):
|
||||
for i in self.inputs:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user