[allocator] Move getFreeMutex (#87237)
It isn't used by the allocators at all, and this change makes that clearer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87237
Approved by: https://github.com/wconstab
parent 89e6078bc3
commit f56ce8dbad
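In short, callers stop routing through the allocator interface (`c10::cuda::CUDACachingAllocator::getFreeMutex()`) and call a plain `c10::cuda::getFreeMutex()` free function instead. A minimal sketch of the resulting call pattern, with a hypothetical `safe_cuda_free` helper standing in for the real call sites:

#include <mutex>

// Hypothetical helper (not in the patch): any code that frees CUDA memory
// outside the caching allocator takes the shared mutex first, so a
// concurrent NCCL group, which holds the same mutex, cannot deadlock
// against the free.
void safe_cuda_free(void* ptr) {
  std::lock_guard<std::mutex> guard(*c10::cuda::getFreeMutex());
  // cudaFree(ptr);  // the call being serialized
}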
@@ -3,7 +3,6 @@
 #include <c10/util/irange.h>
 #include <c10/util/hash.h>
 #include <c10/util/Optional.h>
-#include <c10/cuda/CUDACachingAllocator.h>
 #include <ATen/jit_macros.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
@@ -927,7 +926,7 @@ void __inline__ initializeCudaContext() {
   AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx));
   if (!pctx) {
     std::unique_lock<std::mutex> cudaFreeMutexLock(
-        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
+        *(c10::cuda::getFreeMutex()));
     cudaFree(nullptr);
   }
 }
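Aside on the hunk above: `cudaFree(nullptr)` is a no-op as a deallocation, but it forces lazy creation of a CUDA context when none is current, which is why it sits under the free mutex like any other `cudaFree` call. A standalone sketch of the same trick:

#include <cuda_runtime.h>
#include <mutex>

// Create the CUDA context on the current device if one does not exist yet,
// serializing the cudaFree(nullptr) trick behind a caller-provided mutex
// (the patch uses c10::cuda::getFreeMutex() for this).
void ensure_cuda_context(std::mutex& free_mutex) {
  std::unique_lock<std::mutex> lock(free_mutex);
  cudaFree(nullptr);  // no-op free; side effect: initializes the context
}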
@@ -1894,9 +1894,6 @@ class NativeCachingAllocator {
   // allocated blocks by device pointer
   ska::flat_hash_map<void*, Block*> allocated_blocks;
 
-  // lock around calls to cudaFree (to prevent deadlocks with NCCL)
-  mutable std::mutex cuda_free_mutex;
-
   void add_allocated_block(Block* block) {
     std::lock_guard<std::mutex> lock(mutex);
     allocated_blocks[block->ptr] = block;
@@ -1905,10 +1902,6 @@ class NativeCachingAllocator {
  public:
   std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator;
 
-  std::mutex* getCudaFreeMutex() const {
-    return &cuda_free_mutex;
-  }
-
   Block* get_allocated_block(void* ptr, bool remove = false) {
     std::lock_guard<std::mutex> lock(mutex);
     auto it = allocated_blocks.find(ptr);
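Note on the two hunks above: the member only had to be `mutable` so that the `const` accessor could hand out a pointer to it; with the mutex gone from the class, both the member and the qualifier disappear. A minimal illustration of the idiom being deleted (names hypothetical, not PyTorch's):

#include <mutex>

struct Allocator {
  mutable std::mutex free_mutex;  // mutable: lockable from const members
  std::mutex* getFreeMutex() const {
    return &free_mutex;  // compiles only because free_mutex is mutable
  }
};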
@@ -2156,10 +2149,6 @@ void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
   caching_allocator.recordStream(ptr, stream);
 }
 
-std::mutex* getFreeMutex() {
-  return caching_allocator.getCudaFreeMutex();
-}
-
 static inline void assertValidDevice(int device) {
   const auto device_num = caching_allocator.device_allocator.size();
   TORCH_CHECK(
@@ -211,7 +211,6 @@ using OutOfMemoryObserver = std::function<void(
   _(C10_CUDA_API void, \
     notifyCaptureDestroy, \
     (int device, MempoolId_t mempool_id)) \
-  _(C10_CUDA_API std::mutex*, getFreeMutex, ()) \
   _(C10_CUDA_API std::shared_ptr<void>, getIpcDevPtr, (std::string handle)) \
   _(C10_CUDA_API void, \
     recordHistory, \
@@ -324,11 +323,6 @@ inline void notifyCaptureEnded(int device, CaptureId_t graph_id) {
 inline void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
   return Chosen::notifyCaptureDestroy(device, mempool_id);
 }
 
-inline std::mutex* getFreeMutex() {
-  return Chosen::getFreeMutex();
-}
-
 // Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
 inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
   return Chosen::getIpcDevPtr(handle);
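The two hunks above trim the backend interface: the interface list is an X-macro, so deleting the `getFreeMutex` row removes the entry point from every backend and from the inline `Chosen::` forwarders in one pass. The pattern, reduced to a toy example (toy names, not PyTorch's):

#include <cstddef>

// Each _(return_type, name, args) row declares one backend entry point;
// expanding the list with different definitions of _ generates matching
// declarations, dispatch stubs, and so on.
#define BACKEND_INTERFACE(_) \
  _(void*, allocate, (size_t size)) \
  _(void, deallocate, (void* ptr))

#define DECLARE_FN(ret, name, args) ret name args;
BACKEND_INTERFACE(DECLARE_FN)  // expands to the two declarations above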
@@ -77,11 +77,6 @@ std::vector<UsageStream> dummy_unifying_free_streams;
 // Keeping it simple with an ordinary mutex for now.
 std::mutex general_mutex;
 
-// Copy-paste CUDACachingAllocator's
-// lock around calls to cudaFree (to prevent deadlocks with NCCL)
-// is this safe?
-std::mutex cuda_free_mutex;
-
 /**
  * Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -602,10 +597,6 @@ void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
   }
 }
 
-std::mutex* getFreeMutex() {
-  return &cuda_free_mutex;
-}
-
 std::shared_ptr<void> getIpcDevPtr(std::string handle) {
   TORCH_CHECK(
       false,
@@ -899,9 +890,6 @@ void notifyCaptureEnded(int device, CaptureId_t graph_id) {
 void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
   NOT_AVAILABLE("notifyCaptureDestroy");
 }
-std::mutex* getFreeMutex() {
-  NOT_AVAILABLE("getFreeMutex");
-}
 std::shared_ptr<void> getIpcDevPtr(std::string handle) {
   NOT_AVAILABLE("getIpcDevPtr");
 }
@@ -16,5 +16,10 @@ const char* get_cuda_check_suffix() noexcept {
       "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
   }
 }
+std::mutex* getFreeMutex() {
+  static std::mutex cuda_free_mutex;
+  return &cuda_free_mutex;
+}
+
 } // namespace cuda
 } // namespace c10
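The replacement above relies on the C++11 guarantee that a function-local static is initialized exactly once, even when several threads hit the first call concurrently, so the accessor needs no separate initialization lock (the "Meyers singleton" pattern). Illustrated with a hypothetical stand-in:

#include <mutex>

std::mutex* demo_free_mutex() {
  // Guaranteed thread-safe one-time construction since C++11; every caller
  // gets a pointer to the same mutex for the lifetime of the process.
  static std::mutex m;
  return &m;
}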
@@ -4,8 +4,11 @@
 
 #include <c10/cuda/CUDAMacros.h>
 
+#include <mutex>
+
 namespace c10 {
 namespace cuda {
 C10_CUDA_API const char* get_cuda_check_suffix() noexcept;
-}
+C10_CUDA_API std::mutex* getFreeMutex();
+} // namespace cuda
 } // namespace c10
@@ -11,7 +11,6 @@
 #include <ATen/cuda/Sleep.h>
 #include <ATen/cuda/detail/CUDAHooks.h>
 #include <ATen/cuda/jiterator.h>
-#include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAFunctions.h>
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
 #ifdef USE_NCCL
@@ -407,7 +406,7 @@ PyObject* THCPModule_cudaSleep(PyObject* _unused, PyObject* cycles) {
 static PyGILState_STATE cudaMutexGILState;
 
 PyObject* THCPModule_cudaLockMutex(PyObject* module, PyObject* noargs) {
-  auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex();
+  auto mutex = c10::cuda::getFreeMutex();
   // This has to be a busy loop because we **absolutely need to** hold the GIL
   // or it's a recipe for a deadlock otherwise (if we let other Python threads
   // run while we have the cudaMutex, but not the GIL, they might try to e.g.
@@ -427,7 +426,7 @@ PyObject* THCPModule_cudaLockMutex(PyObject* module, PyObject* noargs) {
 }
 
 PyObject* THCPModule_cudaUnlockMutex(PyObject* module, PyObject* noargs) {
-  auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex();
+  auto mutex = c10::cuda::getFreeMutex();
   PyGILState_Release(cudaMutexGILState);
   mutex->unlock();
   Py_RETURN_NONE;
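Context for the two Python bindings above: lock and unlock arrive as separate Python calls, so the mutex cannot be scoped to a C++ block, and acquisition has to be a `try_lock` busy loop rather than a blocking `lock()` so the GIL can keep being yielded to other Python threads between attempts. A sketch of the shape of that loop, with the GIL handling simplified to a callback:

#include <mutex>

// Spin on try_lock instead of blocking, running a caller-supplied "yield"
// action between attempts; in the real binding that action releases and
// reacquires the GIL so other Python threads can make progress.
template <typename Yield>
void lock_with_yield(std::mutex& m, Yield yield) {
  while (!m.try_lock()) {
    yield();
  }
}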
@@ -297,7 +297,7 @@ void check_inputs(
 } // namespace detail
 
 AutoNcclGroup::AutoNcclGroup() {
-  (c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
+  (c10::cuda::getFreeMutex())->lock();
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
   detail::NCCL_CHECK(ncclGroupStart());
 #endif
@@ -307,7 +307,7 @@ AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
   detail::NCCL_CHECK(ncclGroupEnd());
 #endif
-  (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
+  (c10::cuda::getFreeMutex())->unlock();
 }
 
 bool is_available(TensorList tensors) {
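The two hunks above show the NCCL side of the contract: `AutoNcclGroup` is an RAII bracket whose constructor takes the free mutex and opens an NCCL group, and whose destructor closes the group and releases the mutex, so no `cudaFree` can run inside the group. A stripped-down equivalent, with the NCCL calls stubbed out as comments:

#include <mutex>

struct GroupGuard {
  explicit GroupGuard(std::mutex& free_mutex) : m_(free_mutex) {
    m_.lock();          // nothing may cudaFree while the group is open
    // ncclGroupStart();
  }
  ~GroupGuard() {
    // ncclGroupEnd();
    m_.unlock();        // frees may proceed again
  }
  std::mutex& m_;
};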
@@ -2,7 +2,6 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/util/Optional.h>
 
 #include <cstddef>
@@ -2,7 +2,6 @@
 #include <ATen/cuda/CUDAGeneratorImpl.h>
 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
 
-#include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/util/irange.h>
 
 #include <torch/csrc/jit/codegen/cuda/contiguity.h>
@@ -941,7 +940,7 @@ void initializeCudaContext() {
   AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx));
   if (!pctx) {
     std::unique_lock<std::mutex> cudaFreeMutexLock(
-        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
+        *(c10::cuda::getFreeMutex()));
     cudaFree(nullptr);
   }
 }