From 058782c6ab347a424945f081f938d36548347e38 Mon Sep 17 00:00:00 2001 From: Malay Bag Date: Tue, 14 Oct 2025 20:26:24 +0000 Subject: [PATCH] [torch.export] Rmoving unused constants - add support for corner case (#165205) Summary: In some cases unused constant had only one level of child node, no second level of child node. Those constants should be removed too. The added test case has the scenario where this scenario will happen. Test Plan: ``` buck test mode/opt caffe2/test:test_export -- 'test_unused_constant' ``` https://www.internalfb.com/intern/testinfra/testrun/15481123837456594 Differential Revision: D84398413 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165205 Approved by: https://github.com/angelayi --- test/export/test_export.py | 18 ++++++++++++++++++ torch/_export/passes/lift_constants_pass.py | 5 +++++ 2 files changed, 23 insertions(+) diff --git a/test/export/test_export.py b/test/export/test_export.py index 23dab73d898..197978a19d4 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1628,6 +1628,24 @@ graph(): ep = export(M(), (torch.ones(3),)) self.assertEqual(len(ep.constants), 0) + class M(torch.nn.Module): + def __init__(self, num_features: int = 1) -> None: + super().__init__() + self.num_features = num_features + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + res = [torch.Tensor([])] * self.num_features + for i in range(self.num_features): + res[i] = x * (i + 1) + return res + + inp = torch.ones(3) + ep = export(M(), (inp,)) + self.assertEqual(len(ep.constants), 0) + + unf = unflatten(ep) + self.assertTrue(torch.allclose(M()(inp)[0], unf(inp)[0])) + def test_unbacked_bincount(self): class Foo(torch.nn.Module): def forward(self, xs): diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 20253a91c25..7e57817eb68 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -142,6 +142,10 @@ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]: if len(lift_fresh_node.users) > 1: return None + # Case 1: lift node is not used anywhere + if len(lift_fresh_node.users) == 0: + return [lift_fresh_node, node] + detach_node = next(iter(lift_fresh_node.users.keys())) if not ( detach_node.op == "call_function" @@ -156,6 +160,7 @@ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]: if len(detach_node.users) > 0: return None else: + # Case 2: Lift node's child is not used anywhere return [detach_node, lift_fresh_node, node]