[Hierarchical Compile] Take into account mutation deps in cycle detection (#152506)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152506
Approved by: https://github.com/anijain2305
ghstack dependencies: #152389, #152505, #152410
This commit is contained in:
Michael Lazos 2025-05-09 14:33:22 -07:00 committed by PyTorch MergeBot
parent bc8b305eb8
commit 779e647999
3 changed files with 93 additions and 29 deletions

View File

@ -11,6 +11,7 @@ from torch._dynamo.testing import (
extract_graph_and_tracker,
normalize_gm,
)
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
@ -38,6 +39,19 @@ class GraphDededuplicationTests(TestCase):
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)
def run_and_get_simple_graph(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
return fw_graphs[0]
def test_single_subgraph(self):
def inner_fn(x, y):
x0 = x + 1
@ -599,20 +613,12 @@ class <lambda>(torch.nn.Module):
)
def test_cycle_detection_no_cycle(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
mod = self.run_and_get_simple_graph()
self.assertExpectedInline(
_detect_cycles(mod.graph, {}), """no cycle detected"""
)
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
mod = fw_graphs[0]
self.assertExpectedInline(_detect_cycles(mod.graph), """no cycle detected""")
def test_cycle_detection_simple(self):
def test_cycle_detection_single_node(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
@ -629,8 +635,64 @@ class <lambda>(torch.nn.Module):
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(mod.graph),
"""cycle detected in path: deque([arg0_1, add, sum_1, add_2, add])""",
_detect_cycles(mod.graph, {add_2: OrderedSet([add_2])}),
"""cycle detected in path: deque([output, add_2, add_2])""",
)
def test_cycle_detection_two_node(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
mod = fw_graphs[0]
add_node = next(n for n in mod.graph.nodes if n.name == "add")
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(
mod.graph,
{add_2: OrderedSet([add_node]), add_node: OrderedSet([add_2])},
),
"""cycle detected in path: deque([output, add_2, add, add_2])""",
)
def test_cycle_detection_arg_and_additional_deps(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
mod = fw_graphs[0]
add_node = next(n for n in mod.graph.nodes if n.name == "add")
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(mod.graph, {add_2: OrderedSet([add_node])}),
"""cycle detected in path: deque([output, add_2, add, add_2])""",
)
def test_cycle_detection_simple(self):
mod = self.run_and_get_simple_graph()
add_node = next(n for n in mod.graph.nodes if n.name == "add")
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(mod.graph, {}),
"""cycle detected in path: deque([output, add_2, sum_1, add, add_2])""",
)
def test_cycle_detection_complex(self):
@ -664,8 +726,8 @@ class <lambda>(torch.nn.Module):
args = invoke_subgraph_node.args
invoke_subgraph_node.args = (add_2, args[1])
self.assertExpectedInline(
_detect_cycles(mod.graph),
"""cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""",
_detect_cycles(mod.graph, {}),
"""cycle detected in path: deque([output, add_2, add_1, sum_1, getitem, invoke_subgraph, add_2])""",
)
def test_autocast_ordering(self):

View File

@ -127,8 +127,6 @@ def _replace_region_with_subgraph(
)
return
from torch._inductor.pattern_matcher import stable_topological_sort
invoke_subgraph_node = graph.create_node(
"call_function",
torch.ops.higher_order.invoke_subgraph,
@ -154,11 +152,8 @@ def _replace_region_with_subgraph(
pass
if config.graph_deduplication_lint:
_detect_cycles(graph)
stable_topological_sort(graph)
graph.lint()
if config.graph_deduplication_lint:
_detect_cycles(graph, node_to_additional_deps)
_stable_topological_sort(graph, node_to_additional_deps)
graph.lint()

View File

@ -30,7 +30,9 @@ def _get_flat_args_unique(
return args
def _detect_cycles(graph: Graph) -> str:
def _detect_cycles(
graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> str:
current_path: deque[Node] = deque()
current_path_set: set[Node] = set()
pending: deque[tuple[Node, Node]] = deque()
@ -46,25 +48,30 @@ def _detect_cycles(graph: Graph) -> str:
def current_path_head() -> Node:
return current_path[-1]
for origin in graph.find_nodes(op="placeholder"):
for origin in graph.find_nodes(op="output"):
current_path.clear()
current_path_set.clear()
add_to_current_path(origin)
for child in origin.users:
for child in _get_flat_args_unique(origin, node_to_additional_deps):
pending.append((child, origin))
while pending:
cur_node, parent = pending.pop()
while current_path_head() != parent:
# handle backtracking
while current_path and current_path_head() != parent:
pop_current_path()
if not isinstance(cur_node, Node):
continue
if cur_node in current_path_set:
current_path.append(cur_node)
return f"cycle detected in path: {current_path}"
add_to_current_path(cur_node)
for child in cur_node.users:
for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
pending.append((child, cur_node))
return "no cycle detected"