diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 3cb6aee8b9d..3e9802d855e 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -2284,6 +2284,10 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
       // Work status logging for desync debug
       desyncDebugger_.logWorkStart(work);
 
+      // allow watchdog to do an event query on a side thread
+      at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index());
+      at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal};
+
       // a work could be started but not completed, so we should not update
       // lastStartedSeq and lastStartedOpName if the work state is checked
       // multiple times after the start
@@ -2295,10 +2299,6 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
         pg_->pgStatus_->lastStartedNumelOut = work.numelOut_;
       }
 
-      // allow watchdog to do an event query on a side thread
-      at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index());
-      at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal};
-
       // Clean up completed work
       if (work.isCompleted()) {
         // In case user didn't call `work.wait()` with async collectives,
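
The change hoists the two RAII guards above the "started work" bookkeeping, presumably so that the event query behind that earlier state check also runs on the right device and under the thread-local capture mode, not only the later work.isCompleted() query. Below is a minimal sketch of the underlying CUDA runtime mechanism the two guards roughly wrap, assuming only public CUDA runtime calls; the helper name query_event_from_side_thread and its parameters are illustrative and not part of this patch.

// Minimal sketch, not the PyTorch implementation: the raw CUDA runtime calls
// the two RAII guards roughly correspond to. Helper name and parameters are
// hypothetical, for illustration only.
#include <cuda_runtime.h>

bool query_event_from_side_thread(int device_index, cudaEvent_t event) {
  // at::cuda::CUDAGuard equivalent: remember the current device, then switch
  // to the device the event was recorded on.
  int prev_device = 0;
  cudaGetDevice(&prev_device);
  cudaSetDevice(device_index);

  // CUDAStreamCaptureModeGuard equivalent: swap this thread's capture
  // interaction mode to thread-local; `mode` then holds the previous mode.
  // In thread-local mode, a call CUDA may treat as unsafe during capture
  // (such as this event query) is neither prohibited by nor invalidates a
  // capture another thread started in cudaStreamCaptureModeGlobal.
  cudaStreamCaptureMode mode = cudaStreamCaptureModeThreadLocal;
  cudaThreadExchangeStreamCaptureMode(&mode);

  // Non-blocking completion check: cudaSuccess means the event has completed,
  // cudaErrorNotReady means work recorded before the event is still running.
  const bool completed = (cudaEventQuery(event) == cudaSuccess);

  // Restore the previous capture mode and device, as the guards' destructors do.
  cudaThreadExchangeStreamCaptureMode(&mode);
  cudaSetDevice(prev_device);
  return completed;
}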