pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp
Shivam Raikundalia a11538aa46 [GPU Snapshot] Add Clear History Flag (#149352)
Summary:
Users often complain that a bunch of extra events are prepended to their desired GPU snapshot. This happens because an OOM logger is usually attached without their knowledge, so when they go to collect the actual snapshot, it includes all of the OOM logger's contents. Since the OOM logger and the regular snapshot use the same backend, we currently don't have the infrastructure in place to split these snapshots.

As a solution, we add a flag to the snapshot frontend that clears the existing history when auto-trace memory-history recording starts.

A more thorough solution would be to have the user pass in a handle and keep one snapshot per handle to separate the events. However, that would likely be more complicated than it is worth, as we would have to change the callbacks in the caching allocator and pass these objects between Python and C++.
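For reference, the flag surfaces in this file as the clearHistory parameter of recordHistory (see the override below). A minimal sketch of driving the backend with it, assuming the c10 free-function wrapper and a null context recorder:

    // Start recording and drop any history accumulated so far
    // (e.g. events from a previously attached OOM logger).
    c10::cuda::CUDACachingAllocator::recordHistory(
        /*enabled=*/true,
        /*context_recorder=*/nullptr,
        /*alloc_trace_max_entries=*/100000,
        /*when=*/c10::cuda::CUDACachingAllocator::RecordContext::ALL,
        /*clearHistory=*/true);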

Test Plan:
See diff below

Differential Revision: D71159720

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149352
Approved by: https://github.com/eqy, https://github.com/aaronenyeshi
2025-03-19 21:44:20 +00:00

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <mutex>
#include <utility>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
namespace torch::cuda::CUDAPluggableAllocator {
CUDAPluggableAllocatorDeleterContext::CUDAPluggableAllocatorDeleterContext(
std::function<FreeFuncType> free_fn,
void* data,
size_t size,
int device,
cudaStream_t stream)
: free_fn_(std::move(free_fn)),
data_(data),
size_(size),
device_(device),
stream_(stream) {}
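// Invokes the user-provided free function with the allocation parameters
// captured at allocation time, then destroys this context; each context is
// heap-allocated in CUDAPluggableAllocator::allocate() and used exactly once.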
void CUDAPluggableAllocatorDeleterContext::free() {
free_fn_(data_, size_, device_, stream_);
delete this;
}
int device_count = 0;
void custom_raw_deleter(void* ptr);
_AllocationMetadata::_AllocationMetadata() : size(0), device_idx(-1), stream{} {}
_AllocationMetadata::_AllocationMetadata(
size_t size,
c10::DeviceIndex device_idx,
cudaStream_t stream)
: size(size), device_idx(device_idx), stream(stream) {}
// This is a fast API for registering allocators based on function pointers
// (i.e. from external .so libraries). It avoids having to link against
// libtorch for C++-based custom allocators, and it can also be used from
// Python.
CUDAPluggableAllocator::CUDAPluggableAllocator(
std::function<MallocFuncType> alloc_fn,
std::function<FreeFuncType> free_fn)
: alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}
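// Example (a minimal sketch; my_malloc/my_free are hypothetical client
// functions, not part of this file). The callbacks must match MallocFuncType
// and FreeFuncType, i.e. the shapes that alloc_fn_/free_fn_ are called with
// below:
//
//   void* my_malloc(size_t size, int device, cudaStream_t stream) {
//     void* ptr = nullptr;
//     cudaSetDevice(device);  // error handling elided for brevity
//     cudaMalloc(&ptr, size);
//     return ptr;
//   }
//
//   void my_free(void* ptr, size_t size, int device, cudaStream_t stream) {
//     cudaFree(ptr);
//   }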
CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other)
: alloc_fn_(other.alloc_fn_),
free_fn_(other.free_fn_),
init_fn_(other.init_fn_),
reset_fn_(other.reset_fn_),
memory_fraction_fn_(other.memory_fraction_fn_),
base_alloc_fn_(other.base_alloc_fn_),
record_stream_fn_(other.record_stream_fn_),
begin_allocate_to_pool_fn_(other.begin_allocate_to_pool_fn_),
end_allocate_to_pool_fn_(other.end_allocate_to_pool_fn_),
release_pool_fn_(other.release_pool_fn_) {}
void CUDAPluggableAllocator::set_init_fn(std::function<void(int)> init_fn) {
init_fn_ = std::move(init_fn);
}
void CUDAPluggableAllocator::set_reset_fn(std::function<void()> reset_fn) {
reset_fn_ = std::move(reset_fn);
}
void CUDAPluggableAllocator::set_memory_fraction_fn(
std::function<void(double, int)> memory_fraction_fn) {
memory_fraction_fn_ = std::move(memory_fraction_fn);
}
void CUDAPluggableAllocator::set_base_alloc_fn(
std::function<void*(void*, size_t*)> base_alloc_fn) {
base_alloc_fn_ = std::move(base_alloc_fn);
}
void CUDAPluggableAllocator::set_record_stream_fn(
std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn) {
record_stream_fn_ = std::move(record_stream_fn);
}
void CUDAPluggableAllocator::set_begin_allocate_to_pool(
std::function<
void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)>
capture_begin_fn) {
begin_allocate_to_pool_fn_ = std::move(capture_begin_fn);
}
void CUDAPluggableAllocator::set_end_allocate_to_pool_fn(
std::function<void(int, c10::cuda::MempoolId_t)> capture_about_to_end_fn) {
end_allocate_to_pool_fn_ = std::move(capture_about_to_end_fn);
}
void CUDAPluggableAllocator::set_release_pool(
std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn) {
release_pool_fn_ = std::move(capture_destroy_fn);
}
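// Allocates through the user callback and records the (size, device, stream)
// triple under the allocator mutex so that raw_delete() can later reconstruct
// the arguments the free callback expects.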
void* CUDAPluggableAllocator::malloc(
size_t size,
c10::DeviceIndex device,
cudaStream_t stream) {
void* r = alloc_fn_(size, device, stream);
{
const std::lock_guard<std::mutex> lock(allocator_mutex_);
allocation_metadata_.emplace(r, _AllocationMetadata(size, device, stream));
}
return r;
}
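// Wraps malloc() in a c10::DataPtr whose deleter context captures the
// free-function arguments; the context deletes itself after one use (see
// CUDAPluggableAllocatorDeleterContext::free above).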
c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) {
c10::DeviceIndex device = -1;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
void* r = this->malloc(size, device, stream);
auto* ctx = new CUDAPluggableAllocatorDeleterContext(
free_fn_, r, size, device, stream);
c10::DataPtr data_ptr = {
r, ctx, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
return data_ptr;
}
c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const {
return &custom_raw_deleter;
}
void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) {
c10::DeviceIndex device = -1;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
return malloc(nbytes, device, stream);
}
void* CUDAPluggableAllocator::raw_alloc_with_stream(
size_t nbytes,
cudaStream_t stream) {
c10::DeviceIndex device = -1;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return malloc(nbytes, device, stream);
}
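// Looks up and erases the metadata recorded by malloc() while holding the
// allocator mutex, then invokes the user free callback outside the lock.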
void CUDAPluggableAllocator::raw_delete(void* ptr) {
cudaStream_t stream{};
c10::DeviceIndex device_idx = -1;
size_t size = 0;
{
const std::lock_guard<std::mutex> lock(allocator_mutex_);
TORCH_CHECK(
allocation_metadata_.count(ptr),
"Trying to free a pointer not allocated here");
_AllocationMetadata& metadata = allocation_metadata_[ptr];
size = metadata.size;
device_idx = metadata.device_idx;
stream = metadata.stream;
allocation_metadata_.erase(ptr);
}
free_fn_(ptr, size, device_idx, stream);
}
void CUDAPluggableAllocator::init(int device_count) {
if (init_fn_) {
init_fn_(device_count);
}
initialized_ = true;
}
bool CUDAPluggableAllocator::initialized() {
return initialized_;
}
double CUDAPluggableAllocator::getMemoryFraction(c10::DeviceIndex device) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support getMemoryFraction. "
"If you need it, please file an issue describing your use case.");
}
void CUDAPluggableAllocator::setMemoryFraction(
double fraction,
c10::DeviceIndex device) {
if (memory_fraction_fn_) {
memory_fraction_fn_(fraction, device);
}
}
void CUDAPluggableAllocator::emptyCache() {
if (reset_fn_) {
return reset_fn_();
}
}
void CUDAPluggableAllocator::cacheInfo(
c10::DeviceIndex device,
size_t* largestBlock) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support cacheInfo. "
"If you need it, please file an issue describing your use case.");
}
void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) {
if (base_alloc_fn_) {
return base_alloc_fn_(ptr, size);
} else {
return ptr;
}
}
void CUDAPluggableAllocator::recordStream(
const c10::DataPtr& ptr,
streamType stream) {
if (record_stream_fn_) {
record_stream_fn_(ptr.get(), stream);
}
}
c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats(
c10::DeviceIndex device) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support getDeviceStats. "
"If you need it, please file an issue describing your use case.");
}
void CUDAPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support resetAccumulatedStats. "
"If you need it, please file an issue describing your use case.");
}
void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support resetPeakStats. "
"If you need it, please file an issue describing your use case.");
}
c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
snapshot() {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support snapshot. "
"If you need it, please file an issue describing your use case.");
}
c10::cuda::CUDACachingAllocator::ShareableHandle CUDAPluggableAllocator::
shareIpcHandle(void* ptr) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support shareIPcHandle. "
"If you need it, please file an issue describing your use case.");
}
std::shared_ptr<void> CUDAPluggableAllocator::getIpcDevPtr(std::string handle) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support getIpcDevPtr. "
"If you need it, please file an issue describing your use case.");
}
// CUDAGraph interactions
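// These hooks run around CUDA graph capture: in the native caching allocator,
// beginAllocateToPool routes allocations on streams accepted by the filter
// into the given memory pool, endAllocateToPool stops that routing, and
// releasePool lets the pool be freed once its users are done. Here each hook
// simply forwards to the corresponding optional user callback, if one was set.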
void CUDAPluggableAllocator::beginAllocateToPool(
c10::DeviceIndex device,
c10::cuda::MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) {
if (begin_allocate_to_pool_fn_) {
begin_allocate_to_pool_fn_(device, mempool_id, std::move(filter));
}
}
void CUDAPluggableAllocator::endAllocateToPool(
c10::DeviceIndex device,
c10::cuda::MempoolId_t mempool_id) {
if (end_allocate_to_pool_fn_) {
end_allocate_to_pool_fn_(device, mempool_id);
}
}
void CUDAPluggableAllocator::releasePool(
c10::DeviceIndex device,
c10::cuda::MempoolId_t mempool_id) {
if (release_pool_fn_) {
release_pool_fn_(device, mempool_id);
}
}
void CUDAPluggableAllocator::recordHistory(
bool enabled,
c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
c10::cuda::CUDACachingAllocator::RecordContext when,
bool clearHistory) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support recordHistory. "
"If you need it, please file an issue describing your use case.");
}
void CUDAPluggableAllocator::attachOutOfMemoryObserver(
c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support attachOutOfMemoryObserver. "
"If you need it, please file an issue describing your use case.");
}
void CUDAPluggableAllocator::attachAllocatorTraceTracker(
c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not support attachAllocatorTraceTracker. "
"attachAllocatorTraceTracker is only used inside Pytorch.");
}
std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState>
CUDAPluggableAllocator::getCheckpointState(
c10::DeviceIndex device,
at::cuda::MempoolId_t id) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support getCheckpointState. "
"If you need it, please file an issue describing your use case.");
}
c10::cuda::CUDACachingAllocator::CheckpointDelta CUDAPluggableAllocator::
setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps) {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support setCheckpointPoolState. "
"If you need it, please file an issue describing your use case.");
}
void CUDAPluggableAllocator::enablePeerAccess(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access) {
c10::cuda::CUDAGuard device_guard(dev);
cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
if (err == cudaErrorPeerAccessAlreadyEnabled) {
// ignore and clear the error if access was already enabled
(void)cudaGetLastError();
} else {
C10_CUDA_CHECK(err);
}
}
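// Forwards to cudaMemcpyAsync on the given stream; the device indices and the
// p2p_enabled hint are currently unused by this implementation.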
cudaError_t CUDAPluggableAllocator::memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
cudaStream_t stream,
bool p2p_enabled) {
return cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream);
}
std::string CUDAPluggableAllocator::name() {
return "pluggable";
}
void CUDAPluggableAllocator::copy_data(
void* dest,
const void* src,
std::size_t count) const {
C10_CUDA_CHECK(
cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
}
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
current_custom_allocator;
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
getCurrentAllocator() {
return current_custom_allocator;
}
// TODO: accept more of the allocator hook functions as arguments
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
createCustomAllocator(
std::function<MallocFuncType> alloc_fn,
std::function<FreeFuncType> free_fn) {
std::shared_ptr<CUDAPluggableAllocator> allocator(
new CUDAPluggableAllocator(std::move(alloc_fn), std::move(free_fn)));
allocator->init(device_count);
return allocator;
}
void changeCurrentAllocator(
const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>&
allocator) {
TORCH_CHECK(
!c10::cuda::CUDACachingAllocator::allocator.load()->initialized(),
"Can't swap an already initialized allocator");
c10::cuda::CUDACachingAllocator::allocator.store(allocator.get());
current_custom_allocator = allocator;
}
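// Example usage (a sketch; my_malloc/my_free are the hypothetical callbacks
// illustrated above). The swap must happen before the default allocator has
// been initialized, or changeCurrentAllocator will throw:
//
//   auto allocator = torch::cuda::CUDAPluggableAllocator::
//       createCustomAllocator(my_malloc, my_free);
//   torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(allocator);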
void custom_raw_deleter(void* ctx) {
reinterpret_cast<CUDAPluggableAllocatorDeleterContext*>(ctx)->free();
}
} // namespace torch::cuda::CUDAPluggableAllocator