[c10d][fr] Enable FR analysis script for rest of all coalesce op (#151247)

We revisited how coalesced collectives work in https://github.com/pytorch/pytorch/pull/151243, and we now want to enable the script to handle the slow path as well. The change is indeed BC-breaking, but it is needed to make this work, and the API is for internal use only; it is not user facing. On the slow path, each individual collective has its input and output sizes recorded but no state; only the final entry has the state ready. We therefore check the correctness of each individual collective one by one, but we do not check state matching for them; the state match can only be checked on the last entry, the work item carrying the "coalesced" label.

Added more unit tests for the slow path.
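
For illustration, here is a minimal sketch (not code from this PR) of the grouping rule the slow path relies on, using the entry fields that appear in the diff below; the group_slow_path_coalesced helper and the literal entries are hypothetical:

from typing import Any

def group_slow_path_coalesced(entries: list[dict[str, Any]]) -> list[tuple[int, dict[str, Any]]]:
    """Collect leading non-P2P entries that share one collective_seq_id."""
    found: list[tuple[int, dict[str, Any]]] = []
    seq_id = None
    for i, e in enumerate(entries):
        if seq_id is None:
            seq_id = e["collective_seq_id"]
            found.append((i, e))
        elif not e["is_p2p"] and e["collective_seq_id"] == seq_id:
            found.append((i, e))
        else:
            break
    # A well-formed slow-path group is terminated by the "coalesced" work
    # item, which is the only entry whose state is checked.
    if len(found) > 1 and found[-1][1]["profiling_name"] != "nccl:coalesced":
        raise ValueError("coalesced group has no terminating 'coalesced' entry")
    return found if len(found) > 1 else []

# Two coalesced broadcasts: sizes are recorded per individual op, but the
# state lives only on the trailing "coalesced" entry (mirrors test case 3).
entries = [
    {"profiling_name": "nccl:broadcast", "is_p2p": False, "collective_seq_id": 1},
    {"profiling_name": "nccl:broadcast", "is_p2p": False, "collective_seq_id": 1},
    {"profiling_name": "nccl:coalesced", "is_p2p": False, "collective_seq_id": 1},
]
assert len(group_slow_path_coalesced(entries)) == 3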

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151247
Approved by: https://github.com/d4l3k, https://github.com/XilunWu
fduwjj 2025-04-15 10:32:52 -07:00 committed by PyTorch MergeBot
parent f98150fc8e
commit ae648f047c
4 changed files with 362 additions and 19 deletions

View File

@@ -230,7 +230,7 @@ class FlightRecorderE2ETest(TestCase):
def testBuildDB(self):
config = JobConfig()
args = config.parse_args([])
- version = "2.6"  # Same as the version in FlightRecorder.hpp
version = "2.7"  # Same as the version in FlightRecorder.hpp
LOADED_FR_DETAIL_TEMPLATE["dump_file_rank_0"]["version"] = version
LOADED_FR_DETAIL_TEMPLATE["dump_file_rank_1"]["version"] = version
# Test case 1: matched all_reduce case.
@@ -259,6 +259,58 @@ class FlightRecorderE2ETest(TestCase):
self.assertEqual(db.collectives[0].record_id, 0)
self.assertEqual(db.collectives[0].collective_name, "nccl:allreduce_coalesced")
self.assertEqual(db.collectives[0].pass_check, True)
# Test case 3: matched slow path, two broadcast coalesce case.
details3 = copy.deepcopy(LOADED_FR_DETAIL_TEMPLATE)
# sequence ID should not increase for coalesced collectives
details3["dump_file_rank_0"]["entries"].append(
create_one_entry(0, "broadcast", [[4, 4]], [[4, 4]])
)
details3["dump_file_rank_0"]["entries"].append(
create_one_entry(1, "broadcast", [[4, 4]], [[4, 4]])
)
details3["dump_file_rank_0"]["entries"].append(
create_one_entry(2, "coalesced", [[]], [[]])
)
details3["dump_file_rank_1"]["entries"].append(
create_one_entry(0, "broadcast", [[4, 4]], [[4, 4]])
)
details3["dump_file_rank_1"]["entries"].append(
create_one_entry(1, "broadcast", [[4, 4]], [[4, 4]])
)
details3["dump_file_rank_1"]["entries"].append(
create_one_entry(2, "coalesced", [[]], [[]])
)
db = build_db(details3, args, version)
self.assertEqual(len(db.collectives), 1)
self.assertEqual(db.collectives[0].record_id, 2)
self.assertEqual(db.collectives[0].collective_name, "nccl:coalesced")
self.assertEqual(db.collectives[0].pass_check, True)
# Test case 4: mis-matched slow path, two _broadcast_oop coalesce case (uneven sizes).
details4 = copy.deepcopy(LOADED_FR_DETAIL_TEMPLATE)
# sequence ID should not increase for coalesced collectives
details4["dump_file_rank_0"]["entries"].append(
create_one_entry(0, "_broadcast_oop", [[4, 4]], [[4, 4]])
)
details4["dump_file_rank_0"]["entries"].append(
create_one_entry(1, "_broadcast_oop", [[5, 5]], [[5, 5]])
)
details4["dump_file_rank_0"]["entries"].append(
create_one_entry(2, "coalesced", [[]], [[]])
)
details4["dump_file_rank_1"]["entries"].append(
create_one_entry(0, "_broadcast_oop", [[4, 4]], [[4, 4]])
)
details4["dump_file_rank_1"]["entries"].append(
create_one_entry(1, "_broadcast_oop", [[4, 4]], [[4, 4]])
)
details4["dump_file_rank_1"]["entries"].append(
create_one_entry(2, "coalesced", [[]], [[]])
)
db = build_db(details4, args, version)
self.assertEqual(len(db.collectives), 1)
self.assertEqual(db.collectives[0].record_id, 1)
self.assertEqual(db.collectives[0].collective_name, "nccl:_broadcast_oop")
self.assertEqual(db.collectives[0].pass_check, False)
if __name__ == "__main__":

View File

@@ -6,6 +6,7 @@
import argparse
import ast
import copy
import os
import sys
from typing import Any  # type: ignore[attr-defined]
@@ -28,10 +29,12 @@ from tools.flight_recorder.components.utils import (
check_no_missing_dump_files,
check_version,
error_analysis,
- find_coalesced_group,
find_coalesced_group as find_coalesced_group_p2p_only,
find_coalesced_group_with_non_p2p,
get_version_detail,
just_print_entries,
- match_coalesced_groups,
match_coalesced_groups as match_coalesced_groups_p2p_only,
match_coalesced_groups_with_non_p2p,
)
@@ -209,12 +212,23 @@ def build_collectives(
errors=set(),
)
- if find_coalesced_group(pg_name, entries, _pg_guids, first_rank):
- expected_ranks.add(first_rank)
major_v, minor_v = get_version_detail(version)
find_coalesced_group = (
find_coalesced_group_p2p_only
if major_v <= 2 and minor_v < 7
else find_coalesced_group_with_non_p2p
)
maybe_coalesced_group = find_coalesced_group(
pg_name, entries, _pg_guids, first_rank
)
if len(maybe_coalesced_group) > 1:
num_coalesced_entries = len(maybe_coalesced_group)
# We need a copy of the original expected ranks to avoid modifying it.
candidate_ranks = copy.deepcopy(expected_ranks)
done_ranks = set()
all_coalesced_entries = {}
- while expected_ranks:
while candidate_ranks:
- curr = expected_ranks.pop()
curr = candidate_ranks.pop()
done_ranks.add(curr)
grp = (
find_coalesced_group(pg_name, all_entries[curr], _pg_guids, curr)  # type: ignore[index]
@@ -226,31 +240,54 @@ def build_collectives(
op = Op(entry, _memberships, pg_name)
peer = None
if op.type == "send":
- assert op._src_g == curr, (op._src_g, curr)
assert op._src_g == curr, (
f"Send src error: {curr} expected but {op._src_g} is set"
)
peer = op._dst_g
elif op.type == "recv":
- assert op._dst_g == curr, (op._dst_g, curr)
assert op._dst_g == curr, (
f"Recv dst error: {curr} expected but {op._dst_g} is set"
)
peer = op._src_g
if peer and peer not in done_ranks:
- expected_ranks.add(peer)
candidate_ranks.add(peer)
- match = match_coalesced_groups(
- all_coalesced_entries,
- group_size=_groups[pg_name].size,
- groups=_groups,
- memberships=_memberships,
- _pg_guids=_pg_guids,
- )
if major_v <= 2 and minor_v < 7:
match = match_coalesced_groups_p2p_only(
all_coalesced_entries,
group_size=_groups[pg_name].size,
groups=_groups,
memberships=_memberships,
_pg_guids=_pg_guids,
)
else:
match = match_coalesced_groups_with_non_p2p(
copy.deepcopy(
all_coalesced_entries
), # We want to keep a copy for cleanup.
pg_info=(pg_name, desc),
memberships=_memberships,
_pg_guids=_pg_guids,
mismatch=mismatch,
dumps_ranks=dumps_ranks,
version=version,
collectives=collectives,
match_record=match_record,
)
if match and mismatch[pg_name] == 0:
- collectives.append(entry_state.to_collective(len(collectives)))
# We treat coalesced collectives as a single collective.
# TODO: we need to surface a merged collective info like input/output sizes to users.
collectives.append(
match_record.entry_state.to_collective(len(collectives))
)
else:
mismatch[pg_name] += 1
for r in all_coalesced_entries:
idx_map = {r: i for i, _ in reversed(all_coalesced_entries[r])}  # noqa: B035
nccl_calls.extend(
reversed(
- entry_state.to_nccl_call(
match_record.entry_state.to_nccl_call(
all_entries,
idx_map,
len(nccl_calls),
@@ -258,6 +295,10 @@ def build_collectives(
)
)
)
# This extra cleanup is needed because we need to pop all collectives within a coalesced collective.
for i, k in idx_map.items():
for _ in range(1, num_coalesced_entries):
all_entries[i].pop(k)
else:
# Iterate through all the ranks and check if there is a mis-match for the current entry.
check_current_entry_match(

View File

@@ -187,7 +187,9 @@ https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html.
"""
COLLECTIVES = {
"broadcast",
"_broadcast_oop",
"reduce", "reduce",
"_reduce_oop",
"all_gather", "all_gather",
"all_reduce", "all_reduce",
"_all_gather_base", "_all_gather_base",
@@ -604,3 +606,13 @@ class MatchStateRecord:
self.found_idx = found_idx
self.errors = errors
self.has_undecided_case = False
def reset_for_coalesced(
self, entry_state: EntryState, candidate_ranks: set[int]
) -> None:
self.entry_state = entry_state
self.candidate_ranks = candidate_ranks
self.candidate_idx = {}
self.found_ranks = set()
self.found_idx = {}
self.errors = set()

View File

@@ -10,6 +10,8 @@ from typing import Any
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
from tools.flight_recorder.components.types import (
Collective,
EntryState,
Group,
MatchInfo,
MatchState,
@@ -176,6 +178,208 @@ def match_coalesced_groups(
return True
# We enabled creating FR entries for non-P2P slow-path collective ops in v2.7.
def match_coalesced_groups_with_non_p2p(
all_rank_events: dict[Any, Any],
pg_info: tuple[str, str],
memberships: dict[str, set[Any]],
_pg_guids: dict[tuple[str, int], str],
mismatch: dict[str, int],
dumps_ranks: set[int],
version: str,
collectives: list[Collective],
match_record: MatchStateRecord,
) -> bool:
"""
all_rank_events: {
rank: [
(idx, event_dict)
]
}
Note: it is possible for event dicts in a coalesced group to be asymmetric.
e.g. the following event lists form a valid coalescing group
events0 [send:1]
events1 [recv:0, send:2]
events2 [recv:1]
Rule 1: all ops should find a match
Rule 2: relative ordering of sends and recvs in one event list can be arbitrary
e.g.
events1 [recv:0, send:2] -> okay
events1 [send:2, recv:0] -> also okay
Rule 3: sends to the same dest or recvs from the src should be in a consistent order
e.g.
rank0 [send:1 (100B), send:1 (1000B)]
rank1 [recv:0 (1000B), recv:0 (100B)] -> not okay
"""
all_ops = {
rank: [
Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
for _, e in all_rank_events[rank]
]
for rank in all_rank_events
}
is_p2p = any(op.type in P2P for op in all_ops[0])
pg_name = pg_info[0]
def visualize_ops(
match: bool,
_pg_guids: dict[tuple[str, int], str],
) -> None:
all_ops = {
rank: [
Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
for _, e in all_rank_events[rank]
]
for rank in all_rank_events
}
i = 0
row = []
progress = True
table = []
while progress:
progress = False
for r in all_ops:
if len(all_ops[r]) > i:
rank, event = all_rank_events[r][i]
row.append(
Op(
event,
memberships,
_pg_guids[(event["process_group"][0], rank)],
)
)
progress = True
else:
row.append(None) # type: ignore[arg-type]
table.append(row)
row = []
i += 1
title = "Match" if match else "MISMATCH"
logger.info("%s \n", title)
logger.info("%s", tabulate(table)) # type: ignore[operator]
# TODO Need to verify no seq_id deltas for P2P ops.
for rank, op_list in all_ops.items():
if not op_list:
logger.error("Rank %s has an empty op list.", rank)
if op_list[-1].type == "coalesced" and is_p2p:
op_list.pop(-1)
while all_ops:
first_rank = next(iter(all_ops))
my_ops = all_ops[first_rank]
if len(all_ops[first_rank]) == 0:
all_ops.pop(first_rank)
continue
# lets match the first collective! we need to know which ranks are involved, and ensure that this same
# collective is also the first one on those ranks within that group
op = my_ops[0]
match_idx = -1
if is_p2p:
dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
peer_ops = all_ops[dst_global_rank]
for i, other in enumerate(peer_ops):
if op.match(other).state == MatchState.FULLY_MATCHED:
match_idx = i
break
elif op.dst == other.src:
# Rule 3
break
else:
# Rule 1
continue
if match_idx >= 0:
my_ops.pop(0)
peer_ops.pop(match_idx)
else:
visualize_ops(False, _pg_guids)
return False
else:
all_coalesced_entries = {
rank: [e for _, e in all_rank_events[rank]] for rank in all_rank_events
}
current_entry = all_coalesced_entries[first_rank][0]
my_ops.pop(0)
match_record.reset_for_coalesced(
EntryState(current_entry, match_record.expected_ranks),
{first_rank},
)
# Iterate through all the ranks and check if there is a mis-match for the current entry.
check_current_entry_match(
all_coalesced_entries,
_pg_guids,
pg_info,
current_entry,
memberships,
mismatch,
match_record,
)
# Use heuristics to decide what type of errors and error messages we should print.
error_analysis(
all_coalesced_entries,
match_record,
dumps_ranks,
first_rank,
current_entry,
mismatch,
get_version_detail(version),
pg_info[0],
)
# TODO: For now, we only check the correctness of individual collective within a coalesced one in
# this script. We need to merge (e.g, input/output sizes) together
# for downstream consumer.
# at this point there are 3 possibilities
# 1. we found a match on all the ranks that are members of the group
# -> we create a Collective and remove the individual entries from their original lists
if (
match_record.found_ranks == match_record.expected_ranks
and mismatch[pg_name] == 0
):
# Just pop out this collective.
idx_map = {
r: match_record.found_idx[r] if r != first_rank else 0
for r in match_record.found_ranks
}
for i, k in idx_map.items():
all_rank_events[i].pop(k)
for r in match_record.found_ranks:
if r != first_rank:
all_ops[r].pop(0)
# 2. we found a partial match but some ranks are missing
# 3. we found no match
# -> since its not a complete collective, no entry goes into collectives but we still record a nccl call
else:
logger.debug("Non-matching collective inside coalesced group")
idx_map = {
r: match_record.candidate_idx[r] if r != first_rank else 0
for r in match_record.candidate_ranks
}
collectives.append(
match_record.entry_state.to_collective(
len(collectives),
errors=match_record.errors,
idx_map=idx_map,
all_entries=all_coalesced_entries,
)
)
return False
if is_p2p:
visualize_ops(True, _pg_guids)
return True
def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int, int]:
input_numel = 0
output_numel = 0
@@ -369,6 +573,40 @@ def find_coalesced_group(
return []
# We enabled creating FR entries for non-P2P slow-path collective ops in v2.7.
def find_coalesced_group_with_non_p2p(
pg_name: str,
entries: list[dict[str, Any]],
_pg_guids: dict[tuple[str, int], str],
rank: int,
) -> list[tuple[int, dict[str, Any]]]:
"""Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
build and return a list of entries terminating in a 'coalesced' op entry, all sharing a collective_seq_id
"""
found = []
collective_seq_id = None
for i, e in enumerate(entries):
if _pg_guids[(e["process_group"][0], rank)] != pg_name:
continue
elif collective_seq_id is None:
collective_seq_id = (
e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"]
)
found.append((i, e))
elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id:
found.append((i, e))
elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id:
found.append((i, e))
else:
break
if len(found) > 1:
if found[-1][1]["profiling_name"] != "nccl:coalesced":
logger.error("Rank %s does not have a coalesced end.", rank)
return found
return []
def just_print_entries(
all_entries: dict[int, list[dict[str, Any]]],
_groups: dict[str, Group],