mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[simple_fsdp][inductor_collectives] rewrite reorder_collectives, sink_waits_iterative (#158062)
Differential Revision: [D78159013](https://our.internmc.facebook.com/intern/diff/D78159013)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158062
Approved by: https://github.com/wconstab
This commit is contained in:
parent
ef256ad17b
commit
eeb0783fe6
|
|
@ -19,6 +19,7 @@ from torch._dynamo.utils import same
|
||||||
from torch._inductor.comms import (
|
from torch._inductor.comms import (
|
||||||
_reorder_communication_preserving_peak_memory_internal,
|
_reorder_communication_preserving_peak_memory_internal,
|
||||||
ReorderInfo,
|
ReorderInfo,
|
||||||
|
sink_waits_iterative,
|
||||||
)
|
)
|
||||||
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
|
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
|
||||||
from torch._inductor.scheduler import BaseSchedulerNode
|
from torch._inductor.scheduler import BaseSchedulerNode
|
||||||
|
|
@ -1621,7 +1622,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||||
comm from moving due to data dependency.
|
comm from moving due to data dependency.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
|
def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
|
||||||
# do some unrelated matmuls
|
# do some unrelated matmuls
|
||||||
y = torch.mm(x, w)
|
y = torch.mm(x, w)
|
||||||
|
|
||||||
|
|
@ -1654,14 +1655,52 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||||
# wait op
|
# wait op
|
||||||
rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out)
|
rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out)
|
||||||
rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out)
|
rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out)
|
||||||
|
y += torch.mm(2 * x, 2 * w)
|
||||||
|
|
||||||
return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out
|
# cast the inputs
|
||||||
|
ag_2_cast = ag_2.to(torch.bfloat16)
|
||||||
|
ag_3_cast = ag_3.to(torch.bfloat16)
|
||||||
|
ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||||
|
ag_2_cast, group_size, group_name
|
||||||
|
)
|
||||||
|
ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||||
|
ag_3_cast, group_size, group_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# wait op
|
||||||
|
ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out)
|
||||||
|
ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out)
|
||||||
|
|
||||||
|
#
|
||||||
|
rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor(
|
||||||
|
ag_2_cast, "sum", group_size, group_name
|
||||||
|
)
|
||||||
|
rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor(
|
||||||
|
ag_3_cast, "sum", group_size, group_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# wait op
|
||||||
|
rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out)
|
||||||
|
rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out)
|
||||||
|
return (
|
||||||
|
y,
|
||||||
|
ag_0_out,
|
||||||
|
ag_1_out,
|
||||||
|
ag_2_out,
|
||||||
|
ag_3_out,
|
||||||
|
rs_0_out,
|
||||||
|
rs_1_out,
|
||||||
|
rs_2_out,
|
||||||
|
rs_3_out,
|
||||||
|
)
|
||||||
|
|
||||||
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
|
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
|
||||||
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||||
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
|
||||||
ag_1 = torch.ones(512, device="cuda", dtype=torch.float32)
|
ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
|
||||||
inputs = [x, w, ag_0, ag_1]
|
ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
|
||||||
|
ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
|
||||||
|
inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
|
||||||
|
|
||||||
# get stats directly from the internal helper without affecting the real pass's signature
|
# get stats directly from the internal helper without affecting the real pass's signature
|
||||||
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
|
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
|
||||||
|
|
@ -1679,11 +1718,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||||
with torch._inductor.config.patch(
|
with torch._inductor.config.patch(
|
||||||
{
|
{
|
||||||
"bucket_all_gathers_fx": "all",
|
"bucket_all_gathers_fx": "all",
|
||||||
|
"bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
|
||||||
"bucket_reduce_scatters_fx": "all",
|
"bucket_reduce_scatters_fx": "all",
|
||||||
|
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
|
||||||
"reorder_for_compute_comm_overlap": True,
|
"reorder_for_compute_comm_overlap": True,
|
||||||
"reorder_for_compute_comm_overlap_passes": [
|
"reorder_for_compute_comm_overlap_passes": [
|
||||||
|
sink_waits_iterative,
|
||||||
_reorder_communication_preserving_peak_memory,
|
_reorder_communication_preserving_peak_memory,
|
||||||
],
|
],
|
||||||
|
"allow_buffer_reuse": False,
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
compiled = torch.compile(func)
|
compiled = torch.compile(func)
|
||||||
|
|
@ -1694,31 +1737,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||||
FileCheck()
|
FileCheck()
|
||||||
.check_count(
|
.check_count(
|
||||||
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
|
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
|
||||||
count=1,
|
count=2,
|
||||||
exactly=True,
|
exactly=True,
|
||||||
)
|
)
|
||||||
|
.check(
|
||||||
|
"extern_kernels.mm",
|
||||||
|
)
|
||||||
|
.check(
|
||||||
|
"extern_kernels.addmm",
|
||||||
|
)
|
||||||
.run(code)
|
.run(code)
|
||||||
)
|
)
|
||||||
(
|
(
|
||||||
FileCheck()
|
FileCheck()
|
||||||
.check_count(
|
.check_count(
|
||||||
"torch.ops._c10d_functional.reduce_scatter_tensor.default(",
|
"torch.ops._c10d_functional.reduce_scatter_tensor.default(",
|
||||||
count=1,
|
count=2,
|
||||||
exactly=True,
|
exactly=True,
|
||||||
)
|
)
|
||||||
.run(code)
|
|
||||||
)
|
|
||||||
(
|
|
||||||
FileCheck()
|
|
||||||
.check(
|
|
||||||
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
|
|
||||||
)
|
|
||||||
.check(
|
|
||||||
"torch.ops._c10d_functional.reduce_scatter_tensor.default(",
|
|
||||||
)
|
|
||||||
.check(
|
.check(
|
||||||
"extern_kernels.mm",
|
"extern_kernels.mm",
|
||||||
)
|
)
|
||||||
|
.check(
|
||||||
|
"extern_kernels.addmm",
|
||||||
|
)
|
||||||
.run(code)
|
.run(code)
|
||||||
)
|
)
|
||||||
out = compiled(*inputs, **self.get_world_trs())
|
out = compiled(*inputs, **self.get_world_trs())
|
||||||
|
|
@ -1726,7 +1768,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||||
assert same(out, correct), f"{out} va {correct}"
|
assert same(out, correct), f"{out} va {correct}"
|
||||||
assert node_stats is not None
|
assert node_stats is not None
|
||||||
self.assertTrue(isinstance(node_stats, dict))
|
self.assertTrue(isinstance(node_stats, dict))
|
||||||
self.assertEqual(len(node_stats), 2)
|
self.assertEqual(len(node_stats), 4)
|
||||||
it = iter(node_stats.values())
|
it = iter(node_stats.values())
|
||||||
node_stat0 = next(it)
|
node_stat0 = next(it)
|
||||||
self.assertTrue(node_stat0.moves > 0)
|
self.assertTrue(node_stat0.moves > 0)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import heapq
|
import heapq
|
||||||
import importlib
|
import importlib
|
||||||
import itertools
|
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -149,9 +148,8 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(node, "python_kernel_name")
|
python_kernel_name := getattr(node, "python_kernel_name", None)
|
||||||
and node.python_kernel_name == "extern_kernels.mm"
|
) and "extern_kernels" in python_kernel_name:
|
||||||
):
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -189,15 +187,23 @@ def _group_name(snode, with_bufs=False) -> str:
|
||||||
def _reorder_communication_preserving_peak_memory_internal(
|
def _reorder_communication_preserving_peak_memory_internal(
|
||||||
snodes: list[BaseSchedulerNode],
|
snodes: list[BaseSchedulerNode],
|
||||||
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
|
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
|
||||||
from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
|
|
||||||
|
|
||||||
original_snodes_num = len(snodes)
|
|
||||||
"""
|
"""
|
||||||
Internal testing helper that also returns debug info.
|
Internal testing helper that also returns debug info.
|
||||||
Returns:
|
Returns:
|
||||||
- reordered snodes list
|
- reordered snodes list
|
||||||
- dict {snode: ReorderInfo}
|
- dict {snode: ReorderInfo}
|
||||||
"""
|
"""
|
||||||
|
has_collectives = False
|
||||||
|
for snode in snodes:
|
||||||
|
if contains_collective(snode):
|
||||||
|
has_collectives = True
|
||||||
|
break
|
||||||
|
if not has_collectives:
|
||||||
|
return snodes, {}
|
||||||
|
|
||||||
|
from torch._inductor.scheduler import GroupedSchedulerNode
|
||||||
|
|
||||||
|
original_snodes_num = len(snodes)
|
||||||
# heuristic to avoid degenerating to quadratic time
|
# heuristic to avoid degenerating to quadratic time
|
||||||
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
||||||
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
||||||
|
|
@ -208,7 +214,8 @@ def _reorder_communication_preserving_peak_memory_internal(
|
||||||
snodes, name_to_freeable_input_buf, graph_outputs
|
snodes, name_to_freeable_input_buf, graph_outputs
|
||||||
)
|
)
|
||||||
runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
|
runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
|
||||||
snode_to_curr_memory = dict(zip(snodes, curr_memory))
|
_curr_memory = dict(zip(snodes, curr_memory))
|
||||||
|
_curr_memory[None] = 0 # type: ignore[index]
|
||||||
|
|
||||||
# debug stats
|
# debug stats
|
||||||
stats: dict[BaseSchedulerNode, ReorderInfo] = {}
|
stats: dict[BaseSchedulerNode, ReorderInfo] = {}
|
||||||
|
|
@ -232,153 +239,151 @@ def _reorder_communication_preserving_peak_memory_internal(
|
||||||
_temp_group_visit_leaves(snode, accumulate_time)
|
_temp_group_visit_leaves(snode, accumulate_time)
|
||||||
return max(0, comm_time - compute_time)
|
return max(0, comm_time - compute_time)
|
||||||
|
|
||||||
MOVE_LIMIT = len(snodes) * 100
|
|
||||||
total_moves = 0
|
total_moves = 0
|
||||||
# TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it
|
|
||||||
PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes)
|
|
||||||
if config.reorder_prefetch_limit is not None:
|
|
||||||
PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit
|
|
||||||
|
|
||||||
# Dicts to keep track of "next" and "previous" as double-linked structure during grouping
|
# Dicts to keep track of "next" and "previous" as double-linked structure during grouping
|
||||||
_prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
|
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
|
||||||
_next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {}
|
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
|
||||||
for i, snode in enumerate(snodes):
|
for i, snode in enumerate(snodes):
|
||||||
_prev[snode] = snodes[i - 1] if i > 0 else None
|
_prev[snode] = snodes[i - 1] if i > 0 else None
|
||||||
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
|
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
|
||||||
|
_curr_memory = dict(zip(snodes, curr_memory))
|
||||||
|
_curr_memory[None] = 0 # type: ignore[index]
|
||||||
|
|
||||||
gsnodes: list[GroupedSchedulerNode] = [
|
_head = snodes[0]
|
||||||
GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
|
|
||||||
for snode in snodes
|
def _group_nodes(head, tail):
|
||||||
]
|
ret = []
|
||||||
for i, gsnode in enumerate(gsnodes):
|
n = head
|
||||||
snode = gsnode.snodes[0] # type: ignore[attr-defined]
|
while True:
|
||||||
if contains_collective(snode):
|
if n is not None:
|
||||||
reorder_info = stats[snode] = ReorderInfo()
|
ret.append(n)
|
||||||
|
if n == tail:
|
||||||
|
break
|
||||||
|
n = _next[n]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _group_names(head, tail):
|
||||||
|
ret = ""
|
||||||
|
for n in _group_nodes(head, tail):
|
||||||
|
if ret:
|
||||||
|
ret += "~"
|
||||||
|
ret += n.get_name()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
curr = _head
|
||||||
|
while _next[curr] is not None:
|
||||||
|
if contains_collective(curr):
|
||||||
|
reorder_info = stats[curr] = ReorderInfo()
|
||||||
reorder_info.initial_exposed = reorder_info.final_exposed = (
|
reorder_info.initial_exposed = reorder_info.final_exposed = (
|
||||||
exposed_communication_time(snode, snodes[i + 1 :])
|
exposed_communication_time(curr, _group_nodes(_next[curr], None))
|
||||||
)
|
)
|
||||||
if total_moves >= MOVE_LIMIT:
|
|
||||||
reorder_info.limiting_factor = "move limit"
|
|
||||||
continue
|
|
||||||
|
|
||||||
for j in range(i - 1, -1, -1):
|
candidate = _prev[curr]
|
||||||
prev_gsnode = gsnodes[j]
|
group_head = curr
|
||||||
if len(prev_gsnode.snodes) == 0:
|
group_tail = curr
|
||||||
continue
|
group_peak_memory = _curr_memory[curr]
|
||||||
|
while candidate is not None:
|
||||||
if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT):
|
if contains_collective(candidate):
|
||||||
reorder_info.limiting_factor = "prefetch limit"
|
|
||||||
break
|
|
||||||
if contains_collective(prev_gsnode):
|
|
||||||
reorder_info.limiting_factor = "collective ordering"
|
reorder_info.limiting_factor = "collective ordering"
|
||||||
break
|
break
|
||||||
|
|
||||||
dep_names = OrderedSet([s.name for s in snode.unmet_dependencies])
|
group = GroupedSchedulerNode(
|
||||||
prev_outs = prev_gsnode.get_outputs()
|
curr.scheduler,
|
||||||
|
_group_nodes(group_head, group_tail),
|
||||||
|
temp_grouping=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_deps = {s.name: s for s in group.unmet_dependencies}
|
||||||
|
candidate_outs = candidate.get_outputs()
|
||||||
data_dep = None
|
data_dep = None
|
||||||
for o in prev_outs:
|
for o in candidate_outs:
|
||||||
if o.get_name() in dep_names:
|
if d := data_deps.get(o.get_name(), None):
|
||||||
data_dep = o.get_name()
|
if isinstance(d, WeakDep) and d.is_fake:
|
||||||
|
continue
|
||||||
|
data_dep = d
|
||||||
break
|
break
|
||||||
|
|
||||||
if data_dep is not None:
|
if data_dep is not None:
|
||||||
|
|
||||||
def is_groupable(prev_gsnode):
|
def is_groupable(candidate):
|
||||||
# preserve ordering
|
# preserve ordering
|
||||||
if contains_collective(prev_gsnode):
|
if contains_collective(candidate):
|
||||||
return False
|
return False, "contains_collective"
|
||||||
|
|
||||||
if contains_gemm_like(prev_gsnode):
|
if contains_gemm_like(candidate):
|
||||||
return False
|
return False, "contains_gemm_like"
|
||||||
return True
|
return True, None
|
||||||
|
|
||||||
if is_groupable(prev_gsnode):
|
is_grp, grp_reason = is_groupable(candidate)
|
||||||
new_snodes = prev_gsnode.snodes + gsnode.snodes
|
if is_grp:
|
||||||
init_group_node(gsnode, gsnode.scheduler, new_snodes)
|
group_head = candidate
|
||||||
prev_gsnode.snodes = []
|
group_peak_memory = max(
|
||||||
|
group_peak_memory, _curr_memory[candidate]
|
||||||
|
)
|
||||||
reorder_info.grouped += 1
|
reorder_info.grouped += 1
|
||||||
reorder_info.grouped_info = gsnode.get_name()
|
reorder_info.grouped_info = _group_names(group_head, group_tail)
|
||||||
|
candidate = _prev[candidate]
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f"data dependency {data_dep}(dep_names:{dep_names})"
|
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
|
||||||
f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}"
|
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
|
||||||
|
f"dep on {_group_names(group_head, group_tail)}"
|
||||||
|
f"\n non_group_reason:{grp_reason}"
|
||||||
)
|
)
|
||||||
reorder_info.limiting_factor = msg
|
reorder_info.limiting_factor = msg
|
||||||
break
|
break
|
||||||
|
|
||||||
if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]:
|
delta_memory_candidate = (
|
||||||
|
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
|
||||||
|
)
|
||||||
|
|
||||||
|
if group_peak_memory - delta_memory_candidate > peak_memory:
|
||||||
reorder_info.limiting_factor = "peak memory"
|
reorder_info.limiting_factor = "peak memory"
|
||||||
break
|
break
|
||||||
if reorder_info.final_exposed > runtimes[snode]:
|
|
||||||
reorder_info.limiting_factor = "sufficient overlapping"
|
|
||||||
break
|
|
||||||
reorder_info.moves += 1
|
reorder_info.moves += 1
|
||||||
total_moves += 1
|
total_moves += 1
|
||||||
|
|
||||||
# swapping nodes j and j+1 affects curr memory at j only
|
mem_deltas = {}
|
||||||
# j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j]
|
for n in [candidate, *_group_nodes(group_head, group_tail)]:
|
||||||
# j_alloc = curr_memory[j] - curr_memory[j - 1]
|
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
|
||||||
# curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc
|
# swap (candidate, group_head...group_tail)
|
||||||
def swap_curr_memory_with_previous(
|
# Before:
|
||||||
snode_j_plus_one, snode_j, snode_j_minus_one
|
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
|
||||||
):
|
# After:
|
||||||
curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one]
|
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
|
||||||
curr_memory_j = snode_to_curr_memory[snode_j]
|
# 0
|
||||||
curr_memory_j_minus_one = (
|
candidate_prev = _prev[candidate]
|
||||||
snode_to_curr_memory[snode_j_minus_one]
|
if candidate_prev:
|
||||||
if snode_j_minus_one is not None
|
_next[candidate_prev] = group_head
|
||||||
else 0
|
_prev[group_head] = candidate_prev
|
||||||
)
|
|
||||||
j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j
|
|
||||||
j_alloc = curr_memory_j - curr_memory_j_minus_one
|
|
||||||
snode_to_curr_memory[snode_j] = (
|
|
||||||
curr_memory_j - j_alloc + j_plus_one_alloc
|
|
||||||
)
|
|
||||||
|
|
||||||
# Recomputing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B)
|
# 2
|
||||||
# swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2]
|
group_tail_next = _next[group_tail]
|
||||||
# decomposing to:
|
if group_tail_next:
|
||||||
# swap(A2, B0) -> A0, A1, B0, A2, B1
|
_prev[group_tail_next] = candidate
|
||||||
# swap(A2, B1) -> A0, A1, B0, B1, A2
|
_next[candidate] = group_tail_next
|
||||||
# swap(A1, B0) -> A0, B0, A1, B1, A2
|
|
||||||
# swap(A1, B1) -> A0, B0, B1, A1, A2
|
|
||||||
# swap(A0, B0) -> B0, A0, B1, A1, A2
|
|
||||||
# swap(A0, B1) -> B0, B1, A0, A1, A2
|
|
||||||
for _j in range(len(gsnodes[j].snodes) - 1, -1, -1): # group A
|
|
||||||
snode_j = gsnodes[j].snodes[_j]
|
|
||||||
for _i, snode_i in enumerate(gsnode.snodes): # group B
|
|
||||||
swap_curr_memory_with_previous(
|
|
||||||
snode_j_plus_one=snode_i,
|
|
||||||
snode_j=snode_j,
|
|
||||||
snode_j_minus_one=_prev[snode_j],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j]
|
# 1
|
||||||
first = snode_j
|
_prev[candidate] = group_tail
|
||||||
second = snode_i
|
_next[group_tail] = candidate
|
||||||
first_prev = _prev[first]
|
|
||||||
second_next = _next[second]
|
|
||||||
if first_prev:
|
|
||||||
_next[first_prev] = second
|
|
||||||
_prev[second] = first_prev
|
|
||||||
|
|
||||||
if second_next:
|
if _head == candidate:
|
||||||
_prev[second_next] = first
|
_head = group_head
|
||||||
_next[first] = second_next
|
|
||||||
|
|
||||||
_next[second] = first
|
|
||||||
_prev[first] = second
|
|
||||||
|
|
||||||
tmp = gsnodes[j]
|
|
||||||
gsnodes[j] = gsnodes[j + 1]
|
|
||||||
gsnodes[j + 1] = tmp
|
|
||||||
reorder_info.final_exposed = exposed_communication_time(
|
reorder_info.final_exposed = exposed_communication_time(
|
||||||
snode,
|
curr, _group_nodes(_next[curr], None)
|
||||||
itertools.chain(
|
|
||||||
gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
# Recompute curr_memory
|
||||||
|
_prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index]
|
||||||
|
for n in _group_nodes(group_head, candidate):
|
||||||
|
_curr_memory[n] = _prev_curr_memory = (
|
||||||
|
_prev_curr_memory + mem_deltas[n]
|
||||||
|
)
|
||||||
|
candidate = _prev[group_head]
|
||||||
|
curr = _next[curr] # type: ignore[assignment]
|
||||||
|
|
||||||
node_stats = stats
|
node_stats = stats
|
||||||
improvement = {snode: node_stats[snode].improvement for snode in node_stats}
|
improvement = {snode: node_stats[snode].improvement for snode in node_stats}
|
||||||
|
|
@ -426,17 +431,13 @@ def _reorder_communication_preserving_peak_memory_internal(
|
||||||
reorder_log_str += str(headers) + "\n"
|
reorder_log_str += str(headers) + "\n"
|
||||||
reorder_log_str += "\n".join(map(str, rows))
|
reorder_log_str += "\n".join(map(str, rows))
|
||||||
|
|
||||||
grouping_logs: list[str] = []
|
new_snodes = _group_nodes(_head, None)
|
||||||
flatten_gsnodes: list[BaseSchedulerNode] = []
|
assert len(new_snodes) == original_snodes_num
|
||||||
for i, gsnode in enumerate(gsnodes):
|
new_peak_memory, curr_memory = estimate_peak_memory(
|
||||||
if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
|
new_snodes, name_to_freeable_input_buf, graph_outputs
|
||||||
flatten_gsnodes.extend(gsnode.snodes)
|
)
|
||||||
else:
|
reorder_log_str += f"\n peak_memory_before:{peak_memory}"
|
||||||
flatten_gsnodes.append(gsnode)
|
reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"
|
||||||
|
|
||||||
grouping_log_str = "\n".join(grouping_logs)
|
|
||||||
reorder_log_str += "\n"
|
|
||||||
reorder_log_str += grouping_log_str
|
|
||||||
|
|
||||||
overlap_log.info(reorder_log_str)
|
overlap_log.info(reorder_log_str)
|
||||||
trace_structured(
|
trace_structured(
|
||||||
|
|
@ -448,8 +449,7 @@ def _reorder_communication_preserving_peak_memory_internal(
|
||||||
payload_fn=lambda: reorder_log_str,
|
payload_fn=lambda: reorder_log_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(flatten_gsnodes) == original_snodes_num
|
return new_snodes, stats
|
||||||
return flatten_gsnodes, stats
|
|
||||||
|
|
||||||
|
|
||||||
def _schedule_for_comm(
|
def _schedule_for_comm(
|
||||||
|
|
@ -623,7 +623,9 @@ def decide_global_ordering_of_comms(
|
||||||
# Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
|
# Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
|
||||||
mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
|
mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
|
||||||
for buf in comm_nodes[i - 1].get_buffer_names():
|
for buf in comm_nodes[i - 1].get_buffer_names():
|
||||||
comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))
|
comm_nodes[i].add_fake_dep(
|
||||||
|
WeakDep(buf, mutating_buf=mutating_buf, is_fake=True)
|
||||||
|
)
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
|
@ -640,66 +642,166 @@ class SinkWaitInfo:
|
||||||
def _sink_waits_iterative_internal(
|
def _sink_waits_iterative_internal(
|
||||||
snodes: list[BaseSchedulerNode],
|
snodes: list[BaseSchedulerNode],
|
||||||
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
|
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
|
||||||
from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node
|
from torch._inductor.scheduler import GroupedSchedulerNode
|
||||||
|
|
||||||
|
original_snodes_num = len(snodes)
|
||||||
|
if original_snodes_num == 0:
|
||||||
|
return snodes, {}
|
||||||
|
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
||||||
|
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
||||||
|
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
|
||||||
|
snodes, graph_inputs
|
||||||
|
)
|
||||||
|
peak_memory, curr_memory = estimate_peak_memory(
|
||||||
|
snodes, name_to_freeable_input_buf, graph_outputs
|
||||||
|
)
|
||||||
|
|
||||||
n = len(snodes)
|
|
||||||
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
|
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
|
||||||
gsnodes: list[GroupedSchedulerNode] = [
|
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
|
||||||
GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True)
|
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
|
||||||
for snode in snodes
|
_head = snodes[0]
|
||||||
]
|
for i, snode in enumerate(snodes):
|
||||||
for i in range(n - 1, -1, -1):
|
_prev[snode] = snodes[i - 1] if i > 0 else None
|
||||||
gsnode = gsnodes[i]
|
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
|
||||||
if contains_wait(gsnode):
|
_curr_memory = dict(zip(snodes, curr_memory))
|
||||||
info = stats[gsnode.snodes[0]] = SinkWaitInfo()
|
_curr_memory[None] = 0 # type: ignore[index]
|
||||||
for j in range(i + 1, n):
|
|
||||||
wait_gsnode = gsnodes[j - 1]
|
def _group_nodes(head, tail):
|
||||||
wait_outs = wait_gsnode.get_outputs()
|
ret = []
|
||||||
next_gsnode = gsnodes[j]
|
n = head
|
||||||
dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies])
|
while True:
|
||||||
|
if n is not None:
|
||||||
|
ret.append(n)
|
||||||
|
if n == tail:
|
||||||
|
break
|
||||||
|
n = _next[n]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _group_names(head, tail):
|
||||||
|
ret = ""
|
||||||
|
for n in _group_nodes(head, tail):
|
||||||
|
if ret:
|
||||||
|
ret += "~"
|
||||||
|
ret += n.get_name()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
curr = snodes[-1]
|
||||||
|
|
||||||
|
processed_waits = OrderedSet() # type: ignore[var-annotated]
|
||||||
|
while _prev[curr] is not None:
|
||||||
|
if contains_wait(curr) and curr not in processed_waits:
|
||||||
|
processed_waits.add(curr)
|
||||||
|
info = stats[curr] = SinkWaitInfo()
|
||||||
|
candidate = _next[curr]
|
||||||
|
wait_snode = curr
|
||||||
|
group_head = curr
|
||||||
|
group_tail = curr
|
||||||
|
group_peak_memory = _curr_memory[curr]
|
||||||
|
while candidate is not None:
|
||||||
|
group = GroupedSchedulerNode(
|
||||||
|
wait_snode.scheduler,
|
||||||
|
_group_nodes(group_head, group_tail),
|
||||||
|
temp_grouping=True,
|
||||||
|
)
|
||||||
|
group_outs = group.get_outputs()
|
||||||
|
|
||||||
|
data_deps = {s.name: s for s in candidate.unmet_dependencies}
|
||||||
data_dep = None
|
data_dep = None
|
||||||
for o in wait_outs:
|
for o in group_outs:
|
||||||
if o.get_name() in dep_names:
|
if d := data_deps.get(o.get_name(), None):
|
||||||
data_dep = o.get_name()
|
if isinstance(d, WeakDep) and d.is_fake:
|
||||||
|
continue
|
||||||
|
data_dep = d
|
||||||
break
|
break
|
||||||
# 1. If we have data_dep - we can not swap => trying to group
|
# 1. If we have data_dep - we can not swap => trying to group
|
||||||
# 2. If swap candidate and current node both contain collectives => trying to group
|
# 2. If swap candidate and current node both contain collectives => trying to group
|
||||||
if data_dep is not None or (
|
if data_dep is not None or (
|
||||||
both_contain_comms := (
|
both_contain_comms := (
|
||||||
contains_collective(wait_gsnode)
|
contains_collective(group) and contains_collective(candidate)
|
||||||
and contains_collective(next_gsnode)
|
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
|
|
||||||
def is_groupable(snode):
|
def is_groupable(snode):
|
||||||
return not contains_gemm_like(snode)
|
# We do not want to group with collectives to not reorder them forward.
|
||||||
|
if contains_collective(snode):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"candidate contains collective {snode.get_name()}",
|
||||||
|
)
|
||||||
|
if contains_gemm_like(snode):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"candidate contains gemm_like {snode.get_name()}",
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
if is_groupable(next_gsnode):
|
is_grp, grp_reason = is_groupable(candidate)
|
||||||
new_snodes = wait_gsnode.snodes + next_gsnode.snodes
|
if is_grp:
|
||||||
init_group_node(next_gsnode, gsnode.scheduler, new_snodes)
|
group_tail = candidate
|
||||||
wait_gsnode.snodes = []
|
group_peak_memory = max(
|
||||||
|
group_peak_memory, _curr_memory[candidate]
|
||||||
|
)
|
||||||
info.grouped += 1
|
info.grouped += 1
|
||||||
info.grouped_info = _group_name(next_gsnode)
|
info.grouped_info = _group_names(group_head, group_tail)
|
||||||
|
candidate = _next[candidate]
|
||||||
continue
|
continue
|
||||||
elif (data_dep is None) and both_contain_comms:
|
elif (data_dep is None) and both_contain_comms:
|
||||||
info.limiting_factor = (
|
info.limiting_factor = (
|
||||||
f"collective ordering {_group_name(wait_gsnode)}"
|
f"collective ordering {_group_names(group_head, group_tail)}"
|
||||||
f" with candidate:{_group_name(next_gsnode)}"
|
f" with candidate:{candidate.get_name()}"
|
||||||
)
|
|
||||||
else:
|
|
||||||
info.limiting_factor = (
|
|
||||||
f"data dependency {data_dep}(dep_names:{dep_names})"
|
|
||||||
f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}"
|
|
||||||
f" outs:{[o.get_name() for o in wait_outs]}"
|
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
info.moves += 1
|
else:
|
||||||
info.moves_info += f"+{_group_name(next_gsnode)}"
|
info.limiting_factor = (
|
||||||
|
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
|
||||||
|
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
|
||||||
|
f"dep on {_group_names(group_head, group_tail)}"
|
||||||
|
f"\n outs:{[o.get_name() for o in group_outs]}"
|
||||||
|
f"\n non_group_reason:{grp_reason}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
candidate_delta_memory = (
|
||||||
|
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
|
||||||
|
)
|
||||||
|
if group_peak_memory + candidate_delta_memory > peak_memory:
|
||||||
|
info.limiting_factor = "peak_memory"
|
||||||
|
break
|
||||||
|
|
||||||
|
info.moves += 1
|
||||||
|
info.moves_info += f"+{candidate.get_name()}"
|
||||||
|
|
||||||
|
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
|
||||||
|
mem_deltas = {}
|
||||||
|
for n in [candidate, *_group_nodes(group_head, group_tail)]:
|
||||||
|
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
|
||||||
|
# 0:
|
||||||
|
group_head_prev = _prev[group_head]
|
||||||
|
if group_head_prev:
|
||||||
|
_next[group_head_prev] = candidate
|
||||||
|
_prev[candidate] = group_head_prev
|
||||||
|
|
||||||
|
# 2:
|
||||||
|
candidate_next = _next[candidate]
|
||||||
|
if candidate_next:
|
||||||
|
_prev[candidate_next] = group_tail
|
||||||
|
_next[group_tail] = candidate_next
|
||||||
|
|
||||||
|
# 1:
|
||||||
|
_prev[group_head] = candidate
|
||||||
|
_next[candidate] = group_head
|
||||||
|
if group_head == _head:
|
||||||
|
_head = candidate
|
||||||
|
|
||||||
|
# Recompute curr_memory
|
||||||
|
_prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index]
|
||||||
|
for n in _group_nodes(candidate, group_tail):
|
||||||
|
_curr_memory[n] = _prev_curr_memory = (
|
||||||
|
_prev_curr_memory + mem_deltas[n]
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate = _next[group_tail]
|
||||||
|
curr = _prev[curr] # type: ignore[assignment]
|
||||||
|
|
||||||
# Swapping snodes j and j - 1
|
|
||||||
tmp = gsnodes[j - 1]
|
|
||||||
gsnodes[j - 1] = gsnodes[j]
|
|
||||||
gsnodes[j] = tmp
|
|
||||||
headers = [
|
headers = [
|
||||||
"Wait node",
|
"Wait node",
|
||||||
"grouped",
|
"grouped",
|
||||||
|
|
@ -732,16 +834,13 @@ def _sink_waits_iterative_internal(
|
||||||
log_str += str(headers) + "\n"
|
log_str += str(headers) + "\n"
|
||||||
log_str += "\n".join(map(str, rows))
|
log_str += "\n".join(map(str, rows))
|
||||||
overlap_log.info(log_str)
|
overlap_log.info(log_str)
|
||||||
grouping_logs = []
|
new_snodes = _group_nodes(_head, None)
|
||||||
flatten_snodes = []
|
assert len(new_snodes) == original_snodes_num
|
||||||
for i, gsnode in enumerate(gsnodes):
|
new_peak_memory, curr_memory = estimate_peak_memory(
|
||||||
grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}")
|
new_snodes, name_to_freeable_input_buf, graph_outputs
|
||||||
if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping:
|
)
|
||||||
flatten_snodes.extend(gsnode.snodes)
|
log_str += f"\n peak_memory_before:{peak_memory}"
|
||||||
else:
|
log_str += f"\n peak_memory_after:{new_peak_memory}"
|
||||||
flatten_snodes.append(gsnode)
|
|
||||||
grouping_log_str = "\n".join(grouping_logs)
|
|
||||||
log_str += grouping_log_str
|
|
||||||
trace_structured(
|
trace_structured(
|
||||||
"artifact",
|
"artifact",
|
||||||
metadata_fn=lambda: {
|
metadata_fn=lambda: {
|
||||||
|
|
@ -750,8 +849,7 @@ def _sink_waits_iterative_internal(
|
||||||
},
|
},
|
||||||
payload_fn=lambda: log_str,
|
payload_fn=lambda: log_str,
|
||||||
)
|
)
|
||||||
assert len(flatten_snodes) == n
|
return new_snodes, stats
|
||||||
return flatten_snodes, stats
|
|
||||||
|
|
||||||
|
|
||||||
def sink_waits_iterative(
|
def sink_waits_iterative(
|
||||||
|
|
@ -777,7 +875,9 @@ def node_summary(snode):
|
||||||
if len(snodes) == 1:
|
if len(snodes) == 1:
|
||||||
detail = ""
|
detail = ""
|
||||||
if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
|
if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
|
||||||
detail = f" ({snode.node.python_kernel_name})"
|
outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}"
|
||||||
|
ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}"
|
||||||
|
detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})"
|
||||||
layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
|
layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
|
||||||
out_tensor_info = ",".join(
|
out_tensor_info = ",".join(
|
||||||
[
|
[
|
||||||
|
|
@ -1352,7 +1452,7 @@ def enforce_comm_ordering_for_fsdp(
|
||||||
mutating_buf = next(iter(ag_group_node.get_buffer_names()))
|
mutating_buf = next(iter(ag_group_node.get_buffer_names()))
|
||||||
for o in prev_ag_wait.get_outputs():
|
for o in prev_ag_wait.get_outputs():
|
||||||
ag_group_node.add_fake_dep(
|
ag_group_node.add_fake_dep(
|
||||||
WeakDep(o.get_name(), mutating_buf=mutating_buf)
|
WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
|
||||||
)
|
)
|
||||||
prev_ag_wait = wait_group_node
|
prev_ag_wait = wait_group_node
|
||||||
|
|
||||||
|
|
@ -1364,7 +1464,7 @@ def enforce_comm_ordering_for_fsdp(
|
||||||
mutating_buf = next(iter(rs_group_node.get_buffer_names()))
|
mutating_buf = next(iter(rs_group_node.get_buffer_names()))
|
||||||
for o in prev_rs_wait.get_outputs():
|
for o in prev_rs_wait.get_outputs():
|
||||||
rs_group_node.add_fake_dep(
|
rs_group_node.add_fake_dep(
|
||||||
WeakDep(o.get_name(), mutating_buf=mutating_buf)
|
WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
|
||||||
)
|
)
|
||||||
prev_rs_wait = wait_group_node
|
prev_rs_wait = wait_group_node
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -342,6 +342,12 @@ class WeakDep(Dep):
|
||||||
name: str
|
name: str
|
||||||
# Buffer that is doing the mutation
|
# Buffer that is doing the mutation
|
||||||
mutating_buf: str
|
mutating_buf: str
|
||||||
|
# WeakDep's are also used to add dependencies to prevent some specific reordering,
|
||||||
|
# E.g. collectives global ordering.
|
||||||
|
# But if other pass guarantees proper ordering by its logic,
|
||||||
|
# This additional "fake" deps will be holding optimizations.
|
||||||
|
# This flag is used to identify those additional deps.
|
||||||
|
is_fake: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def index(self) -> sympy.Expr:
|
def index(self) -> sympy.Expr:
|
||||||
|
|
@ -352,7 +358,7 @@ class WeakDep(Dep):
|
||||||
|
|
||||||
def rename(self, renames: dict[str, str]) -> "WeakDep":
|
def rename(self, renames: dict[str, str]) -> "WeakDep":
|
||||||
if self.name in renames:
|
if self.name in renames:
|
||||||
return WeakDep(renames[self.name], self.mutating_buf)
|
return WeakDep(renames[self.name], self.mutating_buf, self.is_fake)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def numbytes_hint(self) -> int:
|
def numbytes_hint(self) -> int:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user