[export] Fix serialization of empty torch artifact (#125542)

A previous PR added support for serializing/deserializing example inputs, but this fails when `example_inputs` is none.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125542
Approved by: https://github.com/pianpwk, https://github.com/BoyuanFeng, https://github.com/ydwu4
This commit is contained in:
angelayi 2024-05-07 15:54:45 +00:00 committed by PyTorch MergeBot
parent b37bef9b13
commit 0de9ce9bb3
2 changed files with 18 additions and 1 deletions

View File

@ -860,6 +860,20 @@ class TestDeserialize(TestCase):
inputs = (torch.ones(2, 3),)
self.check_graph(m, inputs, strict=False)
def test_export_no_inputs(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = torch.ones(3, 3)
def forward(self):
return self.p * self.p
ep = torch.export.export(M(), ())
ep._example_inputs = None
roundtrip_ep = deserialize(serialize(ep))
self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()()))
instantiate_parametrized_tests(TestDeserialize)

View File

@ -284,7 +284,10 @@ def _reconstruct_fake_tensor(
return fake_tensor
def serialize_torch_artifact(artifact: Dict[str, Any]) -> bytes:
def serialize_torch_artifact(artifact: Optional[Any]) -> bytes:
if artifact is None:
return b""
assert (
FakeTensor not in copyreg.dispatch_table
), "Refusing to stomp on existing FakeTensor reducer"