[inductor collectives] sink waits iterative (#157708)
Differential Revision: [D77861763](https://our.internmc.facebook.com/intern/diff/D77861763)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157708
Approved by: https://github.com/wconstab
ghstack dependencies: #157706
This commit is contained in:
parent 2af7c67e48
commit 8134684d44
test/distributed/test_inductor_collectives.py
@@ -19,6 +19,7 @@ from torch._dynamo.utils import same
 from torch._inductor.comms import (
     _reorder_communication_preserving_peak_memory_internal,
     ReorderInfo,
+    sink_waits_iterative,
 )
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch._inductor.scheduler import BaseSchedulerNode
@@ -1613,6 +1614,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
                 "reorder_for_compute_comm_overlap": True,
                 "reorder_for_compute_comm_overlap_passes": [
                     _reorder_communication_preserving_peak_memory,
+                    sink_waits_iterative,
                 ],
             }
         ):
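For reference, the test above enables the new pass through the existing inductor config knobs. A minimal usage sketch along the same lines (the config keys and pass functions come straight from the test; `my_model` and `inp` are hypothetical stand-ins):

# Sketch: run the reorder pass, then sink waits, during inductor scheduling.
import torch
import torch._inductor.config as inductor_config
from torch._inductor.comms import (
    _reorder_communication_preserving_peak_memory,
    sink_waits_iterative,
)

with inductor_config.patch(
    {
        "reorder_for_compute_comm_overlap": True,
        "reorder_for_compute_comm_overlap_passes": [
            _reorder_communication_preserving_peak_memory,
            sink_waits_iterative,
        ],
    }
):
    compiled = torch.compile(my_model)  # my_model: hypothetical module
    out = compiled(inp)                 # inp: hypothetical input; passes run at compile time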
torch/_inductor/comms.py
@@ -175,6 +175,17 @@ def _temp_group_visit_leaves(snode, fn):
         fn(snode)
 
 
+def _group_name(snode, with_bufs=False) -> str:
+    ret = ""
+    for n in snode.snodes:
+        if ret:
+            ret += "_"
+        ret += n.get_name()
+    if with_bufs:
+        ret += f"{list(snode.get_buffer_names())}"
+    return ret
+
+
 def _reorder_communication_preserving_peak_memory_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
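(`_group_name` above simply underscore-joins the names of a temporary group's member nodes, optionally appending the group's buffer names. A toy illustration, with hypothetical stand-in classes rather than real scheduler types:)

# Toy stand-ins only; _ToyNode/_ToyGroup are not the scheduler's node types.
class _ToyNode:
    def __init__(self, name):
        self._name = name

    def get_name(self):
        return self._name


class _ToyGroup:
    def __init__(self, snodes):
        self.snodes = snodes


group = _ToyGroup([_ToyNode("op3"), _ToyNode("op4")])
# Same joining rule as _group_name(group) -> "op3_op4"
print("_".join(n.get_name() for n in group.snodes))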
@@ -617,6 +628,138 @@ def decide_global_ordering_of_comms(
     return nodes
 
 
+@dataclass
+class SinkWaitInfo:
+    grouped: int = 0
+    grouped_info: str = ""
+    moves: int = 0
+    moves_info: str = ""
+    limiting_factor: str = "None"
+
+
+def _sink_waits_iterative_internal(
+    snodes: list[BaseSchedulerNode],
+) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
+    from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
+
+    n = len(snodes)
+    stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
+    gsnodes: list[GroupedSchedulerNode] = [
+        GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
+        for snode in snodes
+    ]
+    for i in range(n - 1, -1, -1):
+        gsnode = gsnodes[i]
+        if contains_wait(gsnode):
+            info = stats[gsnode.snodes[0]] = SinkWaitInfo()
+            for j in range(i + 1, n):
+                wait_gsnode = gsnodes[j - 1]
+                wait_outs = wait_gsnode.get_outputs()
+                next_gsnode = gsnodes[j]
+                dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies])
+                data_dep = None
+                for o in wait_outs:
+                    if o.get_name() in dep_names:
+                        data_dep = o.get_name()
+                        break
+                # 1. If we have data_dep - we cannot swap => trying to group
+                # 2. If swap candidate and current node both contain collectives => trying to group
+                if data_dep is not None or (
+                    both_contain_comms := (
+                        contains_collective(wait_gsnode)
+                        and contains_collective(next_gsnode)
+                    )
+                ):
+
+                    def is_groupable(snode):
+                        return not contains_gemm_like(snode)
+
+                    if is_groupable(next_gsnode):
+                        new_snodes = wait_gsnode.snodes + next_gsnode.snodes
+                        init_group_node(next_gsnode, gsnode.scheduler, new_snodes)
+                        wait_gsnode.snodes = []
+                        info.grouped += 1
+                        info.grouped_info = _group_name(next_gsnode)
+                        continue
+                    elif (data_dep is None) and both_contain_comms:
+                        info.limiting_factor = (
+                            f"collective ordering {_group_name(wait_gsnode)}"
+                            f" with candidate:{_group_name(next_gsnode)}"
+                        )
+                    else:
+                        info.limiting_factor = (
+                            f"data dependency {data_dep}(dep_names:{dep_names})"
+                            f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}"
+                            f" outs:{[o.get_name() for o in wait_outs]}"
+                        )
+                    break
+                info.moves += 1
+                info.moves_info += f"+{_group_name(next_gsnode)}"
+
+                # Swapping snodes j and j - 1
+                tmp = gsnodes[j - 1]
+                gsnodes[j - 1] = gsnodes[j]
+                gsnodes[j] = tmp
+    headers = [
+        "Wait node",
+        "grouped",
+        "grouped_info",
+        "moves",
+        "moves_info",
+        "limiting factor",
+    ]
+    rows = [
+        [
+            node_summary(snode),
+            info.grouped,
+            info.grouped_info,
+            info.moves,
+            info.moves_info,
+            info.limiting_factor,
+        ]
+        for snode, info in stats.items()
+    ]
+    log_str = ""
+    if importlib.util.find_spec("tabulate"):
+        from tabulate import tabulate
+
+        log_str += tabulate(
+            rows,
+            headers=headers,
+        )
+    else:
+        log_str += "Please `pip install tabulate` to nicely render overlap stats.\n"
+        log_str += str(headers) + "\n"
+        log_str += "\n".join(map(str, rows))
+    overlap_log.info(log_str)
+    grouping_logs = []
+    flatten_snodes = []
+    for i, gsnode in enumerate(gsnodes):
+        grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}")
+        if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
+            flatten_snodes.extend(gsnode.snodes)
+        else:
+            flatten_snodes.append(gsnode)
+    grouping_log_str = "\n".join(grouping_logs)
+    log_str += grouping_log_str
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "sink_waits_iterative_info",
+            "encoding": "string",
+        },
+        payload_fn=lambda: log_str,
+    )
+    assert len(flatten_snodes) == n
+    return flatten_snodes, stats
+
+
+def sink_waits_iterative(
+    snodes: list[BaseSchedulerNode],
+) -> list[BaseSchedulerNode]:
+    return _sink_waits_iterative_internal(snodes)[0]
+
+
 def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
     """
     Returns estimated op runtime in nanoseconds (ns)
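The loop structure above is easiest to see on a toy schedule. Below is a simplified standalone model of the sinking pass (plain tuples instead of scheduler nodes; the grouping branch is omitted): each wait bubbles toward the end of the list, swapping past later nodes until it hits one that consumes its output or is itself a collective.

# Simplified model of the sink-waits loop; an illustration, not the scheduler code.
# A node is (name, kind, deps); a wait's own name stands in for its output buffer.
def toy_sink_waits(nodes):
    nodes = list(nodes)
    n = len(nodes)
    for i in range(n - 1, -1, -1):
        if nodes[i][1] != "wait":
            continue
        for j in range(i + 1, n):
            wait_name = nodes[j - 1][0]
            name, kind, deps = nodes[j]
            if wait_name in deps:
                break  # data dependency: cannot sink further (real pass tries grouping)
            if kind == "collective":
                break  # preserve collective ordering (real pass tries grouping)
            # Swap nodes j and j - 1, sinking the wait one slot toward the end.
            nodes[j - 1], nodes[j] = nodes[j], nodes[j - 1]
    return nodes


schedule = [
    ("ar0", "collective", []),
    ("w0", "wait", ["ar0"]),
    ("mm1", "compute", []),      # independent compute: w0 sinks past it
    ("mm2", "compute", ["w0"]),  # consumes w0's output: sinking stops here
]
print([name for name, _, _ in toy_sink_waits(schedule)])
# -> ['ar0', 'mm1', 'w0', 'mm2']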