[simple_fsdp][inductor_collectives] rewrite reorder_collectives, sink_waits_iterative (#158062)

Differential Revision: [D78159013](https://our.internmc.facebook.com/intern/diff/D78159013)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158062
Approved by: https://github.com/wconstab
This commit is contained in:
IvanKobzarev 2025-07-17 05:48:35 -07:00 committed by PyTorch MergeBot
parent ef256ad17b
commit eeb0783fe6
3 changed files with 351 additions and 203 deletions

View File

@ -19,6 +19,7 @@ from torch._dynamo.utils import same
from torch._inductor.comms import (
_reorder_communication_preserving_peak_memory_internal,
ReorderInfo,
sink_waits_iterative,
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.scheduler import BaseSchedulerNode
@ -1621,7 +1622,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
comm from moving due to data dependency.
"""
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
# do some unrelated matmuls
y = torch.mm(x, w)
@ -1654,14 +1655,52 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
# wait op
rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out)
rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out)
y += torch.mm(2 * x, 2 * w)
return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out
# cast the inputs
ag_2_cast = ag_2.to(torch.bfloat16)
ag_3_cast = ag_3.to(torch.bfloat16)
ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_2_cast, group_size, group_name
)
ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_3_cast, group_size, group_name
)
# wait op
ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out)
ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out)
#
rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor(
ag_2_cast, "sum", group_size, group_name
)
rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor(
ag_3_cast, "sum", group_size, group_name
)
# wait op
rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out)
rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out)
return (
y,
ag_0_out,
ag_1_out,
ag_2_out,
ag_3_out,
rs_0_out,
rs_1_out,
rs_2_out,
rs_3_out,
)
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(512, device="cuda", dtype=torch.float32)
inputs = [x, w, ag_0, ag_1]
ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
# get stats directly from the internal helper without affecting the real pass's signature
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
@ -1679,11 +1718,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
with torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
"bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
"bucket_reduce_scatters_fx": "all",
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
"reorder_for_compute_comm_overlap": True,
"reorder_for_compute_comm_overlap_passes": [
sink_waits_iterative,
_reorder_communication_preserving_peak_memory,
],
"allow_buffer_reuse": False,
}
):
compiled = torch.compile(func)
@ -1694,31 +1737,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
FileCheck()
.check_count(
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
count=1,
count=2,
exactly=True,
)
.check(
"extern_kernels.mm",
)
.check(
"extern_kernels.addmm",
)
.run(code)
)
(
FileCheck()
.check_count(
"torch.ops._c10d_functional.reduce_scatter_tensor.default(",
count=1,
count=2,
exactly=True,
)
.run(code)
)
(
FileCheck()
.check(
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
)
.check(
"torch.ops._c10d_functional.reduce_scatter_tensor.default(",
)
.check(
"extern_kernels.mm",
)
.check(
"extern_kernels.addmm",
)
.run(code)
)
out = compiled(*inputs, **self.get_world_trs())
@ -1726,7 +1768,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert same(out, correct), f"{out} va {correct}"
assert node_stats is not None
self.assertTrue(isinstance(node_stats, dict))
self.assertEqual(len(node_stats), 2)
self.assertEqual(len(node_stats), 4)
it = iter(node_stats.values())
node_stat0 = next(it)
self.assertTrue(node_stat0.moves > 0)

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import heapq
import importlib
import itertools
import logging
import operator
import sys
@ -149,9 +148,8 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool:
return True
if (
hasattr(node, "python_kernel_name")
and node.python_kernel_name == "extern_kernels.mm"
):
python_kernel_name := getattr(node, "python_kernel_name", None)
) and "extern_kernels" in python_kernel_name:
return True
return False
@ -189,15 +187,23 @@ def _group_name(snode, with_bufs=False) -> str:
def _reorder_communication_preserving_peak_memory_internal(
snodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
original_snodes_num = len(snodes)
"""
Internal testing helper that also returns debug info.
Returns:
- reordered snodes list
- dict {snode: ReorderInfo}
"""
has_collectives = False
for snode in snodes:
if contains_collective(snode):
has_collectives = True
break
if not has_collectives:
return snodes, {}
from torch._inductor.scheduler import GroupedSchedulerNode
original_snodes_num = len(snodes)
# heuristic to avoid degenerating to quadratic time
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
@ -208,7 +214,8 @@ def _reorder_communication_preserving_peak_memory_internal(
snodes, name_to_freeable_input_buf, graph_outputs
)
runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
snode_to_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
# debug stats
stats: dict[BaseSchedulerNode, ReorderInfo] = {}
@ -232,153 +239,151 @@ def _reorder_communication_preserving_peak_memory_internal(
_temp_group_visit_leaves(snode, accumulate_time)
return max(0, comm_time - compute_time)
MOVE_LIMIT = len(snodes) * 100
total_moves = 0
# TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it
PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes)
if config.reorder_prefetch_limit is not None:
PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit
# Dicts to keep track of "next" and "previous" as double-linked structure during grouping
_prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
_next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
gsnodes: list[GroupedSchedulerNode] = [
GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
for snode in snodes
]
for i, gsnode in enumerate(gsnodes):
snode = gsnode.snodes[0] # type: ignore[attr-defined]
if contains_collective(snode):
reorder_info = stats[snode] = ReorderInfo()
reorder_info.initial_exposed = reorder_info.final_exposed = (
exposed_communication_time(snode, snodes[i + 1 :])
)
if total_moves >= MOVE_LIMIT:
reorder_info.limiting_factor = "move limit"
continue
_head = snodes[0]
for j in range(i - 1, -1, -1):
prev_gsnode = gsnodes[j]
if len(prev_gsnode.snodes) == 0:
continue
if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT):
reorder_info.limiting_factor = "prefetch limit"
def _group_nodes(head, tail):
ret = []
n = head
while True:
if n is not None:
ret.append(n)
if n == tail:
break
if contains_collective(prev_gsnode):
n = _next[n]
return ret
def _group_names(head, tail):
ret = ""
for n in _group_nodes(head, tail):
if ret:
ret += "~"
ret += n.get_name()
return ret
curr = _head
while _next[curr] is not None:
if contains_collective(curr):
reorder_info = stats[curr] = ReorderInfo()
reorder_info.initial_exposed = reorder_info.final_exposed = (
exposed_communication_time(curr, _group_nodes(_next[curr], None))
)
candidate = _prev[curr]
group_head = curr
group_tail = curr
group_peak_memory = _curr_memory[curr]
while candidate is not None:
if contains_collective(candidate):
reorder_info.limiting_factor = "collective ordering"
break
dep_names = OrderedSet([s.name for s in snode.unmet_dependencies])
prev_outs = prev_gsnode.get_outputs()
group = GroupedSchedulerNode(
curr.scheduler,
_group_nodes(group_head, group_tail),
temp_grouping=True,
)
data_deps = {s.name: s for s in group.unmet_dependencies}
candidate_outs = candidate.get_outputs()
data_dep = None
for o in prev_outs:
if o.get_name() in dep_names:
data_dep = o.get_name()
for o in candidate_outs:
if d := data_deps.get(o.get_name(), None):
if isinstance(d, WeakDep) and d.is_fake:
continue
data_dep = d
break
if data_dep is not None:
def is_groupable(prev_gsnode):
def is_groupable(candidate):
# preserve ordering
if contains_collective(prev_gsnode):
return False
if contains_collective(candidate):
return False, "contains_collective"
if contains_gemm_like(prev_gsnode):
return False
return True
if contains_gemm_like(candidate):
return False, "contains_gemm_like"
return True, None
if is_groupable(prev_gsnode):
new_snodes = prev_gsnode.snodes + gsnode.snodes
init_group_node(gsnode, gsnode.scheduler, new_snodes)
prev_gsnode.snodes = []
is_grp, grp_reason = is_groupable(candidate)
if is_grp:
group_head = candidate
group_peak_memory = max(
group_peak_memory, _curr_memory[candidate]
)
reorder_info.grouped += 1
reorder_info.grouped_info = gsnode.get_name()
reorder_info.grouped_info = _group_names(group_head, group_tail)
candidate = _prev[candidate]
continue
else:
msg = (
f"data dependency {data_dep}(dep_names:{dep_names})"
f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}"
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(group_head, group_tail)}"
f"\n non_group_reason:{grp_reason}"
)
reorder_info.limiting_factor = msg
break
if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]:
delta_memory_candidate = (
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
)
if group_peak_memory - delta_memory_candidate > peak_memory:
reorder_info.limiting_factor = "peak memory"
break
if reorder_info.final_exposed > runtimes[snode]:
reorder_info.limiting_factor = "sufficient overlapping"
break
reorder_info.moves += 1
total_moves += 1
# swapping nodes j and j+1 affects curr memory at j only
# j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j]
# j_alloc = curr_memory[j] - curr_memory[j - 1]
# curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc
def swap_curr_memory_with_previous(
snode_j_plus_one, snode_j, snode_j_minus_one
):
curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one]
curr_memory_j = snode_to_curr_memory[snode_j]
curr_memory_j_minus_one = (
snode_to_curr_memory[snode_j_minus_one]
if snode_j_minus_one is not None
else 0
)
j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j
j_alloc = curr_memory_j - curr_memory_j_minus_one
snode_to_curr_memory[snode_j] = (
curr_memory_j - j_alloc + j_plus_one_alloc
)
mem_deltas = {}
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# swap (candidate, group_head...group_tail)
# Before:
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
# After:
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
# 0
candidate_prev = _prev[candidate]
if candidate_prev:
_next[candidate_prev] = group_head
_prev[group_head] = candidate_prev
# Recompuing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B)
# swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2]
# decomposing to:
# swap(A2, B0) -> A0, A1, B0, A2, B1
# swap(A2, B1) -> A0, A1, B0, B1, A2
# swap(A1, B0) -> A0, B0, A1, B1, A2
# swap(A1, B1) -> A0, B0, B1, A1, A2
# swap(A0, B0) -> B0, A0, B1, A1, A2
# swap(A0, B1) -> B0, B1, A0, A1, A2
for _j in range(len(gsnodes[j].snodes) - 1, -1, -1): # group A
snode_j = gsnodes[j].snodes[_j]
for _i, snode_i in enumerate(gsnode.snodes): # group B
swap_curr_memory_with_previous(
snode_j_plus_one=snode_i,
snode_j=snode_j,
snode_j_minus_one=_prev[snode_j],
)
# 2
group_tail_next = _next[group_tail]
if group_tail_next:
_prev[group_tail_next] = candidate
_next[candidate] = group_tail_next
# Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j]
first = snode_j
second = snode_i
first_prev = _prev[first]
second_next = _next[second]
if first_prev:
_next[first_prev] = second
_prev[second] = first_prev
# 1
_prev[candidate] = group_tail
_next[group_tail] = candidate
if second_next:
_prev[second_next] = first
_next[first] = second_next
if _head == candidate:
_head = group_head
_next[second] = first
_prev[first] = second
tmp = gsnodes[j]
gsnodes[j] = gsnodes[j + 1]
gsnodes[j + 1] = tmp
reorder_info.final_exposed = exposed_communication_time(
snode,
itertools.chain(
gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]]
),
curr, _group_nodes(_next[curr], None)
)
# Recompute curr_memory
_prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index]
for n in _group_nodes(group_head, candidate):
_curr_memory[n] = _prev_curr_memory = (
_prev_curr_memory + mem_deltas[n]
)
candidate = _prev[group_head]
curr = _next[curr] # type: ignore[assignment]
node_stats = stats
improvement = {snode: node_stats[snode].improvement for snode in node_stats}
@ -426,17 +431,13 @@ def _reorder_communication_preserving_peak_memory_internal(
reorder_log_str += str(headers) + "\n"
reorder_log_str += "\n".join(map(str, rows))
grouping_logs: list[str] = []
flatten_gsnodes: list[BaseSchedulerNode] = []
for i, gsnode in enumerate(gsnodes):
if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
flatten_gsnodes.extend(gsnode.snodes)
else:
flatten_gsnodes.append(gsnode)
grouping_log_str = "\n".join(grouping_logs)
reorder_log_str += "\n"
reorder_log_str += grouping_log_str
new_snodes = _group_nodes(_head, None)
assert len(new_snodes) == original_snodes_num
new_peak_memory, curr_memory = estimate_peak_memory(
new_snodes, name_to_freeable_input_buf, graph_outputs
)
reorder_log_str += f"\n peak_memory_before:{peak_memory}"
reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"
overlap_log.info(reorder_log_str)
trace_structured(
@ -448,8 +449,7 @@ def _reorder_communication_preserving_peak_memory_internal(
payload_fn=lambda: reorder_log_str,
)
assert len(flatten_gsnodes) == original_snodes_num
return flatten_gsnodes, stats
return new_snodes, stats
def _schedule_for_comm(
@ -623,7 +623,9 @@ def decide_global_ordering_of_comms(
# Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
for buf in comm_nodes[i - 1].get_buffer_names():
comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))
comm_nodes[i].add_fake_dep(
WeakDep(buf, mutating_buf=mutating_buf, is_fake=True)
)
return nodes
@ -640,66 +642,166 @@ class SinkWaitInfo:
def _sink_waits_iterative_internal(
snodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
from torch._inductor.scheduler import GroupedSchedulerNode
original_snodes_num = len(snodes)
if original_snodes_num == 0:
return snodes, {}
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
snodes, graph_inputs
)
peak_memory, curr_memory = estimate_peak_memory(
snodes, name_to_freeable_input_buf, graph_outputs
)
n = len(snodes)
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
gsnodes: list[GroupedSchedulerNode] = [
GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
for snode in snodes
]
for i in range(n - 1, -1, -1):
gsnode = gsnodes[i]
if contains_wait(gsnode):
info = stats[gsnode.snodes[0]] = SinkWaitInfo()
for j in range(i + 1, n):
wait_gsnode = gsnodes[j - 1]
wait_outs = wait_gsnode.get_outputs()
next_gsnode = gsnodes[j]
dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies])
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_head = snodes[0]
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
def _group_nodes(head, tail):
ret = []
n = head
while True:
if n is not None:
ret.append(n)
if n == tail:
break
n = _next[n]
return ret
def _group_names(head, tail):
ret = ""
for n in _group_nodes(head, tail):
if ret:
ret += "~"
ret += n.get_name()
return ret
curr = snodes[-1]
processed_waits = OrderedSet() # type: ignore[var-annotated]
while _prev[curr] is not None:
if contains_wait(curr) and curr not in processed_waits:
processed_waits.add(curr)
info = stats[curr] = SinkWaitInfo()
candidate = _next[curr]
wait_snode = curr
group_head = curr
group_tail = curr
group_peak_memory = _curr_memory[curr]
while candidate is not None:
group = GroupedSchedulerNode(
wait_snode.scheduler,
_group_nodes(group_head, group_tail),
temp_grouping=True,
)
group_outs = group.get_outputs()
data_deps = {s.name: s for s in candidate.unmet_dependencies}
data_dep = None
for o in wait_outs:
if o.get_name() in dep_names:
data_dep = o.get_name()
for o in group_outs:
if d := data_deps.get(o.get_name(), None):
if isinstance(d, WeakDep) and d.is_fake:
continue
data_dep = d
break
# 1. If we have data_dep - we can not swap => trying to group
# 2. If swap candidate and current node both contain collectives => trying to group
if data_dep is not None or (
both_contain_comms := (
contains_collective(wait_gsnode)
and contains_collective(next_gsnode)
contains_collective(group) and contains_collective(candidate)
)
):
def is_groupable(snode):
return not contains_gemm_like(snode)
# We do not want to group with collectives to not reorder them forward.
if contains_collective(snode):
return (
False,
f"candidate contains collective {snode.get_name()}",
)
if contains_gemm_like(snode):
return (
False,
f"candidate contains gemm_like {snode.get_name()}",
)
return True, None
if is_groupable(next_gsnode):
new_snodes = wait_gsnode.snodes + next_gsnode.snodes
init_group_node(next_gsnode, gsnode.scheduler, new_snodes)
wait_gsnode.snodes = []
is_grp, grp_reason = is_groupable(candidate)
if is_grp:
group_tail = candidate
group_peak_memory = max(
group_peak_memory, _curr_memory[candidate]
)
info.grouped += 1
info.grouped_info = _group_name(next_gsnode)
info.grouped_info = _group_names(group_head, group_tail)
candidate = _next[candidate]
continue
elif (data_dep is None) and both_contain_comms:
info.limiting_factor = (
f"collective ordering {_group_name(wait_gsnode)}"
f" with candidate:{_group_name(next_gsnode)}"
)
else:
info.limiting_factor = (
f"data dependency {data_dep}(dep_names:{dep_names})"
f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}"
f" outs:{[o.get_name() for o in wait_outs]}"
f"collective ordering {_group_names(group_head, group_tail)}"
f" with candidate:{candidate.get_name()}"
)
break
info.moves += 1
info.moves_info += f"+{_group_name(next_gsnode)}"
else:
info.limiting_factor = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(group_head, group_tail)}"
f"\n outs:{[o.get_name() for o in group_outs]}"
f"\n non_group_reason:{grp_reason}"
)
break
candidate_delta_memory = (
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
)
if group_peak_memory + candidate_delta_memory > peak_memory:
info.limiting_factor = "peak_memory"
break
info.moves += 1
info.moves_info += f"+{candidate.get_name()}"
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
mem_deltas = {}
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# 0:
group_head_prev = _prev[group_head]
if group_head_prev:
_next[group_head_prev] = candidate
_prev[candidate] = group_head_prev
# 2:
candidate_next = _next[candidate]
if candidate_next:
_prev[candidate_next] = group_tail
_next[group_tail] = candidate_next
# 1:
_prev[group_head] = candidate
_next[candidate] = group_head
if group_head == _head:
_head = candidate
# Recompute curr_memory
_prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index]
for n in _group_nodes(candidate, group_tail):
_curr_memory[n] = _prev_curr_memory = (
_prev_curr_memory + mem_deltas[n]
)
candidate = _next[group_tail]
curr = _prev[curr] # type: ignore[assignment]
# Swapping snodes j and j - 1
tmp = gsnodes[j - 1]
gsnodes[j - 1] = gsnodes[j]
gsnodes[j] = tmp
headers = [
"Wait node",
"grouped",
@ -732,16 +834,13 @@ def _sink_waits_iterative_internal(
log_str += str(headers) + "\n"
log_str += "\n".join(map(str, rows))
overlap_log.info(log_str)
grouping_logs = []
flatten_snodes = []
for i, gsnode in enumerate(gsnodes):
grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}")
if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
flatten_snodes.extend(gsnode.snodes)
else:
flatten_snodes.append(gsnode)
grouping_log_str = "\n".join(grouping_logs)
log_str += grouping_log_str
new_snodes = _group_nodes(_head, None)
assert len(new_snodes) == original_snodes_num
new_peak_memory, curr_memory = estimate_peak_memory(
new_snodes, name_to_freeable_input_buf, graph_outputs
)
log_str += f"\n peak_memory_before:{peak_memory}"
log_str += f"\n peak_memory_after:{new_peak_memory}"
trace_structured(
"artifact",
metadata_fn=lambda: {
@ -750,8 +849,7 @@ def _sink_waits_iterative_internal(
},
payload_fn=lambda: log_str,
)
assert len(flatten_snodes) == n
return flatten_snodes, stats
return new_snodes, stats
def sink_waits_iterative(
@ -777,7 +875,9 @@ def node_summary(snode):
if len(snodes) == 1:
detail = ""
if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
detail = f" ({snode.node.python_kernel_name})"
outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}"
ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}"
detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})"
layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
out_tensor_info = ",".join(
[
@ -1352,7 +1452,7 @@ def enforce_comm_ordering_for_fsdp(
mutating_buf = next(iter(ag_group_node.get_buffer_names()))
for o in prev_ag_wait.get_outputs():
ag_group_node.add_fake_dep(
WeakDep(o.get_name(), mutating_buf=mutating_buf)
WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
)
prev_ag_wait = wait_group_node
@ -1364,7 +1464,7 @@ def enforce_comm_ordering_for_fsdp(
mutating_buf = next(iter(rs_group_node.get_buffer_names()))
for o in prev_rs_wait.get_outputs():
rs_group_node.add_fake_dep(
WeakDep(o.get_name(), mutating_buf=mutating_buf)
WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
)
prev_rs_wait = wait_group_node

View File

@ -342,6 +342,12 @@ class WeakDep(Dep):
name: str
# Buffer that is doing the mutation
mutating_buf: str
# WeakDep's are also used to add dependencies to prevent some specific reordering,
# E.g. collectives global ordering.
# But if other pass guarantees proper ordering by its logic,
# This additional "fake" deps will be holding optimizations.
# This flag is used to identify those additional deps.
is_fake: bool = False
@property
def index(self) -> sympy.Expr:
@ -352,7 +358,7 @@ class WeakDep(Dep):
def rename(self, renames: dict[str, str]) -> "WeakDep":
if self.name in renames:
return WeakDep(renames[self.name], self.mutating_buf)
return WeakDep(renames[self.name], self.mutating_buf, self.is_fake)
return self
def numbytes_hint(self) -> int: