mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[fr] Fix one error in analysis script when subPG world size is smaller than global size (#156156)
Summary: We run into an interesting case when we see so many mismatches while lot of mismatch turns out to be a fully match. The reason is that we use the dump ranks (which is from 0 to 79) to compare against the local pg ranks (0 to 7) this leads to false positive of mismatches. We can just check whether dump ranks contain all expected ranks or not, that should be sufficient. Test Plan: Test with the failed case with the script and we now see the correct behavior + new unit test case. Rollback Plan: Differential Revision: D76775373 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156156 Approved by: https://github.com/VieEeEw
This commit is contained in:
parent
bb462a6237
commit
3106a33e41
|
|
@ -187,6 +187,7 @@ LOADED_FR_DETAIL_TEMPLATE: dict[str, dict[str, Any]] = {
|
|||
"entries": [],
|
||||
"pg_config": {
|
||||
"0": {"name": "0", "desc": "default_pg", "ranks": "[0, 1]"},
|
||||
"1": {"name": "1", "desc": "sub_pg", "ranks": "[0]"},
|
||||
},
|
||||
"rank": 0,
|
||||
},
|
||||
|
|
@ -194,6 +195,7 @@ LOADED_FR_DETAIL_TEMPLATE: dict[str, dict[str, Any]] = {
|
|||
"entries": [],
|
||||
"pg_config": {
|
||||
"0": {"name": "0", "desc": "default_pg", "ranks": "[0, 1]"},
|
||||
"1": {"name": "1", "desc": "sub_pg", "ranks": "[1]"},
|
||||
},
|
||||
"rank": 1,
|
||||
},
|
||||
|
|
@ -209,10 +211,11 @@ def create_one_entry(
|
|||
collective_seq_id=0,
|
||||
p2p_seq_id=0,
|
||||
output_dtypes="float32",
|
||||
pg_info=("0", "default"),
|
||||
):
|
||||
event = create_one_event(
|
||||
collective_name,
|
||||
("0", "default"),
|
||||
pg_info,
|
||||
input_sizes,
|
||||
output_sizes,
|
||||
state,
|
||||
|
|
@ -229,7 +232,7 @@ class FlightRecorderE2ETest(TestCase):
|
|||
def testBuildDB(self):
|
||||
config = JobConfig()
|
||||
args = config.parse_args([])
|
||||
version = "2.7" # Same as the version in FlightRecorder.hpp
|
||||
version = "2.8" # Same as the version in FlightRecorder.hpp
|
||||
LOADED_FR_DETAIL_TEMPLATE["dump_file_rank_0"]["version"] = version
|
||||
LOADED_FR_DETAIL_TEMPLATE["dump_file_rank_1"]["version"] = version
|
||||
# Test case 1: matched all_reduce case.
|
||||
|
|
@ -240,11 +243,25 @@ class FlightRecorderE2ETest(TestCase):
|
|||
details1["dump_file_rank_1"]["entries"].append(
|
||||
create_one_entry(0, "all_reduce", [[4, 4]], [[4, 4]])
|
||||
)
|
||||
details1["dump_file_rank_0"]["entries"].append(
|
||||
create_one_entry(
|
||||
1, "all_reduce", [[5, 5]], [[5, 5]], pg_info=("1", "sub_pg")
|
||||
)
|
||||
)
|
||||
details1["dump_file_rank_1"]["entries"].append(
|
||||
create_one_entry(
|
||||
1, "all_reduce", [[5, 5]], [[5, 5]], pg_info=("1", "sub_pg")
|
||||
)
|
||||
)
|
||||
db = build_db(details1, args, version)
|
||||
self.assertEqual(len(db.collectives), 1)
|
||||
self.assertEqual(len(db.collectives), 3)
|
||||
self.assertEqual(db.collectives[0].record_id, 0)
|
||||
self.assertEqual(db.collectives[0].collective_name, "nccl:all_reduce")
|
||||
self.assertEqual(db.collectives[0].pass_check, True)
|
||||
self.assertEqual(db.collectives[1].record_id, 1)
|
||||
self.assertEqual(db.collectives[1].collective_name, "nccl:all_reduce")
|
||||
self.assertEqual(db.collectives[1].pass_check, True)
|
||||
self.assertEqual(db.collectives[2].pass_check, True)
|
||||
# Test case 2: matched allreduce_coalesced case.
|
||||
details2 = copy.deepcopy(LOADED_FR_DETAIL_TEMPLATE)
|
||||
details2["dump_file_rank_0"]["entries"].append(
|
||||
|
|
|
|||
|
|
@ -463,10 +463,10 @@ def error_analysis(
|
|||
match_record.candidate_idx.update(match_record.found_idx)
|
||||
match_record.found_idx.clear()
|
||||
match_record.found_ranks.clear()
|
||||
elif (
|
||||
len(match_record.candidate_ranks) == 1
|
||||
and dumps_ranks == match_record.expected_ranks
|
||||
):
|
||||
# We didn't see any mismatch and all expected ranks are in the dump.
|
||||
elif len(
|
||||
match_record.candidate_ranks
|
||||
) == 1 and match_record.expected_ranks.issubset(dumps_ranks):
|
||||
# case two: alltoall or alltoall_base case.
|
||||
if match_record.has_undecided_case:
|
||||
alltoall_cases = [current_entry] + [
|
||||
|
|
@ -527,6 +527,7 @@ def error_analysis(
|
|||
match_record.candidate_idx.update(match_record.found_idx)
|
||||
match_record.found_idx.clear()
|
||||
match_record.found_ranks.clear()
|
||||
# if any element in expected_ranks not in dumps_ranks.
|
||||
if match_record.expected_ranks - dumps_ranks:
|
||||
mismatch[pg_name] += 1
|
||||
logger.info(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user