[simple_fsdp][inductor_collectives] rewrite reorder_collectives, sink_waits_iterative (#158062)
Differential Revision: [D78159013](https://our.internmc.facebook.com/intern/diff/D78159013)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158062
Approved by: https://github.com/wconstab
This commit is contained in:
parent
ef256ad17b
commit
eeb0783fe6
test/distributed/test_inductor_collectives.py
@@ -19,6 +19,7 @@ from torch._dynamo.utils import same
 from torch._inductor.comms import (
     _reorder_communication_preserving_peak_memory_internal,
     ReorderInfo,
     sink_waits_iterative,
 )
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch._inductor.scheduler import BaseSchedulerNode
@@ -1621,7 +1622,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         comm from moving due to data dependency.
         """

-        def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
+        def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)

@@ -1654,14 +1655,52 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             # wait op
             rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out)
             rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out)
             y += torch.mm(2 * x, 2 * w)

-            return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out
+            # cast the inputs
+            ag_2_cast = ag_2.to(torch.bfloat16)
+            ag_3_cast = ag_3.to(torch.bfloat16)
+            ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_2_cast, group_size, group_name
+            )
+            ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_3_cast, group_size, group_name
+            )
+
+            # wait op
+            ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out)
+            ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out)
+
+            #
+            rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor(
+                ag_2_cast, "sum", group_size, group_name
+            )
+            rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor(
+                ag_3_cast, "sum", group_size, group_name
+            )
+
+            # wait op
+            rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out)
+            rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out)
+
+            return (
+                y,
+                ag_0_out,
+                ag_1_out,
+                ag_2_out,
+                ag_3_out,
+                rs_0_out,
+                rs_1_out,
+                rs_2_out,
+                rs_3_out,
+            )

         x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
         w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        ag_1 = torch.ones(512, device="cuda", dtype=torch.float32)
-        inputs = [x, w, ag_0, ag_1]
+        ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
+        ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
+        ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
+        ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
+        inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
         # get stats directly from the internal helper without affecting the real pass's signature
         node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
@@ -1679,11 +1718,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with torch._inductor.config.patch(
             {
+                "bucket_all_gathers_fx": "all",
+                "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
+                "bucket_reduce_scatters_fx": "all",
+                "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
                 "reorder_for_compute_comm_overlap": True,
                 "reorder_for_compute_comm_overlap_passes": [
                     sink_waits_iterative,
                     _reorder_communication_preserving_peak_memory,
                 ],
                 "allow_buffer_reuse": False,
             }
         ):
             compiled = torch.compile(func)
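
A note on the bucket-size determinator configs above: each determinator is a callable that picks a bucket size, and `lambda _: 2` pairs up the four all-gathers (and the four reduce-scatters). A minimal standalone sketch of the idea — not Inductor's actual bucketing code; `bucket` and `determinator` are illustrative names:

    # Group items into buckets whose size is chosen per-bucket by a callable,
    # mirroring how `*_bucket_size_determinator: lambda _: 2` pairs up the
    # four all-gathers (and the four reduce-scatters) in this test.
    def bucket(items, determinator):
        buckets, i, idx = [], 0, 0
        while i < len(items):
            size = determinator(idx)  # idx identifies the bucket being formed
            buckets.append(items[i : i + size])
            i += size
            idx += 1
        return buckets

    assert bucket(["ag_0", "ag_1", "ag_2", "ag_3"], lambda _: 2) == [
        ["ag_0", "ag_1"],
        ["ag_2", "ag_3"],
    ]
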
@@ -1694,31 +1737,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
                 FileCheck()
                 .check_count(
                     "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
-                    count=1,
+                    count=2,
                     exactly=True,
                 )
                 .check(
                     "extern_kernels.mm",
                 )
                 .check(
                     "extern_kernels.addmm",
                 )
                 .run(code)
             )
             (
                 FileCheck()
                 .check_count(
                     "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
-                    count=1,
+                    count=2,
                     exactly=True,
                 )
                 .run(code)
             )
             (
                 FileCheck()
                 .check(
                     "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
                 )
                 .check(
                     "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
                 )
                 .check(
                     "extern_kernels.mm",
                 )
                 .check(
                     "extern_kernels.addmm",
                 )
                 .run(code)
             )
             out = compiled(*inputs, **self.get_world_trs())
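
For readers unfamiliar with the assertion style above: `FileCheck` (from `torch.testing`) matches patterns against generated code in order; `check_count(..., exactly=True)` asserts an exact number of occurrences, and chained `check(...)` calls assert ordered appearance after the previous match. A small self-contained usage example against a made-up code string:

    from torch.testing import FileCheck

    code = """
    buf0 = torch.ops._c10d_functional.all_gather_into_tensor_out.default(a)
    buf1 = torch.ops._c10d_functional.all_gather_into_tensor_out.default(b)
    extern_kernels.mm(buf0, w)
    """
    # check_count asserts an exact number of occurrences; the chained check()
    # asserts the pattern appears after the previous matches.
    FileCheck().check_count(
        "all_gather_into_tensor_out.default(", count=2, exactly=True
    ).check("extern_kernels.mm").run(code)
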
@@ -1726,7 +1768,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             assert same(out, correct), f"{out} va {correct}"
             assert node_stats is not None
             self.assertTrue(isinstance(node_stats, dict))
-            self.assertEqual(len(node_stats), 2)
+            self.assertEqual(len(node_stats), 4)
             it = iter(node_stats.values())
             node_stat0 = next(it)
             self.assertTrue(node_stat0.moves > 0)
torch/_inductor/comms.py
@@ -4,7 +4,6 @@ from __future__ import annotations

 import heapq
-import importlib
 import itertools
 import logging
 import operator
 import sys
@@ -149,9 +148,8 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool:
         return True

     if (
-        hasattr(node, "python_kernel_name")
-        and node.python_kernel_name == "extern_kernels.mm"
-    ):
+        python_kernel_name := getattr(node, "python_kernel_name", None)
+    ) and "extern_kernels" in python_kernel_name:
         return True
     return False

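
The predicate change above widens `is_gemm_like` from matching `extern_kernels.mm` exactly to matching any extern kernel, using a walrus-bound `getattr` so nodes without `python_kernel_name` fall through. A standalone sketch of that pattern (`_Node` and `looks_extern` are illustrative names, not Inductor code):

    # Bind the attribute (or None) and only then test its contents, so objects
    # without `python_kernel_name` short-circuit to False instead of raising.
    class _Node:
        python_kernel_name = "extern_kernels.addmm"

    def looks_extern(node) -> bool:
        if (
            kernel_name := getattr(node, "python_kernel_name", None)
        ) and "extern_kernels" in kernel_name:
            return True
        return False

    assert looks_extern(_Node())
    assert not looks_extern(object())
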
@@ -189,15 +187,23 @@ def _group_name(snode, with_bufs=False) -> str:
 def _reorder_communication_preserving_peak_memory_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
-    from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
-
-    original_snodes_num = len(snodes)
     """
     Internal testing helper that also returns debug info.
     Returns:
         - reordered snodes list
        - dict {snode: ReorderInfo}
     """
+    has_collectives = False
+    for snode in snodes:
+        if contains_collective(snode):
+            has_collectives = True
+            break
+    if not has_collectives:
+        return snodes, {}
+
+    from torch._inductor.scheduler import GroupedSchedulerNode
+
+    original_snodes_num = len(snodes)
     # heuristic to avoid degenerating to quadratic time
     graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
     graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
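
The `_internal` suffix signals the testing contract: the helper returns both the reordered nodes and a per-collective `ReorderInfo` dict. The test captures the stats with a wrapper along these lines — a sketch of the capture pattern only; in the test the closure lives inside the test method, and `make_reorder_pass` is an illustrative name:

    def make_reorder_pass():
        node_stats = None

        # Runs the internal pass, stashes the debug dict, and returns only the
        # snodes, so this can stand in for the public pass in the config list.
        def _reorder_communication_preserving_peak_memory(snodes):
            nonlocal node_stats
            snodes, node_stats = (
                _reorder_communication_preserving_peak_memory_internal(snodes)
            )
            return snodes

        return _reorder_communication_preserving_peak_memory, lambda: node_stats
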
@@ -208,7 +214,8 @@ def _reorder_communication_preserving_peak_memory_internal(
         snodes, name_to_freeable_input_buf, graph_outputs
     )
     runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
-    snode_to_curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]

     # debug stats
     stats: dict[BaseSchedulerNode, ReorderInfo] = {}
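
`_curr_memory` records cumulative memory after each node, with a `None` sentinel standing for "before the head", so a node's allocation delta is a difference of two prefix values. A toy illustration with made-up numbers:

    # Cumulative memory after each of three nodes; None maps to 0 so the head
    # node's delta needs no special case.
    snodes = ["a", "b", "c"]
    curr_memory = [10, 25, 5]
    _prev = {"a": None, "b": "a", "c": "b"}
    _curr_memory = dict(zip(snodes, curr_memory))
    _curr_memory[None] = 0

    # Per-node allocation delta falls out as a difference of prefix values.
    delta = {n: _curr_memory[n] - _curr_memory[_prev[n]] for n in snodes}
    assert delta == {"a": 10, "b": 15, "c": -20}
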
@@ -232,153 +239,151 @@ def _reorder_communication_preserving_peak_memory_internal(
             _temp_group_visit_leaves(snode, accumulate_time)
         return max(0, comm_time - compute_time)

-    MOVE_LIMIT = len(snodes) * 100
-    total_moves = 0
-    # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it
-    PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes)
-    if config.reorder_prefetch_limit is not None:
-        PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit
-
     # Dicts to keep track of "next" and "previous" as double-linked structure during grouping
-    _prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
-    _next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
+    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
     for i, snode in enumerate(snodes):
         _prev[snode] = snodes[i - 1] if i > 0 else None
         _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
-    _curr_memory = dict(zip(snodes, curr_memory))
-    _curr_memory[None] = 0  # type: ignore[index]

-    gsnodes: list[GroupedSchedulerNode] = [
-        GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
-        for snode in snodes
-    ]
-    for i, gsnode in enumerate(gsnodes):
-        snode = gsnode.snodes[0]  # type: ignore[attr-defined]
-        if contains_collective(snode):
-            reorder_info = stats[snode] = ReorderInfo()
+    _head = snodes[0]
+
+    def _group_nodes(head, tail):
+        ret = []
+        n = head
+        while True:
+            if n is not None:
+                ret.append(n)
+            if n == tail:
+                break
+            n = _next[n]
+        return ret
+
+    def _group_names(head, tail):
+        ret = ""
+        for n in _group_nodes(head, tail):
+            if ret:
+                ret += "~"
+            ret += n.get_name()
+        return ret
+
+    curr = _head
+    while _next[curr] is not None:
+        if contains_collective(curr):
+            reorder_info = stats[curr] = ReorderInfo()
             reorder_info.initial_exposed = reorder_info.final_exposed = (
-                exposed_communication_time(snode, snodes[i + 1 :])
+                exposed_communication_time(curr, _group_nodes(_next[curr], None))
             )
-            if total_moves >= MOVE_LIMIT:
-                reorder_info.limiting_factor = "move limit"
-                continue
-
-            for j in range(i - 1, -1, -1):
-                prev_gsnode = gsnodes[j]
-                if len(prev_gsnode.snodes) == 0:
-                    continue
-
-                if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT):
-                    reorder_info.limiting_factor = "prefetch limit"
-                    break
-                if contains_collective(prev_gsnode):
+            candidate = _prev[curr]
+            group_head = curr
+            group_tail = curr
+            group_peak_memory = _curr_memory[curr]
+            while candidate is not None:
+                if contains_collective(candidate):
                     reorder_info.limiting_factor = "collective ordering"
                     break

-                dep_names = OrderedSet([s.name for s in snode.unmet_dependencies])
-                prev_outs = prev_gsnode.get_outputs()
+                group = GroupedSchedulerNode(
+                    curr.scheduler,
+                    _group_nodes(group_head, group_tail),
+                    temp_grouping=True,
+                )
+
+                data_deps = {s.name: s for s in group.unmet_dependencies}
+                candidate_outs = candidate.get_outputs()
                 data_dep = None
-                for o in prev_outs:
-                    if o.get_name() in dep_names:
-                        data_dep = o.get_name()
+                for o in candidate_outs:
+                    if d := data_deps.get(o.get_name(), None):
+                        if isinstance(d, WeakDep) and d.is_fake:
+                            continue
+                        data_dep = d
+                        break

                 if data_dep is not None:

-                    def is_groupable(prev_gsnode):
+                    def is_groupable(candidate):
                         # preserve ordering
-                        if contains_collective(prev_gsnode):
-                            return False
+                        if contains_collective(candidate):
+                            return False, "contains_collective"

-                        if contains_gemm_like(prev_gsnode):
-                            return False
-                        return True
+                        if contains_gemm_like(candidate):
+                            return False, "contains_gemm_like"
+                        return True, None

-                    if is_groupable(prev_gsnode):
-                        new_snodes = prev_gsnode.snodes + gsnode.snodes
-                        init_group_node(gsnode, gsnode.scheduler, new_snodes)
-                        prev_gsnode.snodes = []
+                    is_grp, grp_reason = is_groupable(candidate)
+                    if is_grp:
+                        group_head = candidate
+                        group_peak_memory = max(
+                            group_peak_memory, _curr_memory[candidate]
+                        )
                         reorder_info.grouped += 1
-                        reorder_info.grouped_info = gsnode.get_name()
+                        reorder_info.grouped_info = _group_names(group_head, group_tail)
+                        candidate = _prev[candidate]
                         continue
                     else:
                         msg = (
-                            f"data dependency {data_dep}(dep_names:{dep_names})"
-                            f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}"
+                            f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
+                            f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
+                            f"dep on {_group_names(group_head, group_tail)}"
+                            f"\n non_group_reason:{grp_reason}"
                         )
                         reorder_info.limiting_factor = msg
                         break

-                if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]:
+                delta_memory_candidate = (
+                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
+                )
+
+                if group_peak_memory - delta_memory_candidate > peak_memory:
                     reorder_info.limiting_factor = "peak memory"
                     break
                 if reorder_info.final_exposed > runtimes[snode]:
                     reorder_info.limiting_factor = "sufficient overlapping"
                     break

                 reorder_info.moves += 1
-                total_moves += 1

-                # swapping nodes j and j+1 affects curr memory at j only
-                # j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j]
-                # j_alloc = curr_memory[j] - curr_memory[j - 1]
-                # curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc
-                def swap_curr_memory_with_previous(
-                    snode_j_plus_one, snode_j, snode_j_minus_one
-                ):
-                    curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one]
-                    curr_memory_j = snode_to_curr_memory[snode_j]
-                    curr_memory_j_minus_one = (
-                        snode_to_curr_memory[snode_j_minus_one]
-                        if snode_j_minus_one is not None
-                        else 0
-                    )
-                    j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j
-                    j_alloc = curr_memory_j - curr_memory_j_minus_one
-                    snode_to_curr_memory[snode_j] = (
-                        curr_memory_j - j_alloc + j_plus_one_alloc
-                    )
+                mem_deltas = {}
+                for n in [candidate, *_group_nodes(group_head, group_tail)]:
+                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
+                # swap (candidate, group_head...group_tail)
+                # Before:
+                # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
+                # After:
+                # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
+                # 0
+                candidate_prev = _prev[candidate]
+                if candidate_prev:
+                    _next[candidate_prev] = group_head
+                _prev[group_head] = candidate_prev

-                # Recompuing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B)
-                # swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2]
-                # decomposing to:
-                # swap(A2, B0) -> A0, A1, B0, A2, B1
-                # swap(A2, B1) -> A0, A1, B0, B1, A2
-                # swap(A1, B0) -> A0, B0, A1, B1, A2
-                # swap(A1, B1) -> A0, B0, B1, A1, A2
-                # swap(A0, B0) -> B0, A0, B1, A1, A2
-                # swap(A0, B1) -> B0, B1, A0, A1, A2
-                for _j in range(len(gsnodes[j].snodes) - 1, -1, -1):  # group A
-                    snode_j = gsnodes[j].snodes[_j]
-                    for _i, snode_i in enumerate(gsnode.snodes):  # group B
-                        swap_curr_memory_with_previous(
-                            snode_j_plus_one=snode_i,
-                            snode_j=snode_j,
-                            snode_j_minus_one=_prev[snode_j],
-                        )
+                # 2
+                group_tail_next = _next[group_tail]
+                if group_tail_next:
+                    _prev[group_tail_next] = candidate
+                _next[candidate] = group_tail_next

-                        # Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j]
-                        first = snode_j
-                        second = snode_i
-                        first_prev = _prev[first]
-                        second_next = _next[second]
-                        if first_prev:
-                            _next[first_prev] = second
-                        _prev[second] = first_prev
+                # 1
+                _prev[candidate] = group_tail
+                _next[group_tail] = candidate

-                        if second_next:
-                            _prev[second_next] = first
-                        _next[first] = second_next
+                if _head == candidate:
+                    _head = group_head

-                        _next[second] = first
-                        _prev[first] = second

-                tmp = gsnodes[j]
-                gsnodes[j] = gsnodes[j + 1]
-                gsnodes[j + 1] = tmp
                 reorder_info.final_exposed = exposed_communication_time(
-                    snode,
-                    itertools.chain(
-                        gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]]
-                    ),
+                    curr, _group_nodes(_next[curr], None)
                 )
+                # Recompute curr_memory
+                _prev_curr_memory = _curr_memory[_prev[group_head]]  # type: ignore[index]
+                for n in _group_nodes(group_head, candidate):
+                    _curr_memory[n] = _prev_curr_memory = (
+                        _prev_curr_memory + mem_deltas[n]
+                    )
+                candidate = _prev[group_head]
+        curr = _next[curr]  # type: ignore[assignment]

     node_stats = stats
     improvement = {snode: node_stats[snode].improvement for snode in node_stats}
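
The rewrite's core data structure is the `_prev`/`_next` doubly linked list: moving a collective (with its grouped neighbors) past a candidate becomes three pointer splices instead of list surgery. A minimal sketch with plain strings in place of scheduler nodes:

    # Before: p -> candidate -> head..tail -> q
    # After:  p -> head..tail -> candidate -> q
    def swap(_prev, _next, candidate, head, tail):
        p, q = _prev[candidate], _next[tail]
        if p is not None:
            _next[p] = head
        _prev[head] = p                                   # region 0
        if q is not None:
            _prev[q] = candidate
        _next[candidate] = q                              # region 2
        _prev[candidate], _next[tail] = tail, candidate   # region 1

    order = ["p", "c", "h", "t", "q"]
    last = len(order) - 1
    _prev = {n: (order[i - 1] if i > 0 else None) for i, n in enumerate(order)}
    _next = {n: (order[i + 1] if i < last else None) for i, n in enumerate(order)}
    swap(_prev, _next, "c", "h", "t")

    walk, n = [], "p"
    while n is not None:
        walk.append(n)
        n = _next[n]
    assert walk == ["p", "h", "t", "c", "q"]
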
@@ -426,17 +431,13 @@ def _reorder_communication_preserving_peak_memory_internal(
     reorder_log_str += str(headers) + "\n"
     reorder_log_str += "\n".join(map(str, rows))

-    grouping_logs: list[str] = []
-    flatten_gsnodes: list[BaseSchedulerNode] = []
-    for i, gsnode in enumerate(gsnodes):
-        if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
-            flatten_gsnodes.extend(gsnode.snodes)
-        else:
-            flatten_gsnodes.append(gsnode)
-
-    grouping_log_str = "\n".join(grouping_logs)
-    reorder_log_str += "\n"
-    reorder_log_str += grouping_log_str
+    new_snodes = _group_nodes(_head, None)
+    assert len(new_snodes) == original_snodes_num
+    new_peak_memory, curr_memory = estimate_peak_memory(
+        new_snodes, name_to_freeable_input_buf, graph_outputs
+    )
+    reorder_log_str += f"\n peak_memory_before:{peak_memory}"
+    reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"

     overlap_log.info(reorder_log_str)
     trace_structured(
@@ -448,8 +449,7 @@ def _reorder_communication_preserving_peak_memory_internal(
         payload_fn=lambda: reorder_log_str,
     )

-    assert len(flatten_gsnodes) == original_snodes_num
-    return flatten_gsnodes, stats
+    return new_snodes, stats


 def _schedule_for_comm(
@@ -623,7 +623,9 @@ def decide_global_ordering_of_comms(
         # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
         mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
         for buf in comm_nodes[i - 1].get_buffer_names():
-            comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))
+            comm_nodes[i].add_fake_dep(
+                WeakDep(buf, mutating_buf=mutating_buf, is_fake=True)
+            )

     return nodes

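
The `is_fake=True` marking above lets later passes distinguish ordering-only edges from real data flow. A rough sketch of the chaining idea — `Dep` here is a stand-in, not Inductor's `WeakDep`:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Dep:
        name: str
        mutating_buf: str
        is_fake: bool = False

    # Each collective gets a weak dep on the previous collective's buffers,
    # so the scheduler cannot reorder them relative to each other.
    def chain_comms(comm_bufs: list[list[str]]) -> dict[int, list[Dep]]:
        deps: dict[int, list[Dep]] = {i: [] for i in range(len(comm_bufs))}
        for i in range(1, len(comm_bufs)):
            mutating = comm_bufs[i][0]
            for buf in comm_bufs[i - 1]:
                # marked fake: ordering-only, carries no real data flow
                deps[i].append(Dep(buf, mutating, is_fake=True))
        return deps

    deps = chain_comms([["buf0"], ["buf1"], ["buf2"]])
    assert deps[1] == [Dep("buf0", "buf1", is_fake=True)]
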
@@ -640,66 +642,166 @@ class SinkWaitInfo:

 def _sink_waits_iterative_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
-    from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
+    from torch._inductor.scheduler import GroupedSchedulerNode

     original_snodes_num = len(snodes)
     if original_snodes_num == 0:
         return snodes, {}
     graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
     graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
     name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
         snodes, graph_inputs
     )
     peak_memory, curr_memory = estimate_peak_memory(
         snodes, name_to_freeable_input_buf, graph_outputs
     )

-    n = len(snodes)
     stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
-    gsnodes: list[GroupedSchedulerNode] = [
-        GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
-        for snode in snodes
-    ]
-    for i in range(n - 1, -1, -1):
-        gsnode = gsnodes[i]
-        if contains_wait(gsnode):
-            info = stats[gsnode.snodes[0]] = SinkWaitInfo()
-            for j in range(i + 1, n):
-                wait_gsnode = gsnodes[j - 1]
-                wait_outs = wait_gsnode.get_outputs()
-                next_gsnode = gsnodes[j]
-                dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies])
+    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _head = snodes[0]
+    for i, snode in enumerate(snodes):
+        _prev[snode] = snodes[i - 1] if i > 0 else None
+        _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]
+
+    def _group_nodes(head, tail):
+        ret = []
+        n = head
+        while True:
+            if n is not None:
+                ret.append(n)
+            if n == tail:
+                break
+            n = _next[n]
+        return ret
+
+    def _group_names(head, tail):
+        ret = ""
+        for n in _group_nodes(head, tail):
+            if ret:
+                ret += "~"
+            ret += n.get_name()
+        return ret
+
+    curr = snodes[-1]
+
+    processed_waits = OrderedSet()  # type: ignore[var-annotated]
+    while _prev[curr] is not None:
+        if contains_wait(curr) and curr not in processed_waits:
+            processed_waits.add(curr)
+            info = stats[curr] = SinkWaitInfo()
+            candidate = _next[curr]
+            wait_snode = curr
+            group_head = curr
+            group_tail = curr
+            group_peak_memory = _curr_memory[curr]
+            while candidate is not None:
+                group = GroupedSchedulerNode(
+                    wait_snode.scheduler,
+                    _group_nodes(group_head, group_tail),
+                    temp_grouping=True,
+                )
+                group_outs = group.get_outputs()
+
+                data_deps = {s.name: s for s in candidate.unmet_dependencies}
                 data_dep = None
-                for o in wait_outs:
-                    if o.get_name() in dep_names:
-                        data_dep = o.get_name()
+                for o in group_outs:
+                    if d := data_deps.get(o.get_name(), None):
+                        if isinstance(d, WeakDep) and d.is_fake:
+                            continue
+                        data_dep = d
+                        break
                 # 1. If we have data_dep - we can not swap => trying to group
                 # 2. If swap candidate and current node both contain collectives => trying to group
                 if data_dep is not None or (
                     both_contain_comms := (
-                        contains_collective(wait_gsnode)
-                        and contains_collective(next_gsnode)
+                        contains_collective(group) and contains_collective(candidate)
                     )
                 ):

                     def is_groupable(snode):
-                        return not contains_gemm_like(snode)
+                        # We do not want to group with collectives to not reorder them forward.
+                        if contains_collective(snode):
+                            return (
+                                False,
+                                f"candidate contains collective {snode.get_name()}",
+                            )
+                        if contains_gemm_like(snode):
+                            return (
+                                False,
+                                f"candidate contains gemm_like {snode.get_name()}",
+                            )
+                        return True, None

-                    if is_groupable(next_gsnode):
-                        new_snodes = wait_gsnode.snodes + next_gsnode.snodes
-                        init_group_node(next_gsnode, gsnode.scheduler, new_snodes)
-                        wait_gsnode.snodes = []
+                    is_grp, grp_reason = is_groupable(candidate)
+                    if is_grp:
+                        group_tail = candidate
+                        group_peak_memory = max(
+                            group_peak_memory, _curr_memory[candidate]
+                        )
                         info.grouped += 1
-                        info.grouped_info = _group_name(next_gsnode)
+                        info.grouped_info = _group_names(group_head, group_tail)
+                        candidate = _next[candidate]
                         continue
                     elif (data_dep is None) and both_contain_comms:
                         info.limiting_factor = (
-                            f"collective ordering {_group_name(wait_gsnode)}"
-                            f" with candidate:{_group_name(next_gsnode)}"
+                            f"collective ordering {_group_names(group_head, group_tail)}"
+                            f" with candidate:{candidate.get_name()}"
                         )
+                        break
                     else:
                         info.limiting_factor = (
-                            f"data dependency {data_dep}(dep_names:{dep_names})"
-                            f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}"
-                            f" outs:{[o.get_name() for o in wait_outs]}"
+                            f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
+                            f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
+                            f"dep on {_group_names(group_head, group_tail)}"
+                            f"\n outs:{[o.get_name() for o in group_outs]}"
+                            f"\n non_group_reason:{grp_reason}"
                         )
-                    break
-                info.moves += 1
-                info.moves_info += f"+{_group_name(next_gsnode)}"
+                        break
+                candidate_delta_memory = (
+                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
+                )
+                if group_peak_memory + candidate_delta_memory > peak_memory:
+                    info.limiting_factor = "peak_memory"
+                    break

+                info.moves += 1
+                info.moves_info += f"+{candidate.get_name()}"
+
+                # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
+                mem_deltas = {}
+                for n in [candidate, *_group_nodes(group_head, group_tail)]:
+                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
+                # 0:
+                group_head_prev = _prev[group_head]
+                if group_head_prev:
+                    _next[group_head_prev] = candidate
+                _prev[candidate] = group_head_prev
+
+                # 2:
+                candidate_next = _next[candidate]
+                if candidate_next:
+                    _prev[candidate_next] = group_tail
+                _next[group_tail] = candidate_next
+
+                # 1:
+                _prev[group_head] = candidate
+                _next[candidate] = group_head
+                if group_head == _head:
+                    _head = candidate
+
+                # Recompute curr_memory
+                _prev_curr_memory = _curr_memory[_prev[candidate]]  # type: ignore[index]
+                for n in _group_nodes(candidate, group_tail):
+                    _curr_memory[n] = _prev_curr_memory = (
+                        _prev_curr_memory + mem_deltas[n]
+                    )
+
+                candidate = _next[group_tail]
+        curr = _prev[curr]  # type: ignore[assignment]

-                # Swapping snodes j and j - 1
-                tmp = gsnodes[j - 1]
-                gsnodes[j - 1] = gsnodes[j]
-                gsnodes[j] = tmp
     headers = [
         "Wait node",
         "grouped",
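
The peak-memory gate in the sink pass reduces to one comparison: moving `candidate` ahead of the wait group shifts the group's running peak by the candidate's allocation delta, and the move is rejected if that exceeds the original peak. A toy version:

    # Reject a sink move that would raise the schedule's peak memory.
    def may_sink_past(group_peak_memory, candidate_delta_memory, peak_memory):
        return group_peak_memory + candidate_delta_memory <= peak_memory

    assert may_sink_past(group_peak_memory=90, candidate_delta_memory=5, peak_memory=100)
    assert not may_sink_past(group_peak_memory=90, candidate_delta_memory=15, peak_memory=100)
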
@@ -732,16 +834,13 @@ def _sink_waits_iterative_internal(
     log_str += str(headers) + "\n"
     log_str += "\n".join(map(str, rows))
     overlap_log.info(log_str)
-    grouping_logs = []
-    flatten_snodes = []
-    for i, gsnode in enumerate(gsnodes):
-        grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}")
-        if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
-            flatten_snodes.extend(gsnode.snodes)
-        else:
-            flatten_snodes.append(gsnode)
-    grouping_log_str = "\n".join(grouping_logs)
-    log_str += grouping_log_str
+    new_snodes = _group_nodes(_head, None)
+    assert len(new_snodes) == original_snodes_num
+    new_peak_memory, curr_memory = estimate_peak_memory(
+        new_snodes, name_to_freeable_input_buf, graph_outputs
+    )
+    log_str += f"\n peak_memory_before:{peak_memory}"
+    log_str += f"\n peak_memory_after:{new_peak_memory}"
     trace_structured(
         "artifact",
         metadata_fn=lambda: {
@@ -750,8 +849,7 @@ def _sink_waits_iterative_internal(
         },
         payload_fn=lambda: log_str,
     )
-    assert len(flatten_snodes) == n
-    return flatten_snodes, stats
+    return new_snodes, stats


 def sink_waits_iterative(
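
Both passes follow the same split: an `_internal` function returning `(snodes, stats)` plus a thin public wrapper that drops the stats. Presumed shape of the wrapper (the actual body may differ):

    # Keep the debug-returning `_internal` entry point for tests; expose the
    # plain pass with the signature the scheduler expects.
    def sink_waits_iterative(
        snodes: list[BaseSchedulerNode],
    ) -> list[BaseSchedulerNode]:
        new_snodes, _stats = _sink_waits_iterative_internal(snodes)
        return new_snodes
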
@@ -777,7 +875,9 @@ def node_summary(snode):
     if len(snodes) == 1:
         detail = ""
         if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
-            detail = f" ({snode.node.python_kernel_name})"
+            outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}"
+            ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}"
+            detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})"
         layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
         out_tensor_info = ",".join(
             [
@@ -1352,7 +1452,7 @@ def enforce_comm_ordering_for_fsdp(
             mutating_buf = next(iter(ag_group_node.get_buffer_names()))
             for o in prev_ag_wait.get_outputs():
                 ag_group_node.add_fake_dep(
-                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
+                    WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                 )
             prev_ag_wait = wait_group_node

@@ -1364,7 +1464,7 @@ def enforce_comm_ordering_for_fsdp(
             mutating_buf = next(iter(rs_group_node.get_buffer_names()))
             for o in prev_rs_wait.get_outputs():
                 rs_group_node.add_fake_dep(
-                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
+                    WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                 )
             prev_rs_wait = wait_group_node

torch/_inductor/dependencies.py
@@ -342,6 +342,12 @@ class WeakDep(Dep):
     name: str
     # Buffer that is doing the mutation
     mutating_buf: str
+    # WeakDep's are also used to add dependencies to prevent some specific reordering,
+    # E.g. collectives global ordering.
+    # But if other pass guarantees proper ordering by its logic,
+    # This additional "fake" deps will be holding optimizations.
+    # This flag is used to identify those additional deps.
+    is_fake: bool = False

     @property
     def index(self) -> sympy.Expr:
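
Given the `rename` change below, the `is_fake` flag now survives buffer renaming instead of being silently reset. A quick check, assuming `WeakDep` is importable from `torch._inductor.dependencies`:

    from torch._inductor.dependencies import WeakDep

    w = WeakDep("buf0", mutating_buf="buf1", is_fake=True)
    assert w.rename({"buf0": "buf2"}).is_fake  # flag preserved through rename
    assert w.rename({"other": "buf9"}) is w    # untouched names return self
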
@@ -352,7 +358,7 @@ class WeakDep(Dep):

     def rename(self, renames: dict[str, str]) -> "WeakDep":
         if self.name in renames:
-            return WeakDep(renames[self.name], self.mutating_buf)
+            return WeakDep(renames[self.name], self.mutating_buf, self.is_fake)
         return self

     def numbytes_hint(self) -> int: