update Node.is_impure check if subgraph contains impure ops (#166609)

Summary:
## Context
when `const_fold.split_const_subgraphs` sees a `call_module` node that is a GraphModule, by the existing implementation it can mark this node as const-foldable when it shouldn't.

For example, a parent graph contains a `call_module` to a subgraph that has no inputs but contain impure ops inside.
```
parent graph():
    %sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
    return (getitem,)

submodule graph():
    %randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
    return (add,)
```
when `submodule` graph is fed to const_fold.split_const_subgraph, it would come out unmodified since randn is impure.

But if the `submodule` is called by a `parent` graph, when `parent` is fed to const_fold.split_const_subgraph, it would come out folded.
```
parent after fold graph():
    %_fx_const_folded_attrs : [num_users=1] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
    return (_fx_const_folded_attrs,)
```

This is because `node.is_impure()` check inside `const_fold.split_const_subgraph` fail through, leading the call_module node to be marked as pure.

## Fix

We can update `fx.node.Node.is_impure` function to check for ops inside a call_module node with an additional `subgraph_has_impure_ops` check:
- if a call_module node calls a GraphModule,
- check any call_function nodes are impure ops
- recursively check any call_module nodes that call GraphModule

If the call_module subgraph has impure ops, return True to `is_impure`

Test Plan: added tests to test_fx_const_fold.py

Differential Revision: D85798483

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166609
Approved by: https://github.com/blaine-rister
This commit is contained in:
Jazlyn Li 2025-10-31 16:58:15 +00:00 committed by PyTorch MergeBot
parent aa9c96af04
commit 121235956b
2 changed files with 66 additions and 1 deletions

View File

@ -707,6 +707,50 @@ class TestConstFold(TestCase):
fold_result = mod_folded(in_x, in_y) fold_result = mod_folded(in_x, in_y)
self.assertTrue(torch.equal(fold_result, base_result)) self.assertTrue(torch.equal(fold_result, base_result))
def test_fold_pure_subgraph(self):
class SubModule(torch.nn.Module):
def forward(self):
return torch.full((5, 10), 2.0) + 1
# Create a parent graph with this module as a subgraph and output
ep = torch.export.export(SubModule(), ())
parent_graph = torch.fx.Graph()
call_mod = parent_graph.call_module("sub", args=())
get_item = parent_graph.call_function(
operator.getitem, args=(call_mod, slice(None))
)
parent_graph.output((get_item,))
parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)
mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
parent, device_for_folded_attrs="cpu"
)
self._verify_const_fold_mod(mod_folded)
def test_do_not_fold_impure_subgraph(self):
"""
Skip folding any subgraph containing impure ops.
"""
class SubModule(torch.nn.Module):
def forward(self):
return torch.randn(5, 10) + 1
# Create a parent graph with this module as a subgraph and output
ep = torch.export.export(SubModule(), ())
parent_graph = torch.fx.Graph()
call_mod = parent_graph.call_module("sub", args=())
get_item = parent_graph.call_function(
operator.getitem, args=(call_mod, slice(None))
)
parent_graph.output((get_item,))
parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)
mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
parent, device_for_folded_attrs="cpu"
)
self.assertIsNone(mod_folded.const_subgraph_module)
if __name__ == "__main__": if __name__ == "__main__":
raise_on_run_directly("test/test_fx.py") raise_on_run_directly("test/test_fx.py")

View File

@ -754,6 +754,24 @@ class Node(_NodeBase):
return self.target in _side_effectful_functions return self.target in _side_effectful_functions
def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
"""
Return True if a GraphModule type subgraph contains any impure op, else False.
"""
assert isinstance(module, torch.fx.GraphModule), (
"caller should only pass GraphModule to subgraph_has_impure_ops check"
)
for node in module.graph.nodes:
if node.op == "call_function" and node.is_impure(impure_random):
return True
if (
node.op == "call_module"
and (submodule := module.get_submodule(node.target))
and isinstance(submodule, torch.fx.GraphModule)
):
return subgraph_has_impure_ops(submodule)
return False
# Check if an impure module. # Check if an impure module.
if self.op == "call_module": if self.op == "call_module":
assert self.graph.owning_module is not None, ( assert self.graph.owning_module is not None, (
@ -763,6 +781,9 @@ class Node(_NodeBase):
assert target_mod is not None, ( assert target_mod is not None, (
f"Did not find expected submodule target {self.target}" f"Did not find expected submodule target {self.target}"
) )
if isinstance(target_mod, torch.fx.GraphModule):
return subgraph_has_impure_ops(target_mod)
else:
return getattr(target_mod, "_is_impure", False) return getattr(target_mod, "_is_impure", False)
return False return False