diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py
index 8ff9fb43861..ed1a049a81b 100644
--- a/test/fx/test_fx_const_fold.py
+++ b/test/fx/test_fx_const_fold.py
@@ -707,6 +707,50 @@ class TestConstFold(TestCase):
         fold_result = mod_folded(in_x, in_y)
         self.assertTrue(torch.equal(fold_result, base_result))
 
+    def test_fold_pure_subgraph(self):
+        class SubModule(torch.nn.Module):
+            def forward(self):
+                return torch.full((5, 10), 2.0) + 1
+
+        # Create a parent graph with this module as a subgraph and output
+        ep = torch.export.export(SubModule(), ())
+        parent_graph = torch.fx.Graph()
+        call_mod = parent_graph.call_module("sub", args=())
+        get_item = parent_graph.call_function(
+            operator.getitem, args=(call_mod, slice(None))
+        )
+        parent_graph.output((get_item,))
+        parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)
+
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
+            parent, device_for_folded_attrs="cpu"
+        )
+        self._verify_const_fold_mod(mod_folded)
+
+    def test_do_not_fold_impure_subgraph(self):
+        """
+        Skip folding any subgraph containing impure ops.
+        """
+
+        class SubModule(torch.nn.Module):
+            def forward(self):
+                return torch.randn(5, 10) + 1
+
+        # Create a parent graph with this module as a subgraph and output
+        ep = torch.export.export(SubModule(), ())
+        parent_graph = torch.fx.Graph()
+        call_mod = parent_graph.call_module("sub", args=())
+        get_item = parent_graph.call_function(
+            operator.getitem, args=(call_mod, slice(None))
+        )
+        parent_graph.output((get_item,))
+        parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)
+
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
+            parent, device_for_folded_attrs="cpu"
+        )
+        self.assertIsNone(mod_folded.const_subgraph_module)
+
 
 if __name__ == "__main__":
     raise_on_run_directly("test/test_fx.py")
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 48f57d58863..07a049ae956 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -754,6 +754,26 @@ class Node(_NodeBase):
             return self.target in _side_effectful_functions
 
+        def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
+            """
+            Return True if a GraphModule type subgraph contains any impure op, else False.
+            """
+            assert isinstance(module, torch.fx.GraphModule), (
+                "caller should only pass GraphModule to subgraph_has_impure_ops check"
+            )
+            for node in module.graph.nodes:
+                if node.op == "call_function" and node.is_impure(impure_random):
+                    return True
+                # Recurse into nested GraphModules; keep scanning when the nested graph is pure.
+                if (
+                    node.op == "call_module"
+                    and (submodule := module.get_submodule(node.target))
+                    and isinstance(submodule, torch.fx.GraphModule)
+                    and subgraph_has_impure_ops(submodule)
+                ):
+                    return True
+            return False
+
 
         # Check if an impure module.
         if self.op == "call_module":
@@ -763,7 +783,10 @@
             assert target_mod is not None, (
                 f"Did not find expected submodule target {self.target}"
            )
-            return getattr(target_mod, "_is_impure", False)
+            if isinstance(target_mod, torch.fx.GraphModule):
+                return subgraph_has_impure_ops(target_mod)
+            else:
+                return getattr(target_mod, "_is_impure", False)
 
         return False