[export] Unflatten None (#153000)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153000
Approved by: https://github.com/pianpwk
This commit is contained in:
angelayi 2025-05-08 16:40:11 +00:00 committed by PyTorch MergeBot
parent 7b806a8cb1
commit 3cd69350ed
2 changed files with 28 additions and 1 deletions

View File

@ -954,6 +954,27 @@ class TestUnflatten(TestCase):
unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
self.compare_outputs(orig_eager, unflattened, inputs)
def test_unflatten_none(self):
class M2(torch.nn.Module):
def forward(self, x, y):
return x + x, None
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m2 = M2()
def forward(self, x, y):
x = x + x
return self.m2(x, y)
ep = export(
M(), (torch.rand(2, 3), None), preserve_module_call_signature=("m2",)
)
unflattened = unflatten(ep)
inp = (torch.randn(2, 3), None)
self.assertTrue(torch.allclose(M()(*inp)[0], unflattened(*inp)[0]))
if __name__ == "__main__":
run_tests()

View File

@ -1171,7 +1171,13 @@ class _ModuleFrame:
for output in signature.outputs:
if isinstance(
output,
(TensorArgument, SymIntArgument, SymBoolArgument, SymFloatArgument),
(
TensorArgument,
SymIntArgument,
SymBoolArgument,
SymFloatArgument,
ConstantArgument,
),
):
if output.name in self.seen_nodes:
orig_outputs.append(self.seen_nodes[output.name])