diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index fad2f819560..1f09d72ea2b 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -19,6 +19,7 @@ from torch._dynamo.utils import same
 from torch._inductor.comms import (
     _reorder_communication_preserving_peak_memory_internal,
     ReorderInfo,
+    sink_waits_iterative,
 )
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch._inductor.scheduler import BaseSchedulerNode
@@ -1621,7 +1622,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         comm from moving due to data dependency.
         """
 
-        def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
+        def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)
 
@@ -1654,14 +1655,52 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             # wait op
             rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out)
             rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out)
+            y += torch.mm(2 * x, 2 * w)
 
-            return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out
+            # cast the inputs
+            ag_2_cast = ag_2.to(torch.bfloat16)
+            ag_3_cast = ag_3.to(torch.bfloat16)
+            ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_2_cast, group_size, group_name
+            )
+            ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_3_cast, group_size, group_name
+            )
+
+            # wait op
+            ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out)
+            ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out)
+
+            # reduce_scatter the casted inputs
+            rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor(
+                ag_2_cast, "sum", group_size, group_name
+            )
+            rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor(
+                ag_3_cast, "sum", group_size, group_name
+            )
+
+            # wait op
+            rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out)
+            rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out)
+            return (
+                y,
+                ag_0_out,
+                ag_1_out,
+                ag_2_out,
+                ag_3_out,
+                rs_0_out,
+                rs_1_out,
+                rs_2_out,
+                rs_3_out,
+            )
 
         x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
         w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        ag_1 = torch.ones(512, device="cuda", dtype=torch.float32)
-        inputs = [x, w, ag_0, ag_1]
+        ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
+        ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
+        ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
+        ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
+        inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
 
         # get stats directly from the internal helper without affecting the real pass's signature
         node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
@@ -1679,11 +1718,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with torch._inductor.config.patch(
             {
                 "bucket_all_gathers_fx": "all",
+                "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
                 "bucket_reduce_scatters_fx": "all",
+                "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
                 "reorder_for_compute_comm_overlap": True,
                 "reorder_for_compute_comm_overlap_passes": [
+                    sink_waits_iterative,
                     _reorder_communication_preserving_peak_memory,
                 ],
+                "allow_buffer_reuse": False,
             }
         ):
             compiled = torch.compile(func)
@@ -1694,31 +1737,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             FileCheck()
             .check_count(
                 "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
-                count=1,
+                count=2,
                 exactly=True,
             )
+            .check(
+                "extern_kernels.mm",
+            )
+            .check(
+                "extern_kernels.addmm",
+            )
             .run(code)
         )
         (
             FileCheck()
             .check_count(
                 "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
-                count=1,
+                count=2,
                 exactly=True,
             )
-            .run(code)
-        )
-        (
-            FileCheck()
-            .check(
-                "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
-            )
-            .check(
-                "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
-            )
             .check(
                 "extern_kernels.mm",
             )
+            .check(
+                "extern_kernels.addmm",
+            )
             .run(code)
         )
         out = compiled(*inputs, **self.get_world_trs())
@@ -1726,7 +1768,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         assert same(out, correct), f"{out} va {correct}"
         assert node_stats is not None
         self.assertTrue(isinstance(node_stats, dict))
-        self.assertEqual(len(node_stats), 2)
+        self.assertEqual(len(node_stats), 4)
         it = iter(node_stats.values())
         node_stat0 = next(it)
         self.assertTrue(node_stat0.moves > 0)
diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py
index caaf43dba59..f93485333d3 100644
--- a/torch/_inductor/comms.py
+++ b/torch/_inductor/comms.py
@@ -4,7 +4,6 @@ from __future__ import annotations
 
 import heapq
 import importlib
-import itertools
 import logging
 import operator
 import sys
@@ -149,9 +148,8 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool:
         return True
 
     if (
-        hasattr(node, "python_kernel_name")
-        and node.python_kernel_name == "extern_kernels.mm"
-    ):
+        python_kernel_name := getattr(node, "python_kernel_name", None)
+    ) and "extern_kernels" in python_kernel_name:
         return True
 
     return False
@@ -189,15 +187,23 @@ def _group_name(snode, with_bufs=False) -> str:
 def _reorder_communication_preserving_peak_memory_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
-    from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
-
-    original_snodes_num = len(snodes)
     """
     Internal testing helper that also returns debug info.
     Returns:
         - reordered snodes list
         - dict {snode: ReorderInfo}
     """
+    has_collectives = False
+    for snode in snodes:
+        if contains_collective(snode):
+            has_collectives = True
+            break
+    if not has_collectives:
+        return snodes, {}
+
+    from torch._inductor.scheduler import GroupedSchedulerNode
+
+    original_snodes_num = len(snodes)
     # heuristic to avoid degenerating to quadratic time
     graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
     graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
@@ -208,7 +214,8 @@ def _reorder_communication_preserving_peak_memory_internal(
         snodes, name_to_freeable_input_buf, graph_outputs
     )
     runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
-    snode_to_curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]
 
     # debug stats
     stats: dict[BaseSchedulerNode, ReorderInfo] = {}
@@ -232,153 +239,151 @@ def _reorder_communication_preserving_peak_memory_internal(
             _temp_group_visit_leaves(snode, accumulate_time)
         return max(0, comm_time - compute_time)
 
-    MOVE_LIMIT = len(snodes) * 100
     total_moves = 0
-    # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it
-    PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes)
-    if config.reorder_prefetch_limit is not None:
-        PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit
 
     # Dicts to keep track of "next" and "previous" as double-linked structure during grouping
-    _prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
-    _next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
+    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
     for i, snode in enumerate(snodes):
         _prev[snode] = snodes[i - 1] if i > 0 else None
         _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]
 
-    gsnodes: list[GroupedSchedulerNode] = [
-        GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
-        for snode in snodes
-    ]
-    for i, gsnode in enumerate(gsnodes):
-        snode = gsnode.snodes[0]  # type: ignore[attr-defined]
-        if contains_collective(snode):
-            reorder_info = stats[snode] = ReorderInfo()
+    _head = snodes[0]
+
+    def _group_nodes(head, tail):
+        ret = []
+        n = head
+        while True:
+            if n is not None:
+                ret.append(n)
+            if n == tail:
+                break
+            n = _next[n]
+        return ret
+
+    def _group_names(head, tail):
+        ret = ""
+        for n in _group_nodes(head, tail):
+            if ret:
+                ret += "~"
+            ret += n.get_name()
+        return ret
+
+    curr = _head
+    while _next[curr] is not None:
+        if contains_collective(curr):
+            reorder_info = stats[curr] = ReorderInfo()
             reorder_info.initial_exposed = reorder_info.final_exposed = (
-                exposed_communication_time(snode, snodes[i + 1 :])
+                exposed_communication_time(curr, _group_nodes(_next[curr], None))
             )
-            if total_moves >= MOVE_LIMIT:
-                reorder_info.limiting_factor = "move limit"
-                continue
-            for j in range(i - 1, -1, -1):
-                prev_gsnode = gsnodes[j]
-                if len(prev_gsnode.snodes) == 0:
-                    continue
-
-                if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT):
-                    reorder_info.limiting_factor = "prefetch limit"
-                    break
-                if contains_collective(prev_gsnode):
+            candidate = _prev[curr]
+            group_head = curr
+            group_tail = curr
+            group_peak_memory = _curr_memory[curr]
+            while candidate is not None:
+                if contains_collective(candidate):
                     reorder_info.limiting_factor = "collective ordering"
                     break
 
-                dep_names = OrderedSet([s.name for s in snode.unmet_dependencies])
-                prev_outs = prev_gsnode.get_outputs()
+                group = GroupedSchedulerNode(
+                    curr.scheduler,
+                    _group_nodes(group_head, group_tail),
+                    temp_grouping=True,
+                )
+
+                data_deps = {s.name: s for s in group.unmet_dependencies}
+                candidate_outs = candidate.get_outputs()
                 data_dep = None
-                for o in prev_outs:
-                    if o.get_name() in dep_names:
-                        data_dep = o.get_name()
+                for o in candidate_outs:
+                    if d := data_deps.get(o.get_name(), None):
+                        if isinstance(d, WeakDep) and d.is_fake:
+                            continue
+                        data_dep = d
                         break
                 if data_dep is not None:
 
-                    def is_groupable(prev_gsnode):
+                    def is_groupable(candidate):
                         # preserve ordering
-                        if contains_collective(prev_gsnode):
-                            return False
+                        if contains_collective(candidate):
+                            return False, "contains_collective"
 
-                        if contains_gemm_like(prev_gsnode):
-                            return False
-                        return True
+                        if contains_gemm_like(candidate):
+                            return False, "contains_gemm_like"
+                        return True, None
 
-                    if is_groupable(prev_gsnode):
-                        new_snodes = prev_gsnode.snodes + gsnode.snodes
-                        init_group_node(gsnode, gsnode.scheduler, new_snodes)
-                        prev_gsnode.snodes = []
+                    is_grp, grp_reason = is_groupable(candidate)
+                    if is_grp:
+                        group_head = candidate
+                        group_peak_memory = max(
+                            group_peak_memory, _curr_memory[candidate]
+                        )
                         reorder_info.grouped += 1
-                        reorder_info.grouped_info = gsnode.get_name()
+                        reorder_info.grouped_info = _group_names(group_head, group_tail)
+                        candidate = _prev[candidate]
                         continue
                     else:
                         msg = (
-                            f"data dependency {data_dep}(dep_names:{dep_names})"
-                            f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}"
+                            f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
+                            f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
+                            f"dep on {_group_names(group_head, group_tail)}"
+                            f"\n non_group_reason:{grp_reason}"
                         )
                         reorder_info.limiting_factor = msg
                         break
 
-                if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]:
+                delta_memory_candidate = (
+                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
+                )
+
+                if group_peak_memory - delta_memory_candidate > peak_memory:
                     reorder_info.limiting_factor = "peak memory"
                     break
-                if reorder_info.final_exposed > runtimes[snode]:
-                    reorder_info.limiting_factor = "sufficient overlapping"
-                    break
+
                 reorder_info.moves += 1
                 total_moves += 1
 
-                # swapping nodes j and j+1 affects curr memory at j only
-                # j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j]
-                # j_alloc = curr_memory[j] - curr_memory[j - 1]
-                # curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc
-                def swap_curr_memory_with_previous(
-                    snode_j_plus_one, snode_j, snode_j_minus_one
-                ):
-                    curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one]
-                    curr_memory_j = snode_to_curr_memory[snode_j]
-                    curr_memory_j_minus_one = (
-                        snode_to_curr_memory[snode_j_minus_one]
-                        if snode_j_minus_one is not None
-                        else 0
-                    )
-                    j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j
-                    j_alloc = curr_memory_j - curr_memory_j_minus_one
-                    snode_to_curr_memory[snode_j] = (
-                        curr_memory_j - j_alloc + j_plus_one_alloc
-                    )
+                mem_deltas = {}
+                for n in [candidate, *_group_nodes(group_head, group_tail)]:
+                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
+                # swap (candidate, group_head...group_tail)
+                # Before:
+                # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
+                # After:
+                # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
+                # 0
+                candidate_prev = _prev[candidate]
+                if candidate_prev:
+                    _next[candidate_prev] = group_head
+                    _prev[group_head] = candidate_prev
 
-                # Recompuing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B)
-                # swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2]
-                # decomposing to:
-                # swap(A2, B0) -> A0, A1, B0, A2, B1
-                # swap(A2, B1) -> A0, A1, B0, B1, A2
-                # swap(A1, B0) -> A0, B0, A1, B1, A2
-                # swap(A1, B1) -> A0, B0, B1, A1, A2
-                # swap(A0, B0) -> B0, A0, B1, A1, A2
-                # swap(A0, B1) -> B0, B1, A0, A1, A2
-                for _j in range(len(gsnodes[j].snodes) - 1, -1, -1):  # group A
-                    snode_j = gsnodes[j].snodes[_j]
-                    for _i, snode_i in enumerate(gsnode.snodes):  # group B
-                        swap_curr_memory_with_previous(
-                            snode_j_plus_one=snode_i,
-                            snode_j=snode_j,
-                            snode_j_minus_one=_prev[snode_j],
-                        )
+
+                # 2
+                group_tail_next = _next[group_tail]
+                if group_tail_next:
+                    _prev[group_tail_next] = candidate
+                    _next[candidate] = group_tail_next
 
-                # Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j]
-                first = snode_j
-                second = snode_i
-                first_prev = _prev[first]
-                second_next = _next[second]
-                if first_prev:
-                    _next[first_prev] = second
-                    _prev[second] = first_prev
+                # 1
+                _prev[candidate] = group_tail
+                _next[group_tail] = candidate
 
-                if second_next:
-                    _prev[second_next] = first
-                    _next[first] = second_next
+                if _head == candidate:
+                    _head = group_head
 
-                _next[second] = first
-                _prev[first] = second
-
-                tmp = gsnodes[j]
-                gsnodes[j] = gsnodes[j + 1]
-                gsnodes[j + 1] = tmp
                 reorder_info.final_exposed = exposed_communication_time(
-                    snode,
-                    itertools.chain(
-                        gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]]
-                    ),
+                    curr, _group_nodes(_next[curr], None)
                 )
+                # Recompute curr_memory
+                _prev_curr_memory = _curr_memory[_prev[group_head]]  # type: ignore[index]
+                for n in _group_nodes(group_head, candidate):
+                    _curr_memory[n] = _prev_curr_memory = (
+                        _prev_curr_memory + mem_deltas[n]
+                    )
+                candidate = _prev[group_head]
+        curr = _next[curr]  # type: ignore[assignment]
 
     node_stats = stats
     improvement = {snode: node_stats[snode].improvement for snode in node_stats}
@@ -426,17 +431,13 @@ def _reorder_communication_preserving_peak_memory_internal(
     reorder_log_str += str(headers) + "\n"
     reorder_log_str += "\n".join(map(str, rows))
 
-    grouping_logs: list[str] = []
-    flatten_gsnodes: list[BaseSchedulerNode] = []
-    for i, gsnode in enumerate(gsnodes):
-        if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
-            flatten_gsnodes.extend(gsnode.snodes)
-        else:
-            flatten_gsnodes.append(gsnode)
-
-    grouping_log_str = "\n".join(grouping_logs)
-    reorder_log_str += "\n"
-    reorder_log_str += grouping_log_str
+    new_snodes = _group_nodes(_head, None)
+    assert len(new_snodes) == original_snodes_num
+    new_peak_memory, curr_memory = estimate_peak_memory(
+        new_snodes, name_to_freeable_input_buf, graph_outputs
+    )
+    reorder_log_str += f"\n peak_memory_before:{peak_memory}"
+    reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"
     overlap_log.info(reorder_log_str)
 
     trace_structured(
@@ -448,8 +449,7 @@ def _reorder_communication_preserving_peak_memory_internal(
         payload_fn=lambda: reorder_log_str,
     )
 
-    assert len(flatten_gsnodes) == original_snodes_num
-    return flatten_gsnodes, stats
+    return new_snodes, stats
 
 
 def _schedule_for_comm(
@@ -623,7 +623,9 @@ def decide_global_ordering_of_comms(
         # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
         mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
         for buf in comm_nodes[i - 1].get_buffer_names():
-            comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))
+            comm_nodes[i].add_fake_dep(
+                WeakDep(buf, mutating_buf=mutating_buf, is_fake=True)
+            )
 
     return nodes
 
@@ -640,66 +642,166 @@ class SinkWaitInfo:
 def _sink_waits_iterative_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
-    from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
+    from torch._inductor.scheduler import GroupedSchedulerNode
+
+    original_snodes_num = len(snodes)
+    if original_snodes_num == 0:
+        return snodes, {}
+    graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
+    graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
+        snodes, graph_inputs
+    )
+    peak_memory, curr_memory = estimate_peak_memory(
+        snodes, name_to_freeable_input_buf, graph_outputs
+    )
 
-    n = len(snodes)
     stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
-    gsnodes: list[GroupedSchedulerNode] = [
-        GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
-        for snode in snodes
-    ]
-    for i in range(n - 1, -1, -1):
-        gsnode = gsnodes[i]
-        if contains_wait(gsnode):
-            info = stats[gsnode.snodes[0]] = SinkWaitInfo()
-            for j in range(i + 1, n):
-                wait_gsnode = gsnodes[j - 1]
-                wait_outs = wait_gsnode.get_outputs()
-                next_gsnode = gsnodes[j]
-                dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies])
+    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _head = snodes[0]
+    for i, snode in enumerate(snodes):
+        _prev[snode] = snodes[i - 1] if i > 0 else None
+        _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]
+
+    def _group_nodes(head, tail):
+        ret = []
+        n = head
+        while True:
+            if n is not None:
+                ret.append(n)
+            if n == tail:
+                break
+            n = _next[n]
+        return ret
+
+    def _group_names(head, tail):
+        ret = ""
+        for n in _group_nodes(head, tail):
+            if ret:
+                ret += "~"
+            ret += n.get_name()
+        return ret
+
+    curr = snodes[-1]
+
+    processed_waits = OrderedSet()  # type: ignore[var-annotated]
+    while _prev[curr] is not None:
+        if contains_wait(curr) and curr not in processed_waits:
+            processed_waits.add(curr)
+            info = stats[curr] = SinkWaitInfo()
+            candidate = _next[curr]
+            wait_snode = curr
+            group_head = curr
+            group_tail = curr
+            group_peak_memory = _curr_memory[curr]
+            while candidate is not None:
+                group = GroupedSchedulerNode(
+                    wait_snode.scheduler,
+                    _group_nodes(group_head, group_tail),
+                    temp_grouping=True,
+                )
+                group_outs = group.get_outputs()
+
+                data_deps = {s.name: s for s in candidate.unmet_dependencies}
                 data_dep = None
-                for o in wait_outs:
-                    if o.get_name() in dep_names:
-                        data_dep = o.get_name()
+                for o in group_outs:
+                    if d := data_deps.get(o.get_name(), None):
+                        if isinstance(d, WeakDep) and d.is_fake:
+                            continue
+                        data_dep = d
                         break
                 # 1. If we have data_dep - we can not swap => trying to group
                 # 2. If swap candidate and current node both contain collectives => trying to group
                 if data_dep is not None or (
                     both_contain_comms := (
-                        contains_collective(wait_gsnode)
-                        and contains_collective(next_gsnode)
+                        contains_collective(group) and contains_collective(candidate)
                     )
                 ):
 
                     def is_groupable(snode):
-                        return not contains_gemm_like(snode)
+                        # We do not want to group with collectives to not reorder them forward.
+                        if contains_collective(snode):
+                            return (
+                                False,
+                                f"candidate contains collective {snode.get_name()}",
+                            )
+                        if contains_gemm_like(snode):
+                            return (
+                                False,
+                                f"candidate contains gemm_like {snode.get_name()}",
+                            )
+                        return True, None
 
-                    if is_groupable(next_gsnode):
-                        new_snodes = wait_gsnode.snodes + next_gsnode.snodes
-                        init_group_node(next_gsnode, gsnode.scheduler, new_snodes)
-                        wait_gsnode.snodes = []
+                    is_grp, grp_reason = is_groupable(candidate)
+                    if is_grp:
+                        group_tail = candidate
+                        group_peak_memory = max(
+                            group_peak_memory, _curr_memory[candidate]
+                        )
                         info.grouped += 1
-                        info.grouped_info = _group_name(next_gsnode)
+                        info.grouped_info = _group_names(group_head, group_tail)
+                        candidate = _next[candidate]
                         continue
                     elif (data_dep is None) and both_contain_comms:
                         info.limiting_factor = (
-                            f"collective ordering {_group_name(wait_gsnode)}"
-                            f" with candidate:{_group_name(next_gsnode)}"
-                        )
-                    else:
-                        info.limiting_factor = (
-                            f"data dependency {data_dep}(dep_names:{dep_names})"
-                            f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}"
-                            f" outs:{[o.get_name() for o in wait_outs]}"
+                            f"collective ordering {_group_names(group_head, group_tail)}"
+                            f" with candidate:{candidate.get_name()}"
                         )
                         break
-                info.moves += 1
-                info.moves_info += f"+{_group_name(next_gsnode)}"
+                    else:
+                        info.limiting_factor = (
+                            f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
+                            f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
+                            f"dep on {_group_names(group_head, group_tail)}"
+                            f"\n outs:{[o.get_name() for o in group_outs]}"
+                            f"\n non_group_reason:{grp_reason}"
+                        )
+                        break
+                candidate_delta_memory = (
+                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
+                )
+                if group_peak_memory + candidate_delta_memory > peak_memory:
+                    info.limiting_factor = "peak_memory"
+                    break
+
+                info.moves += 1
+                info.moves_info += f"+{candidate.get_name()}"
+
+                # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
+                mem_deltas = {}
+                for n in [candidate, *_group_nodes(group_head, group_tail)]:
+                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
+                # 0:
+                group_head_prev = _prev[group_head]
+                if group_head_prev:
+                    _next[group_head_prev] = candidate
+                    _prev[candidate] = group_head_prev
+
+                # 2:
+                candidate_next = _next[candidate]
+                if candidate_next:
+                    _prev[candidate_next] = group_tail
+                    _next[group_tail] = candidate_next
+
+                # 1:
+                _prev[group_head] = candidate
+                _next[candidate] = group_head
+                if group_head == _head:
+                    _head = candidate
+
+                # Recompute curr_memory
+                _prev_curr_memory = _curr_memory[_prev[candidate]]  # type: ignore[index]
+                for n in _group_nodes(candidate, group_tail):
+                    _curr_memory[n] = _prev_curr_memory = (
+                        _prev_curr_memory + mem_deltas[n]
+                    )
+
+                candidate = _next[group_tail]
+        curr = _prev[curr]  # type: ignore[assignment]
 
-                # Swapping snodes j and j - 1
-                tmp = gsnodes[j - 1]
-                gsnodes[j - 1] = gsnodes[j]
-                gsnodes[j] = tmp
     headers = [
         "Wait node",
         "grouped",
@@ -732,16 +834,13 @@ def _sink_waits_iterative_internal(
     log_str += str(headers) + "\n"
    log_str += "\n".join(map(str, rows))
     overlap_log.info(log_str)
-    grouping_logs = []
-    flatten_snodes = []
-    for i, gsnode in enumerate(gsnodes):
-        grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}")
-        if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
-            flatten_snodes.extend(gsnode.snodes)
-        else:
-            flatten_snodes.append(gsnode)
-    grouping_log_str = "\n".join(grouping_logs)
-    log_str += grouping_log_str
+    new_snodes = _group_nodes(_head, None)
+    assert len(new_snodes) == original_snodes_num
+    new_peak_memory, curr_memory = estimate_peak_memory(
+        new_snodes, name_to_freeable_input_buf, graph_outputs
+    )
+    log_str += f"\n peak_memory_before:{peak_memory}"
+    log_str += f"\n peak_memory_after:{new_peak_memory}"
     trace_structured(
         "artifact",
         metadata_fn=lambda: {
@@ -750,8 +849,7 @@ def _sink_waits_iterative_internal(
         },
         payload_fn=lambda: log_str,
    )
-    assert len(flatten_snodes) == n
-    return flatten_snodes, stats
+    return new_snodes, stats
 
 
 def sink_waits_iterative(
@@ -777,7 +875,9 @@ def node_summary(snode):
     if len(snodes) == 1:
         detail = ""
         if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
-            detail = f" ({snode.node.python_kernel_name})"
+            outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}"
+            ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}"
+            detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})"
         layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
         out_tensor_info = ",".join(
             [
@@ -1352,7 +1452,7 @@ def enforce_comm_ordering_for_fsdp(
                 mutating_buf = next(iter(ag_group_node.get_buffer_names()))
                 for o in prev_ag_wait.get_outputs():
                     ag_group_node.add_fake_dep(
-                        WeakDep(o.get_name(), mutating_buf=mutating_buf)
+                        WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                    )
             prev_ag_wait = wait_group_node
 
@@ -1364,7 +1464,7 @@ def enforce_comm_ordering_for_fsdp(
                 mutating_buf = next(iter(rs_group_node.get_buffer_names()))
                 for o in prev_rs_wait.get_outputs():
                     rs_group_node.add_fake_dep(
-                        WeakDep(o.get_name(), mutating_buf=mutating_buf)
+                        WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                    )
             prev_rs_wait = wait_group_node
 
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
index 9de52061c64..8a374f5bab3 100644
--- a/torch/_inductor/dependencies.py
+++ b/torch/_inductor/dependencies.py
@@ -342,6 +342,12 @@ class WeakDep(Dep):
     name: str
     # Buffer that is doing the mutation
     mutating_buf: str
+    # WeakDeps are also used to add dependencies that prevent specific reorderings,
+    # e.g. the global ordering of collectives.
+    # But if another pass already guarantees the proper ordering through its own logic,
+    # these additional "fake" deps only hold back optimizations.
+    # This flag is used to identify those additional deps.
+    is_fake: bool = False
 
     @property
     def index(self) -> sympy.Expr:
@@ -352,7 +358,7 @@ class WeakDep(Dep):
 
     def rename(self, renames: dict[str, str]) -> "WeakDep":
         if self.name in renames:
-            return WeakDep(renames[self.name], self.mutating_buf)
+            return WeakDep(renames[self.name], self.mutating_buf, self.is_fake)
         return self
 
     def numbytes_hint(self) -> int: