mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add original forward names to schema so that prettify pass works (#136887)
When we run_decomp, we retrace if it is training IR. As a result, we do need to reliably store the oroiginal forward names when we run decomp. Differential Revision: [D63064453](https://our.internmc.facebook.com/intern/diff/D63064453/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136887 Approved by: https://github.com/angelayi
This commit is contained in:
parent
46525abb71
commit
bb31e3f57e
|
|
@ -176,6 +176,28 @@ class TestSerialize(TestCase):
|
||||||
loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
|
loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
|
||||||
self.assertEqual(orig_out, loaded_out)
|
self.assertEqual(orig_out, loaded_out)
|
||||||
|
|
||||||
|
def test_metadata_run_decomp_serder(self):
|
||||||
|
class M(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x.sin()
|
||||||
|
|
||||||
|
exp_program = torch.export.export_for_training(M(), (torch.randn(4, 4),))
|
||||||
|
|
||||||
|
output_buffer = io.BytesIO()
|
||||||
|
# Tests that example forward arg names are preserved when saving and loading module.
|
||||||
|
torch.export.save(exp_program, output_buffer)
|
||||||
|
loaded_model = torch.export.load(output_buffer)
|
||||||
|
|
||||||
|
ep = loaded_model.run_decompositions({})
|
||||||
|
# We should preserve the original module name
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(ep.graph_module.code).strip(),
|
||||||
|
"""\
|
||||||
|
def forward(self, x):
|
||||||
|
sin = torch.ops.aten.sin.default(x); x = None
|
||||||
|
return (sin,)""",
|
||||||
|
)
|
||||||
|
|
||||||
def test_metadata_parsing_with_layer_split(self):
|
def test_metadata_parsing_with_layer_split(self):
|
||||||
# Tests that modules with more complicated layer patterns can be serialized
|
# Tests that modules with more complicated layer patterns can be serialized
|
||||||
# and deserialized correctly.
|
# and deserialized correctly.
|
||||||
|
|
|
||||||
|
|
@ -345,6 +345,10 @@ class ModuleCallSignature:
|
||||||
in_spec: str
|
in_spec: str
|
||||||
out_spec: str
|
out_spec: str
|
||||||
|
|
||||||
|
# This field is used to prettify the graph placeholders
|
||||||
|
# after we ser/der and retrace
|
||||||
|
forward_arg_names: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModuleCallEntry:
|
class ModuleCallEntry:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
# @generated by update_schema.py
|
# @generated by update_schema.py
|
||||||
# checksum<<8c4e7dd0f6fdb1fcb9225325cd7637139c1a69c5941847a86af9d4f1fbc0610c>>
|
# checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
|
||||||
Argument:
|
Argument:
|
||||||
kind: union
|
kind: union
|
||||||
fields:
|
fields:
|
||||||
|
|
@ -264,6 +264,9 @@ ModuleCallSignature:
|
||||||
type: str
|
type: str
|
||||||
out_spec:
|
out_spec:
|
||||||
type: str
|
type: str
|
||||||
|
forward_arg_names:
|
||||||
|
type: Optional[List[str]]
|
||||||
|
default: None
|
||||||
NamedArgument:
|
NamedArgument:
|
||||||
kind: struct
|
kind: struct
|
||||||
fields:
|
fields:
|
||||||
|
|
|
||||||
|
|
@ -1109,6 +1109,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||||
],
|
],
|
||||||
in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
|
in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
|
||||||
out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
|
out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
|
||||||
|
forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None
|
||||||
)
|
)
|
||||||
|
|
||||||
def serialize_module_call_graph(
|
def serialize_module_call_graph(
|
||||||
|
|
@ -2257,6 +2258,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||||
],
|
],
|
||||||
in_spec=treespec_loads(module_call_signature.in_spec),
|
in_spec=treespec_loads(module_call_signature.in_spec),
|
||||||
out_spec=treespec_loads(module_call_signature.out_spec),
|
out_spec=treespec_loads(module_call_signature.out_spec),
|
||||||
|
forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def deserialize_module_call_graph(
|
def deserialize_module_call_graph(
|
||||||
|
|
|
||||||
|
|
@ -507,6 +507,7 @@ def _make_module_call_graph(
|
||||||
in_spec: TreeSpec,
|
in_spec: TreeSpec,
|
||||||
out_spec: TreeSpec,
|
out_spec: TreeSpec,
|
||||||
module_call_signatures: Dict[str, ModuleCallSignature],
|
module_call_signatures: Dict[str, ModuleCallSignature],
|
||||||
|
forward_arg_names: Optional[List[str]] = None,
|
||||||
) -> List[ModuleCallEntry]:
|
) -> List[ModuleCallEntry]:
|
||||||
ret = [
|
ret = [
|
||||||
ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
|
ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
|
||||||
|
|
@ -514,7 +515,11 @@ def _make_module_call_graph(
|
||||||
]
|
]
|
||||||
assert ret[0].fqn == ""
|
assert ret[0].fqn == ""
|
||||||
ret[0].signature = ModuleCallSignature(
|
ret[0].signature = ModuleCallSignature(
|
||||||
inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
in_spec=in_spec,
|
||||||
|
out_spec=out_spec,
|
||||||
|
forward_arg_names=forward_arg_names,
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
@ -1063,6 +1068,7 @@ def _get_module_call_graph(
|
||||||
original_in_spec: TreeSpec,
|
original_in_spec: TreeSpec,
|
||||||
preserve_module_call_signature: Tuple[str, ...],
|
preserve_module_call_signature: Tuple[str, ...],
|
||||||
strict_mode_export: bool,
|
strict_mode_export: bool,
|
||||||
|
forward_arg_names: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and
|
In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and
|
||||||
|
|
@ -1080,7 +1086,11 @@ def _get_module_call_graph(
|
||||||
for fqn, specs in module_call_specs.items():
|
for fqn, specs in module_call_specs.items():
|
||||||
mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
|
mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
|
||||||
module_call_signatures[mod_fqn] = ModuleCallSignature(
|
module_call_signatures[mod_fqn] = ModuleCallSignature(
|
||||||
inputs=[], outputs=[], **specs
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
in_spec=specs["in_spec"],
|
||||||
|
out_spec=specs["out_spec"],
|
||||||
|
forward_arg_names=None, # we only propage forward_arg_names for the top level module
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(preserve_module_call_signature) > 0:
|
if len(preserve_module_call_signature) > 0:
|
||||||
|
|
@ -1096,6 +1106,7 @@ def _get_module_call_graph(
|
||||||
original_in_spec,
|
original_in_spec,
|
||||||
out_spec,
|
out_spec,
|
||||||
module_call_signatures,
|
module_call_signatures,
|
||||||
|
forward_arg_names,
|
||||||
)
|
)
|
||||||
return gm, module_call_graph
|
return gm, module_call_graph
|
||||||
|
|
||||||
|
|
@ -1807,12 +1818,13 @@ def _export_for_training(
|
||||||
)
|
)
|
||||||
# The returned the gm is in-place modified
|
# The returned the gm is in-place modified
|
||||||
gm, module_call_graph = _get_module_call_graph(
|
gm, module_call_graph = _get_module_call_graph(
|
||||||
export_artifact, orig_in_spec, preserve_module_call_signature, strict
|
export_artifact,
|
||||||
|
orig_in_spec,
|
||||||
|
preserve_module_call_signature,
|
||||||
|
strict,
|
||||||
|
forward_arg_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add forward args metadata.
|
|
||||||
gm.meta["forward_arg_names"] = forward_arg_names
|
|
||||||
|
|
||||||
_verify_nn_module_stack(gm)
|
_verify_nn_module_stack(gm)
|
||||||
_verify_stack_trace(gm)
|
_verify_stack_trace(gm)
|
||||||
_verify_placeholder_names(gm, export_graph_signature)
|
_verify_placeholder_names(gm, export_graph_signature)
|
||||||
|
|
@ -1953,12 +1965,13 @@ def _export(
|
||||||
dynamic_shapes,
|
dynamic_shapes,
|
||||||
)
|
)
|
||||||
gm, module_call_graph = _get_module_call_graph(
|
gm, module_call_graph = _get_module_call_graph(
|
||||||
export_artifact, original_in_spec, preserve_module_call_signature, strict
|
export_artifact,
|
||||||
|
original_in_spec,
|
||||||
|
preserve_module_call_signature,
|
||||||
|
strict,
|
||||||
|
forward_arg_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add forward args metadata.
|
|
||||||
gm.meta["forward_arg_names"] = forward_arg_names
|
|
||||||
|
|
||||||
_verify_nn_module_stack(gm)
|
_verify_nn_module_stack(gm)
|
||||||
_verify_stack_trace(gm)
|
_verify_stack_trace(gm)
|
||||||
if not _is_torch_jit_trace:
|
if not _is_torch_jit_trace:
|
||||||
|
|
|
||||||
|
|
@ -335,8 +335,9 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu
|
||||||
ep = _remove_effect_tokens(ep)
|
ep = _remove_effect_tokens(ep)
|
||||||
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
|
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
|
||||||
_register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
|
_register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
|
||||||
forward_arg_names = ep.graph_module.meta.get("forward_arg_names")
|
forward_arg_names = (
|
||||||
|
sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
|
||||||
|
)
|
||||||
lifted_inputs: List[Optional[str]] = [
|
lifted_inputs: List[Optional[str]] = [
|
||||||
(
|
(
|
||||||
in_spec.target
|
in_spec.target
|
||||||
|
|
|
||||||
|
|
@ -92,6 +92,7 @@ class ModuleCallSignature:
|
||||||
outputs: List[ArgumentSpec]
|
outputs: List[ArgumentSpec]
|
||||||
in_spec: pytree.TreeSpec
|
in_spec: pytree.TreeSpec
|
||||||
out_spec: pytree.TreeSpec
|
out_spec: pytree.TreeSpec
|
||||||
|
forward_arg_names: Optional[List[str]] = None
|
||||||
|
|
||||||
def replace_all_uses_with(self, original_node, new_node):
|
def replace_all_uses_with(self, original_node, new_node):
|
||||||
for i in self.inputs:
|
for i in self.inputs:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user