mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Fix deserialization issue (#150515)
An internal model was serialized in 2023, and is now breaking while loading with the following error:
```
File "<eval_with_key>.1675", line 4
def forward(self, arg1163_1, arg1164_1, , arg1166_1, , arg1168_1, arg1169_1, arg1170_1, , arg1172_1, arg1173_1, arg1174_1, arg1175_1, arg1176_1, arg1177_1, arg1178_1, arg1179_1, arg1180_1, arg1181_1, arg1182_1, arg1183_1, arg1184_1, arg1185_1, arg1186_1, arg1187_1, arg1188_1, arg1189_1, arg1190_1, arg1191_1, arg1192_1, arg1193_1, arg1194_1, arg1195_1, arg1196_1, arg1197_1, arg1198_1, arg1199_1, arg1200_1, arg1201_1, arg1202_1, arg1203_1, arg1204_1, arg1205_1, arg1206_1, arg1207_1, arg1208_1, arg1209_1, arg1210_1, arg1211_1, arg1212_1, arg1213_1, arg1214_1, arg1215_1, arg1216_1, , arg1218_1, arg1219_1, arg1220_1, arg1221_1, arg1222_1, arg1223_1, arg1224_1, , arg1226_1, arg1227_1, arg1228_1, , arg1230_1, , , , , , , , , , , , , , , ):
^
SyntaxError: invalid syntax
```
The syntax errors are due to inputs that are `None` when exporting. Prior to changes in https://github.com/pytorch/pytorch/pull/123590 (landed 4/2024), input specs for none inputs look like `InputSpec(userInput=UserInputSpec(arg=Argument(asNone=True)))`, and during deserialization when creating a node, we would just use a dummy name `arg`. After to those changes, the input specs for none inputs look like `InputSpec(constantInput=InputToConstantInputSpec(name='y', value=ConstantValue(asNone=True)))`, and when creating a node we would use the name `y` as the name. However the PR didn't handle the case if it's loading an old package which doesn't have this name, so ended up putting empty names in the placeholder nodes.
This error was uncovered after https://github.com/pytorch/pytorch/pull/149717, where we now use the GraphModule's python codegen to run the UnflattenedModule instead of going through the interpreter path. The placeholder nodes having empty names caused the python codegen to fail.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150515
Approved by: https://github.com/yushangdi
This commit is contained in:
parent
a72b4eb806
commit
5314a6fe82
|
|
@ -16,6 +16,7 @@ from typing import NamedTuple
|
|||
|
||||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
import torch._export.serde.schema as schema
|
||||
import torch.export._trace
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._export.db.case import ExportCase, SupportLevel
|
||||
|
|
@ -918,6 +919,32 @@ class TestDeserialize(TestCase):
|
|||
inp = (torch.ones(3, 3), torch.ones(3, 3), torch.tensor(2))
|
||||
self.check_graph(Mod(), inp, use_pre_dispatch=False)
|
||||
|
||||
def test_none_input(self):
|
||||
"""
|
||||
Testing a backwards-compatibility breakage where old models do not have
|
||||
an input spec with the node name.
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y, z):
|
||||
return x + z
|
||||
|
||||
ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3)))
|
||||
|
||||
serialized_program = ExportedProgramSerializer(None, 2).serialize(ep)
|
||||
serialized_program.exported_program.graph_module.signature.input_specs[
|
||||
1
|
||||
] = schema.InputSpec.create(
|
||||
user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True))
|
||||
)
|
||||
ep = ExportedProgramDeserializer(None).deserialize(
|
||||
serialized_program.exported_program, {}, {}, {}
|
||||
)
|
||||
ep.graph_module.recompile()
|
||||
unflattened = torch.export.unflatten(ep)
|
||||
inp = (torch.rand(3, 3), None, torch.rand(3, 3))
|
||||
self.assertEqual(unflattened(*inp), M()(*inp))
|
||||
|
||||
def test_multi_return(self) -> None:
|
||||
"""
|
||||
Test multiple return from a single node (ex. layer_norm has 2 outputs)
|
||||
|
|
|
|||
|
|
@ -1861,7 +1861,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
"as_none",
|
||||
"as_string",
|
||||
):
|
||||
node_name = self.signature.input_specs[i].arg.name
|
||||
node_name = self.signature.input_specs[i].arg.name or f"arg{i}"
|
||||
placeholder_node = self.graph.placeholder(node_name)
|
||||
placeholder_node.meta["val"] = self.deserialize_input(input_)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user