[FR] Enable best effort partial analysis and verbose mode for trace printing (#139853)

Based on user feedback, we want to enable two things for FR analysis script:
1. Print out more information when verbose is specified.
2. Perform best effort based analysis when not all ranks have FR trace dumped.

Differential Revision: [D65516081](https://our.internmc.facebook.com/intern/diff/D65516081/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139853
Approved by: https://github.com/c-p-i-o
This commit is contained in:
fduwjj 2024-11-07 15:38:31 -08:00 committed by PyTorch MergeBot
parent cb15c15157
commit ceb44b22dc
6 changed files with 73 additions and 17 deletions

View File

@ -37,6 +37,7 @@ def create_one_event(
"output_dtypes": output_dtypes,
"collective_seq_id": str(collective_seq_id),
"p2p_seq_id": str(p2p_seq_id),
"time_created_ns": 0,
}

View File

@ -6,6 +6,7 @@
import argparse
import ast
import os
import sys
from typing import Any, Dict, List, Set, Tuple # type: ignore[attr-defined]
@ -186,6 +187,9 @@ def build_collectives(
# instead, just record the remaining ops as NCCLCalls
mismatch = {_groups[g].id: 0 for g in _groups}
MISMATCH_TAIL = 10
# For best effort partial analysis.
dumps_ranks = {int(key) for key in all_entries.keys()}
"""
- it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time
- there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast)
@ -238,7 +242,7 @@ def build_collectives(
else []
)
all_coalesced_entries[curr] = grp
for index, entry in grp:
for _, entry in grp:
op = Op(entry, _memberships, pg_name)
peer = None
if op.type == "send":
@ -314,7 +318,9 @@ def build_collectives(
break
# case one: not every rank join the collective or in the flight recorder.
if (candidate_ranks | found_ranks) != expected_ranks:
if (candidate_ranks | found_ranks) != expected_ranks and expected_ranks - (
candidate_ranks | found_ranks
) <= dumps_ranks:
mismatch[pg_name] += 1
logger.info(
"Not all ranks joining collective %s at entry %s",
@ -334,7 +340,7 @@ def build_collectives(
candidate_idx.update(found_idx)
found_idx.clear()
found_ranks.clear()
elif len(candidate_ranks) == 1:
elif len(candidate_ranks) == 1 and dumps_ranks == expected_ranks:
# case two: alltoall or alltoall_base case.
if has_undecided_case:
alltoall_cases = [entries[0]] + [
@ -398,6 +404,19 @@ def build_collectives(
candidate_idx.update(found_idx)
found_idx.clear()
found_ranks.clear()
# partial analysis case when we cannot decide what's wrong with this collective entry.
else:
candidate_ranks.update(found_ranks)
candidate_idx.update(found_idx)
found_idx.clear()
found_ranks.clear()
mismatch[pg_name] += 1
logger.info(
"We cannot decide what's wrong with this collective entry "
"because we missed FR dumps from ranks (%s) so we don't have enough "
"information. If you want to debug further use -j to dump all raw trace",
str(expected_ranks - dumps_ranks),
)
# at this point there are 3 possibilities
# 1. we found a match on all the ranks that are members of the group
@ -450,6 +469,8 @@ def build_collectives(
def build_db(
details: Dict[str, Dict[str, Any]], args: argparse.Namespace, version: str
) -> Database:
if args.verbose:
os.environ["FR_TRACE_VERBOSE_OUTPUT"] = "1"
# temporary state used for building database
entries = {}
pg_config = {}
@ -470,12 +491,13 @@ def build_db(
)
logger.debug("built groups, memberships")
if not args.allow_incomplete_ranks:
check_no_missing_dump_files(entries, memberships)
if args.just_print_entries:
just_print_entries(entries, _groups, _memberships, _pg_guids, args)
sys.exit(0)
check_no_missing_dump_files(entries, memberships)
tracebacks, collectives, nccl_calls = build_collectives(
entries, _groups, _memberships, _pg_guids, version
)

View File

@ -35,6 +35,15 @@ class JobConfig:
type=int,
help="List of ranks we want to show traces for.",
)
self.parser.add_argument(
"--allow-incomplete-ranks",
action="store_true",
help=(
"FR trace require all ranks to have dumps for analysis. "
"This flag allows best-effort partial analysis of results "
"and printing of collected data."
),
)
self.parser.add_argument(
"--pg-filters",
default=None,

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import gc
import os
import pickle
@ -11,7 +12,7 @@ import re
import time
import typing
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Set, Tuple, Union
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
@ -66,16 +67,15 @@ def _determine_prefix(files: List[str]) -> str:
)
def read_dir(
prefix: Optional[str], folder: str
) -> Tuple[Dict[str, Dict[str, Any]], str]:
def read_dir(args: argparse.Namespace) -> Tuple[Dict[str, Dict[str, Any]], str]:
gc.disable()
prefix = args.prefix
details = {}
t0 = time.time()
version = ""
filecount = 0
assert os.path.isdir(folder), f"folder {folder} does not exist"
for root, _, files in os.walk(folder):
assert os.path.isdir(args.folder), f"folder {args.folder} does not exist"
for root, _, files in os.walk(args.folder):
if prefix is None:
prefix = _determine_prefix(files)
for f in files:
@ -86,6 +86,6 @@ def read_dir(
if not version:
version = str(details[f]["version"])
tb = time.time()
assert len(details) > 0, f"no files loaded from {folder} with prefix {prefix}"
assert len(details) > 0, f"no files loaded from {args.folder} with prefix {prefix}"
logger.debug("loaded %s files in %ss", filecount, tb - t0)
return details, version

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import math
import os
from enum import auto, Enum
from typing import ( # type: ignore[attr-defined]
_eval_type,
@ -199,7 +200,7 @@ class Op:
type = parts[0]
meta = parts[1] if len(parts) == 2 else None
self.state = event["state"]
self.pg_name, _ = event["process_group"]
self.pg_name, self.pg_desc = event["process_group"]
assert type in COLLECTIVES | P2P | {
"coalesced"
}, f"{type} is not a supported operation"
@ -212,7 +213,6 @@ class Op:
self._dst, self._src = int(d), int(s)
else:
self._src, self._dst = -1, -1
_, pg_desc = event["process_group"]
self._init_global_src_dst(memberships[pg_name])
self.pg_size = len(memberships[pg_name])
if type in P2P | COLLECTIVES:
@ -224,6 +224,8 @@ class Op:
self.p2p_seq_id = event["p2p_seq_id"]
self.input_dtypes = event["input_dtypes"]
self.output_dtypes = event["output_dtypes"]
self.time_created_ns = event["time_created_ns"]
self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1"
def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None:
pg_ranks = sorted(pg_ranks)
@ -241,9 +243,31 @@ class Op:
return self._dst
def __repr__(self) -> str:
p2p_info = ""
if self.type in P2P:
return f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes}, state={self.state})"
return f"{self.type}(input_sizes={self.input_sizes}, state={self.state})"
p2p_info = f"s={self._src_g} d={self._dst_g}"
if self.is_verbose:
verbose_info = (
f"timestamp_created={self.time_created_ns}",
p2p_info,
f"input_sizes={self.input_sizes}",
f"output_sizes={self.output_sizes}",
f"input_dtypes={self.input_dtypes}",
f"output_dtypes={self.output_dtypes}",
"collective_seq_id | p2p_seq_id="
f"{self.p2p_seq_id if self.type in P2P else self.collective_seq_id}",
f"pg_name={self.pg_name}",
f"pg_description={self.pg_desc}",
f"pg_size={self.pg_size}",
f"state={self.state}",
)
return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s)
return (
f"{self.type}(%sinput_sizes={self.input_sizes}, state={self.state})"
% f"{p2p_info}, "
if p2p_info
else ""
)
def match(self, other: "Op") -> MatchState:
# TODO: I think this can validly not match,

View File

@ -41,7 +41,7 @@ def main(args: Optional[Sequence[str]] = None) -> None:
config = JobConfig()
args = config.parse_args(args)
assert args.trace_dir, "Trace directory trace_dir is required"
details, version = read_dir(args.prefix, args.trace_dir)
details, version = read_dir(args)
db = build_db(details, args, version)
if args.output:
with open(args.output, "wb") as f: