[fr][fix] Split MatchState and dynamic info for fr analysis downstream (#147439)

The original MatchState type was declared as a Python Enum. Although we made it callable, we consume the result right away. There are downstream cases where we need it to be a regular Python class, which a Python Enum does not support. So we did a small refactoring to keep both the enum state and the dynamic info (culprit) for the fr analysis script.

Differential Revision: [D69830994](https://our.internmc.facebook.com/intern/diff/D69830994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147439
Approved by: https://github.com/fegin
This commit is contained in:
fduwjj 2025-02-18 21:33:50 -08:00 committed by PyTorch MergeBot
parent 41ae15faa3
commit fb55bac3de
4 changed files with 86 additions and 47 deletions

View File

@ -8,7 +8,7 @@ import sys
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
sys.path.insert(0, str(REPO_ROOT))
from tools.flight_recorder.components.types import COLLECTIVES, MatchState
from tools.flight_recorder.components.types import COLLECTIVES, MatchInfo, MatchState
from tools.flight_recorder.components.utils import match_one_event
@ -50,14 +50,14 @@ class FlightRecorderEventTest(TestCase):
)
membership = {"0": {0, 1}}
self.assertEqual(
match_one_event(e1, e1, membership, "0"), MatchState.FULLY_MATCHED
match_one_event(e1, e1, membership, "0").state, MatchState.FULLY_MATCHED
)
e2 = create_one_event(
"all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
self.assertEqual(
match_one_event(e1, e2, membership, "0"),
match_one_event(e1, e2, membership, "0").state,
MatchState.COLLECTIVE_TYPE_MISMATCH,
)
@ -67,34 +67,39 @@ class FlightRecorderEventTest(TestCase):
e4 = create_one_event(
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
self.assertEqual(match_one_event(e3, e4, membership, "0"), MatchState.UNDECIDED)
self.assertEqual(
match_one_event(e3, e4, membership, "0").state, MatchState.UNDECIDED
)
e5 = create_one_event(
"all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1
)
self.assertEqual(
match_one_event(e1, e5, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
match_one_event(e1, e5, membership, "0").state,
MatchState.SIZE_OR_SYNTAX_MISMATCH,
)
e6 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2
)
self.assertEqual(
match_one_event(e1, e6, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
match_one_event(e1, e6, membership, "0").state,
MatchState.SIZE_OR_SYNTAX_MISMATCH,
)
e7 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2
)
self.assertEqual(
match_one_event(e7, e7, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
match_one_event(e7, e7, membership, "0").state,
MatchState.SIZE_OR_SYNTAX_MISMATCH,
)
e9 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1
)
self.assertEqual(
match_one_event(e1, e9, membership, "0"),
match_one_event(e1, e9, membership, "0").state,
MatchState.COLLECTIVE_STATE_MISMATCH,
)
@ -108,7 +113,7 @@ class FlightRecorderEventTest(TestCase):
output_dtypes="float16",
)
self.assertEqual(
match_one_event(e10, e9, membership, "0"),
match_one_event(e10, e9, membership, "0").state,
MatchState.COLLECTIVE_DTYPE_MISMATCH,
)
@ -128,9 +133,19 @@ class FlightRecorderEventTest(TestCase):
collective, ("0", "default"), input_sizes, output_sizes, "scheduled", 1
)
membership = {"0": {0, 1}}
result = match_one_event(event, event, membership, "0")
result = match_one_event(event, event, membership, "0").state
self.assertEqual(result, expectedState)
class FlightMatchInfoTest(TestCase):
    """Checks that MatchInfo exposes its state and renders the culprit in str()."""

    def test_match_info(self):
        """Two infos built with the same state differ only in their culprit text."""
        first = MatchInfo(MatchState.FULLY_MATCHED, "rank 0")
        second = MatchInfo(MatchState.FULLY_MATCHED, "rank 1")

        # The state passed to the constructor is readable back via `.state`
        # and compares equal across instances built with the same state.
        self.assertEqual(first.state, MatchState.FULLY_MATCHED)
        self.assertEqual(first.state, second.state)

        # str() shows the state name followed by the culprit text.
        self.assertEqual(str(first), "Error type: FULLY_MATCHED, rank 0")
        self.assertEqual(str(second), "Error type: FULLY_MATCHED, rank 1")
if __name__ == "__main__":
run_tests()

View File

@ -16,6 +16,7 @@ from tools.flight_recorder.components.types import (
Database,
EntryState,
Group,
MatchInfo,
MatchState,
Membership,
NCCLCall,
@ -265,21 +266,23 @@ def build_collectives(
and e["process_group"][1] == desc
and e["collective_seq_id"] == entry_state.collective_seq_id
):
match_state = match_one_event(
match_info = match_one_event(
entries[0], e, _memberships, pg_name
)
if (
match_state
match_info.state
in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
and mismatch[pg_name] == 0
):
found_ranks.add(o)
found_idx[o] = i
has_undecided_case = match_state == MatchState.UNDECIDED
has_undecided_case = (
match_info.state == MatchState.UNDECIDED
)
else:
candidate_ranks.add(o)
candidate_idx[o] = i
if match_state not in [
if match_info.state not in [
MatchState.FULLY_MATCHED,
MatchState.UNDECIDED,
]:
@ -287,7 +290,7 @@ def build_collectives(
# But it's possible that the current rank is the culprit, then users will
# see lots of normal ranks reported as culprit.
# TODO: we need to figure out a better way to handle the case mentioned above.
errors.add((o, match_state))
errors.add((o, match_info))
break
# case one: not every rank join the collective or in the flight recorder.
@ -331,7 +334,9 @@ def build_collectives(
candidate_idx.update(found_idx)
found_idx.clear()
found_ranks.clear()
errors.add((first_rank, MatchState.SIZE_OR_SYNTAX_MISMATCH))
errors.add(
(first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH))
)
else:
found_ranks.update(candidate_ranks)
found_idx.update(candidate_idx)

View File

@ -63,14 +63,24 @@ class MatchState(Enum):
COLLECTIVE_DTYPE_MISMATCH = auto()
UNDECIDED = auto()
def __call__(self, culprit: Optional[str] = None) -> "MatchState":
# Make the enum instance callable to add culprit.
class MatchInfo:
"""
Aside from the match state, we also store some dynamic info for the match such as the culprit rank
or collective state that caused the mismatch.
"""
def __init__(self, state: MatchState, culprit: Optional[str] = None) -> None:
self._state = state
self.culprit = culprit
return self
def __str__(self) -> str:
details = f", {self.culprit}" if getattr(self, "culprit", None) else ""
return f"Error type: {self.name}{details}"
return f"Error type: {self._state.name}{details}"
@property
def state(self) -> MatchState:
return self._state
"""
@ -130,7 +140,7 @@ class Collective(NamedTuple):
output_numel: Optional[int] = None
missing_ranks: Optional[set[int]] = None
mismatch_collectives: Optional[dict[int, "Collective"]] = None
type_of_mismatch: Optional[MatchState] = None
type_of_mismatch: Optional[MatchInfo] = None
class NCCLCall(NamedTuple):
@ -219,7 +229,7 @@ class EntryState:
self.missing_ranks: set[int]
self.input_numel: int
self.output_numel: int
self.errors: set[tuple[int, MatchState]]
self.errors: set[tuple[int, MatchInfo]]
def log(
self,
@ -227,7 +237,7 @@ class EntryState:
logger_msg: str,
frame_formatter: Any,
total_numel: Optional[tuple[int, int]] = None,
errors: Optional[set[tuple[int, MatchState]]] = None,
errors: Optional[set[tuple[int, MatchInfo]]] = None,
missing_ranks: Optional[set[int]] = None,
) -> None:
logger.info(
@ -263,7 +273,7 @@ class EntryState:
def to_collective(
self,
id: int,
errors: Optional[set[tuple[int, MatchState]]] = None,
errors: Optional[set[tuple[int, MatchInfo]]] = None,
idx_map: Optional[dict[int, int]] = None,
all_entries: Optional[dict[int, list[dict[str, Any]]]] = None,
) -> Collective:
@ -446,7 +456,7 @@ class Op:
f"{p2p_info}, " if p2p_info else ""
)
def match(self, other: "Op") -> MatchState:
def match(self, other: "Op") -> MatchInfo:
# TODO: I think this can validly not match,
# e.g. if one PG was used for p2p ops between only some of the peers?
# if self.seq_id != other.seq_id:
@ -455,61 +465,67 @@ class Op:
if self.type == "send":
# TODO: We need more states for p2p ops.
return (
MatchState.FULLY_MATCHED
MatchInfo(MatchState.FULLY_MATCHED)
if (
other.type == "recv"
and self.src == other.src
and self.dst == other.dst
and self.input_sizes == other.output_sizes
)
else MatchState.SIZE_OR_SYNTAX_MISMATCH
else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
)
elif self.type == "recv":
return (
MatchState.FULLY_MATCHED
MatchInfo(MatchState.FULLY_MATCHED)
if (
other.type == "send"
and self.src == other.src
and self.dst == other.dst
and self.output_sizes == other.input_sizes
)
else MatchState.SIZE_OR_SYNTAX_MISMATCH
else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
)
elif self.type in COLLECTIVES:
if self.type != other.type:
return MatchState.COLLECTIVE_TYPE_MISMATCH(
f"Expected collective type: '{self.type}' does not match found collective type: '{other.type}'"
return MatchInfo(
MatchState.COLLECTIVE_TYPE_MISMATCH,
f"Expected collective type: '{self.type}' does not match found collective type: '{other.type}'",
)
if self.state != other.state:
# MatchState()
return MatchState.COLLECTIVE_STATE_MISMATCH(
f"Expected state: '{self.state}' does not match found state: '{other.state}'"
return MatchInfo(
MatchState.COLLECTIVE_STATE_MISMATCH,
f"Expected state: '{self.state}' does not match found state: '{other.state}'",
)
if (
set(self.input_dtypes) != set(self.output_dtypes)
or set(self.input_dtypes) != set(other.input_dtypes)
or set(self.input_dtypes) != set(other.output_dtypes)
):
return MatchState.COLLECTIVE_DTYPE_MISMATCH(
return MatchInfo(
MatchState.COLLECTIVE_DTYPE_MISMATCH,
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
f"match found dtype: '{set(self.output_dtypes)}/"
f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
)
if self.type == "all_to_all":
return MatchState.UNDECIDED
return MatchInfo(MatchState.UNDECIDED)
if self.type != "scatter" and self.input_sizes != other.input_sizes:
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Expected input sizes: '{self.input_sizes}' does not match found input sizes: "
f"'{other.input_sizes}'",
)
if self.type != "gather" and self.output_sizes != other.output_sizes:
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Expected output sizes: '{self.output_sizes}' does not match found output sizes: "
f"'{other.output_sizes}'"
f"'{other.output_sizes}'",
)
if self.type == "all_reduce" and self.input_sizes != other.output_sizes:
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'"
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'",
)
# TODO: need to consider uneven sharding for all-gather.
# TODO: need to consider all_gather_into_tensor_coalesced (coalesced related)
@ -520,7 +536,8 @@ class Op:
math.prod(other.output_sizes[0])
== math.prod(self.input_sizes[0]) * self.pg_size
):
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Found input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' "
f"does not match output numel '{math.prod(other.output_sizes[0])}'",
)
@ -531,14 +548,15 @@ class Op:
math.prod(other.input_sizes[0])
== math.prod(self.output_sizes[0]) * self.pg_size
):
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Found input numel '{math.prod(other.input_sizes[0])}' does not match output numel "
f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'",
)
elif self.type == "coalesced":
return (
MatchState.FULLY_MATCHED
MatchInfo(MatchState.FULLY_MATCHED)
if (other.type == "coalesced")
else MatchState.SIZE_OR_SYNTAX_MISMATCH
else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
)
return MatchState.FULLY_MATCHED
return MatchInfo(MatchState.FULLY_MATCHED)

View File

@ -11,6 +11,7 @@ from typing import Any
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
from tools.flight_recorder.components.types import (
Group,
MatchInfo,
MatchState,
Membership,
Op,
@ -46,7 +47,7 @@ def match_one_event(
event_b: dict[Any, Any],
memberships: dict[str, set[Any]],
pg_name: str,
) -> MatchState:
) -> MatchInfo:
op_a = Op(event_a, memberships, pg_name)
op_b = Op(event_b, memberships, pg_name)
return op_a.match(op_b)
@ -152,7 +153,7 @@ def match_coalesced_groups(
dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
peer_ops = all_ops[dst_global_rank]
for i, other in enumerate(peer_ops):
if op.match(other) == MatchState.FULLY_MATCHED:
if op.match(other).state == MatchState.FULLY_MATCHED:
match_idx = i
break
elif op.dst == other.src: