[inductor collectives] sink waits iterative (#157708)

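Adds a new inductor comms pass, sink_waits_iterative, which iteratively sinks wait nodes later in the schedule to increase compute/communication overlap. When a wait cannot move past its successor (the successor reads the wait's outputs, or both nodes contain collectives), the two are merged into a temporary group so they can keep sinking together. Per-wait statistics are logged, and the pass is registered in the single-process inductor collectives test.
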
Differential Revision: [D77861763](https://our.internmc.facebook.com/intern/diff/D77861763)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157708
Approved by: https://github.com/wconstab
ghstack dependencies: #157706
Author: IvanKobzarev
Date: 2025-07-07 11:29:58 -07:00
Committed by: PyTorch MergeBot
Parent: 2af7c67e48
Commit: 8134684d44
2 changed files with 145 additions and 0 deletions

test/distributed/test_inductor_collectives.py

@@ -19,6 +19,7 @@ from torch._dynamo.utils import same
from torch._inductor.comms import (
    _reorder_communication_preserving_peak_memory_internal,
    ReorderInfo,
    sink_waits_iterative,
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.scheduler import BaseSchedulerNode
@@ -1613,6 +1614,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
        "reorder_for_compute_comm_overlap": True,
        "reorder_for_compute_comm_overlap_passes": [
            _reorder_communication_preserving_peak_memory,
            sink_waits_iterative,
        ],
    }
):

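For context, enabling the new pass outside the test looks roughly like this. This is a minimal sketch mirroring the config keys the test patches above; my_model and inputs are hypothetical placeholders:

    import torch
    from torch._inductor import config as inductor_config
    from torch._inductor.comms import sink_waits_iterative

    # Turn on comm/compute overlap reordering and register the new pass
    # in the pass list, exactly as the test does.
    with inductor_config.patch(
        {
            "reorder_for_compute_comm_overlap": True,
            "reorder_for_compute_comm_overlap_passes": [sink_waits_iterative],
        }
    ):
        compiled = torch.compile(my_model)  # my_model: hypothetical nn.Module
        out = compiled(inputs)              # inputs: hypothetical tensors
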
torch/_inductor/comms.py

@ -175,6 +175,17 @@ def _temp_group_visit_leaves(snode, fn):
fn(snode)

def _group_name(snode, with_bufs=False) -> str:
    """Join the names of a group's nodes with "_" (e.g. "op3_op4");
    with with_bufs=True, also append the group's buffer names."""
    ret = ""
    for n in snode.snodes:
        if ret:
            ret += "_"
        ret += n.get_name()
    if with_bufs:
        ret += f"{list(snode.get_buffer_names())}"
    return ret

def _reorder_communication_preserving_peak_memory_internal(
    snodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
@@ -617,6 +628,138 @@ def decide_global_ordering_of_comms(
    return nodes

@dataclass
class SinkWaitInfo:
    grouped: int = 0
    grouped_info: str = ""
    moves: int = 0
    moves_info: str = ""
    limiting_factor: str = "None"
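

# Overview of the pass below: walk the schedule from back to front; for each
# node containing a wait, repeatedly try to swap it with its successor so the
# wait sinks toward its first consumer. A swap is blocked when the successor
# reads one of the wait's outputs, or when both nodes contain collectives
# (their relative order must be preserved); in those cases the two nodes are
# merged into a temporary GroupedSchedulerNode (unless the successor contains
# a gemm-like node) so they keep sinking together. Per-wait statistics are
# collected in SinkWaitInfo.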
def _sink_waits_iterative_internal(
    snodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
    from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node

    n = len(snodes)
    stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
    gsnodes: list[GroupedSchedulerNode] = [
        GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
        for snode in snodes
    ]
    for i in range(n - 1, -1, -1):
        gsnode = gsnodes[i]
        if contains_wait(gsnode):
            info = stats[gsnode.snodes[0]] = SinkWaitInfo()
            for j in range(i + 1, n):
                wait_gsnode = gsnodes[j - 1]
                wait_outs = wait_gsnode.get_outputs()
                next_gsnode = gsnodes[j]
                dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies])
                data_dep = None
                for o in wait_outs:
                    if o.get_name() in dep_names:
                        data_dep = o.get_name()
                        break
                # 1. If there is a data_dep, we cannot swap => try to group.
                # 2. If the swap candidate and the current node both contain
                #    collectives => try to group.
                if data_dep is not None or (
                    both_contain_comms := (
                        contains_collective(wait_gsnode)
                        and contains_collective(next_gsnode)
                    )
                ):

                    def is_groupable(snode):
                        # Gemm-like nodes are the compute we want to overlap;
                        # never pull them into a group that sinks with a wait.
                        return not contains_gemm_like(snode)

                    if is_groupable(next_gsnode):
                        new_snodes = wait_gsnode.snodes + next_gsnode.snodes
                        init_group_node(next_gsnode, gsnode.scheduler, new_snodes)
                        wait_gsnode.snodes = []
                        info.grouped += 1
                        info.grouped_info = _group_name(next_gsnode)
                        continue
                    elif (data_dep is None) and both_contain_comms:
                        info.limiting_factor = (
                            f"collective ordering {_group_name(wait_gsnode)}"
                            f" with candidate:{_group_name(next_gsnode)}"
                        )
                    else:
                        info.limiting_factor = (
                            f"data dependency {data_dep}(dep_names:{dep_names})"
                            f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}"
                            f" outs:{[o.get_name() for o in wait_outs]}"
                        )
                    break
                info.moves += 1
                info.moves_info += f"+{_group_name(next_gsnode)}"
                # No blocker: swap snodes j and j - 1 so the wait sinks one slot.
                gsnodes[j - 1], gsnodes[j] = gsnodes[j], gsnodes[j - 1]
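
    # Log a per-wait summary: how many nodes each wait was grouped with or
    # moved past, and the factor that stopped it from sinking further.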
    headers = [
        "Wait node",
        "grouped",
        "grouped_info",
        "moves",
        "moves_info",
        "limiting factor",
    ]
    rows = [
        [
            node_summary(snode),
            info.grouped,
            info.grouped_info,
            info.moves,
            info.moves_info,
            info.limiting_factor,
        ]
        for snode, info in stats.items()
    ]
    log_str = ""
    if importlib.util.find_spec("tabulate"):
        from tabulate import tabulate

        log_str += tabulate(
            rows,
            headers=headers,
        )
    else:
        log_str += "Please `pip install tabulate` to nicely render overlap stats.\n"
        log_str += str(headers) + "\n"
        log_str += "\n".join(map(str, rows))
    overlap_log.info(log_str)
    grouping_logs = []
    flatten_snodes = []
    for i, gsnode in enumerate(gsnodes):
        grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}")
        if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
            flatten_snodes.extend(gsnode.snodes)
        else:
            flatten_snodes.append(gsnode)
    grouping_log_str = "\n".join(grouping_logs)
    log_str += grouping_log_str
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "sink_waits_iterative_info",
            "encoding": "string",
        },
        payload_fn=lambda: log_str,
    )
    assert len(flatten_snodes) == n
    return flatten_snodes, stats


def sink_waits_iterative(
    snodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
    return _sink_waits_iterative_internal(snodes)[0]
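

# Illustration with hypothetical node names: given a schedule
#     [allreduce_1, wait_1, mm_1, mm_2]
# where neither matmul reads wait_1's outputs, sink_waits_iterative produces
#     [allreduce_1, mm_1, mm_2, wait_1]
# letting both matmuls overlap with the in-flight collective. If mm_1 did read
# wait_1's output, grouping would be refused (mm_1 is gemm-like) and the wait's
# SinkWaitInfo would record a "data dependency" limiting factor instead.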


def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
    """
    Returns estimated op runtime in nanoseconds (ns)