diff --git a/aten/src/ATen/test/cuda_allocator_test.cpp b/aten/src/ATen/test/cuda_allocator_test.cpp
index 27a352e7d5a..5aa2378c22c 100644
--- a/aten/src/ATen/test/cuda_allocator_test.cpp
+++ b/aten/src/ATen/test/cuda_allocator_test.cpp
@@ -5,6 +5,51 @@
 #include <c10/cuda/CUDACachingAllocator.h>
 
+#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
+
 TEST(AllocatorTestCUDA, test_clone) {
   test_allocator_clone(c10::cuda::CUDACachingAllocator::get());
 }
+
+static int called_dummy_free_0 = 0;
+static int called_dummy_free_1 = 0;
+
+void* dummy_alloc_0(size_t size, int device, void* stream) {return nullptr;}
+void dummy_free_0(void* data, size_t size, int device, void* stream) {
+  called_dummy_free_0++;
+}
+void dummy_free_1(void* data, size_t size, int device, void* stream) {
+  called_dummy_free_1++;
+}
+
+// Tests that data_ptrs have their respective deleters
+// when mixing allocators
+TEST(AllocatorTestCUDA, test_pluggable_allocator_deleters) {
+  // Create a tensor with dummy_allocator_0, where dummy_free_0 is the deleter
+  auto dummy_allocator_0 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_0);
+  c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_0.get());
+  at::Tensor a = at::empty({0}, at::TensorOptions().device(at::kCUDA));
+
+  // Create a tensor with dummy_allocator_1, where dummy_free_1 is the deleter
+  auto dummy_allocator_1 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_1);
+  c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_1.get());
+  at::Tensor b = at::empty({0}, at::TensorOptions().device(at::kCUDA));
+
+  // Manually use a's deleter
+  auto* ctx = a.storage().data_ptr().get_context();
+  a.storage().data_ptr().get_deleter()(ctx);
+  a.storage().mutable_data_ptr().release_context();
+
+  // a's deleter is dummy_free_0
+  // dummy_free_0 should be called above, so called_dummy_free_0 should be 1
+  ASSERT_TRUE(called_dummy_free_0 == 1);
+
+  // Manually use b's deleter
+  ctx = b.storage().data_ptr().get_context();
+  b.storage().data_ptr().get_deleter()(ctx);
+  b.storage().mutable_data_ptr().release_context();
+
+  // b's deleter is dummy_free_1
+  // dummy_free_1 should be called above, so called_dummy_free_1 should be 1
+  ASSERT_TRUE(called_dummy_free_1 == 1);
+}
diff --git a/build_variables.bzl b/build_variables.bzl
index c1d0b5dca25..66930eb0c60 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -661,6 +661,7 @@ libtorch_cuda_core_sources = [
    "torch/csrc/CudaIPCTypes.cpp",
    "torch/csrc/cuda/comm.cpp",
    "torch/csrc/cuda/memory_snapshot.cpp",
+   "torch/csrc/cuda/CUDAPluggableAllocator.cpp",
    "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp",
    "torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
    "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
@@ -772,7 +773,6 @@ libtorch_python_cuda_core_sources = [
    "torch/csrc/cuda/shared/cudart.cpp",
    "torch/csrc/cuda/shared/nvtx.cpp",
    "torch/csrc/cuda/utils.cpp",
-   "torch/csrc/cuda/CUDAPluggableAllocator.cpp",
 ]
 
 libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
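The test drives allocator switching through the internal `CUDACachingAllocator::allocator.store(...)` hook. For reference, a minimal sketch of how user code reaches the same code path through the public API of this patch; the `cudaMallocAsync`/`cudaFreeAsync` hooks and the `main()` scaffolding are illustrative choices, not part of the patch itself:

    #include <ATen/ATen.h>
    #include <cuda_runtime.h>
    #include <torch/csrc/cuda/CUDAPluggableAllocator.h>

    // Hooks matching the MallocFuncType / FreeFuncType signatures.
    void* my_alloc(size_t size, int device, cudaStream_t stream) {
      void* ptr = nullptr;
      cudaMallocAsync(&ptr, size, stream);
      return ptr;
    }

    void my_free(void* ptr, size_t size, int device, cudaStream_t stream) {
      cudaFreeAsync(ptr, stream);
    }

    int main() {
      namespace cpa = torch::cuda::CUDAPluggableAllocator;
      // Must run before the first CUDA allocation. Every DataPtr created
      // while this allocator is current now carries a deleter context that
      // remembers my_free.
      auto allocator = cpa::createCustomAllocator(my_alloc, my_free);
      cpa::changeCurrentAllocator(allocator);
      at::Tensor t = at::empty({64}, at::TensorOptions().device(at::kCUDA));
      return 0;
    }
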
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
index cb7b62387b6..5d651fa83f0 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
@@ -8,6 +8,23 @@
 namespace torch::cuda::CUDAPluggableAllocator {
 
+CUDAPluggableAllocatorDeleterContext::CUDAPluggableAllocatorDeleterContext(
+    std::function<FreeFuncType> free_fn,
+    void* data,
+    size_t size,
+    int device,
+    cudaStream_t stream)
+    : free_fn_(free_fn),
+      data_(data),
+      size_(size),
+      device_(device),
+      stream_(stream) {}
+
+void CUDAPluggableAllocatorDeleterContext::free() {
+  free_fn_(data_, size_, device_, stream_);
+  delete this;
+}
+
 int device_count = 0;
 
 void custom_raw_deleter(void* ptr);
@@ -26,8 +43,8 @@ _AllocationMetadata::_AllocationMetadata(
 // This avoids having to link against libtorch for C++ based custom allocators
 // And also use this from python
 CUDAPluggableAllocator::CUDAPluggableAllocator(
-    std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
-    std::function<void(void*, size_t, int, cudaStream_t)> free_fn)
+    std::function<MallocFuncType> alloc_fn,
+    std::function<FreeFuncType> free_fn)
     : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}
 
 CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other)
@@ -99,8 +116,10 @@ c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) {
   C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
   cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
   void* r = this->malloc(size, device, stream);
+  auto* ctx = new CUDAPluggableAllocatorDeleterContext(
+      free_fn_, r, size, device, stream);
   c10::DataPtr data_ptr = {
-      r, r, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
+      r, ctx, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
   return data_ptr;
 }
 
@@ -348,8 +367,8 @@ getCurrentAllocator() {
 // TODO: add more functions in the argument
 std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
 createCustomAllocator(
-    std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
-    std::function<void(void*, size_t, int, cudaStream_t)> free_fn) {
+    std::function<MallocFuncType> alloc_fn,
+    std::function<FreeFuncType> free_fn) {
   std::shared_ptr<CUDAPluggableAllocator> allocator(
       new CUDAPluggableAllocator(std::move(alloc_fn), std::move(free_fn)));
   allocator->init(device_count);
@@ -366,8 +385,8 @@ void changeCurrentAllocator(
   current_custom_allocator = allocator;
 }
 
-void custom_raw_deleter(void* ptr) {
-  current_custom_allocator->raw_delete(ptr);
+void custom_raw_deleter(void* ctx) {
+  reinterpret_cast<CUDAPluggableAllocatorDeleterContext*>(ctx)->free();
 }
 
 } // namespace torch::cuda::CUDAPluggableAllocator
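The heart of the fix is visible in `allocate` and `custom_raw_deleter` above: the old deleter dispatched through `current_custom_allocator`, i.e. whichever allocator happened to be current at free time, while the new one goes through a context that binds the free function chosen at allocation time and then deletes itself. A self-contained sketch of the pattern, with invented names, assuming only the standard library:

    #include <cstdio>
    #include <cstdlib>
    #include <functional>

    // Captures the free function at allocation time. The DataPtr deleter
    // receives this context instead of the raw pointer, so a later
    // allocator switch cannot change which free function runs.
    struct DeleterContext {
      std::function<void(void*)> free_fn;
      void* data;
      void free() {
        free_fn(data);
        delete this; // the context owns itself, as in the patch
      }
    };

    void raw_deleter(void* ctx) {
      static_cast<DeleterContext*>(ctx)->free();
    }

    int main() {
      auto* ctx = new DeleterContext{
          [](void* p) { std::puts("free_0"); std::free(p); },
          std::malloc(16)};
      // ... the "current allocator" may change here without consequence ...
      raw_deleter(ctx); // always runs free_0, bound at allocation time
    }
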
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h
index 22a61e48e4a..cd5f3196aba 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.h
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.h
@@ -11,19 +11,47 @@
 namespace torch::cuda::CUDAPluggableAllocator {
 
+using MallocFuncType = void*(size_t, int, cudaStream_t);
+using FreeFuncType = void(void*, size_t, int, cudaStream_t);
+
+// A CUDAPluggableAllocatorDeleterContext object is used as the `ctx`
+// argument for DataPtr. We need context because a user can use
+// multiple allocators in the same PyTorch program, and
+// the allocators can have different free functions, such as:
+// free, cudaFree, cudaFreeAsync, ncclMemFree etc.
+struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext {
+  explicit CUDAPluggableAllocatorDeleterContext(
+      std::function<FreeFuncType> free_fn,
+      void* data,
+      size_t size,
+      int device,
+      cudaStream_t stream);
+
+  void free();
+
+ private:
+  std::function<FreeFuncType> free_fn_;
+  void* data_;
+  size_t size_;
+  int device_;
+  cudaStream_t stream_;
+};
+
 #if defined(TORCH_HIP_VERSION)
 using streamType = c10::hip::HIPStream;
 #else
 using streamType = c10::cuda::CUDAStream;
 #endif
 
-std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
+TORCH_CUDA_CPP_API std::shared_ptr<
+    c10::cuda::CUDACachingAllocator::CUDAAllocator>
 getCurrentAllocator();
-std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
+TORCH_CUDA_CPP_API std::shared_ptr<
+    c10::cuda::CUDACachingAllocator::CUDAAllocator>
 createCustomAllocator(
-    std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
-    std::function<void(void*, size_t, int, cudaStream_t)> free_fn);
-void changeCurrentAllocator(
+    std::function<MallocFuncType> alloc_fn,
+    std::function<FreeFuncType> free_fn);
+TORCH_CUDA_CPP_API void changeCurrentAllocator(
     const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>&
         allocator);
 
@@ -38,11 +66,11 @@ struct _AllocationMetadata {
   cudaStream_t stream;
 };
 
-struct CUDAPluggableAllocator
+struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
     : public c10::cuda::CUDACachingAllocator::CUDAAllocator {
   CUDAPluggableAllocator(
-      std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
-      std::function<void(void*, size_t, int, cudaStream_t)> free_fn);
+      std::function<MallocFuncType> alloc_fn,
+      std::function<FreeFuncType> free_fn);
 
   CUDAPluggableAllocator(CUDAPluggableAllocator& other);
 
@@ -131,8 +159,8 @@ struct CUDAPluggableAllocator
   void copy_data(void* dest, const void* src, std::size_t count) const final;
 
  protected:
-  std::function<void*(size_t, int, cudaStream_t)> alloc_fn_;
-  std::function<void(void*, size_t, int, cudaStream_t)> free_fn_;
+  std::function<MallocFuncType> alloc_fn_;
+  std::function<FreeFuncType> free_fn_;
   std::function<void(int)> init_fn_;
   std::function<void()> reset_fn_;
   std::function<void(double, int)> memory_fraction_fn_;
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 772c73c0053..f8154570d89 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1175,16 +1175,14 @@ static void registerCudaPluggableAllocator(PyObject* module) {
         self.set_release_pool(func);
       });
   m.def("_cuda_customAllocator", [](uint64_t malloc_ptr, uint64_t free_ptr) {
-    using MallocFuncType = void*(size_t, int, cudaStream_t);
-    using FreeFuncType = void(void*, size_t, int, cudaStream_t);
+    using namespace torch::cuda::CUDAPluggableAllocator;
     std::function<MallocFuncType> malloc_fn =
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        reinterpret_cast<MallocFuncType*>(malloc_ptr);
     std::function<FreeFuncType> free_fn =
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        reinterpret_cast<FreeFuncType*>(free_ptr);
-    return torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
-        malloc_fn, free_fn);
+    return createCustomAllocator(malloc_fn, free_fn);
   });
 
   // NOLINTNEXTLINE(bugprone-unused-raii)
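`_cuda_customAllocator` receives the two hooks as raw integer addresses, so a C++ plugin only needs to export functions with the `MallocFuncType`/`FreeFuncType` signatures from a shared library; on the Python side, `torch.cuda.memory.CUDAPluggableAllocator` resolves the symbols and forwards their addresses to this binding. A sketch of such a plugin (the file name, build line, and the plain `cudaMalloc`/`cudaFree` choice are illustrative):

    // alloc_plugin.cpp -- e.g. nvcc -shared -Xcompiler -fPIC alloc_plugin.cpp -o alloc_plugin.so
    #include <cstddef>
    #include <cuda_runtime.h>

    extern "C" {

    // Must match MallocFuncType: void*(size_t, int, cudaStream_t).
    void* plug_malloc(size_t size, int device, cudaStream_t stream) {
      void* ptr = nullptr;
      cudaMalloc(&ptr, size);
      return ptr;
    }

    // Must match FreeFuncType: void(void*, size_t, int, cudaStream_t).
    // After this patch, each DataPtr's deleter context remembers which
    // plug_free it was allocated with, even if allocators are mixed.
    void plug_free(void* ptr, size_t size, int device, cudaStream_t stream) {
      cudaFree(ptr);
    }

    } // extern "C"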