[fr] Enable reset the FR recording for fault tolerance (#164988)
We also want a Python-side API that lets users reset the FR (Flight Recorder) entries. We don't need to reset PGNCCL's member counters, since we are creating a new PGNCCL anyway; but FR is a global ring buffer, so it does need to be reset.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164988
Approved by: https://github.com/tushar00jain
ghstack dependencies: #164752
Commit 8ca986ee60 · parent 81dbeb06f4
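For context, here is a minimal sketch of how the new binding might be used from the Python side during a fault-tolerant restart. It assumes Flight Recorder is enabled and that NCCL is available; the helper name `recreate_process_group` is hypothetical and stands in for whatever the application uses to rebuild its process group.

# Illustrative sketch only; not part of this PR's diff.
import pickle

import torch
import torch.distributed as dist


def restart_and_reset_fr(recreate_process_group):
    # `recreate_process_group` is a hypothetical callable that tears down the
    # failed PG and builds a fresh ProcessGroupNCCL for the surviving ranks.
    dist.destroy_process_group()
    recreate_process_group()

    # The new PG starts with fresh member counters, but Flight Recorder is a
    # global ring buffer, so clear its entries explicitly (the API added here).
    torch._C._distributed_c10d._reset_fr_recording_nccl()


def dump_fr_entries():
    # Only entries recorded after the reset show up in the dump.
    trace = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
    return trace["entries"]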
@@ -4602,6 +4602,34 @@ class NCCLTraceTest(NCCLTraceTestBase):
            )
        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset(self, timing_enabled):
        if self.rank == self.MAIN_PROCESS_RANK:
            return
        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)
        for _ in range(5):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        # duration_ms is populated best-effort, since it can only be filled in outside the dump() API
        time.sleep(1)
        torch._C._distributed_c10d._reset_fr_recording_nccl()
        for _ in range(4):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 4)
        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_dump_pipe(self):
@@ -231,6 +231,8 @@ struct FlightRecorder {
      std::optional<size_t> id,
      bool compute_duration = true);

  TORCH_API void reset_all();

  const c10::List<c10::IValue> getCollectiveTrace(
      bool includeStacktraces,
      bool onlyActive);
@@ -249,6 +249,14 @@ void FlightRecorder<EventType>::retire_id(
  }
}

template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
  std::lock_guard<std::mutex> guard(mutex_);
  next_ = 0;
  id_ = 0;
  entries_.clear();
}

template <typename EventType>
const c10::List<c10::IValue> FlightRecorder<EventType>::getCollectiveTrace(
    bool includeStacktraces,
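The reset itself is just ring-buffer bookkeeping: the write cursor (next_) and the monotonically increasing entry id (id_) go back to zero and the stored entries are dropped, all under the recorder's mutex. A standalone illustration of that idea, in Python and with hypothetical names (this is not PyTorch code):

# Standalone sketch of a fixed-capacity ring buffer with a reset();
# RingRecorder, capacity, and record are illustrative names only.
class RingRecorder:
    def __init__(self, capacity):
        self.capacity = capacity
        self.entries = []   # stored records, at most `capacity` of them
        self.next = 0       # write cursor into the ring
        self.id = 0         # monotonically increasing record id

    def record(self, event):
        entry = {"id": self.id, "event": event}
        if len(self.entries) < self.capacity:
            self.entries.append(entry)
        else:
            self.entries[self.next] = entry  # overwrite the oldest slot
        self.next = (self.next + 1) % self.capacity
        self.id += 1

    def reset(self):
        # Mirrors the shape of FlightRecorder::reset_all(): zero the cursor
        # and the id counter, and drop all previously recorded entries.
        self.next = 0
        self.id = 0
        self.entries.clear()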
@@ -393,6 +393,10 @@ static std::
#endif // (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
}

void reset_nccl_trace() {
  FlightRecorderCUDA::get()->reset_all();
}

std::string dump_nccl_trace(
    bool includeCollectives,
    bool includeStackTraces,
@@ -1463,6 +1463,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  std::unique_ptr<c10::cuda::MemPool> memPool_ = nullptr;
};

// Reset the FlightRecorder recordings for the current rank.
TORCH_API void reset_nccl_trace();

// Dumps the NCCL comm traces and additional information about the Process
// Group.
TORCH_API std::string dump_nccl_trace(
@@ -4091,6 +4091,10 @@ such as `dist.all_reduce(tensor, async_op=True)`.
        Stringified pickle work traces.
        Default settings return everything - i.e. contains NCCL comm dumps and collective traces.
      )");
  module.def(
      "_reset_fr_recording_nccl",
      []() { ::c10d::reset_nccl_trace(); },
      "API to reset the Flight Recorder recording for fault tolerance.");
#endif

  module.def(