[fr] Enable resetting the FR recording for fault tolerance (#164988)

We also want a Python-side API that lets users reset the FR (Flight Recorder) entries. We don't need to reset PGNCCL's member counters, since we are creating a new PGNCCL anyway, but FR is a global ring buffer, so it does need to be reset.
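
Below is a minimal usage sketch (not part of this commit) of how a fault-tolerance restart path might combine the new reset API with the existing dump API. rebuild_process_group() is a hypothetical helper, and FR is assumed to be enabled (e.g. via TORCH_NCCL_TRACE_BUFFER_SIZE):

import pickle

import torch
import torch.distributed as dist

def recover_from_fault(rebuild_process_group):
    # Tear down the old PGNCCL and create a fresh one; per-PG counters
    # start from zero in the new object, so they need no explicit reset.
    dist.destroy_process_group()
    rebuild_process_group()
    # FR is a global ring buffer shared across PG instances, so stale
    # pre-fault entries must be cleared explicitly.
    torch._C._distributed_c10d._reset_fr_recording_nccl()

def current_fr_entries():
    # The dump is a pickled dict; after a reset, "entries" only contains
    # collectives recorded since the reset.
    trace = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
    return trace["entries"]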

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164988
Approved by: https://github.com/tushar00jain
ghstack dependencies: #164752
Authored by fduwjj on 2025-10-08 13:36:35 -07:00; committed by PyTorch MergeBot
parent 81dbeb06f4
commit 8ca986ee60
6 changed files with 49 additions and 0 deletions


@@ -4602,6 +4602,34 @@ class NCCLTraceTest(NCCLTraceTestBase):
)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset(self, timing_enabled):
if self.rank == self.MAIN_PROCESS_RANK:
return
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
for _ in range(5):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
# duration_ms is populated best-effort, since that can only happen outside the dump() API
time.sleep(1)
torch._C._distributed_c10d._reset_fr_recording_nccl()
for _ in range(4):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 4)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_dump_pipe(self):


@@ -231,6 +231,8 @@ struct FlightRecorder {
std::optional<size_t> id,
bool compute_duration = true);
TORCH_API void reset_all();
const c10::List<c10::IValue> getCollectiveTrace(
bool includeStacktraces,
bool onlyActive);


@@ -249,6 +249,14 @@ void FlightRecorder<EventType>::retire_id(
}
}
template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
std::lock_guard<std::mutex> guard(mutex_);
next_ = 0;
id_ = 0;
entries_.clear();
}
template <typename EventType>
const c10::List<c10::IValue> FlightRecorder<EventType>::getCollectiveTrace(
bool includeStacktraces,


@@ -393,6 +393,10 @@ static std::
#endif // (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
}
void reset_nccl_trace() {
FlightRecorderCUDA::get()->reset_all();
}
std::string dump_nccl_trace(
bool includeCollectives,
bool includeStackTraces,


@@ -1463,6 +1463,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::unique_ptr<c10::cuda::MemPool> memPool_ = nullptr;
};
// Reset the flight recorder recordings for the current rank.
TORCH_API void reset_nccl_trace();
// Dumps the NCCL comm traces and additional information about the Process
// Group.
TORCH_API std::string dump_nccl_trace(


@@ -4091,6 +4091,10 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Stringified pickle work traces.
Default settings return everything - i.e. contains NCCL comm dumps and collective traces.
)");
module.def(
"_reset_fr_recording_nccl",
[]() { ::c10d::reset_nccl_trace(); },
"API to reset Flight recorder recording when it comes fault tolerance.");
#endif
module.def(