[fx] Recursive DCE on subgraphs (#152772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152772 Approved by: https://github.com/bdhirsh, https://github.com/zou3519
This commit is contained in:
parent 35c727e7ff
commit b1d34acac5
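For orientation, here is a minimal sketch (not part of the diff) of the behavior this change adds: when a graph references a child GraphModule through a get_attr node, as higher-order ops like invoke_subgraph and map_impl do, eliminate_dead_code now recurses into that child, removes its dead nodes, and recompiles it. The hand-built child_fn/root names below are illustrative, and the sketch assumes torch.ops.higher_order.invoke_subgraph is available in your build.

import operator

import torch
from torch import fx

def child_fn(x):
    y = torch.cos(x)  # noqa: F841 -- unused, i.e. dead code inside the subgraph
    return (torch.sin(x),)

child = fx.symbolic_trace(child_fn)

root = torch.nn.Module()
root.child = child  # child GraphModule reachable from the owning module

# Hand-build the get_attr + invoke_subgraph pattern that higher-order ops emit.
g = fx.Graph()
x = g.placeholder("x")
sub = g.get_attr("child")  # get_attr targets are what the recursive DCE looks for
res = g.call_function(torch.ops.higher_order.invoke_subgraph, (sub, "subgraph_0", x))
g.output(g.call_function(operator.getitem, (res, 0)))

gm = fx.GraphModule(root, g)
changed = gm.graph.eliminate_dead_code()  # with this change, also visits gm.child
gm.recompile()
print(changed)        # True: the child graph lost its dead cos node
print(gm.child.code)  # the unused cos call is gone from the child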
@@ -296,7 +296,6 @@ class GraphModule(torch.nn.Module):
mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1)
clone: "f32[3, 3]" = torch.ops.aten.clone.default(mm)
sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); cos = None
sin_1: "f32[3, 3]" = torch.ops.aten.sin.default(sin); sin = None
neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin_1); sin_1 = None
mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg2_1, neg); arg2_1 = neg = None
@@ -4565,13 +4565,11 @@ def forward(self, arg0_1, arg1_1):
"""\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
    select = torch.ops.aten.select.int(cos, 0, 0); select = None
    body_graph_0 = self.body_graph_0
    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
    getitem = map_impl[0]; map_impl = None
    sum_1 = torch.ops.aten.sum.default(getitem); getitem = None
    add = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None
    select_1 = torch.ops.aten.select.int(cos, 0, 0); select_1 = None
    body_graph_1 = self.body_graph_1
    map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
    getitem_1 = map_impl_1[0]; map_impl_1 = None
@@ -5108,8 +5106,6 @@ def forward(self, arg0_1):
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 4)
    add_1 = torch.ops.aten.add.Tensor(add, 5); add = add_1 = None
    cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
    return (cos,)""",
)
@@ -5118,8 +5114,6 @@ def forward(self, arg0_1):
gm.false_graph_0.code.strip(),
"""\
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 5)
    add_1 = torch.ops.aten.add.Tensor(add, 6); add = add_1 = None
    sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
    return (sin,)""",
)
@@ -423,6 +423,40 @@ class GraphModule(torch.nn.Module):
""",
            )

    def test_dce(self):
        @mark_compile_region
        def gn(x):
            x = torch.sin(x)
            # should be dce'd
            y = torch.cos(x)  # noqa: F841
            return x

        def fn(x):
            return gn(x)

        backend = AotEagerAndRecordGraphs()
        torch.compile(fn, backend=backend, fullgraph=True)(
            torch.randn(4, requires_grad=False)
        )

        if not TEST_WITH_CROSSREF:
            self.assertExpectedInline(
                normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
                """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[4]"):
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1); repeated_subgraph0 = arg0_1 = None
        getitem: "f32[4]" = invoke_subgraph[0]; invoke_subgraph = None
        return (getitem,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[4]"):
            sin: "f32[4]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            return (sin,)
""",
            )

    def test_nonlocal_update(self):
        counter = 2
@@ -3781,7 +3781,6 @@ class GraphModule(torch.nn.Module):

    class joint_graph0(torch.nn.Module):
        def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
            mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); mul = None
            mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
            mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
            add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
@@ -1779,6 +1779,8 @@ class Graph:
        of functional operations or you supply your own custom
        function for detecting side-effectful nodes.
        """
        from torch.utils._ordered_set import OrderedSet

        # Lint the graph first to make sure its topologically sorted, otherwise
        # DCE below will not behave as expected.
        self.lint()
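The docstring context above mentions supplying a custom function for detecting side-effectful nodes. As a hedged aside, recent torch.fx exposes this as an optional is_impure_node callable on eliminate_dead_code; the exact keyword and its semantics should be checked against your installed version. A minimal sketch, with the predicate deliberately falling back to node.is_impure() so placeholders and outputs stay protected:

import torch
from torch import fx

def f(x):
    torch.cos(x)  # unused and pure, so plain DCE would drop it
    return torch.sin(x)

gm = fx.symbolic_trace(f)

def keep_cos(node: fx.Node) -> bool:
    # Treat cos calls as side-effectful; defer to the default check otherwise.
    if node.op == "call_function" and node.target is torch.cos:
        return True
    return node.is_impure()

gm.graph.eliminate_dead_code(is_impure_node=keep_cos)  # assumed keyword name
gm.recompile()
print(gm.code)  # the cos node survives because the predicate marked it impure

Note that, in the hunk below, the recursive call on child graphs uses the default detector rather than forwarding a caller-supplied one.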
@@ -1801,6 +1803,20 @@ class Graph:
                self.erase_node(node)
                changed = True

        # Call DCE on the subgraphs
        if self.owning_module is not None:
            subgraph_names = OrderedSet(
                x.target for x in self.find_nodes(op="get_attr")
            )
            for child_name, child_module in self.owning_module.named_children():
                # Sometimes an owning_module can have unused children. Skip them
                # by checking them from get_attr node targets.
                if child_name in subgraph_names and isinstance(
                    child_module, torch.fx.GraphModule
                ):
                    changed |= child_module.graph.eliminate_dead_code()
                    child_module.recompile()

        return changed

    @compatibility(is_backward_compatible=False)
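As a small illustration (again not from the diff) of the guard in the loop above: children of the owning module that no get_attr node references are skipped, so their graphs are left untouched.

import torch
from torch import fx

def dead_child(x):
    torch.cos(x)  # dead node inside the child graph
    return x

def outer(x):
    return x + 1

gm = fx.symbolic_trace(outer)
gm.unused = fx.symbolic_trace(dead_child)  # registered child, but no get_attr points at it

gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.unused.code)  # cos is still there: unreferenced children are not visited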