helper function for replacing nodes in aug graph (#166309)

When we do bucketing, we replace collective start and wait nodes with new nodes. This PR adds a helper that transfers the augmented graph's additional dependencies from the replaced nodes to their replacements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166309
Approved by: https://github.com/IvanKobzarev
This commit is contained in:
eellison 2025-10-29 06:14:36 -07:00 committed by PyTorch MergeBot
parent c54e2c5b41
commit c3d205d598
2 changed files with 241 additions and 0 deletions

View File

@ -360,6 +360,191 @@ class TestAugmentedGraphHelper(TestCase):
self.assertEqual(self.tracker.merge_sets[nodes[0]], {nodes[0]})
self.assertEqual(len(self.tracker.merge_sets[nodes[1]]), 1)
# ========== Dependency Transfer Tests ==========
def test_transfer_with_cross_deps(self):
    """Erased nodes linked by an extra dep stay linked after replacement.

    Setup: old_wait depends on old_start, and compute depends on old_wait.
    Both old nodes are replaced in a single call, so afterwards new_wait
    must depend on new_start and compute must depend on new_wait.
    """
    g = fx.Graph()
    inp = g.placeholder("x")
    start_before = g.call_function(torch.relu, args=(inp,), name="old_start")
    wait_before = g.call_function(torch.abs, args=(inp,), name="old_wait")
    consumer = g.call_function(torch.neg, args=(wait_before,), name="compute")
    g.output(consumer)

    helper = AugmentedGraphHelper(g)
    # Extra edge between the two nodes that will both be erased.
    helper.add_extra_dep(n=wait_before, dep=start_before)
    # Extra edge from a surviving node onto an erased one.
    helper.add_extra_dep(n=consumer, dep=wait_before)

    start_after = g.call_function(torch.sigmoid, args=(inp,), name="new_start")
    wait_after = g.call_function(torch.tanh, args=(inp,), name="new_wait")

    # Both replacements go through one call so the cross edge can be redirected.
    replacements = {start_before: start_after, wait_before: wait_after}
    helper.transfer_erased_node_deps(replacements)

    # The cross edge now connects the two replacement nodes.
    self.assertIn(start_after, helper.extra_deps[wait_after])
    # The surviving user points at the replacement.
    self.assertIn(wait_after, helper.extra_deps[consumer])

    # Nothing is left behind on the erased nodes, in either direction.
    for table in (helper.extra_deps, helper.extra_uses):
        for erased in (start_before, wait_before):
            self.assertEqual(len(table[erased]), 0)
def test_transfer_preserves_external_deps(self):
    """Edges between erased nodes and untouched nodes survive the transfer.

    Setup: old1 depends on external1, and external2 depends on old2.
    Expected after replacement: new1 depends on external1, and external2
    depends on new2 (and no longer on old2).
    """
    g = fx.Graph()
    inp = g.placeholder("x")
    ext_before = g.call_function(torch.relu, args=(inp,), name="external1")
    erased_first = g.call_function(torch.abs, args=(inp,), name="old1")
    erased_second = g.call_function(torch.neg, args=(inp,), name="old2")
    ext_after = g.call_function(torch.sigmoid, args=(inp,), name="external2")
    g.output(ext_after)

    helper = AugmentedGraphHelper(g)
    helper.add_extra_dep(n=erased_first, dep=ext_before)
    helper.add_extra_dep(n=ext_after, dep=erased_second)

    repl_first = g.call_function(torch.tanh, args=(inp,), name="new1")
    repl_second = g.call_function(torch.exp, args=(inp,), name="new2")

    helper.transfer_erased_node_deps(
        {
            erased_first: repl_first,
            erased_second: repl_second,
        }
    )

    # Outgoing edge of the erased node moved onto its replacement.
    self.assertIn(ext_before, helper.extra_deps[repl_first])
    # Incoming edge of the erased node now targets the replacement only.
    deps_of_ext_after = helper.extra_deps[ext_after]
    self.assertIn(repl_second, deps_of_ext_after)
    self.assertNotIn(erased_second, deps_of_ext_after)
def test_transfer_with_merge_sets(self):
    """Replacing one member of a merge set transfers the whole set.

    old_a and old_b are merged. old_a depends on dep, and user depends on
    old_a. Only old_a is named in the replacement mapping, yet both old
    nodes must end up with no extra deps.
    """
    g = fx.Graph()
    inp = g.placeholder("x")
    merged_first = g.call_function(torch.relu, args=(inp,), name="old_a")
    merged_second = g.call_function(torch.abs, args=(inp,), name="old_b")
    upstream = g.call_function(torch.neg, args=(inp,), name="dep")
    downstream = g.call_function(torch.sigmoid, args=(inp,), name="user")
    g.output(downstream)

    helper = AugmentedGraphHelper(g)
    helper.merge_to_set(merged_first, merged_second)

    helper.add_extra_dep(n=merged_first, dep=upstream)
    helper.add_extra_dep(n=downstream, dep=merged_first)

    replacement = g.call_function(torch.tanh, args=(inp,), name="new")

    # Naming a single member is enough; its merge-set partners follow along.
    helper.transfer_erased_node_deps({merged_first: replacement})

    self.assertIn(upstream, helper.extra_deps[replacement])
    self.assertIn(replacement, helper.extra_deps[downstream])

    # Neither member of the erased merge set keeps any extra deps.
    for erased in (merged_first, merged_second):
        self.assertEqual(len(helper.extra_deps[erased]), 0)
def test_transfer_multiple_merge_sets_with_chain(self):
    """Transfer two merge sets that are chained through a surviving node.

    Nodes:
        node1                 singleton, survives
        node2 + node3         merged, replaced by new_2_3
        other_node            singleton, survives
        node4 + node5         merged, replaced by new_4_5

    Extra dependencies before the transfer:
        node2 depends on node1
        other_node depends on node3
        node4 depends on other_node

    Extra dependencies expected afterwards:
        new_2_3 depends on node1
        other_node depends on new_2_3
        new_4_5 depends on other_node
    """
    g = fx.Graph()
    inp = g.placeholder("x")

    n1 = g.call_function(torch.relu, args=(inp,), name="node1")
    n2 = g.call_function(torch.abs, args=(inp,), name="node2")
    n3 = g.call_function(torch.neg, args=(inp,), name="node3")
    middle = g.call_function(torch.sigmoid, args=(inp,), name="other_node")
    n4 = g.call_function(torch.tanh, args=(inp,), name="node4")
    n5 = g.call_function(torch.exp, args=(inp,), name="node5")
    g.output(middle)

    helper = AugmentedGraphHelper(g)
    helper.merge_to_set(n2, n3)
    helper.merge_to_set(n4, n5)

    # (dependent, dependency) pairs, registered in this order.
    initial_edges = (
        (n2, n1),
        (middle, n3),
        (n4, middle),
    )
    for dependent, dependency in initial_edges:
        helper.add_extra_dep(n=dependent, dep=dependency)

    repl_23 = g.call_function(torch.sin, args=(inp,), name="new_2_3")
    repl_45 = g.call_function(torch.cos, args=(inp,), name="new_4_5")

    # One representative per merge set; both sets are handled in one call.
    helper.transfer_erased_node_deps({n2: repl_23, n4: repl_45})

    # Forward direction: each edge was redirected to the replacement nodes.
    self.assertIn(n1, helper.extra_deps[repl_23])
    self.assertIn(repl_23, helper.extra_deps[middle])
    self.assertNotIn(n3, helper.extra_deps[middle])
    self.assertIn(middle, helper.extra_deps[repl_45])

    # Every erased node, including the unnamed merge partners, is emptied.
    for erased in (n2, n3, n4, n5):
        self.assertEqual(len(helper.extra_deps[erased]), 0)

    # Reverse direction: extra_uses mirrors the redirected edges.
    expected_uses = (
        (repl_23, n1),
        (middle, repl_23),
        (repl_45, middle),
    )
    for dependent, dependency in expected_uses:
        self.assertIn(dependent, helper.extra_uses[dependency])
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -26,6 +26,8 @@ class AugmentedGraphHelper:
# Extra dependencies: node depends on dep (dep must come before node)
self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
# Extra uses: reverse of extra_deps (node is used by user)
self.extra_uses: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
# Note: only reflect original ancestors, not maintained through additional deps
# or merge sets
self.node_ancestors = node_ancestors
@ -33,6 +35,12 @@ class AugmentedGraphHelper:
def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None:
    """Record that ``n`` depends on ``dep``.

    The edge is stored in both directions: ``dep`` is added to
    ``extra_deps[n]`` and ``n`` is added to ``extra_uses[dep]``.
    """
    # Reverse index first, then the forward one; the two tables are independent.
    self.extra_uses[dep].add(n)
    self.extra_deps[n].add(dep)
def remove_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None:
    """Drop the extra dependency of ``n`` on ``dep`` if it is recorded.

    When ``dep`` is not among ``extra_deps[n]`` nothing is removed and
    ``extra_uses`` is left untouched.
    """
    recorded_deps = self.extra_deps[n]
    if dep not in recorded_deps:
        return
    # Remove the edge from the forward table and from its reverse index.
    recorded_deps.discard(dep)
    self.extra_uses[dep].discard(n)
def merge_to_set(self, existing_node: fx.Node, new_node: fx.Node) -> None:
"""
@ -123,3 +131,51 @@ class AugmentedGraphHelper:
queue.append(dep)
return False
def transfer_erased_node_deps(self, erased_to_new: dict[fx.Node, fx.Node]) -> None:
    """
    Transfer all extra dependencies from erased nodes to their replacements, handling
    cross-dependencies between erased nodes correctly.

    Every node in the merge set of a key of ``erased_to_new`` is treated as erased
    and mapped to the same replacement, so only one member of a merge set needs to
    be listed. Edges whose other endpoint is also erased are redirected to that
    endpoint's replacement. An edge whose two endpoints collapse onto the same
    replacement is dropped instead of being recorded as a self-dependency, which
    would otherwise introduce a cycle.

    After the transfer, the erased nodes have empty ``extra_deps`` / ``extra_uses``
    and are removed from ``merge_sets``.

    Args:
        erased_to_new: maps a node being erased to the node that replaces it.
    """
    # Expand the mapping so every member of an erased merge set is covered.
    erased_merge_sets: dict[fx.Node, fx.Node] = {}
    for replaced, new in erased_to_new.items():
        for equiv in self.merge_sets[replaced]:
            erased_merge_sets[equiv] = new

    # Transfer dependencies
    for old_node, new_node in erased_merge_sets.items():
        # Transfer dependencies FROM old_node (what old_node depended on)
        for extra_dep in self.extra_deps[old_node]:
            # Redirect if dep is also being erased
            updated_dep = erased_merge_sets.get(extra_dep, extra_dep)
            if updated_dep == new_node:
                # Both endpoints map to the same replacement: skip the self-edge.
                continue
            self.extra_deps[new_node].add(updated_dep)
            self.extra_uses[updated_dep].discard(old_node)
            self.extra_uses[updated_dep].add(new_node)

        # Transfer dependencies TO old_node (what depended on old_node)
        for extra_use in self.extra_uses[old_node]:
            # Redirect if this user is also being erased
            updated_use = erased_merge_sets.get(extra_use, extra_use)
            if updated_use == new_node:
                # Both endpoints map to the same replacement: skip the self-edge.
                continue
            # Update the user's deps to point to new_node
            self.extra_deps[updated_use].discard(old_node)
            self.extra_deps[updated_use].add(new_node)
            self.extra_uses[new_node].add(updated_use)

    # Clean up erased nodes; any stale edges still stored on them are dropped here.
    for old_node in erased_merge_sets:
        self.extra_deps[old_node].clear()
        self.extra_uses[old_node].clear()
        del self.merge_sets[old_node]
def get_all_extra_deps(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
    """
    Build a snapshot of the extra dependencies for use by a topological sort.

    Nodes whose dependency set is empty are left out. Each included set is
    copied into a fresh OrderedSet so callers cannot modify the stored sets.
    """
    snapshot: dict[fx.Node, OrderedSet[fx.Node]] = {}
    for node, deps in self.extra_deps.items():
        if not deps:
            # Skip nodes with no remaining extra deps.
            continue
        snapshot[node] = OrderedSet(deps)
    return snapshot