mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
helper function for replacing nodes in aug graph (#166309)
When we do bucketing, we replace starts and waits with new nodes. This PR adds a helper that transfers the augmented graph's additional dependencies from the replaced nodes to their replacements. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166309 Approved by: https://github.com/IvanKobzarev
This commit is contained in:
parent
c54e2c5b41
commit
c3d205d598
|
|
@ -360,6 +360,191 @@ class TestAugmentedGraphHelper(TestCase):
|
|||
self.assertEqual(self.tracker.merge_sets[nodes[0]], {nodes[0]})
|
||||
self.assertEqual(len(self.tracker.merge_sets[nodes[1]]), 1)
|
||||
|
||||
# ========== Dependency Transfer Tests ==========
|
||||
|
||||
def test_transfer_with_cross_deps(self):
    """Replacing two nodes at once redirects the extra dep that links them.

    ``old_wait`` carries an extra dep on ``old_start`` and ``compute`` carries
    one on ``old_wait``. Both old nodes are mapped to replacements in a single
    call; afterwards ``new_wait`` must depend on ``new_start`` and ``compute``
    must depend on ``new_wait``.
    """
    g = fx.Graph()
    inp = g.placeholder("x")
    old_start = g.call_function(torch.relu, args=(inp,), name="old_start")
    old_wait = g.call_function(torch.abs, args=(inp,), name="old_wait")
    compute = g.call_function(torch.neg, args=(old_wait,), name="compute")
    g.output(compute)

    helper = AugmentedGraphHelper(g)

    # Extra edge between the two nodes that are about to be replaced ...
    helper.add_extra_dep(n=old_wait, dep=old_start)
    # ... and one from a surviving node onto a replaced node.
    helper.add_extra_dep(n=compute, dep=old_wait)

    # Replacement nodes are created after the helper was built.
    new_start = g.call_function(torch.sigmoid, args=(inp,), name="new_start")
    new_wait = g.call_function(torch.tanh, args=(inp,), name="new_wait")

    # Both replacements go through one call so the cross edge can be redirected.
    replacements = {old_start: new_start, old_wait: new_wait}
    helper.transfer_erased_node_deps(replacements)

    # The cross edge now links the two replacements.
    self.assertIn(new_start, helper.extra_deps[new_wait])

    # The surviving user points at the replacement.
    self.assertIn(new_wait, helper.extra_deps[compute])

    # Nothing is left behind on the erased nodes, in either direction.
    for table in (helper.extra_deps, helper.extra_uses):
        for erased in (old_start, old_wait):
            self.assertEqual(len(table[erased]), 0)
|
||||
|
||||
def test_transfer_preserves_external_deps(self):
    """Extra deps that touch nodes outside the erased set survive the transfer.

    ``old1`` has an extra dep on ``external1`` and ``external2`` has an extra
    dep on ``old2``. After the transfer, ``new1`` must depend on ``external1``
    and ``external2`` must depend on ``new2`` instead of ``old2``.
    """
    g = fx.Graph()
    inp = g.placeholder("x")
    external1 = g.call_function(torch.relu, args=(inp,), name="external1")
    old1 = g.call_function(torch.abs, args=(inp,), name="old1")
    old2 = g.call_function(torch.neg, args=(inp,), name="old2")
    external2 = g.call_function(torch.sigmoid, args=(inp,), name="external2")
    g.output(external2)

    helper = AugmentedGraphHelper(g)

    # One edge leaving an erased node, one edge arriving at an erased node.
    helper.add_extra_dep(n=old1, dep=external1)
    helper.add_extra_dep(n=external2, dep=old2)

    new1 = g.call_function(torch.tanh, args=(inp,), name="new1")
    new2 = g.call_function(torch.exp, args=(inp,), name="new2")

    replacements = {old1: new1, old2: new2}
    helper.transfer_erased_node_deps(replacements)

    # Outgoing edge was copied onto the replacement.
    self.assertIn(external1, helper.extra_deps[new1])

    # Incoming edge was retargeted and the stale reference removed.
    deps_of_external2 = helper.extra_deps[external2]
    self.assertIn(new2, deps_of_external2)
    self.assertNotIn(old2, deps_of_external2)
|
||||
|
||||
def test_transfer_with_merge_sets(self):
    """Naming one member of a merge set transfers the whole set.

    ``old_a`` and ``old_b`` are merged, ``old_a`` has an extra dep on ``dep``
    and ``user`` has an extra dep on ``old_a``. Only ``old_a`` is listed in the
    replacement mapping.
    """
    g = fx.Graph()
    inp = g.placeholder("x")
    old_a = g.call_function(torch.relu, args=(inp,), name="old_a")
    old_b = g.call_function(torch.abs, args=(inp,), name="old_b")
    dep = g.call_function(torch.neg, args=(inp,), name="dep")
    user = g.call_function(torch.sigmoid, args=(inp,), name="user")
    g.output(user)

    helper = AugmentedGraphHelper(g)

    # Put the two old nodes into one merge set.
    helper.merge_to_set(old_a, old_b)

    # One extra edge on each side of old_a.
    helper.add_extra_dep(n=old_a, dep=dep)
    helper.add_extra_dep(n=user, dep=old_a)

    new = g.call_function(torch.tanh, args=(inp,), name="new")

    # old_b is not listed; it is expected to be covered through the merge set.
    helper.transfer_erased_node_deps({old_a: new})

    # The replacement inherited the outgoing edge.
    self.assertIn(dep, helper.extra_deps[new])

    # The external user was retargeted to the replacement.
    self.assertIn(new, helper.extra_deps[user])

    # Both members of the merge set end up with no extra deps.
    for erased in (old_a, old_b):
        self.assertEqual(len(helper.extra_deps[erased]), 0)
|
||||
|
||||
def test_transfer_multiple_merge_sets_with_chain(self):
    """Transfer two merge sets in one call when a chain of extra deps links them.

    Nodes:
        node1                singleton
        node2, node3         merged
        other_node           singleton
        node4, node5         merged

    Extra deps before the transfer (``a -> b`` reads "a depends on b"):
        node2      -> node1
        other_node -> node3
        node4      -> other_node

    Replacement mapping:
        node2 (and, via its merge set, node3) => new_2_3
        node4 (and, via its merge set, node5) => new_4_5

    Extra deps expected afterwards:
        new_2_3    -> node1
        other_node -> new_2_3
        new_4_5    -> other_node
    """
    g = fx.Graph()
    inp = g.placeholder("x")

    node1 = g.call_function(torch.relu, args=(inp,), name="node1")
    node2 = g.call_function(torch.abs, args=(inp,), name="node2")
    node3 = g.call_function(torch.neg, args=(inp,), name="node3")
    other_node = g.call_function(torch.sigmoid, args=(inp,), name="other_node")
    node4 = g.call_function(torch.tanh, args=(inp,), name="node4")
    node5 = g.call_function(torch.exp, args=(inp,), name="node5")
    g.output(other_node)

    helper = AugmentedGraphHelper(g)

    # Two independent merge sets.
    helper.merge_to_set(node2, node3)
    helper.merge_to_set(node4, node5)

    # Chain of extra deps running through both merge sets.
    helper.add_extra_dep(n=node2, dep=node1)
    helper.add_extra_dep(n=other_node, dep=node3)
    helper.add_extra_dep(n=node4, dep=other_node)

    new_2_3 = g.call_function(torch.sin, args=(inp,), name="new_2_3")
    new_4_5 = g.call_function(torch.cos, args=(inp,), name="new_4_5")

    # A single call; each key stands in for its entire merge set.
    replacements = {
        node2: new_2_3,
        node4: new_4_5,
    }
    helper.transfer_erased_node_deps(replacements)

    # new_2_3 took over node2's dep on node1.
    self.assertIn(node1, helper.extra_deps[new_2_3])

    # other_node now depends on new_2_3 rather than node3.
    deps_of_other = helper.extra_deps[other_node]
    self.assertIn(new_2_3, deps_of_other)
    self.assertNotIn(node3, deps_of_other)

    # new_4_5 took over node4's dep on other_node.
    self.assertIn(other_node, helper.extra_deps[new_4_5])

    # Every erased node has an empty dep set.
    for erased in (node2, node3, node4, node5):
        self.assertEqual(len(helper.extra_deps[erased]), 0)

    # The reverse map agrees with the forward map.
    self.assertIn(new_2_3, helper.extra_uses[node1])
    self.assertIn(other_node, helper.extra_uses[new_2_3])
    self.assertIn(new_4_5, helper.extra_uses[other_node])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ class AugmentedGraphHelper:
|
|||
|
||||
# Extra dependencies: node depends on dep (dep must come before node)
|
||||
self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
|
||||
# Extra uses: reverse of extra_deps (node is used by user)
|
||||
self.extra_uses: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
|
||||
# Note: only reflect original ancestors, not maintained through additional deps
|
||||
# or merge sets
|
||||
self.node_ancestors = node_ancestors
|
||||
|
|
@ -33,6 +35,12 @@ class AugmentedGraphHelper:
|
|||
def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None:
    """Record an extra dependency edge: ``n`` depends on ``dep``.

    Args:
        n: The dependent node.
        dep: The node that ``n`` depends on.
    """
    # Reverse map first, forward map second; the two tables are always
    # updated together so they stay mirror images of each other.
    self.extra_uses[dep].add(n)
    self.extra_deps[n].add(dep)
|
||||
|
||||
def remove_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None:
    """Drop the extra dependency edge "``n`` depends on ``dep``" if present.

    Does nothing to the reverse map when the edge is not recorded.

    Args:
        n: The dependent node.
        dep: The node that ``n`` is currently recorded as depending on.
    """
    deps_of_n = self.extra_deps[n]
    if dep not in deps_of_n:
        return

    # Remove the edge from both the forward and the reverse map.
    deps_of_n.discard(dep)
    self.extra_uses[dep].discard(n)
|
||||
|
||||
def merge_to_set(self, existing_node: fx.Node, new_node: fx.Node) -> None:
|
||||
"""
|
||||
|
|
@ -123,3 +131,51 @@ class AugmentedGraphHelper:
|
|||
queue.append(dep)
|
||||
|
||||
return False
|
||||
|
||||
def transfer_erased_node_deps(self, erased_to_new: dict[fx.Node, fx.Node]) -> None:
    """
    Transfer all extra dependencies from erased nodes to their replacements, handling
    cross-dependencies between erased nodes correctly.

    Every key of ``erased_to_new`` is expanded to all members of its merge set, so
    listing a single member of a merge set is enough to transfer the whole set.

    An extra dependency between two erased nodes that map to the *same* replacement
    is dropped rather than recorded: it would otherwise become a self-loop
    (``new -> new``) in ``extra_deps`` / ``extra_uses``.

    Args:
        erased_to_new: Maps each node being removed to the node that replaces it.
            Each key is looked up in ``self.merge_sets``.

    After the call, the ``extra_deps`` and ``extra_uses`` entries of every erased
    node are empty and the erased nodes are removed from ``self.merge_sets``.
    """
    # Expand the mapping so every member of an erased node's merge set is covered.
    erased_merge_sets: dict[fx.Node, fx.Node] = {}
    for replaced, new in erased_to_new.items():
        for equiv in self.merge_sets[replaced]:
            erased_merge_sets[equiv] = new

    for old_node, new_node in erased_merge_sets.items():
        # Transfer dependencies FROM old_node (what old_node depended on)
        for extra_dep in self.extra_deps[old_node]:
            # Redirect if dep is also being erased
            updated_dep = erased_merge_sets.get(extra_dep, extra_dep)
            self.extra_uses[updated_dep].discard(old_node)
            if updated_dep is new_node:
                # Both endpoints collapse onto the same replacement; recording
                # the edge would make new_node depend on itself.
                continue
            self.extra_deps[new_node].add(updated_dep)
            self.extra_uses[updated_dep].add(new_node)

        # Transfer dependencies TO old_node (what depended on old_node)
        for extra_use in self.extra_uses[old_node]:
            # Redirect if this user is also being erased
            updated_use = erased_merge_sets.get(extra_use, extra_use)
            self.extra_deps[updated_use].discard(old_node)
            if updated_use is new_node:
                # Same self-loop guard as above, seen from the user side.
                continue
            # Update the user's deps to point to new_node
            self.extra_deps[updated_use].add(new_node)
            self.extra_uses[new_node].add(updated_use)

    # Clean up erased nodes only after all edges were rewired, because the loops
    # above still read the erased nodes' dep/use sets.
    for old_node in erased_merge_sets:
        self.extra_deps[old_node].clear()
        self.extra_uses[old_node].clear()
        del self.merge_sets[old_node]
|
||||
|
||||
def get_all_extra_deps(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
|
||||
"""
|
||||
Get all extra dependencies in a format suitable for topological sort.
|
||||
Returns a copy to avoid external modifications.
|
||||
"""
|
||||
return {
|
||||
node: OrderedSet(deps)
|
||||
for node, deps in self.extra_deps.items()
|
||||
if deps # Only include nodes with non-empty deps
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user