[fx] Recursive DCE on subgraphs (#152772)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152772
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
Animesh Jain authored on 2025-05-05 14:02:08 -07:00; committed by PyTorch MergeBot
parent 35c727e7ff
commit b1d34acac5
5 changed files with 50 additions and 8 deletions


@@ -296,7 +296,6 @@ class GraphModule(torch.nn.Module):
         mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1)
         clone: "f32[3, 3]" = torch.ops.aten.clone.default(mm)
         sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
-        cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); cos = None
         sin_1: "f32[3, 3]" = torch.ops.aten.sin.default(sin); sin = None
         neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin_1); sin_1 = None
         mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg2_1, neg); arg2_1 = neg = None
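The deleted line above is exactly what DCE targets: the cos node has no users (its only binding is dropped immediately via `cos = None`) and `aten.cos` has no side effects. A minimal standalone sketch of the same mechanism, separate from the test file:

import torch
import torch.fx

def f(x):
    y = torch.cos(x)  # traced into a node with no users -> dead
    return torch.sin(x)

gm = torch.fx.symbolic_trace(f)
changed = gm.graph.eliminate_dead_code()  # erases the dead cos node
gm.recompile()                            # regenerate gm.code from the pruned graph
print(changed)  # True
print(gm.code)  # only the sin call remains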


@@ -4565,13 +4565,11 @@ def forward(self, arg0_1, arg1_1):
             """\
 def forward(self, arg0_1, arg1_1):
     cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
-    select = torch.ops.aten.select.int(cos, 0, 0); select = None
     body_graph_0 = self.body_graph_0
     map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
     getitem = map_impl[0]; map_impl = None
     sum_1 = torch.ops.aten.sum.default(getitem); getitem = None
     add = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None
-    select_1 = torch.ops.aten.select.int(cos, 0, 0); select_1 = None
     body_graph_1 = self.body_graph_1
     map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
     getitem_1 = map_impl_1[0]; map_impl_1 = None
@@ -5108,8 +5106,6 @@ def forward(self, arg0_1):
             gm.true_graph_0.code.strip(),
             """\
 def forward(self, arg0_1):
-    add = torch.ops.aten.add.Tensor(arg0_1, 4)
-    add_1 = torch.ops.aten.add.Tensor(add, 5); add = add_1 = None
     cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
     return (cos,)""",
         )
@@ -5118,8 +5114,6 @@ def forward(self, arg0_1):
             gm.false_graph_0.code.strip(),
             """\
 def forward(self, arg0_1):
-    add = torch.ops.aten.add.Tensor(arg0_1, 5)
-    add_1 = torch.ops.aten.add.Tensor(add, 6); add = add_1 = None
     sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
     return (sin,)""",
         )
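The two hunks above scrub dead arithmetic from the true_graph_0/false_graph_0 branch submodules of a torch.cond. A sketch of the kind of program that lowers into such branch graphs (the module and variable names here are assumed, not taken from the test):

import torch

def true_fn(x):
    unused = (x + 4) + 5  # dead: never reaches the branch output
    return torch.cos(x)

def false_fn(x):
    unused = (x + 5) + 6
    return torch.sin(x)

class M(torch.nn.Module):
    def forward(self, pred, x):
        return torch.cond(pred, true_fn, false_fn, (x,))

# Tracing lowers each branch into a child GraphModule; with recursive DCE
# the unused adds no longer survive inside those children.
ep = torch.export.export(M(), (torch.tensor(True), torch.randn(3)))
print(ep.graph_module.true_graph_0.code)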


@@ -423,6 +423,40 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    def test_dce(self):
+        @mark_compile_region
+        def gn(x):
+            x = torch.sin(x)
+            # should be dce'd
+            y = torch.cos(x)  # noqa: F841
+            return x
+
+        def fn(x):
+            return gn(x)
+
+        backend = AotEagerAndRecordGraphs()
+        torch.compile(fn, backend=backend, fullgraph=True)(
+            torch.randn(4, requires_grad=False)
+        )
+
+        if not TEST_WITH_CROSSREF:
+            self.assertExpectedInline(
+                normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[4]"):
+        repeated_subgraph0 = self.repeated_subgraph0
+        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1); repeated_subgraph0 = arg0_1 = None
+        getitem: "f32[4]" = invoke_subgraph[0]; invoke_subgraph = None
+        return (getitem,)
+
+    class repeated_subgraph0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[4]"):
+            sin: "f32[4]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
+            return (sin,)
+""",
+            )
+
     def test_nonlocal_update(self):
         counter = 2


@@ -3781,7 +3781,6 @@ class GraphModule(torch.nn.Module):
 
     class joint_graph0(torch.nn.Module):
         def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
-            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); mul = None
             mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
             mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
             add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None


@@ -1779,6 +1779,8 @@ class Graph:
         of functional operations or you supply your own custom
         function for detecting side-effectful nodes.
         """
+        from torch.utils._ordered_set import OrderedSet
+
         # Lint the graph first to make sure its topologically sorted, otherwise
         # DCE below will not behave as expected.
         self.lint()
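The docstring context above refers to the is_impure_node hook on eliminate_dead_code, which overrides the default per-node purity check. A hedged sketch of supplying one (the op singled out here is purely illustrative):

import torch
from torch.fx import Node

def keep_asserts(node: Node) -> bool:
    # Treat async asserts as side-effectful so DCE never erases them;
    # defer to the default purity check for every other node.
    if node.op == "call_function" and node.target is torch.ops.aten._assert_async.default:
        return True
    return node.is_impure()

# gm is assumed to be an existing torch.fx.GraphModule:
# gm.graph.eliminate_dead_code(is_impure_node=keep_asserts)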
@@ -1801,6 +1803,20 @@ class Graph:
                 self.erase_node(node)
                 changed = True
 
+        # Call DCE on the subgraphs
+        if self.owning_module is not None:
+            subgraph_names = OrderedSet(
+                x.target for x in self.find_nodes(op="get_attr")
+            )
+            for child_name, child_module in self.owning_module.named_children():
+                # Sometimes an owning_module can have unused children. Skip them
+                # by checking them from get_attr node targets.
+                if child_name in subgraph_names and isinstance(
+                    child_module, torch.fx.GraphModule
+                ):
+                    changed |= child_module.graph.eliminate_dead_code()
+                    child_module.recompile()
+
         return changed
 
     @compatibility(is_backward_compatible=False)
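End to end, the new path can be exercised with a hand-built parent/child pair. This is a minimal sketch (the subgraph_0 attribute and identifier are assumed names) in which the parent graph reaches a child GraphModule through a live get_attr node, the same shape invoke_subgraph produces; calling eliminate_dead_code on the parent now also cleans the child:

import operator

import torch
import torch._higher_order_ops  # ensure invoke_subgraph is registered
from torch.fx import Graph, GraphModule

# Child graph: returns sin(x); the cos(x) node has no users and is dead.
cg = Graph()
cx = cg.placeholder("x")
cg.call_function(torch.ops.aten.cos.default, (cx,))
csin = cg.call_function(torch.ops.aten.sin.default, (cx,))
cg.output((csin,))

root = torch.nn.Module()
root.subgraph_0 = GraphModule(torch.nn.Module(), cg)

# Parent graph: a live get_attr node feeds the child to invoke_subgraph.
pg = Graph()
px = pg.placeholder("x")
sub = pg.get_attr("subgraph_0")
call = pg.call_function(
    torch.ops.higher_order.invoke_subgraph, (sub, "subgraph_0", px)
)
out = pg.call_function(operator.getitem, (call, 0))
pg.output((out,))
parent = GraphModule(root, pg)

changed = parent.graph.eliminate_dead_code()  # recurses into subgraph_0
print(changed)                 # True: the child's dead cos node was erased
print(parent.subgraph_0.code)  # only the sin call remains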