mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Unflatten None (#153000)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/153000 Approved by: https://github.com/pianpwk
This commit is contained in:
parent
7b806a8cb1
commit
3cd69350ed
|
|
@ -954,6 +954,27 @@ class TestUnflatten(TestCase):
|
|||
unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
|
||||
self.compare_outputs(orig_eager, unflattened, inputs)
|
||||
|
||||
def test_unflatten_none(self):
|
||||
class M2(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return x + x, None
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.m2 = M2()
|
||||
|
||||
def forward(self, x, y):
|
||||
x = x + x
|
||||
return self.m2(x, y)
|
||||
|
||||
ep = export(
|
||||
M(), (torch.rand(2, 3), None), preserve_module_call_signature=("m2",)
|
||||
)
|
||||
unflattened = unflatten(ep)
|
||||
inp = (torch.randn(2, 3), None)
|
||||
self.assertTrue(torch.allclose(M()(*inp)[0], unflattened(*inp)[0]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1171,7 +1171,13 @@ class _ModuleFrame:
|
|||
for output in signature.outputs:
|
||||
if isinstance(
|
||||
output,
|
||||
(TensorArgument, SymIntArgument, SymBoolArgument, SymFloatArgument),
|
||||
(
|
||||
TensorArgument,
|
||||
SymIntArgument,
|
||||
SymBoolArgument,
|
||||
SymFloatArgument,
|
||||
ConstantArgument,
|
||||
),
|
||||
):
|
||||
if output.name in self.seen_nodes:
|
||||
orig_outputs.append(self.seen_nodes[output.name])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user