[simple_fsdp][inductor_collectives] rewrite reorder_collectives, sink_waits_iterative (#158062)

Differential Revision: [D78159013](https://our.internmc.facebook.com/intern/diff/D78159013)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158062
Approved by: https://github.com/wconstab
IvanKobzarev 2025-07-17 05:48:35 -07:00 committed by PyTorch MergeBot
parent ef256ad17b
commit eeb0783fe6
3 changed files with 351 additions and 203 deletions
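Both passes now thread the schedule through _prev/_next dicts (a doubly-linked list over scheduler nodes) instead of shuffling a list of temporary GroupedSchedulerNode wrappers, so a contiguous group of nodes can be moved past a neighbor with a constant number of pointer updates. A minimal sketch of that bookkeeping, not the actual implementation (plain strings stand in for scheduler nodes; all names are illustrative):

    def build_links(nodes):
        # Model the schedule as a doubly-linked list keyed by node.
        _prev = {n: (nodes[i - 1] if i > 0 else None) for i, n in enumerate(nodes)}
        _next = {n: (nodes[i + 1] if i < len(nodes) - 1 else None) for i, n in enumerate(nodes)}
        return _prev, _next

    def group_nodes(_next, head, tail):
        # Collect the contiguous chain head..tail by following _next pointers
        # (tail=None walks to the end of the schedule).
        ret, n = [], head
        while n is not None:
            ret.append(n)
            if n == tail:
                break
            n = _next[n]
        return ret

    def hoist_group_before(_prev, _next, candidate, head, tail):
        # Before: p -> candidate -> head..tail -> q
        # After:  p -> head..tail -> candidate -> q
        p, q = _prev[candidate], _next[tail]
        if p is not None:
            _next[p] = head
        _prev[head] = p
        if q is not None:
            _prev[q] = candidate
        _next[candidate] = q
        _prev[candidate] = tail
        _next[tail] = candidate

    nodes = ["compute0", "all_gather", "mm", "wait"]
    _prev, _next = build_links(nodes)
    # Prefetch the collective past the unrelated compute node:
    hoist_group_before(_prev, _next, "compute0", "all_gather", "all_gather")
    assert group_nodes(_next, "all_gather", None) == ["all_gather", "compute0", "mm", "wait"]

The real passes additionally snapshot per-node memory deltas so peak-memory accounting can be replayed over just the moved span (see the sketch after the comms.py diff below).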

test/distributed/test_inductor_collectives.py

@@ -19,6 +19,7 @@ from torch._dynamo.utils import same
 from torch._inductor.comms import (
     _reorder_communication_preserving_peak_memory_internal,
     ReorderInfo,
+    sink_waits_iterative,
 )
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch._inductor.scheduler import BaseSchedulerNode
@@ -1621,7 +1622,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         comm from moving due to data dependency.
         """

-        def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
+        def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)
@@ -1654,14 +1655,52 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             # wait op
             rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out)
             rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out)

-            return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out
+            y += torch.mm(2 * x, 2 * w)
+
+            # cast the inputs
+            ag_2_cast = ag_2.to(torch.bfloat16)
+            ag_3_cast = ag_3.to(torch.bfloat16)
+            ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_2_cast, group_size, group_name
+            )
+            ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_3_cast, group_size, group_name
+            )
+            # wait op
+            ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out)
+            ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out)
+            #
+            rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor(
+                ag_2_cast, "sum", group_size, group_name
+            )
+            rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor(
+                ag_3_cast, "sum", group_size, group_name
+            )
+            # wait op
+            rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out)
+            rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out)
+            return (
+                y,
+                ag_0_out,
+                ag_1_out,
+                ag_2_out,
+                ag_3_out,
+                rs_0_out,
+                rs_1_out,
+                rs_2_out,
+                rs_3_out,
+            )

         x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
         w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        ag_1 = torch.ones(512, device="cuda", dtype=torch.float32)
-        inputs = [x, w, ag_0, ag_1]
+        ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
+        ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
+        ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
+        ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
+        inputs = [x, w, ag_0, ag_1, ag_2, ag_3]

         # get stats directly from the internal helper without affecting the real pass's signature
         node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
@@ -1679,11 +1718,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with torch._inductor.config.patch(
             {
                 "bucket_all_gathers_fx": "all",
+                "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
                 "bucket_reduce_scatters_fx": "all",
+                "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
                 "reorder_for_compute_comm_overlap": True,
                 "reorder_for_compute_comm_overlap_passes": [
+                    sink_waits_iterative,
                     _reorder_communication_preserving_peak_memory,
                 ],
+                "allow_buffer_reuse": False,
             }
         ):
             compiled = torch.compile(func)
@@ -1694,31 +1737,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             FileCheck()
             .check_count(
                 "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
-                count=1,
+                count=2,
                 exactly=True,
             )
+            .check(
+                "extern_kernels.mm",
+            )
+            .check(
+                "extern_kernels.addmm",
+            )
             .run(code)
         )
         (
             FileCheck()
             .check_count(
                 "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
-                count=1,
+                count=2,
                 exactly=True,
             )
-            .run(code)
-        )
-        (
-            FileCheck()
-            .check(
-                "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
-            )
-            .check(
-                "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
-            )
             .check(
                 "extern_kernels.mm",
             )
+            .check(
+                "extern_kernels.addmm",
+            )
             .run(code)
         )
         out = compiled(*inputs, **self.get_world_trs())
@@ -1726,7 +1768,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         assert same(out, correct), f"{out} va {correct}"
         assert node_stats is not None
         self.assertTrue(isinstance(node_stats, dict))
-        self.assertEqual(len(node_stats), 2)
+        self.assertEqual(len(node_stats), 4)
         it = iter(node_stats.values())
         node_stat0 = next(it)
         self.assertTrue(node_stat0.moves > 0)
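For readers unfamiliar with the assertion style above: torch.testing.FileCheck scans a string of generated code for patterns in order; check_count(..., exactly=True) asserts an exact number of occurrences, and chained .check(...) calls assert relative order. A toy illustration on a made-up string (not the test's actual output):

    from torch.testing import FileCheck

    code = """
    buf0 = torch.ops._c10d_functional.all_gather_into_tensor_out.default(arg)
    extern_kernels.mm(buf0, buf1)
    extern_kernels.addmm(buf2, buf0, buf1)
    """
    # Exactly one bucketed all-gather, followed (in order) by mm then addmm.
    FileCheck().check_count(
        "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
        count=1,
        exactly=True,
    ).check("extern_kernels.mm").check("extern_kernels.addmm").run(code)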

torch/_inductor/comms.py

@@ -4,7 +4,6 @@ from __future__ import annotations

 import heapq
 import importlib
-import itertools
 import logging
 import operator
 import sys
@@ -149,9 +148,8 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool:
         return True
     if (
-        hasattr(node, "python_kernel_name")
-        and node.python_kernel_name == "extern_kernels.mm"
-    ):
+        python_kernel_name := getattr(node, "python_kernel_name", None)
+    ) and "extern_kernels" in python_kernel_name:
         return True
     return False
@@ -189,15 +187,23 @@ def _group_name(snode, with_bufs=False) -> str:
 def _reorder_communication_preserving_peak_memory_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
-    from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
-
-    original_snodes_num = len(snodes)
     """
     Internal testing helper that also returns debug info.
     Returns:
     - reordered snodes list
     - dict {snode: ReorderInfo}
     """
+    has_collectives = False
+    for snode in snodes:
+        if contains_collective(snode):
+            has_collectives = True
+            break
+    if not has_collectives:
+        return snodes, {}
+
+    from torch._inductor.scheduler import GroupedSchedulerNode
+
+    original_snodes_num = len(snodes)
+
     # heuristic to avoid degenerating to quadratic time
     graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
     graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
@@ -208,7 +214,8 @@ def _reorder_communication_preserving_peak_memory_internal(
         snodes, name_to_freeable_input_buf, graph_outputs
     )
     runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
-    snode_to_curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]

     # debug stats
     stats: dict[BaseSchedulerNode, ReorderInfo] = {}
@@ -232,153 +239,151 @@ def _reorder_communication_preserving_peak_memory_internal(
         _temp_group_visit_leaves(snode, accumulate_time)
         return max(0, comm_time - compute_time)

-    MOVE_LIMIT = len(snodes) * 100
     total_moves = 0
-    # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it
-    PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes)
-    if config.reorder_prefetch_limit is not None:
-        PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit

     # Dicts to keep track of "next" and "previous" as double-linked structure during grouping
-    _prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
-    _next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
+    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
     for i, snode in enumerate(snodes):
         _prev[snode] = snodes[i - 1] if i > 0 else None
         _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
-    _curr_memory = dict(zip(snodes, curr_memory))
-    _curr_memory[None] = 0  # type: ignore[index]

-    gsnodes: list[GroupedSchedulerNode] = [
-        GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
-        for snode in snodes
-    ]
-    for i, gsnode in enumerate(gsnodes):
-        snode = gsnode.snodes[0]  # type: ignore[attr-defined]
-        if contains_collective(snode):
-            reorder_info = stats[snode] = ReorderInfo()
-            reorder_info.initial_exposed = reorder_info.final_exposed = (
-                exposed_communication_time(snode, snodes[i + 1 :])
-            )
-            if total_moves >= MOVE_LIMIT:
-                reorder_info.limiting_factor = "move limit"
-                continue
+    _head = snodes[0]

-            for j in range(i - 1, -1, -1):
-                prev_gsnode = gsnodes[j]
-                if len(prev_gsnode.snodes) == 0:
-                    continue
+    def _group_nodes(head, tail):
+        ret = []
+        n = head
+        while True:
+            if n is not None:
+                ret.append(n)
+            if n == tail:
+                break
+            n = _next[n]
+        return ret

-                if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT):
-                    reorder_info.limiting_factor = "prefetch limit"
-                    break
+    def _group_names(head, tail):
+        ret = ""
+        for n in _group_nodes(head, tail):
+            if ret:
+                ret += "~"
+            ret += n.get_name()
+        return ret

-                if contains_collective(prev_gsnode):
+    curr = _head
+    while _next[curr] is not None:
+        if contains_collective(curr):
+            reorder_info = stats[curr] = ReorderInfo()
+            reorder_info.initial_exposed = reorder_info.final_exposed = (
+                exposed_communication_time(curr, _group_nodes(_next[curr], None))
+            )
+            candidate = _prev[curr]
+            group_head = curr
+            group_tail = curr
+            group_peak_memory = _curr_memory[curr]
+            while candidate is not None:
+                if contains_collective(candidate):
                     reorder_info.limiting_factor = "collective ordering"
                     break

-                dep_names = OrderedSet([s.name for s in snode.unmet_dependencies])
-                prev_outs = prev_gsnode.get_outputs()
+                group = GroupedSchedulerNode(
+                    curr.scheduler,
+                    _group_nodes(group_head, group_tail),
+                    temp_grouping=True,
+                )
+                data_deps = {s.name: s for s in group.unmet_dependencies}
+                candidate_outs = candidate.get_outputs()
                 data_dep = None
-                for o in prev_outs:
-                    if o.get_name() in dep_names:
-                        data_dep = o.get_name()
+                for o in candidate_outs:
+                    if d := data_deps.get(o.get_name(), None):
+                        if isinstance(d, WeakDep) and d.is_fake:
+                            continue
+                        data_dep = d
                         break

                 if data_dep is not None:

-                    def is_groupable(prev_gsnode):
+                    def is_groupable(candidate):
                         # preserve ordering
-                        if contains_collective(prev_gsnode):
-                            return False
-                        if contains_gemm_like(prev_gsnode):
-                            return False
-                        return True
+                        if contains_collective(candidate):
+                            return False, "contains_collective"
+                        if contains_gemm_like(candidate):
+                            return False, "contains_gemm_like"
+                        return True, None

-                    if is_groupable(prev_gsnode):
-                        new_snodes = prev_gsnode.snodes + gsnode.snodes
-                        init_group_node(gsnode, gsnode.scheduler, new_snodes)
-                        prev_gsnode.snodes = []
+                    is_grp, grp_reason = is_groupable(candidate)
+                    if is_grp:
+                        group_head = candidate
+                        group_peak_memory = max(
+                            group_peak_memory, _curr_memory[candidate]
+                        )
                         reorder_info.grouped += 1
-                        reorder_info.grouped_info = gsnode.get_name()
+                        reorder_info.grouped_info = _group_names(group_head, group_tail)
+                        candidate = _prev[candidate]
                         continue
                     else:
                         msg = (
-                            f"data dependency {data_dep}(dep_names:{dep_names})"
-                            f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}"
+                            f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
+                            f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
+                            f"dep on {_group_names(group_head, group_tail)}"
+                            f"\n non_group_reason:{grp_reason}"
                         )
                         reorder_info.limiting_factor = msg
                         break

-                if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]:
+                delta_memory_candidate = (
+                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
+                )
+                if group_peak_memory - delta_memory_candidate > peak_memory:
                     reorder_info.limiting_factor = "peak memory"
                     break

-                if reorder_info.final_exposed > runtimes[snode]:
-                    reorder_info.limiting_factor = "sufficient overlapping"
-                    break
-
                 reorder_info.moves += 1
                 total_moves += 1

-                # swapping nodes j and j+1 affects curr memory at j only
-                # j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j]
-                # j_alloc = curr_memory[j] - curr_memory[j - 1]
-                # curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc
-                def swap_curr_memory_with_previous(
-                    snode_j_plus_one, snode_j, snode_j_minus_one
-                ):
-                    curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one]
-                    curr_memory_j = snode_to_curr_memory[snode_j]
-                    curr_memory_j_minus_one = (
-                        snode_to_curr_memory[snode_j_minus_one]
-                        if snode_j_minus_one is not None
-                        else 0
-                    )
-                    j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j
-                    j_alloc = curr_memory_j - curr_memory_j_minus_one
-                    snode_to_curr_memory[snode_j] = (
-                        curr_memory_j - j_alloc + j_plus_one_alloc
-                    )
-
-                # Recompuing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B)
-                # swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2]
-                # decomposing to:
-                # swap(A2, B0) -> A0, A1, B0, A2, B1
-                # swap(A2, B1) -> A0, A1, B0, B1, A2
-                # swap(A1, B0) -> A0, B0, A1, B1, A2
-                # swap(A1, B1) -> A0, B0, B1, A1, A2
-                # swap(A0, B0) -> B0, A0, B1, A1, A2
-                # swap(A0, B1) -> B0, B1, A0, A1, A2
-                for _j in range(len(gsnodes[j].snodes) - 1, -1, -1):  # group A
-                    snode_j = gsnodes[j].snodes[_j]
-                    for _i, snode_i in enumerate(gsnode.snodes):  # group B
-                        swap_curr_memory_with_previous(
-                            snode_j_plus_one=snode_i,
-                            snode_j=snode_j,
-                            snode_j_minus_one=_prev[snode_j],
-                        )
-
-                        # Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j]
-                        first = snode_j
-                        second = snode_i
-                        first_prev = _prev[first]
-                        second_next = _next[second]
-                        if first_prev:
-                            _next[first_prev] = second
-                        _prev[second] = first_prev
-                        if second_next:
-                            _prev[second_next] = first
-                        _next[first] = second_next
-                        _next[second] = first
-                        _prev[first] = second
-
-                tmp = gsnodes[j]
-                gsnodes[j] = gsnodes[j + 1]
-                gsnodes[j + 1] = tmp
+                mem_deltas = {}
+                for n in [candidate, *_group_nodes(group_head, group_tail)]:
+                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
+                # swap (candidate, group_head...group_tail)
+                # Before:
+                # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
+                # After:
+                # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
+                # 0
+                candidate_prev = _prev[candidate]
+                if candidate_prev:
+                    _next[candidate_prev] = group_head
+                _prev[group_head] = candidate_prev
+                # 2
+                group_tail_next = _next[group_tail]
+                if group_tail_next:
+                    _prev[group_tail_next] = candidate
+                _next[candidate] = group_tail_next
+                # 1
+                _prev[candidate] = group_tail
+                _next[group_tail] = candidate
+                if _head == candidate:
+                    _head = group_head
                 reorder_info.final_exposed = exposed_communication_time(
-                    snode,
-                    itertools.chain(
-                        gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]]
-                    ),
+                    curr, _group_nodes(_next[curr], None)
                 )
+                # Recompute curr_memory
+                _prev_curr_memory = _curr_memory[_prev[group_head]]  # type: ignore[index]
+                for n in _group_nodes(group_head, candidate):
+                    _curr_memory[n] = _prev_curr_memory = (
+                        _prev_curr_memory + mem_deltas[n]
+                    )
+                candidate = _prev[group_head]
+        curr = _next[curr]  # type: ignore[assignment]

     node_stats = stats
     improvement = {snode: node_stats[snode].improvement for snode in node_stats}
@ -426,17 +431,13 @@ def _reorder_communication_preserving_peak_memory_internal(
reorder_log_str += str(headers) + "\n" reorder_log_str += str(headers) + "\n"
reorder_log_str += "\n".join(map(str, rows)) reorder_log_str += "\n".join(map(str, rows))
grouping_logs: list[str] = [] new_snodes = _group_nodes(_head, None)
flatten_gsnodes: list[BaseSchedulerNode] = [] assert len(new_snodes) == original_snodes_num
for i, gsnode in enumerate(gsnodes): new_peak_memory, curr_memory = estimate_peak_memory(
if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: new_snodes, name_to_freeable_input_buf, graph_outputs
flatten_gsnodes.extend(gsnode.snodes) )
else: reorder_log_str += f"\n peak_memory_before:{peak_memory}"
flatten_gsnodes.append(gsnode) reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"
grouping_log_str = "\n".join(grouping_logs)
reorder_log_str += "\n"
reorder_log_str += grouping_log_str
overlap_log.info(reorder_log_str) overlap_log.info(reorder_log_str)
trace_structured( trace_structured(
@@ -448,8 +449,7 @@ def _reorder_communication_preserving_peak_memory_internal(
         payload_fn=lambda: reorder_log_str,
     )

-    assert len(flatten_gsnodes) == original_snodes_num
-    return flatten_gsnodes, stats
+    return new_snodes, stats


 def _schedule_for_comm(
@@ -623,7 +623,9 @@ def decide_global_ordering_of_comms(
         # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
         mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
         for buf in comm_nodes[i - 1].get_buffer_names():
-            comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))
+            comm_nodes[i].add_fake_dep(
+                WeakDep(buf, mutating_buf=mutating_buf, is_fake=True)
+            )

     return nodes
@ -640,66 +642,166 @@ class SinkWaitInfo:
def _sink_waits_iterative_internal( def _sink_waits_iterative_internal(
snodes: list[BaseSchedulerNode], snodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node from torch._inductor.scheduler import GroupedSchedulerNode
original_snodes_num = len(snodes)
if original_snodes_num == 0:
return snodes, {}
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
snodes, graph_inputs
)
peak_memory, curr_memory = estimate_peak_memory(
snodes, name_to_freeable_input_buf, graph_outputs
)
n = len(snodes)
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
gsnodes: list[GroupedSchedulerNode] = [ _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
for snode in snodes _head = snodes[0]
] for i, snode in enumerate(snodes):
for i in range(n - 1, -1, -1): _prev[snode] = snodes[i - 1] if i > 0 else None
gsnode = gsnodes[i] _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
if contains_wait(gsnode): _curr_memory = dict(zip(snodes, curr_memory))
info = stats[gsnode.snodes[0]] = SinkWaitInfo() _curr_memory[None] = 0 # type: ignore[index]
for j in range(i + 1, n):
wait_gsnode = gsnodes[j - 1] def _group_nodes(head, tail):
wait_outs = wait_gsnode.get_outputs() ret = []
next_gsnode = gsnodes[j] n = head
dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies]) while True:
if n is not None:
ret.append(n)
if n == tail:
break
n = _next[n]
return ret
def _group_names(head, tail):
ret = ""
for n in _group_nodes(head, tail):
if ret:
ret += "~"
ret += n.get_name()
return ret
curr = snodes[-1]
processed_waits = OrderedSet() # type: ignore[var-annotated]
while _prev[curr] is not None:
if contains_wait(curr) and curr not in processed_waits:
processed_waits.add(curr)
info = stats[curr] = SinkWaitInfo()
candidate = _next[curr]
wait_snode = curr
group_head = curr
group_tail = curr
group_peak_memory = _curr_memory[curr]
while candidate is not None:
group = GroupedSchedulerNode(
wait_snode.scheduler,
_group_nodes(group_head, group_tail),
temp_grouping=True,
)
group_outs = group.get_outputs()
data_deps = {s.name: s for s in candidate.unmet_dependencies}
data_dep = None data_dep = None
for o in wait_outs: for o in group_outs:
if o.get_name() in dep_names: if d := data_deps.get(o.get_name(), None):
data_dep = o.get_name() if isinstance(d, WeakDep) and d.is_fake:
continue
data_dep = d
break break
# 1. If we have data_dep - we can not swap => trying to group # 1. If we have data_dep - we can not swap => trying to group
# 2. If swap candidate and current node both contain collectives => trying to group # 2. If swap candidate and current node both contain collectives => trying to group
if data_dep is not None or ( if data_dep is not None or (
both_contain_comms := ( both_contain_comms := (
contains_collective(wait_gsnode) contains_collective(group) and contains_collective(candidate)
and contains_collective(next_gsnode)
) )
): ):
def is_groupable(snode): def is_groupable(snode):
return not contains_gemm_like(snode) # We do not want to group with collectives to not reorder them forward.
if contains_collective(snode):
return (
False,
f"candidate contains collective {snode.get_name()}",
)
if contains_gemm_like(snode):
return (
False,
f"candidate contains gemm_like {snode.get_name()}",
)
return True, None
if is_groupable(next_gsnode): is_grp, grp_reason = is_groupable(candidate)
new_snodes = wait_gsnode.snodes + next_gsnode.snodes if is_grp:
init_group_node(next_gsnode, gsnode.scheduler, new_snodes) group_tail = candidate
wait_gsnode.snodes = [] group_peak_memory = max(
group_peak_memory, _curr_memory[candidate]
)
info.grouped += 1 info.grouped += 1
info.grouped_info = _group_name(next_gsnode) info.grouped_info = _group_names(group_head, group_tail)
candidate = _next[candidate]
continue continue
elif (data_dep is None) and both_contain_comms: elif (data_dep is None) and both_contain_comms:
info.limiting_factor = ( info.limiting_factor = (
f"collective ordering {_group_name(wait_gsnode)}" f"collective ordering {_group_names(group_head, group_tail)}"
f" with candidate:{_group_name(next_gsnode)}" f" with candidate:{candidate.get_name()}"
)
else:
info.limiting_factor = (
f"data dependency {data_dep}(dep_names:{dep_names})"
f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}"
f" outs:{[o.get_name() for o in wait_outs]}"
) )
break break
info.moves += 1 else:
info.moves_info += f"+{_group_name(next_gsnode)}" info.limiting_factor = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(group_head, group_tail)}"
f"\n outs:{[o.get_name() for o in group_outs]}"
f"\n non_group_reason:{grp_reason}"
)
break
candidate_delta_memory = (
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
)
if group_peak_memory + candidate_delta_memory > peak_memory:
info.limiting_factor = "peak_memory"
break
info.moves += 1
info.moves_info += f"+{candidate.get_name()}"
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
mem_deltas = {}
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# 0:
group_head_prev = _prev[group_head]
if group_head_prev:
_next[group_head_prev] = candidate
_prev[candidate] = group_head_prev
# 2:
candidate_next = _next[candidate]
if candidate_next:
_prev[candidate_next] = group_tail
_next[group_tail] = candidate_next
# 1:
_prev[group_head] = candidate
_next[candidate] = group_head
if group_head == _head:
_head = candidate
# Recompute curr_memory
_prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index]
for n in _group_nodes(candidate, group_tail):
_curr_memory[n] = _prev_curr_memory = (
_prev_curr_memory + mem_deltas[n]
)
candidate = _next[group_tail]
curr = _prev[curr] # type: ignore[assignment]
# Swapping snodes j and j - 1
tmp = gsnodes[j - 1]
gsnodes[j - 1] = gsnodes[j]
gsnodes[j] = tmp
headers = [ headers = [
"Wait node", "Wait node",
"grouped", "grouped",
@@ -732,16 +834,13 @@ def _sink_waits_iterative_internal(
     log_str += str(headers) + "\n"
     log_str += "\n".join(map(str, rows))
     overlap_log.info(log_str)

-    grouping_logs = []
-    flatten_snodes = []
-    for i, gsnode in enumerate(gsnodes):
-        grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}")
-        if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
-            flatten_snodes.extend(gsnode.snodes)
-        else:
-            flatten_snodes.append(gsnode)
-
-    grouping_log_str = "\n".join(grouping_logs)
-    log_str += grouping_log_str
+    new_snodes = _group_nodes(_head, None)
+    assert len(new_snodes) == original_snodes_num
+    new_peak_memory, curr_memory = estimate_peak_memory(
+        new_snodes, name_to_freeable_input_buf, graph_outputs
+    )
+    log_str += f"\n peak_memory_before:{peak_memory}"
+    log_str += f"\n peak_memory_after:{new_peak_memory}"

     trace_structured(
         "artifact",
         metadata_fn=lambda: {
@@ -750,8 +849,7 @@ def _sink_waits_iterative_internal(
         },
         payload_fn=lambda: log_str,
     )

-    assert len(flatten_snodes) == n
-    return flatten_snodes, stats
+    return new_snodes, stats


 def sink_waits_iterative(
@@ -777,7 +875,9 @@ def node_summary(snode):
     if len(snodes) == 1:
         detail = ""
         if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
-            detail = f" ({snode.node.python_kernel_name})"
+            outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}"
+            ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}"
+            detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})"
         layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
         out_tensor_info = ",".join(
             [
@@ -1352,7 +1452,7 @@ def enforce_comm_ordering_for_fsdp(
             mutating_buf = next(iter(ag_group_node.get_buffer_names()))
             for o in prev_ag_wait.get_outputs():
                 ag_group_node.add_fake_dep(
-                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
+                    WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                 )
             prev_ag_wait = wait_group_node
@@ -1364,7 +1464,7 @@ def enforce_comm_ordering_for_fsdp(
             mutating_buf = next(iter(rs_group_node.get_buffer_names()))
             for o in prev_rs_wait.get_outputs():
                 rs_group_node.add_fake_dep(
-                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
+                    WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                )
             prev_rs_wait = wait_group_node
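Both rewritten passes treat _curr_memory as a running prefix sum of per-node allocation deltas, which is why a swap only requires replaying the saved mem_deltas over the moved span instead of re-estimating memory for the whole schedule. A self-contained sketch of that invariant with made-up numbers (not the pass's actual accounting):

    deltas = {"ag": 16, "mm": 8, "free": -16}  # bytes allocated (+) / freed (-)

    def prefix_sums(order):
        # _curr_memory[n] in the pass corresponds to mem[n] here.
        curr, mem = 0, {}
        for n in order:
            curr += deltas[n]
            mem[n] = curr
        return mem

    assert prefix_sums(["ag", "mm", "free"]) == {"ag": 16, "mm": 24, "free": 8}
    # After reordering, only the sums inside the moved span change; the value
    # at the end of the span (and everything after it) is unaffected:
    assert prefix_sums(["mm", "ag", "free"]) == {"mm": 8, "ag": 24, "free": 8}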

torch/_inductor/dependencies.py

@@ -342,6 +342,12 @@ class WeakDep(Dep):
     name: str
     # Buffer that is doing the mutation
     mutating_buf: str
+    # WeakDeps are also used to add dependencies that prevent specific
+    # reorderings, e.g. the global ordering of collectives.
+    # But if another pass guarantees proper ordering by its own logic,
+    # these additional "fake" deps would hold back optimizations.
+    # This flag identifies those additional deps.
+    is_fake: bool = False

     @property
     def index(self) -> sympy.Expr:
@@ -352,7 +358,7 @@ class WeakDep(Dep):
     def rename(self, renames: dict[str, str]) -> "WeakDep":
         if self.name in renames:
-            return WeakDep(renames[self.name], self.mutating_buf)
+            return WeakDep(renames[self.name], self.mutating_buf, self.is_fake)
         return self

     def numbytes_hint(self) -> int:
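A toy mirror of the WeakDep.is_fake change (a local dataclass rather than the torch import, so the snippet stands alone): dependency scans that look for real data dependencies can now skip ordering-only deps, as the `isinstance(d, WeakDep) and d.is_fake` filters in the two passes above do.

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class WeakDep:
        name: str
        mutating_buf: str
        is_fake: bool = False  # ordering-only dep; passes may safely ignore it

    deps = [
        WeakDep("buf0", mutating_buf="buf2"),                # real mutation ordering
        WeakDep("buf1", mutating_buf="buf2", is_fake=True),  # collective-order pin only
    ]
    # Keep only deps that represent actual data flow:
    real = [d for d in deps if not d.is_fake]
    assert [d.name for d in real] == ["buf0"]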