mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Hierarchical Compile] Take into account mutation deps in cycle detection (#152506)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152506 Approved by: https://github.com/anijain2305 ghstack dependencies: #152389, #152505, #152410
This commit is contained in:
parent
bc8b305eb8
commit
779e647999
|
|
@ -11,6 +11,7 @@ from torch._dynamo.testing import (
|
|||
extract_graph_and_tracker,
|
||||
normalize_gm,
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
def extract_graph(fn, *args, **kwargs):
|
||||
|
|
@ -38,6 +39,19 @@ class GraphDededuplicationTests(TestCase):
|
|||
def run_and_return_graphs(self, fn, *args, **kwargs):
|
||||
return extract_graph(fn, *args, **kwargs)
|
||||
|
||||
def run_and_get_simple_graph(self):
    """Trace a small two-input function and return one of its captured graphs.

    The traced function adds a constant to each input, reduces both results
    with ``sum()`` and adds the two scalars. The function and two random
    inputs (10x10 and 10x20, no grad) are passed to
    ``self.run_and_return_graphs``; the first entry of the third returned
    value (the forward graphs, going by the original variable name) is
    handed back to the caller.
    """

    # Kept verbatim: the tests look up traced nodes by name, so the traced
    # function is deliberately left untouched.
    def fn(x, y):
        x0 = x + 1
        y0 = y + 2
        z = x0.sum() + y0.sum()
        return z

    sample_inputs = (
        torch.rand(10, 10, requires_grad=False),
        torch.rand(10, 20, requires_grad=False),
    )

    # Exactly three values are expected back; only the last one is used.
    _first, _second, forward_graphs = self.run_and_return_graphs(fn, *sample_inputs)
    return forward_graphs[0]
|
||||
|
||||
def test_single_subgraph(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
|
|
@ -599,20 +613,12 @@ class <lambda>(torch.nn.Module):
|
|||
)
|
||||
|
||||
def test_cycle_detection_no_cycle(self):
|
||||
def fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
mod = self.run_and_get_simple_graph()
|
||||
self.assertExpectedInline(
|
||||
_detect_cycles(mod.graph, {}), """no cycle detected"""
|
||||
)
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=False)
|
||||
y = torch.rand(10, 20, requires_grad=False)
|
||||
|
||||
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
|
||||
mod = fw_graphs[0]
|
||||
self.assertExpectedInline(_detect_cycles(mod.graph), """no cycle detected""")
|
||||
|
||||
def test_cycle_detection_simple(self):
|
||||
def test_cycle_detection_single_node(self):
|
||||
def fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
|
|
@ -629,8 +635,64 @@ class <lambda>(torch.nn.Module):
|
|||
args = add_node.args
|
||||
add_node.args = (args[0], add_2)
|
||||
self.assertExpectedInline(
|
||||
_detect_cycles(mod.graph),
|
||||
"""cycle detected in path: deque([arg0_1, add, sum_1, add_2, add])""",
|
||||
_detect_cycles(mod.graph, {add_2: OrderedSet([add_2])}),
|
||||
"""cycle detected in path: deque([output, add_2, add_2])""",
|
||||
)
|
||||
|
||||
def test_cycle_detection_two_node(self):
    """Check that two nodes naming each other as additional deps form a reported cycle.

    ``add`` and ``add_2`` are each registered as an additional dependency of
    the other, and ``add`` is also rewired to take ``add_2`` as an argument.
    The detector is expected to report the path ``output -> add_2 -> add -> add_2``.
    """
    # Use the shared helper instead of re-declaring the same traced function
    # and inputs inline; it traces the identical graph.
    mod = self.run_and_get_simple_graph()

    add_node = next(n for n in mod.graph.nodes if n.name == "add")
    add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")

    # Replace the second argument of `add` with `add_2`.
    args = add_node.args
    add_node.args = (args[0], add_2)

    self.assertExpectedInline(
        _detect_cycles(
            mod.graph,
            {add_2: OrderedSet([add_node]), add_node: OrderedSet([add_2])},
        ),
        """cycle detected in path: deque([output, add_2, add, add_2])""",
    )
|
||||
|
||||
def test_cycle_detection_arg_and_additional_deps(self):
    """Check a cycle made of one argument edge and one additional-dependency edge.

    ``add`` is rewired to take ``add_2`` as an argument, while ``add_2`` gets
    ``add`` as an additional dependency. The detector is expected to report
    the path ``output -> add_2 -> add -> add_2``.
    """
    # Use the shared helper instead of re-declaring the same traced function
    # and inputs inline; it traces the identical graph.
    mod = self.run_and_get_simple_graph()

    add_node = next(n for n in mod.graph.nodes if n.name == "add")
    add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")

    # Argument edge: `add` now consumes `add_2`.
    args = add_node.args
    add_node.args = (args[0], add_2)

    self.assertExpectedInline(
        # Additional-dependency edge: `add_2` depends on `add`.
        _detect_cycles(mod.graph, {add_2: OrderedSet([add_node])}),
        """cycle detected in path: deque([output, add_2, add, add_2])""",
    )
|
||||
|
||||
def test_cycle_detection_simple(self):
    """Check that a cycle made only of argument edges is reported.

    No additional dependencies are supplied (empty mapping); the cycle comes
    solely from rewiring ``add`` to take ``add_2`` as its second argument.
    """
    graph = self.run_and_get_simple_graph().graph

    first_add = next(node for node in graph.nodes if node.name == "add")
    final_add = next(node for node in graph.nodes if node.name == "add_2")

    # Keep the first argument of `add`, replace the second with `add_2`.
    first_add.args = (first_add.args[0], final_add)

    report = _detect_cycles(graph, {})
    self.assertExpectedInline(
        report,
        """cycle detected in path: deque([output, add_2, sum_1, add, add_2])""",
    )
|
||||
|
||||
def test_cycle_detection_complex(self):
|
||||
|
|
@ -664,8 +726,8 @@ class <lambda>(torch.nn.Module):
|
|||
args = invoke_subgraph_node.args
|
||||
invoke_subgraph_node.args = (add_2, args[1])
|
||||
self.assertExpectedInline(
|
||||
_detect_cycles(mod.graph),
|
||||
"""cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""",
|
||||
_detect_cycles(mod.graph, {}),
|
||||
"""cycle detected in path: deque([output, add_2, add_1, sum_1, getitem, invoke_subgraph, add_2])""",
|
||||
)
|
||||
|
||||
def test_autocast_ordering(self):
|
||||
|
|
|
|||
|
|
@ -127,8 +127,6 @@ def _replace_region_with_subgraph(
|
|||
)
|
||||
return
|
||||
|
||||
from torch._inductor.pattern_matcher import stable_topological_sort
|
||||
|
||||
invoke_subgraph_node = graph.create_node(
|
||||
"call_function",
|
||||
torch.ops.higher_order.invoke_subgraph,
|
||||
|
|
@ -154,11 +152,8 @@ def _replace_region_with_subgraph(
|
|||
pass
|
||||
|
||||
if config.graph_deduplication_lint:
|
||||
_detect_cycles(graph)
|
||||
stable_topological_sort(graph)
|
||||
graph.lint()
|
||||
|
||||
if config.graph_deduplication_lint:
|
||||
_detect_cycles(graph, node_to_additional_deps)
|
||||
_stable_topological_sort(graph, node_to_additional_deps)
|
||||
graph.lint()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,9 @@ def _get_flat_args_unique(
|
|||
return args
|
||||
|
||||
|
||||
def _detect_cycles(graph: Graph) -> str:
|
||||
def _detect_cycles(
|
||||
graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
|
||||
) -> str:
|
||||
current_path: deque[Node] = deque()
|
||||
current_path_set: set[Node] = set()
|
||||
pending: deque[tuple[Node, Node]] = deque()
|
||||
|
|
@ -46,25 +48,30 @@ def _detect_cycles(graph: Graph) -> str:
|
|||
def current_path_head() -> Node:
|
||||
return current_path[-1]
|
||||
|
||||
for origin in graph.find_nodes(op="placeholder"):
|
||||
for origin in graph.find_nodes(op="output"):
|
||||
current_path.clear()
|
||||
current_path_set.clear()
|
||||
add_to_current_path(origin)
|
||||
for child in origin.users:
|
||||
for child in _get_flat_args_unique(origin, node_to_additional_deps):
|
||||
pending.append((child, origin))
|
||||
|
||||
while pending:
|
||||
cur_node, parent = pending.pop()
|
||||
|
||||
while current_path_head() != parent:
|
||||
# handle backtracking
|
||||
while current_path and current_path_head() != parent:
|
||||
pop_current_path()
|
||||
|
||||
if not isinstance(cur_node, Node):
|
||||
continue
|
||||
|
||||
if cur_node in current_path_set:
|
||||
current_path.append(cur_node)
|
||||
return f"cycle detected in path: {current_path}"
|
||||
|
||||
add_to_current_path(cur_node)
|
||||
for child in cur_node.users:
|
||||
|
||||
for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
|
||||
pending.append((child, cur_node))
|
||||
|
||||
return "no cycle detected"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user