From 121235956bab7430fb8d080cee209607f8387ead Mon Sep 17 00:00:00 2001 From: Jazlyn Li Date: Fri, 31 Oct 2025 16:58:15 +0000 Subject: [PATCH] update Node.is_impure check if subgraph contains impure ops (#166609) Summary: ## Context when `const_fold.split_const_subgraphs` sees a `call_module` node that is a GraphModule, by the existing implementation it can mark this node as const-foldable when it shouldn't. For example, a parent graph contains a `call_module` to a subgraph that has no inputs but contain impure ops inside. ``` parent graph(): %sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {}) return (getitem,) submodule graph(): %randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cpu, pin_memory: False}) %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {}) return (add,) ``` when `submodule` graph is fed to const_fold.split_const_subgraph, it would come out unmodified since randn is impure. But if the `submodule` is called by a `parent` graph, when `parent` is fed to const_fold.split_const_subgraph, it would come out folded. ``` parent after fold graph(): %_fx_const_folded_attrs : [num_users=1] = get_attr[target=_FX_CONST_FOLDED_ATTRS] return (_fx_const_folded_attrs,) ``` This is because `node.is_impure()` check inside `const_fold.split_const_subgraph` fail through, leading the call_module node to be marked as pure. ## Fix We can update `fx.node.Node.is_impure` function to check for ops inside a call_module node with an additional `subgraph_has_impure_ops` check: - if a call_module node calls a GraphModule, - check any call_function nodes are impure ops - recursively check any call_module nodes that call GraphModule If the call_module subgraph has impure ops, return True to `is_impure` Test Plan: added tests to test_fx_const_fold.py Differential Revision: D85798483 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166609 Approved by: https://github.com/blaine-rister --- test/fx/test_fx_const_fold.py | 44 +++++++++++++++++++++++++++++++++++ torch/fx/node.py | 23 +++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py index 8ff9fb43861..ed1a049a81b 100644 --- a/test/fx/test_fx_const_fold.py +++ b/test/fx/test_fx_const_fold.py @@ -707,6 +707,50 @@ class TestConstFold(TestCase): fold_result = mod_folded(in_x, in_y) self.assertTrue(torch.equal(fold_result, base_result)) + def test_fold_pure_subgraph(self): + class SubModule(torch.nn.Module): + def forward(self): + return torch.full((5, 10), 2.0) + 1 + + # Create a parent graph with this module as a subgraph and output + ep = torch.export.export(SubModule(), ()) + parent_graph = torch.fx.Graph() + call_mod = parent_graph.call_module("sub", args=()) + get_item = parent_graph.call_function( + operator.getitem, args=(call_mod, slice(None)) + ) + parent_graph.output((get_item,)) + parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph) + + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( + parent, device_for_folded_attrs="cpu" + ) + self._verify_const_fold_mod(mod_folded) + + def test_do_not_fold_impure_subgraph(self): + """ + Skip folding any subgraph containing impure ops. + """ + + class SubModule(torch.nn.Module): + def forward(self): + return torch.randn(5, 10) + 1 + + # Create a parent graph with this module as a subgraph and output + ep = torch.export.export(SubModule(), ()) + parent_graph = torch.fx.Graph() + call_mod = parent_graph.call_module("sub", args=()) + get_item = parent_graph.call_function( + operator.getitem, args=(call_mod, slice(None)) + ) + parent_graph.output((get_item,)) + parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph) + + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( + parent, device_for_folded_attrs="cpu" + ) + self.assertIsNone(mod_folded.const_subgraph_module) + if __name__ == "__main__": raise_on_run_directly("test/test_fx.py") diff --git a/torch/fx/node.py b/torch/fx/node.py index 48f57d58863..07a049ae956 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -754,6 +754,24 @@ class Node(_NodeBase): return self.target in _side_effectful_functions + def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool: + """ + Return True if a GraphModule type subgraph contains any impure op, else False. + """ + assert isinstance(module, torch.fx.GraphModule), ( + "caller should only pass GraphModule to subgraph_has_impure_ops check" + ) + for node in module.graph.nodes: + if node.op == "call_function" and node.is_impure(impure_random): + return True + if ( + node.op == "call_module" + and (submodule := module.get_submodule(node.target)) + and isinstance(submodule, torch.fx.GraphModule) + ): + return subgraph_has_impure_ops(submodule) + return False + # Check if an impure module. if self.op == "call_module": assert self.graph.owning_module is not None, ( @@ -763,7 +781,10 @@ class Node(_NodeBase): assert target_mod is not None, ( f"Did not find expected submodule target {self.target}" ) - return getattr(target_mod, "_is_impure", False) + if isinstance(target_mod, torch.fx.GraphModule): + return subgraph_has_impure_ops(target_mod) + else: + return getattr(target_mod, "_is_impure", False) return False