Update Node.is_impure to check if a subgraph contains impure ops (#166609)
Summary:

## Context

When `const_fold.split_const_subgraphs` sees a `call_module` node whose target is a GraphModule, the existing implementation can mark that node as const-foldable when it shouldn't. For example, consider a parent graph that contains a `call_module` to a subgraph that takes no inputs but contains impure ops:
```
parent graph():
%sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
return (getitem,)
submodule graph():
%randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cpu, pin_memory: False})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
return (add,)
```
When the `submodule` graph is fed to `const_fold.split_const_subgraphs` on its own, it comes out unmodified, since `randn` is impure. But when the `parent` graph that calls `submodule` is fed to `const_fold.split_const_subgraphs`, the subgraph is folded away:
```
parent after fold graph():
%_fx_const_folded_attrs : [num_users=1] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
return (_fx_const_folded_attrs,)
```
This is because the `node.is_impure()` check inside `const_fold.split_const_subgraphs` falls through for such `call_module` nodes, so the node is marked as pure.
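For concreteness, here is a minimal repro sketch assembled from the tests added below (names like `SubModule` and `sub` are illustrative, not part of the fix itself):

```python
import operator

import torch
from torch.fx.experimental import const_fold


class SubModule(torch.nn.Module):
    def forward(self):
        # Impure: produces a fresh random tensor on every call.
        return torch.randn(5, 10) + 1


# Export the submodule and wrap it in a parent graph that just calls it.
ep = torch.export.export(SubModule(), ())
parent_graph = torch.fx.Graph()
call_mod = parent_graph.call_module("sub", args=())
get_item = parent_graph.call_function(operator.getitem, args=(call_mod, slice(None)))
parent_graph.output((get_item,))
parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)

folded = const_fold.split_const_subgraphs(parent, device_for_folded_attrs="cpu")
# Before this fix, the randn subgraph was extracted and folded into a constant
# attribute; with the fix, no const subgraph is split out.
print(folded.const_subgraph_module)
```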
## Fix

We can update the `fx.node.Node.is_impure` function to look inside a `call_module` node with an additional `subgraph_has_impure_ops` check:
- if a `call_module` node calls a GraphModule,
- check whether any of its `call_function` nodes are impure ops, and
- recursively check any nested `call_module` nodes that call a GraphModule.

If the called subgraph contains impure ops, `is_impure` returns True.
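As a sanity check of the new behavior, the `call_module` node in the `parent` graph from the sketch above should now report itself as impure (a hedged snippet reusing `parent` defined earlier):

```python
# The call_module node targets a GraphModule whose body contains aten.randn,
# so the new subgraph_has_impure_ops walk should flag it as impure.
impure_mod_nodes = [
    n for n in parent.graph.nodes if n.op == "call_module" and n.is_impure()
]
assert impure_mod_nodes, "expected the impure subgraph call to be detected"
```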
Test Plan: added tests to test_fx_const_fold.py
Differential Revision: D85798483
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166609
Approved by: https://github.com/blaine-rister
This commit is contained in:
parent
aa9c96af04
commit
121235956b
The new tests in `test_fx_const_fold.py`:

```diff
@@ -707,6 +707,50 @@ class TestConstFold(TestCase):
         fold_result = mod_folded(in_x, in_y)
         self.assertTrue(torch.equal(fold_result, base_result))
 
+    def test_fold_pure_subgraph(self):
+        class SubModule(torch.nn.Module):
+            def forward(self):
+                return torch.full((5, 10), 2.0) + 1
+
+        # Create a parent graph with this module as a subgraph and output
+        ep = torch.export.export(SubModule(), ())
+        parent_graph = torch.fx.Graph()
+        call_mod = parent_graph.call_module("sub", args=())
+        get_item = parent_graph.call_function(
+            operator.getitem, args=(call_mod, slice(None))
+        )
+        parent_graph.output((get_item,))
+        parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)
+
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
+            parent, device_for_folded_attrs="cpu"
+        )
+        self._verify_const_fold_mod(mod_folded)
+
+    def test_do_not_fold_impure_subgraph(self):
+        """
+        Skip folding any subgraph containing impure ops.
+        """
+
+        class SubModule(torch.nn.Module):
+            def forward(self):
+                return torch.randn(5, 10) + 1
+
+        # Create a parent graph with this module as a subgraph and output
+        ep = torch.export.export(SubModule(), ())
+        parent_graph = torch.fx.Graph()
+        call_mod = parent_graph.call_module("sub", args=())
+        get_item = parent_graph.call_function(
+            operator.getitem, args=(call_mod, slice(None))
+        )
+        parent_graph.output((get_item,))
+        parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)
+
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
+            parent, device_for_folded_attrs="cpu"
+        )
+        self.assertIsNone(mod_folded.const_subgraph_module)
+
 
 if __name__ == "__main__":
     raise_on_run_directly("test/test_fx.py")
```
The `subgraph_has_impure_ops` helper in `torch/fx/node.py`:

```diff
@@ -754,6 +754,24 @@ class Node(_NodeBase):
 
             return self.target in _side_effectful_functions
 
+        def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
+            """
+            Return True if a GraphModule type subgraph contains any impure op, else False.
+            """
+            assert isinstance(module, torch.fx.GraphModule), (
+                "caller should only pass GraphModule to subgraph_has_impure_ops check"
+            )
+            for node in module.graph.nodes:
+                if node.op == "call_function" and node.is_impure(impure_random):
+                    return True
+                if (
+                    node.op == "call_module"
+                    and (submodule := module.get_submodule(node.target))
+                    and isinstance(submodule, torch.fx.GraphModule)
+                ):
+                    return subgraph_has_impure_ops(submodule)
+            return False
+
         # Check if an impure module.
         if self.op == "call_module":
             assert self.graph.owning_module is not None, (
```
And the `call_module` branch of `is_impure` now dispatches on the submodule type:

```diff
@@ -763,7 +781,10 @@ class Node(_NodeBase):
             assert target_mod is not None, (
                 f"Did not find expected submodule target {self.target}"
             )
-            return getattr(target_mod, "_is_impure", False)
+            if isinstance(target_mod, torch.fx.GraphModule):
+                return subgraph_has_impure_ops(target_mod)
+            else:
+                return getattr(target_mod, "_is_impure", False)
 
         return False
 
```