mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[fr][fix] Split MatchState and dynamic info for fr analysis downstream (#147439)
The original MatchState type was declared as a python Enum. Although we did make it callable but we consume it right away. There are downstream cases when we need it to be a python class which is not supported in Python enum. So we did a small refactoring so that we keep both the enum state and dynamic info (culprit) for the fr analysis script. Differential Revision: [D69830994](https://our.internmc.facebook.com/intern/diff/D69830994) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147439 Approved by: https://github.com/fegin
This commit is contained in:
parent
41ae15faa3
commit
fb55bac3de
|
|
@ -8,7 +8,7 @@ import sys
|
|||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
|
||||
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
from tools.flight_recorder.components.types import COLLECTIVES, MatchState
|
||||
from tools.flight_recorder.components.types import COLLECTIVES, MatchInfo, MatchState
|
||||
from tools.flight_recorder.components.utils import match_one_event
|
||||
|
||||
|
||||
|
|
@ -50,14 +50,14 @@ class FlightRecorderEventTest(TestCase):
|
|||
)
|
||||
membership = {"0": {0, 1}}
|
||||
self.assertEqual(
|
||||
match_one_event(e1, e1, membership, "0"), MatchState.FULLY_MATCHED
|
||||
match_one_event(e1, e1, membership, "0").state, MatchState.FULLY_MATCHED
|
||||
)
|
||||
|
||||
e2 = create_one_event(
|
||||
"all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||
)
|
||||
self.assertEqual(
|
||||
match_one_event(e1, e2, membership, "0"),
|
||||
match_one_event(e1, e2, membership, "0").state,
|
||||
MatchState.COLLECTIVE_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
|
|
@ -67,34 +67,39 @@ class FlightRecorderEventTest(TestCase):
|
|||
e4 = create_one_event(
|
||||
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||
)
|
||||
self.assertEqual(match_one_event(e3, e4, membership, "0"), MatchState.UNDECIDED)
|
||||
self.assertEqual(
|
||||
match_one_event(e3, e4, membership, "0").state, MatchState.UNDECIDED
|
||||
)
|
||||
|
||||
e5 = create_one_event(
|
||||
"all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1
|
||||
)
|
||||
self.assertEqual(
|
||||
match_one_event(e1, e5, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
match_one_event(e1, e5, membership, "0").state,
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
)
|
||||
|
||||
e6 = create_one_event(
|
||||
"all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2
|
||||
)
|
||||
self.assertEqual(
|
||||
match_one_event(e1, e6, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
match_one_event(e1, e6, membership, "0").state,
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
)
|
||||
|
||||
e7 = create_one_event(
|
||||
"all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2
|
||||
)
|
||||
self.assertEqual(
|
||||
match_one_event(e7, e7, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
match_one_event(e7, e7, membership, "0").state,
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
)
|
||||
|
||||
e9 = create_one_event(
|
||||
"all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1
|
||||
)
|
||||
self.assertEqual(
|
||||
match_one_event(e1, e9, membership, "0"),
|
||||
match_one_event(e1, e9, membership, "0").state,
|
||||
MatchState.COLLECTIVE_STATE_MISMATCH,
|
||||
)
|
||||
|
||||
|
|
@ -108,7 +113,7 @@ class FlightRecorderEventTest(TestCase):
|
|||
output_dtypes="float16",
|
||||
)
|
||||
self.assertEqual(
|
||||
match_one_event(e10, e9, membership, "0"),
|
||||
match_one_event(e10, e9, membership, "0").state,
|
||||
MatchState.COLLECTIVE_DTYPE_MISMATCH,
|
||||
)
|
||||
|
||||
|
|
@ -128,9 +133,19 @@ class FlightRecorderEventTest(TestCase):
|
|||
collective, ("0", "default"), input_sizes, output_sizes, "scheduled", 1
|
||||
)
|
||||
membership = {"0": {0, 1}}
|
||||
result = match_one_event(event, event, membership, "0")
|
||||
result = match_one_event(event, event, membership, "0").state
|
||||
self.assertEqual(result, expectedState)
|
||||
|
||||
|
||||
class FlightMatchInfoTest(TestCase):
|
||||
def test_match_info(self):
|
||||
m1 = MatchInfo(MatchState.FULLY_MATCHED, "rank 0")
|
||||
m2 = MatchInfo(MatchState.FULLY_MATCHED, "rank 1")
|
||||
self.assertEqual(m1.state, MatchState.FULLY_MATCHED)
|
||||
self.assertEqual(m1.state, m2.state)
|
||||
self.assertEqual(str(m1), "Error type: FULLY_MATCHED, rank 0")
|
||||
self.assertEqual(str(m2), "Error type: FULLY_MATCHED, rank 1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from tools.flight_recorder.components.types import (
|
|||
Database,
|
||||
EntryState,
|
||||
Group,
|
||||
MatchInfo,
|
||||
MatchState,
|
||||
Membership,
|
||||
NCCLCall,
|
||||
|
|
@ -265,21 +266,23 @@ def build_collectives(
|
|||
and e["process_group"][1] == desc
|
||||
and e["collective_seq_id"] == entry_state.collective_seq_id
|
||||
):
|
||||
match_state = match_one_event(
|
||||
match_info = match_one_event(
|
||||
entries[0], e, _memberships, pg_name
|
||||
)
|
||||
if (
|
||||
match_state
|
||||
match_info.state
|
||||
in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
|
||||
and mismatch[pg_name] == 0
|
||||
):
|
||||
found_ranks.add(o)
|
||||
found_idx[o] = i
|
||||
has_undecided_case = match_state == MatchState.UNDECIDED
|
||||
has_undecided_case = (
|
||||
match_info.state == MatchState.UNDECIDED
|
||||
)
|
||||
else:
|
||||
candidate_ranks.add(o)
|
||||
candidate_idx[o] = i
|
||||
if match_state not in [
|
||||
if match_info.state not in [
|
||||
MatchState.FULLY_MATCHED,
|
||||
MatchState.UNDECIDED,
|
||||
]:
|
||||
|
|
@ -287,7 +290,7 @@ def build_collectives(
|
|||
# But it's possible that the current rank is the culprit, then users will
|
||||
# see lots of normal ranks reported as culprit.
|
||||
# TODO: we need to figure out a better way to handle the case mentioned above.
|
||||
errors.add((o, match_state))
|
||||
errors.add((o, match_info))
|
||||
break
|
||||
|
||||
# case one: not every rank join the collective or in the flight recorder.
|
||||
|
|
@ -331,7 +334,9 @@ def build_collectives(
|
|||
candidate_idx.update(found_idx)
|
||||
found_idx.clear()
|
||||
found_ranks.clear()
|
||||
errors.add((first_rank, MatchState.SIZE_OR_SYNTAX_MISMATCH))
|
||||
errors.add(
|
||||
(first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH))
|
||||
)
|
||||
else:
|
||||
found_ranks.update(candidate_ranks)
|
||||
found_idx.update(candidate_idx)
|
||||
|
|
|
|||
|
|
@ -63,14 +63,24 @@ class MatchState(Enum):
|
|||
COLLECTIVE_DTYPE_MISMATCH = auto()
|
||||
UNDECIDED = auto()
|
||||
|
||||
def __call__(self, culprit: Optional[str] = None) -> "MatchState":
|
||||
# Make the enum instance callable to add culprit.
|
||||
|
||||
class MatchInfo:
|
||||
"""
|
||||
Aside from the match state, we also store some dynamic info for the match such as the culprit rank
|
||||
or collective state that caused the mismatch.
|
||||
"""
|
||||
|
||||
def __init__(self, state: MatchState, culprit: Optional[str] = None) -> None:
|
||||
self._state = state
|
||||
self.culprit = culprit
|
||||
return self
|
||||
|
||||
def __str__(self) -> str:
|
||||
details = f", {self.culprit}" if getattr(self, "culprit", None) else ""
|
||||
return f"Error type: {self.name}{details}"
|
||||
return f"Error type: {self._state.name}{details}"
|
||||
|
||||
@property
|
||||
def state(self) -> MatchState:
|
||||
return self._state
|
||||
|
||||
|
||||
"""
|
||||
|
|
@ -130,7 +140,7 @@ class Collective(NamedTuple):
|
|||
output_numel: Optional[int] = None
|
||||
missing_ranks: Optional[set[int]] = None
|
||||
mismatch_collectives: Optional[dict[int, "Collective"]] = None
|
||||
type_of_mismatch: Optional[MatchState] = None
|
||||
type_of_mismatch: Optional[MatchInfo] = None
|
||||
|
||||
|
||||
class NCCLCall(NamedTuple):
|
||||
|
|
@ -219,7 +229,7 @@ class EntryState:
|
|||
self.missing_ranks: set[int]
|
||||
self.input_numel: int
|
||||
self.output_numel: int
|
||||
self.errors: set[tuple[int, MatchState]]
|
||||
self.errors: set[tuple[int, MatchInfo]]
|
||||
|
||||
def log(
|
||||
self,
|
||||
|
|
@ -227,7 +237,7 @@ class EntryState:
|
|||
logger_msg: str,
|
||||
frame_formatter: Any,
|
||||
total_numel: Optional[tuple[int, int]] = None,
|
||||
errors: Optional[set[tuple[int, MatchState]]] = None,
|
||||
errors: Optional[set[tuple[int, MatchInfo]]] = None,
|
||||
missing_ranks: Optional[set[int]] = None,
|
||||
) -> None:
|
||||
logger.info(
|
||||
|
|
@ -263,7 +273,7 @@ class EntryState:
|
|||
def to_collective(
|
||||
self,
|
||||
id: int,
|
||||
errors: Optional[set[tuple[int, MatchState]]] = None,
|
||||
errors: Optional[set[tuple[int, MatchInfo]]] = None,
|
||||
idx_map: Optional[dict[int, int]] = None,
|
||||
all_entries: Optional[dict[int, list[dict[str, Any]]]] = None,
|
||||
) -> Collective:
|
||||
|
|
@ -446,7 +456,7 @@ class Op:
|
|||
f"{p2p_info}, " if p2p_info else ""
|
||||
)
|
||||
|
||||
def match(self, other: "Op") -> MatchState:
|
||||
def match(self, other: "Op") -> MatchInfo:
|
||||
# TODO: I think this can validly not match,
|
||||
# e.g. if one PG was used for p2p ops between only some of the peers?
|
||||
# if self.seq_id != other.seq_id:
|
||||
|
|
@ -455,61 +465,67 @@ class Op:
|
|||
if self.type == "send":
|
||||
# TODO: We need more states for p2p ops.
|
||||
return (
|
||||
MatchState.FULLY_MATCHED
|
||||
MatchInfo(MatchState.FULLY_MATCHED)
|
||||
if (
|
||||
other.type == "recv"
|
||||
and self.src == other.src
|
||||
and self.dst == other.dst
|
||||
and self.input_sizes == other.output_sizes
|
||||
)
|
||||
else MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
|
||||
)
|
||||
elif self.type == "recv":
|
||||
return (
|
||||
MatchState.FULLY_MATCHED
|
||||
MatchInfo(MatchState.FULLY_MATCHED)
|
||||
if (
|
||||
other.type == "send"
|
||||
and self.src == other.src
|
||||
and self.dst == other.dst
|
||||
and self.output_sizes == other.input_sizes
|
||||
)
|
||||
else MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
|
||||
)
|
||||
elif self.type in COLLECTIVES:
|
||||
if self.type != other.type:
|
||||
return MatchState.COLLECTIVE_TYPE_MISMATCH(
|
||||
f"Expected collective type: '{self.type}' does not match found collective type: '{other.type}'"
|
||||
return MatchInfo(
|
||||
MatchState.COLLECTIVE_TYPE_MISMATCH,
|
||||
f"Expected collective type: '{self.type}' does not match found collective type: '{other.type}'",
|
||||
)
|
||||
if self.state != other.state:
|
||||
# MatchState()
|
||||
return MatchState.COLLECTIVE_STATE_MISMATCH(
|
||||
f"Expected state: '{self.state}' does not match found state: '{other.state}'"
|
||||
return MatchInfo(
|
||||
MatchState.COLLECTIVE_STATE_MISMATCH,
|
||||
f"Expected state: '{self.state}' does not match found state: '{other.state}'",
|
||||
)
|
||||
if (
|
||||
set(self.input_dtypes) != set(self.output_dtypes)
|
||||
or set(self.input_dtypes) != set(other.input_dtypes)
|
||||
or set(self.input_dtypes) != set(other.output_dtypes)
|
||||
):
|
||||
return MatchState.COLLECTIVE_DTYPE_MISMATCH(
|
||||
return MatchInfo(
|
||||
MatchState.COLLECTIVE_DTYPE_MISMATCH,
|
||||
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
|
||||
f"match found dtype: '{set(self.output_dtypes)}/"
|
||||
f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
|
||||
)
|
||||
if self.type == "all_to_all":
|
||||
return MatchState.UNDECIDED
|
||||
return MatchInfo(MatchState.UNDECIDED)
|
||||
if self.type != "scatter" and self.input_sizes != other.input_sizes:
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
|
||||
return MatchInfo(
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
f"Expected input sizes: '{self.input_sizes}' does not match found input sizes: "
|
||||
f"'{other.input_sizes}'",
|
||||
)
|
||||
if self.type != "gather" and self.output_sizes != other.output_sizes:
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
|
||||
return MatchInfo(
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
f"Expected output sizes: '{self.output_sizes}' does not match found output sizes: "
|
||||
f"'{other.output_sizes}'"
|
||||
f"'{other.output_sizes}'",
|
||||
)
|
||||
if self.type == "all_reduce" and self.input_sizes != other.output_sizes:
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
|
||||
f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'"
|
||||
return MatchInfo(
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'",
|
||||
)
|
||||
# TODO: need to consider uneven sharding for all-gather.
|
||||
# TODO: need to consider all_gather_into_tensor_coalesced (coalesced related)
|
||||
|
|
@ -520,7 +536,8 @@ class Op:
|
|||
math.prod(other.output_sizes[0])
|
||||
== math.prod(self.input_sizes[0]) * self.pg_size
|
||||
):
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
|
||||
return MatchInfo(
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
f"Found input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' "
|
||||
f"does not match output numel '{math.prod(other.output_sizes[0])}'",
|
||||
)
|
||||
|
|
@ -531,14 +548,15 @@ class Op:
|
|||
math.prod(other.input_sizes[0])
|
||||
== math.prod(self.output_sizes[0]) * self.pg_size
|
||||
):
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH(
|
||||
return MatchInfo(
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
f"Found input numel '{math.prod(other.input_sizes[0])}' does not match output numel "
|
||||
f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'",
|
||||
)
|
||||
elif self.type == "coalesced":
|
||||
return (
|
||||
MatchState.FULLY_MATCHED
|
||||
MatchInfo(MatchState.FULLY_MATCHED)
|
||||
if (other.type == "coalesced")
|
||||
else MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
|
||||
)
|
||||
return MatchState.FULLY_MATCHED
|
||||
return MatchInfo(MatchState.FULLY_MATCHED)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from typing import Any
|
|||
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
|
||||
from tools.flight_recorder.components.types import (
|
||||
Group,
|
||||
MatchInfo,
|
||||
MatchState,
|
||||
Membership,
|
||||
Op,
|
||||
|
|
@ -46,7 +47,7 @@ def match_one_event(
|
|||
event_b: dict[Any, Any],
|
||||
memberships: dict[str, set[Any]],
|
||||
pg_name: str,
|
||||
) -> MatchState:
|
||||
) -> MatchInfo:
|
||||
op_a = Op(event_a, memberships, pg_name)
|
||||
op_b = Op(event_b, memberships, pg_name)
|
||||
return op_a.match(op_b)
|
||||
|
|
@ -152,7 +153,7 @@ def match_coalesced_groups(
|
|||
dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
|
||||
peer_ops = all_ops[dst_global_rank]
|
||||
for i, other in enumerate(peer_ops):
|
||||
if op.match(other) == MatchState.FULLY_MATCHED:
|
||||
if op.match(other).state == MatchState.FULLY_MATCHED:
|
||||
match_idx = i
|
||||
break
|
||||
elif op.dst == other.src:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user