[fr] Fix one error in analysis script when subPG world size is smaller than global size (#156156)

Summary: We run into an interesting case when we see so many mismatches while lot of mismatch turns out to be a fully match. The reason is that we use the dump ranks (which is from 0 to 79) to compare against the local pg ranks (0 to 7) this leads to false positive of mismatches. We can just check whether dump ranks contain all expected ranks or not, that should be sufficient.

Test Plan:
Test with the failed case with the script and we now see the correct behavior + new unit test case.

Rollback Plan:

Differential Revision: D76775373

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156156
Approved by: https://github.com/VieEeEw
This commit is contained in:
Junjie Wang (PyTorch) 2025-06-17 21:17:58 +00:00 committed by PyTorch MergeBot
parent bb462a6237
commit 3106a33e41
2 changed files with 25 additions and 7 deletions

View File

@ -187,6 +187,7 @@ LOADED_FR_DETAIL_TEMPLATE: dict[str, dict[str, Any]] = {
"entries": [],
"pg_config": {
"0": {"name": "0", "desc": "default_pg", "ranks": "[0, 1]"},
"1": {"name": "1", "desc": "sub_pg", "ranks": "[0]"},
},
"rank": 0,
},
@ -194,6 +195,7 @@ LOADED_FR_DETAIL_TEMPLATE: dict[str, dict[str, Any]] = {
"entries": [],
"pg_config": {
"0": {"name": "0", "desc": "default_pg", "ranks": "[0, 1]"},
"1": {"name": "1", "desc": "sub_pg", "ranks": "[1]"},
},
"rank": 1,
},
@ -209,10 +211,11 @@ def create_one_entry(
collective_seq_id=0,
p2p_seq_id=0,
output_dtypes="float32",
pg_info=("0", "default"),
):
event = create_one_event(
collective_name,
("0", "default"),
pg_info,
input_sizes,
output_sizes,
state,
@ -229,7 +232,7 @@ class FlightRecorderE2ETest(TestCase):
def testBuildDB(self):
config = JobConfig()
args = config.parse_args([])
version = "2.7" # Same as the version in FlightRecorder.hpp
version = "2.8" # Same as the version in FlightRecorder.hpp
LOADED_FR_DETAIL_TEMPLATE["dump_file_rank_0"]["version"] = version
LOADED_FR_DETAIL_TEMPLATE["dump_file_rank_1"]["version"] = version
# Test case 1: matched all_reduce case.
@ -240,11 +243,25 @@ class FlightRecorderE2ETest(TestCase):
details1["dump_file_rank_1"]["entries"].append(
create_one_entry(0, "all_reduce", [[4, 4]], [[4, 4]])
)
details1["dump_file_rank_0"]["entries"].append(
create_one_entry(
1, "all_reduce", [[5, 5]], [[5, 5]], pg_info=("1", "sub_pg")
)
)
details1["dump_file_rank_1"]["entries"].append(
create_one_entry(
1, "all_reduce", [[5, 5]], [[5, 5]], pg_info=("1", "sub_pg")
)
)
db = build_db(details1, args, version)
self.assertEqual(len(db.collectives), 1)
self.assertEqual(len(db.collectives), 3)
self.assertEqual(db.collectives[0].record_id, 0)
self.assertEqual(db.collectives[0].collective_name, "nccl:all_reduce")
self.assertEqual(db.collectives[0].pass_check, True)
self.assertEqual(db.collectives[1].record_id, 1)
self.assertEqual(db.collectives[1].collective_name, "nccl:all_reduce")
self.assertEqual(db.collectives[1].pass_check, True)
self.assertEqual(db.collectives[2].pass_check, True)
# Test case 2: matched allreduce_coalesced case.
details2 = copy.deepcopy(LOADED_FR_DETAIL_TEMPLATE)
details2["dump_file_rank_0"]["entries"].append(

View File

@ -463,10 +463,10 @@ def error_analysis(
match_record.candidate_idx.update(match_record.found_idx)
match_record.found_idx.clear()
match_record.found_ranks.clear()
elif (
len(match_record.candidate_ranks) == 1
and dumps_ranks == match_record.expected_ranks
):
# We didn't see any mismatch and all expected ranks are in the dump.
elif len(
match_record.candidate_ranks
) == 1 and match_record.expected_ranks.issubset(dumps_ranks):
# case two: alltoall or alltoall_base case.
if match_record.has_undecided_case:
alltoall_cases = [current_entry] + [
@ -527,6 +527,7 @@ def error_analysis(
match_record.candidate_idx.update(match_record.found_idx)
match_record.found_idx.clear()
match_record.found_ranks.clear()
# if any element in expected_ranks not in dumps_ranks.
if match_record.expected_ranks - dumps_ranks:
mismatch[pg_name] += 1
logger.info(