mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[export] Fix serialization of empty torch artifact (#125542)
A previous PR added support for serializing/deserializing example inputs, but this fails when `example_inputs` is none. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125542 Approved by: https://github.com/pianpwk, https://github.com/BoyuanFeng, https://github.com/ydwu4
This commit is contained in:
parent
b37bef9b13
commit
0de9ce9bb3
|
|
@ -860,6 +860,20 @@ class TestDeserialize(TestCase):
|
|||
inputs = (torch.ones(2, 3),)
|
||||
self.check_graph(m, inputs, strict=False)
|
||||
|
||||
def test_export_no_inputs(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.p = torch.ones(3, 3)
|
||||
|
||||
def forward(self):
|
||||
return self.p * self.p
|
||||
|
||||
ep = torch.export.export(M(), ())
|
||||
ep._example_inputs = None
|
||||
roundtrip_ep = deserialize(serialize(ep))
|
||||
self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()()))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDeserialize)
|
||||
|
||||
|
|
|
|||
|
|
@ -284,7 +284,10 @@ def _reconstruct_fake_tensor(
|
|||
return fake_tensor
|
||||
|
||||
|
||||
def serialize_torch_artifact(artifact: Dict[str, Any]) -> bytes:
|
||||
def serialize_torch_artifact(artifact: Optional[Any]) -> bytes:
|
||||
if artifact is None:
|
||||
return b""
|
||||
|
||||
assert (
|
||||
FakeTensor not in copyreg.dispatch_table
|
||||
), "Refusing to stomp on existing FakeTensor reducer"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user