Add knobs to the FR dump by the watchdog (stack trace and only-active collectives) and trigger an FR dump on any exception (#164591)

This PR includes a couple of changes that extend the FlightRecorder (FR) dump performed by the PyTorch watchdog:

- New knobs, as suggested in the public documentation, to control the FR dump even when it is triggered by the watchdog
  (TORCH_INCLUDE_STACK_TRACE, TORCH_INCLUDE_ONLY_ACTIVE; see the usage sketch below)
- Trigger the FlightRecorder dump on exceptions, which can be raised by any CUDA or host-side error
  (TORCH_NCCL_EXTRA_DUMP_ON_EXEC)
  -> The resulting dump can serve as a snapshot of the workload's progress for post-mortem analysis
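
For illustration only (not part of this PR's diff): all three knobs are plain boolean environment variables, so they can be exported in the job launch script or set in the launching process before the NCCL backend is constructed. The sketch below assumes a POSIX environment and sets them via setenv(); the truthy spelling "1" follows the usual c10d convention but should be treated as an assumption.

#include <cstdlib>  // setenv (POSIX)

// Minimal sketch: enable the new knobs before the process group is created.
// In practice these are usually exported in the launch script instead.
int main() {
  // Include stack traces in watchdog-triggered FR dumps (default: true).
  setenv("TORCH_INCLUDE_STACK_TRACE", "1", /*overwrite=*/1);
  // Record only collectives that are still active (default: false).
  setenv("TORCH_INCLUDE_ONLY_ACTIVE", "1", /*overwrite=*/1);
  // Also dump FR when the watchdog hits an exception or the process group is
  // aborted (default: false).
  setenv("TORCH_NCCL_EXTRA_DUMP_ON_EXEC", "1", /*overwrite=*/1);
  // ... construct the ProcessGroupNCCL backend and run the workload here ...
  return 0;
}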

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164591
Approved by: https://github.com/fduwjj
Seonmyeong Bak, 2025-10-09 05:33:31 +00:00, committed by PyTorch MergeBot
commit 263db92563 (parent ed6156e3ea)
3 changed files with 59 additions and 10 deletions


@@ -265,6 +265,15 @@ struct FlightRecorder {
bool onlyActive);
};
// Whether to include stack trace in the Flight Recorder trace (default true)
static std::vector<std::string> TORCH_INCLUDE_STACK_TRACE = {
"TORCH_INCLUDE_STACK_TRACE"};
// Whether to include only active collectives in the Flight Recorder trace
// (default false)
static std::vector<std::string> TORCH_INCLUDE_ONLY_ACTIVE = {
"TORCH_INCLUDE_ONLY_ACTIVE"};
// Dumps the fr traces and additional information about the Process
// Group.
TORCH_API std::string dump_fr_trace(
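
For context on the knob declarations above: each knob is a vector of accepted environment-variable names (matching the existing pattern in this header, where a knob can list more than one accepted name) and is later resolved to a boolean with a default via the existing getCvarBool helper. The standalone sketch below only approximates that resolution logic; it is not the actual c10d implementation.

#include <cstdlib>
#include <string>
#include <vector>

// Rough approximation (not the real c10d helper) of resolving a boolean knob
// declared as a list of candidate environment-variable names.
bool readBoolKnob(const std::vector<std::string>& names, bool defaultValue) {
  for (const auto& name : names) {
    const char* raw = std::getenv(name.c_str());
    if (raw == nullptr) {
      continue;  // this name is unset; try the next accepted alias
    }
    const std::string val(raw);
    if (val == "1" || val == "true" || val == "y") {
      return true;
    }
    if (val == "0" || val == "false" || val == "n") {
      return false;
    }
  }
  return defaultValue;  // no recognized value found
}

// Usage mirroring the new knobs:
//   bool includeStackTrace = readBoolKnob({"TORCH_INCLUDE_STACK_TRACE"}, true);
//   bool onlyActive        = readBoolKnob({"TORCH_INCLUDE_ONLY_ACTIVE"}, false);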


@@ -1430,17 +1430,41 @@ bool ProcessGroupNCCL::abortComms(
return true;
}
void ProcessGroupNCCL::dumpExtraDebuggingInfo() {
// This extra dump is intended to capture a snapshot of the current collectives
// when this process group is terminated due to an exception raised outside of NCCL.
bool dumpExtraOnExec_ = getCvarBool(TORCH_NCCL_EXTRA_DUMP_ON_EXEC, false);
if (dumpExtraOnExec_) {
bool should_dump_local = false;
bool succeeded = shouldDump_.compare_exchange_strong(
should_dump_local,
true,
std::memory_order_release,
std::memory_order_acquire);
if (succeeded) {
LOG(INFO) << logPrefix() << "Sending extra dumping signal";
broadcastDumpSignal();
// When this routine is called, an exception has already been captured, so a
// dump by the default_pg is not guaranteed due to early termination of the
// process. Therefore we trigger the dump manually here.
bool onlyActive = getCvarBool(TORCH_INCLUDE_ONLY_ACTIVE, false);
// Stack traces are not included at the moment to prevent a deadlock due to the GIL.
dumpDebuggingInfo(false, onlyActive);
}
}
}
// Abort this backend.
void ProcessGroupNCCL::abort() {
// This will log counter for how long the abort actually takes.
STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort);
dumpExtraDebuggingInfo();
// Don't join threads here since the purpose of this method is to abort all
// communicators and signal the threads to exit. Joining on the threads could
// potentially block and hence avoid it in this method.
terminateProcessGroup_.store(true);
watchdog_->notify();
// launch abort asynchronously and wait for it to complete or timeout
LOG(INFO) << logPrefix()
<< "Launching ProcessGroupNCCL abort asynchronously.";
@@ -1568,7 +1592,9 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
}
}
bool ProcessGroupNCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) {
bool ProcessGroupNCCL::dumpDebuggingInfo(
bool includeStackTrace /*=true*/,
bool onlyActive /*=false*/) {
// This will log counter for how long dumpDebuggingInfo actually takes.
STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__dumpDebuggingInfo);
@@ -1579,12 +1605,12 @@ bool ProcessGroupNCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) {
LOG(ERROR)
<< logPrefix()
<< "ProcessGroupNCCL preparing to dump debug info. Include stack trace: "
<< includeStackTrace;
<< includeStackTrace << ", only active collectives: " << onlyActive;
if (traceBufferSize_ > 0) {
// We dump nccl trace into local disk by default and users can register
// their customized writer by inheriting `DebugInfoWriter` via
// `registerDebugInfoWriter`.
auto ncclTrace = dump_nccl_trace(true, includeStackTrace, false);
auto ncclTrace = dump_nccl_trace(true, includeStackTrace, onlyActive);
// dump_nccl_trace will hang so we don't grab the global lock until we get
// the trace.
std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
@@ -1852,10 +1878,11 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() {
// recorder and dump. After dump, the training should continue.
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
// best effort dump, not waiting for the dump here
bool onlyActive = getCvarBool(TORCH_INCLUDE_ONLY_ACTIVE, false);
LOG(INFO) << pg_->logPrefix()
<< "Dump signal received through pipe, triggering FR dump.";
futures.emplace_back(std::async(std::launch::async, [this]() {
return this->pg_->dumpDebuggingInfo();
futures.emplace_back(std::async(std::launch::async, [this, onlyActive]() {
return this->pg_->dumpDebuggingInfo(false, onlyActive);
}));
}
}
@@ -1873,7 +1900,8 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() {
if (checkDumpSignal && shouldDump_.load()) {
// Store debug info to storage if no other thread does it. (By default to
// local disk)
bool dumpStackTrace = true;
bool dumpStackTrace = getCvarBool(TORCH_INCLUDE_STACK_TRACE, true);
bool onlyActive = getCvarBool(TORCH_INCLUDE_ONLY_ACTIVE, false);
::c10d::C10dLoggingData debugLog;
debugLog.integers["pg_id"] = static_cast<int64_t>(pg_->getUid());
debugLog.integers["rank"] = pg_->getRank();
@@ -1882,8 +1910,8 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() {
debugLog.strings["flight_recorder_version"] = c10d::version_val_str;
for (int i = 0; i < 2; i++) {
std::future<bool> asyncDebugDump =
std::async(std::launch::async, [this, dumpStackTrace]() {
return this->pg_->dumpDebuggingInfo(dumpStackTrace);
std::async(std::launch::async, [this, dumpStackTrace, onlyActive]() {
return this->pg_->dumpDebuggingInfo(dumpStackTrace, onlyActive);
});
// wait for the dump until timeout - log data
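
In both monitor paths above, the dump is handed to std::async with the knob values captured by value, and the monitor only waits for a bounded time so a wedged dump cannot stall it. The self-contained sketch below illustrates that "best-effort dump with a timeout" shape; the one-second timeout and the sleeping dump body are placeholders, not the real values.

#include <chrono>
#include <future>
#include <iostream>
#include <thread>

// Placeholder dump routine (assumed for this sketch); pretend it can be slow.
bool dumpDebuggingInfo(bool includeStackTrace, bool onlyActive) {
  std::this_thread::sleep_for(std::chrono::milliseconds(200));
  std::cout << "dumped (stacktrace=" << includeStackTrace
            << ", onlyActive=" << onlyActive << ")\n";
  return true;
}

int main() {
  // In the real code these come from TORCH_INCLUDE_STACK_TRACE and
  // TORCH_INCLUDE_ONLY_ACTIVE; here they are hard-coded.
  bool dumpStackTrace = true;
  bool onlyActive = false;

  // Launch the dump asynchronously, capturing the knob values by value.
  std::future<bool> asyncDump =
      std::async(std::launch::async, [dumpStackTrace, onlyActive]() {
        return dumpDebuggingInfo(dumpStackTrace, onlyActive);
      });

  // Best effort: wait only for a bounded time instead of blocking forever.
  if (asyncDump.wait_for(std::chrono::seconds(1)) ==
      std::future_status::ready) {
    std::cout << "dump succeeded: " << asyncDump.get() << "\n";
  } else {
    // Note: the future's destructor still joins the async task at scope exit.
    std::cout << "dump still running; monitor moves on\n";
  }
  return 0;
}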
@@ -2045,6 +2073,9 @@ void ProcessGroupNCCL::Watchdog::run() {
VLOG(2) << pg_->logPrefix()
<< "Process group watchdog thread terminated normally";
} catch (std::exception& e) {
// This branch is reached when any routine in the watchdog throws an
// exception.
pg_->dumpExtraDebuggingInfo();
if (std::string(e.what()).find("driver shutting down") !=
std::string::npos) {
VLOG(2)
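
The catch block above wires up the second feature of this PR: any exception escaping the watchdog loop first triggers the extra FR dump (gated by TORCH_NCCL_EXTRA_DUMP_ON_EXEC inside dumpExtraDebuggingInfo) before the usual error handling. Below is a condensed, hypothetical sketch of that control flow; runLoopBody and triggerExtraDump are stand-ins for the real routines.

#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for the real watchdog internals.
void runLoopBody() {
  throw std::runtime_error("CUDA error: an illegal memory access was encountered");
}
void triggerExtraDump() { std::cout << "extra FR dump requested\n"; }

void watchdogThread() {
  try {
    runLoopBody();
    std::cout << "watchdog terminated normally\n";
  } catch (const std::exception& e) {
    // Snapshot the collectives first; in the real code this path is gated by
    // the TORCH_NCCL_EXTRA_DUMP_ON_EXEC knob.
    triggerExtraDump();
    if (std::string(e.what()).find("driver shutting down") != std::string::npos) {
      std::cout << "benign shutdown: " << e.what() << "\n";
    } else {
      std::cout << "watchdog caught exception: " << e.what() << "\n";
    }
  }
}

int main() {
  watchdogThread();
  return 0;
}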


@@ -126,6 +126,11 @@ static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
static std::vector<std::string> TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN = {
"TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN"};
// Whether to trigger an extra Flight Recorder dump when the watchdog catches
// an exception or the process group is aborted (default false)
static std::vector<std::string> TORCH_NCCL_EXTRA_DUMP_ON_EXEC = {
"TORCH_NCCL_EXTRA_DUMP_ON_EXEC"};
// Control whether to use CudaEventCache for the collective in watchdog thread.
// We noticed in the past when cuda global lock is held, destroying CudaEvent
// can cause a hang.
@@ -1079,7 +1084,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// In the timeout case and we will dump debug info such as the NCCL flight
// recorder to storage. Down the road, if we have more complicated or blocking
// operations, we might need to use a side thread to do it.
bool dumpDebuggingInfo(bool includeStackTrace = true);
bool dumpDebuggingInfo(
bool includeStackTrace = true,
bool onlyActive = false);
void dumpExtraDebuggingInfo();
// Abort all communicators on this rank.
bool abortComms(const std::optional<std::string>& abortReason = std::nullopt);