pytorch/torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp
Tristan Rice ab557421a4 [cca] [c10d] Refactor CUDAEventCache into separate files (#158616)
Summary:
Refactored CUDAEventCache from ProcessGroupNCCL.hpp/.cpp into dedicated header and implementation files for better code organization and maintainability.

Split out CUDAEventCache into:
- New header file: CUDAEventCache.hpp
- New implementation file: CUDAEventCache.cpp
- Updated build_variables.bzl to include the new file

This change improves code maintainability and readability, and follows better code organization practices.
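
For orientation, here is a minimal sketch of what the extracted CUDAEventCache.hpp plausibly declares, inferred from the implementation below; the exact includes, macros, and member types in the real header are assumptions:

```cpp
// CUDAEventCache.hpp (sketch, inferred from CUDAEventCache.cpp below)
#pragma once

#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Device.h>

#include <array>
#include <deque>
#include <memory>
#include <mutex>

namespace c10d {

// enable_shared_from_this is required because create() captures
// shared_from_this() in the event deleter.
class CUDAEventCache
    : public std::enable_shared_from_this<CUDAEventCache> {
 public:
  CUDAEventCache();
  std::shared_ptr<at::cuda::CUDAEvent> create(bool timing);
  static std::shared_ptr<CUDAEventCache> get(at::DeviceIndex device);

 private:
  std::mutex cacheMutex_;
  // Index 0: non-timing events; index 1: timing events (deque type is
  // assumed from the front()/pop_front()/push_back() usage in the .cpp).
  std::array<std::deque<at::cuda::CUDAEvent*>, 2> eventsArray_;
};

} // namespace c10d
```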

Test Plan:
Verified build with:
```
buck build //caffe2/test/distributed:c10d
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158616
Approved by: https://github.com/fduwjj
2025-07-19 02:51:28 +00:00


#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>

#include <map>

namespace c10d {

CUDAEventCache::CUDAEventCache() = default;

// A CUDA event is used to record the start/end of one Work.
// Instead of letting the CUDA event get destroyed, we reuse it after the
// Work has been erased from workMetaList_.
// This avoids a potential deadlock caused by cudaEventDestroy.
std::shared_ptr<at::cuda::CUDAEvent> CUDAEventCache::create(bool timing) {
  // Register the deleter as a callback for when the WorkNCCL object is
  // destroyed. Each deleter holds a reference to the cache object, so that
  // even when the thread that created the cache is gone, the cache object
  // won't be destroyed until all the events in the cache are destroyed (the
  // reference count drops to zero).
  auto deleter = [cache = shared_from_this(),
                  timing](at::cuda::CUDAEvent* event) {
    std::lock_guard<std::mutex> lock(cache->cacheMutex_);
    // Put the event back into the cache deque once the WorkNCCL object is
    // destroyed.
    cache->eventsArray_[timing ? 1 : 0].push_back(event);
  };
  at::cuda::CUDAEvent* event = nullptr;
  {
    std::lock_guard<std::mutex> lock(cacheMutex_);
    auto& events = eventsArray_[timing ? 1 : 0];
    // If we still have events in the cache, reuse one. Otherwise, create a
    // new one.
    if (!events.empty()) {
      event = events.front();
      events.pop_front();
    } else {
      event = new at::cuda::CUDAEvent(
          timing ? cudaEventDefault : cudaEventDisableTiming);
    }
  }
  return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move(deleter));
}

std::shared_ptr<CUDAEventCache> CUDAEventCache::get(at::DeviceIndex device) {
  // A per-thread singleton of the device-to-CUDAEventCache map.
  // The map is needed because events cannot be reused across devices.
  // Per-thread ownership is needed to support the multi-threaded case (as
  // opposed to the multi-process case).
  static thread_local std::map<at::DeviceIndex, std::shared_ptr<CUDAEventCache>>
      cacheDeviceMap;
  // If the device does not already have an entry in the map, add one.
  auto it = cacheDeviceMap.find(device);
  if (it == cacheDeviceMap.end()) {
    cacheDeviceMap.emplace(device, std::make_shared<CUDAEventCache>());
  }
  return cacheDeviceMap[device];
}
} // namespace c10d
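
For context, a hedged usage sketch: a hypothetical caller (loosely modeled on how WorkNCCL records start/end events) obtains the per-thread, per-device cache and lets the custom deleter recycle the events. The function name and structure here are illustrative, not taken from this commit:

```cpp
// Hypothetical caller code; not part of this commit.
#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>

#include <c10/cuda/CUDAStream.h>

void recordWorkBoundaries(at::DeviceIndex device) {
  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(device);

  // Fetch this thread's cache for the device and pull two timing events.
  auto cache = c10d::CUDAEventCache::get(device);
  auto start = cache->create(/*timing=*/true);
  auto end = cache->create(/*timing=*/true);

  start->record(stream);
  // ... enqueue work on `stream` ...
  end->record(stream);

  // When `start` and `end` go out of scope, the custom deleter returns the
  // raw events to the cache instead of invoking cudaEventDestroy.
}
```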