From 2d1e92307d3e67622f4fe8058d62e44fe4fa2f4e Mon Sep 17 00:00:00 2001 From: Xiaochang Wu Date: Mon, 28 Jul 2025 17:36:29 +0000 Subject: [PATCH] Partitioner: Fix to align partition node order with original graph (#157892) Fixes #157891 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157892 Approved by: https://github.com/ezyang --- test/fx/test_partitioner_order.py | 15 ++++++++++++-- torch/fx/passes/infra/partitioner.py | 30 ++++++++++++++++++---------- torch/fx/passes/utils/fuser_utils.py | 4 ++-- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index ab50b59fb96..f4c3ef072f9 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -24,6 +24,7 @@ class DummyPartitioner(CapabilityBasedPartitioner): ) +# original graph node order is: ['x', 'add', 'add_1', 'output'] class AddModule(torch.nn.Module): def forward(self, x): y = torch.add(x, x) @@ -32,8 +33,18 @@ class AddModule(torch.nn.Module): class TestPartitionerOrder(TestCase): - # partitoner test to check graph node order - def test_partitioner_order(self): + # partitioner test to check graph node order remains the same with the original graph after partitioning + def test_partitioner_graph_node_order(self): + m = AddModule() + traced_m = torch.fx.symbolic_trace(m) + origin_node_order = [n.name for n in traced_m.graph.nodes] + partions = DummyPartitioner(traced_m).propose_partitions() + partion_nodes = [list(partition.nodes) for partition in partions] + partition_node_order = [n.name for n in partion_nodes[0]] + self.assertTrue(partition_node_order == origin_node_order) + + # partitioner test to check graph node order remains the same during multiple runs + def test_partitioner_multiple_runs_order(self): m = AddModule() traced_m = torch.fx.symbolic_trace(m) partitions = DummyPartitioner(traced_m).propose_partitions() diff --git a/torch/fx/passes/infra/partitioner.py 
b/torch/fx/passes/infra/partitioner.py index 43866109094..ec0745ad61d 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -18,16 +18,18 @@ logger.setLevel(logging.WARNING) class Partition: def __init__( - self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None + self, + id: Optional[int] = None, + nodes: Optional[Iterable[tuple[Node, Optional[int]]]] = None, ): self.id = id - self.nodes = dict.fromkeys(nodes) if nodes is not None else {} + self.nodes = dict(nodes) if nodes is not None else {} def __repr__(self) -> str: return str(self.nodes) - def add_node(self, node: Node): - self.nodes.update({node: None}) + def add_node(self, node: Node, node_order: Optional[int] = None): + self.nodes.update({node: node_order}) def remove_node(self, node: Node): del self.nodes[node] @@ -172,7 +174,7 @@ class CapabilityBasedPartitioner: return merge_id, True - def merge_single_node(node: Node, id: Optional[int]): + def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]): def _update_partition_map(node: Node, id: int): # Iterate through all the users of this node and update the partition map to indicate # that there is a path from the partition id of this node to the target partition id. 
@@ -189,16 +191,16 @@ class CapabilityBasedPartitioner: assignment.pop(node) elif id not in partitions_by_id: assignment[node] = id - partitions_by_id[id] = Partition(id=id, nodes=[node]) + partitions_by_id[id] = Partition(id=id, nodes=[(node, node_order)]) partition_users[id] = set(node.users) _update_partition_map(node, id) else: assignment[node] = id - partitions_by_id[id].add_node(node) + partitions_by_id[id].add_node(node, node_order) logger.debug("Proposing partitions...") - for node in reversed(self.graph_module.graph.nodes): + for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)): # use Dict as an ordered set to ensure deterministic partitioning result, don't care value merge_candidates: dict[int, None] = {} @@ -211,7 +213,7 @@ class CapabilityBasedPartitioner: partition_id = next(new_partition_id) nodes_order[node] = partition_id partitions_order[partition_id] = partition_id - merge_single_node(node, partition_id) + merge_single_node(node, node_order, partition_id) merge_candidates[partition_id] = None # merge all possible partitions @@ -228,6 +230,14 @@ class CapabilityBasedPartitioner: # in the graph, otherwise, this is a no-op self_id, _ = maybe_merge_partition(self_id, other_id) + # sort partition nodes based on descending node order + for partition in partitions_by_id.values(): + partition.nodes = dict( + sorted( + partition.nodes.items(), key=operator.itemgetter(1), reverse=True + ) + ) + # post processing to re-assign "getitem" nodes into upstream partition logger.debug("Reassigning getitem nodes to its producer node's partition...") nodes_reassignment: dict[Node, int] = {} @@ -248,7 +258,7 @@ class CapabilityBasedPartitioner: if assignment.get(user, None) != id: # type: ignore[arg-type] nodes_reassignment[user] = id # type: ignore[assignment] for node, id in nodes_reassignment.items(): - merge_single_node(node, id) + merge_single_node(node, None, id) # filter out single node partitions if not self.allows_single_node_partition: 
diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 1b22490405d..33db9fd03d7 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -96,7 +96,7 @@ def fuse_as_graphmodule( gm: GraphModule, nodes: NodeList, module_name: str, - partition_lookup_table: _Optional[dict[Node, None]] = None, + partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None, *, always_return_tuple: bool = False, ) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]: @@ -249,7 +249,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList) -> None: @compatibility(is_backward_compatible=False) def fuse_by_partitions( gm: GraphModule, - partitions: list[dict[Node, None]], + partitions: list[dict[Node, _Optional[int]]], prefix: str = "fused_", always_return_tuple: bool = False, ) -> GraphModule: