mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[FR] Fix duplicate output for the case when not all ranks join on collective (#137256)
As title, when testing on an internal case, we found that we have very similar output for the error when certain ranks does not join one collective. This is because we didn't put all ranks into `candidate_ranks` so that they didn't get wiped out from entries and gets checked again. Ideally for the given case, we should report this is an out of order case, because rank 0, 1 calls all-to-all while all the rest ranks call all-gather-base. But when we select entries to compare, we don't have global view of the entries. In the specific case, on rank 0 and 1, it has collective of PG 7 on entry 1130 with seq ID = 1130. However, on other ranks, they have collective of PG 0 on entry 1130 with seq ID = 2. It's hard to use entry idx to do the match because if we later consider p2p, this assumption will collapse, so we now still defer it for users or further down debugging stream to figure it out. To make the message clearer, I also include both seqID and record_id (aka, entry index) in the message. (That does not mean this is not possible to implement in the code, for example, we can let all record_id to minus the maximum p2p seq id before it; but users will easily see the wrong order, so we don't think it's necessary to have that logic now) P1626755348 Differential Revision: [D63815335](https://our.internmc.facebook.com/intern/diff/D63815335/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/137256 Approved by: https://github.com/c-p-i-o
This commit is contained in:
parent
adc48a5b52
commit
c7714b8d8d
|
|
@ -312,11 +312,16 @@ def build_collectives(
|
|||
if (candidate_ranks | found_ranks) != expected_ranks:
|
||||
mismatch[pg_name] += 1
|
||||
print(
|
||||
f"Not all ranks joining collective {record_id} for group {pg_desc} collective {profiling_name} ",
|
||||
f"Not all ranks joining collective {collective_seq_id} at entry {record_id}",
|
||||
f" for group {pg_desc} collective {profiling_name} ",
|
||||
f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)} ",
|
||||
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
|
||||
f"\nCollective stack traces: \n{collective_frames}",
|
||||
)
|
||||
candidate_ranks.update(found_ranks)
|
||||
candidate_idx.update(found_idx)
|
||||
found_idx.clear()
|
||||
found_ranks.clear()
|
||||
elif len(candidate_ranks) == 1:
|
||||
# case two: alltoall or alltoall_base case.
|
||||
if has_undecided_case:
|
||||
|
|
@ -334,8 +339,8 @@ def build_collectives(
|
|||
# When we see errors in all_to_all, it's hard to tell which rank is the source of the error.
|
||||
mismatch[pg_name] += 1
|
||||
print(
|
||||
f"Input/output mismatch in the collective {record_id} ",
|
||||
f"for group {pg_desc} collective {profiling_name} ",
|
||||
f"Input/output mismatch in the collective {collective_seq_id} ",
|
||||
f"at entry {record_id} for group {pg_desc} collective {profiling_name} ",
|
||||
f"input_numel {input_numel} output_numel {output_numel} ",
|
||||
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
|
||||
f"\nCollective stack traces: \n{collective_frames}",
|
||||
|
|
@ -362,7 +367,8 @@ def build_collectives(
|
|||
f"Culprit rank {error[0]}; {str(error[1])}" for error in errors
|
||||
)
|
||||
print(
|
||||
f"Collective {record_id} errors for group {pg_desc} collective {profiling_name} ",
|
||||
f"Collective {collective_seq_id} at entry {record_id} errors",
|
||||
f" for group {pg_desc} collective {profiling_name} ",
|
||||
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
|
||||
f"\nFound errors: {error_msg}.\n",
|
||||
f"\nCollective stack traces: \n{collective_frames} ",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user