Revert "Remove MemPoolContext (#154042)"

This reverts commit 3b38989b5f.

Reverted https://github.com/pytorch/pytorch/pull/154042 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/154042#issuecomment-2921401100))
PyTorch MergeBot 2025-05-30 06:53:35 +00:00
parent 0fdd568b78
commit d173ba5a75
14 changed files with 243 additions and 139 deletions
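
For orientation: this revert restores the thread-local MemPoolContext API, so the caching allocator once again discovers the active pool (and its optional pluggable allocator) through MemPoolContext::getActiveMemPool() instead of taking a MempoolId_t or CUDAAllocator* argument on emptyCache/snapshot/createOrIncrefPool. A minimal sketch of the restored Python-level behaviour, based on test_mempool_context further down in this diff (assumes a CUDA-enabled build with at least one device):

    import torch

    # No pool is active on this thread until a MemPoolContext is constructed.
    assert torch.cuda.MemPoolContext.active_pool() is None

    pool = torch.cuda.MemPool()            # plain pool, no custom allocator attached
    ctx = torch.cuda.MemPoolContext(pool)  # makes `pool` the thread-local active pool
    assert torch.cuda.MemPoolContext.active_pool().id == pool.id

    del ctx  # the destructor restores the previously active pool (here: none)
    assert torch.cuda.MemPoolContext.active_pool() is None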

@ -833,9 +833,8 @@ class EventPool {
// CUDA graphs helper
struct PrivatePool {
PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
PrivatePool(MempoolId_t id)
: id(std::move(id)),
allocator_(allocator),
large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
@ -856,14 +855,8 @@ struct PrivatePool {
// distinguish private blocks by adding a "pool id" check above the stream
// check in BlockComparator. BlockComparator is performance-critical though,
// I'd rather not add more logic to it.
CUDAAllocator* allocator_;
BlockPool large_blocks;
BlockPool small_blocks;
public:
CUDAAllocator* allocator() {
return allocator_;
}
};
MempoolId_t BlockPool::owner_MempoolId() const {
@ -912,8 +905,9 @@ struct MempoolIdHash {
};
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
*ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
auto active_pool = MemPoolContext::getActiveMemPool();
if (active_pool && active_pool->allocator() && p.pool->owner_PrivatePool) {
*ptr = active_pool->allocator()->raw_alloc(size);
return *ptr ? cudaSuccess : cudaErrorMemoryAllocation;
} else {
return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size));
@ -1283,14 +1277,14 @@ class DeviceCachingAllocator {
alloc_block(params, false, context, lock))
// Free all non-split cached blocks and retry alloc.
|| (C10_LIKELY(captures_underway.empty()) &&
release_cached_blocks(context, {0, 0}) &&
release_cached_blocks(context) &&
alloc_block(params, true, context, lock));
}
// we are about to oom, try to use existing mempools as a last resort
if (!block_found && params.err == cudaErrorMemoryAllocation) {
// if already trying to use a mempool, then just oom
bool active_pool = params.pool->owner_PrivatePool;
auto active_pool = MemPoolContext::getActiveMemPool();
if (!active_pool) {
for (MempoolId_t mempool_id : use_on_oom_pools) {
auto tid = std::this_thread::get_id();
@ -1677,10 +1671,10 @@ class DeviceCachingAllocator {
}
/** returns cached blocks to the system allocator **/
void emptyCache(MempoolId_t mempool_id) {
void emptyCache() {
auto context = maybeGatherContext(RecordContext::ALL);
std::lock_guard<std::recursive_mutex> lock(mutex);
release_cached_blocks(context, mempool_id);
release_cached_blocks(context);
}
/** Retrieves size of largest unused block held by the memory cache **/
@ -1998,10 +1992,16 @@ class DeviceCachingAllocator {
/** Dump a complete snapshot of the memory held by the allocator. Potentially
* VERY expensive. **/
std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
std::vector<SegmentInfo> snapshot() {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<Block*> all_blocks;
MempoolId_t mempool_id = {0, 0};
auto active_mempool = MemPoolContext::getActiveMemPool();
if (active_mempool) {
mempool_id = active_mempool->id();
}
if (mempool_id.first != 0 || mempool_id.second != 0) {
// If there is an active mempool, we find the corresponding PrivatePool
@ -2011,7 +2011,7 @@ class DeviceCachingAllocator {
all_blocks = get_private_pool_head_blocks(pool->second.get());
}
} else {
// When snapshot is called with non-default mempool_id, we return
// When snapshot is called outside a MemPoolContext, we return
// all the blocks in the CUDACachingAllocator (as returned by
// get_all_blocks).
all_blocks = get_all_blocks();
@ -2130,11 +2130,11 @@ class DeviceCachingAllocator {
}
}
void createOrIncrefPool(MempoolId_t mempool_id, CUDAAllocator* allocator) {
void ensureExistsAndIncrefPool(MempoolId_t mempool_id) {
// Create a PrivatePool object if it does not exist yet
// and increment its use_count
std::lock_guard<std::recursive_mutex> lock(mutex);
create_or_incref_pool(mempool_id, allocator);
ensure_exists_and_incref_pool(mempool_id);
}
void setUseOnOOM(MempoolId_t mempool_id) {
@ -2150,7 +2150,7 @@ class DeviceCachingAllocator {
MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) {
std::lock_guard<std::recursive_mutex> lock(mutex);
create_or_incref_pool(mempool_id);
ensure_exists_and_incref_pool(mempool_id);
for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
++it2) {
TORCH_CHECK(
@ -2272,24 +2272,21 @@ class DeviceCachingAllocator {
return blocks;
}
void create_or_incref_pool(
MempoolId_t mempool_id,
CUDAAllocator* allocator = nullptr) {
void ensure_exists_and_incref_pool(MempoolId_t mempool_id) {
auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool.
// Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
// usage. use_count is initially 1, which means the pool is
// being used since somebody called createOrIncrefPool.
// being used since somebody called ensureExistsAndIncrefPool.
graph_pools.emplace(
mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
mempool_id, std::make_unique<PrivatePool>(mempool_id));
} else {
// mempool_id references an existing pool, which the current CUDAGraph
// capture or torch.cuda.use_mem_pool will
// share. Check this pool is live (at least one other capture already
// references it). Increment it to establish the usage.
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
TORCH_INTERNAL_ASSERT(allocator == nullptr);
it->second->use_count++;
}
}
@ -2779,8 +2776,7 @@ class DeviceCachingAllocator {
bool in_fbcode = false;
#endif
bool active_pool =
p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
auto active_pool = MemPoolContext::getActiveMemPool();
if (set_fraction &&
total_allocated_memory + size > allowed_memory_maximum) {
p.err = cudaErrorMemoryAllocation;
@ -2805,6 +2801,12 @@ class DeviceCachingAllocator {
}
return bool(p.block);
} else {
if (active_pool && active_pool->allocator() &&
p.pool->owner_PrivatePool) {
// Ensure that active_pool and p.pool are the same
auto pp = get_private_pool(active_pool->id());
TORCH_INTERNAL_ASSERT(pp == p.pool->owner_PrivatePool);
}
if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
// At scope exit, acquire the lock again. This provides safety against
// any potential exceptions in the cudaMallocMaybeCapturing function.
@ -2924,9 +2926,13 @@ class DeviceCachingAllocator {
return true;
}
bool release_cached_blocks(
const std::shared_ptr<GatheredContext>& context,
MempoolId_t mempool_id) {
bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
MempoolId_t mempool_id = {0, 0};
auto active_mempool = MemPoolContext::getActiveMemPool();
if (active_mempool) {
mempool_id = active_mempool->id();
}
if (mempool_id.first == 0 && mempool_id.second == 0) {
// If there is no active mempool, we work on releasing *all* blocks.
@ -2999,10 +3005,15 @@ class DeviceCachingAllocator {
context ? context : block->context_when_segment_allocated);
auto* pool = block->pool;
if (pool->owner_PrivatePool && pool->owner_PrivatePool->allocator()) {
auto active_pool = MemPoolContext::getActiveMemPool();
if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
// Ensure that active_pool and pool are the same
auto pp = get_private_pool(active_pool->id());
TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
// If there is an active mempool with a given allocator,
// we use the given allocator's delete function.
pool->owner_PrivatePool->allocator()->raw_delete((void*)block->ptr);
active_pool->allocator()->raw_delete((void*)block->ptr);
} else {
C10_CUDA_CHECK(cudaFree((void*)block->ptr));
}
@ -3578,9 +3589,9 @@ class NativeCachingAllocator : public CUDAAllocator {
}
}
void emptyCache(MempoolId_t mempool_id) override {
void emptyCache() override {
for (auto& da : device_allocator)
da->emptyCache(mempool_id);
da->emptyCache();
}
void enable(bool value) override {
@ -3628,7 +3639,7 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[block->device]->recordStream(block, stream);
}
SnapshotInfo snapshot(MempoolId_t mempool_id) override {
SnapshotInfo snapshot() override {
// Set-up converter to convert timestamps from tsc to microseconds.
auto tsc_to_ns = clock_converter.makeConverter();
auto tsc_to_us = [=](approx_time_t t_approx) {
@ -3646,7 +3657,7 @@ class NativeCachingAllocator : public CUDAAllocator {
// Get the device_traces' TraceEntry lists.
for (auto& da : device_allocator) {
result.device_traces.emplace_back(da->trace(tsc_to_us));
auto snap = da->snapshot(mempool_id);
auto snap = da->snapshot();
result.segments.insert(result.segments.end(), snap.begin(), snap.end());
}
@ -3774,13 +3785,11 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[device]->resetPeakStats();
}
void createOrIncrefPool(
void ensureExistsAndIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
CUDAAllocator* allocator) override {
MempoolId_t mempool_id) override {
assertValidDevice(device);
device_allocator[device]->createOrIncrefPool(
std::move(mempool_id), allocator);
device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
}
void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
@ -4125,7 +4134,7 @@ MemPool::MemPool(
id_ = {uuid_++, 0};
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
if (use_on_oom) {
CUDACachingAllocator::setUseOnOOM(device_, id_);
}
@ -4134,7 +4143,8 @@ MemPool::MemPool(
MemPool::~MemPool() {
TORCH_INTERNAL_ASSERT(use_count() == 1);
CUDACachingAllocator::releasePool(device_, id_);
c10::cuda::CUDACachingAllocator::emptyCache(id_);
auto ctx = MemPoolContext(this);
c10::cuda::CUDACachingAllocator::emptyCache();
}
MempoolId_t MemPool::id() {
@ -4160,4 +4170,23 @@ MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
return {uuid_++, 0};
}
// Note that active_mempool_ is a global variable here
// and not inside MemPoolContext class, because in windows we
// can't use __declspec(dllexport) and __declspec(thread)
// together: https://stackoverflow.com/a/50967977
static thread_local MemPool* active_mempool_ = nullptr;
MemPoolContext::MemPoolContext(MemPool* mempool)
: prev_mempool_(active_mempool_) {
active_mempool_ = mempool;
}
MemPoolContext::~MemPoolContext() {
active_mempool_ = prev_mempool_;
}
MemPool* MemPoolContext::getActiveMemPool() {
return active_mempool_;
}
} // namespace c10::cuda

@ -211,7 +211,7 @@ class CUDAAllocator : public Allocator {
virtual bool initialized() = 0;
virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
virtual void emptyCache() = 0;
virtual void enable(bool value) = 0;
virtual bool isEnabled() const = 0;
virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
@ -221,7 +221,7 @@ class CUDAAllocator : public Allocator {
c10::DeviceIndex device) = 0;
virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0;
virtual SnapshotInfo snapshot() = 0;
virtual void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
@ -239,14 +239,13 @@ class CUDAAllocator : public Allocator {
" does not yet support getPoolUseCount. "
"If you need it, please file an issue describing your use case.");
}
virtual void createOrIncrefPool(
virtual void ensureExistsAndIncrefPool(
c10::DeviceIndex /*device*/,
MempoolId_t /*mempool_id*/,
CUDAAllocator* allocator = nullptr) {
MempoolId_t /*mempool_id*/) {
TORCH_CHECK(
false,
name(),
" does not yet support createOrIncrefPool. "
" does not yet support ensureExistsAndIncrefPool. "
"If you need it, please file an issue describing your use case.");
}
virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
@ -365,7 +364,7 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
return get()->setMemoryFraction(fraction, device);
}
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
inline void emptyCache() {
return get()->emptyCache();
}
@ -402,8 +401,8 @@ inline void resetPeakStats(c10::DeviceIndex device) {
return get()->resetPeakStats(device);
}
inline SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
return get()->snapshot(mempool_id);
inline SnapshotInfo snapshot() {
return get()->snapshot();
}
inline std::shared_ptr<AllocatorState> getCheckpointState(
@ -476,11 +475,10 @@ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->releasePool(device, mempool_id);
}
inline void createOrIncrefPool(
inline void ensureExistsAndIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
CUDAAllocator* allocator_ptr = nullptr) {
get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
MempoolId_t mempool_id) {
get()->ensureExistsAndIncrefPool(device, mempool_id);
}
inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->setUseOnOOM(device, mempool_id);
@ -557,4 +555,26 @@ struct C10_CUDA_API MemPool {
c10::DeviceIndex device_;
};
// MemPoolContext holds the currently active pool and stashes the previous
// pool. On deletion it makes the previous pool active.
struct C10_CUDA_API MemPoolContext {
MemPoolContext(MemPool* mempool);
~MemPoolContext();
// getActiveMemPool() can be used to get the currently active pool.
// For instance: in CUDACachingAllocator, we can route allocations
// to a user provided allocator, by doing:
//
// auto active_pool = MemPoolContext::getActiveMemPool();
// if (active_pool && active_pool->allocator()) {
// ptr = active_pool->allocator()->raw_alloc(size);
// }
//
static MemPool* getActiveMemPool();
private:
MemPool* prev_mempool_;
};
} // namespace c10::cuda
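
The header comment above shows how raw allocations are routed to a user-provided allocator when the active MemPoolContext wraps a pool that carries one. At the Python level this corresponds to building a MemPool on top of a CUDAPluggableAllocator, as exercised by test_mempool_with_allocator later in this diff. A condensed, hedged sketch; `dummy_allocator` is assumed to be the shared library built via torch.utils.cpp_extension.load_inline in that test, exporting dummy_alloc/dummy_free:

    import torch

    # Assumed to exist: a compiled extension exposing dummy_alloc/dummy_free,
    # e.g. the load_inline(..., is_python_module=False) artifact from the test.
    allocator = torch.cuda.memory.CUDAPluggableAllocator(
        dummy_allocator, "dummy_alloc", "dummy_free"
    )

    pool = torch.cuda.MemPool(allocator.allocator())
    with torch.cuda.use_mem_pool(pool):
        # allocPrimitive sees an active pool with a non-null allocator() and
        # services this segment through dummy_alloc instead of cudaMalloc.
        out = torch.randn(1024, device="cuda")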

@ -496,7 +496,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// introduces performance nondeterminism.
}
void emptyCache(/*unused*/ MempoolId_t mempool_id) override {
void emptyCache() override {
std::lock_guard<std::mutex> lk(general_mutex);
for (int dev = 0; dev < device_count; dev++) {
@ -778,7 +778,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
}
SnapshotInfo snapshot(MempoolId_t mempool_id) override {
SnapshotInfo snapshot() override {
TORCH_CHECK(
false,
"Calling snapshot with backend:cudaMallocAsync is not meaningful. "

@ -2282,6 +2282,7 @@ coverage_ignore_classes = [
"UnsynchronizedAccessError",
# torch.cuda.memory
"MemPool",
"MemPoolContext",
# torch.distributed.elastic.multiprocessing.errors
"ChildFailedError",
"ProcessFailure",

@ -128,6 +128,7 @@ Memory management
CUDAPluggableAllocator
change_current_allocator
MemPool
MemPoolContext
.. currentmodule:: torch.cuda.memory

@ -5049,57 +5049,41 @@ class TestMemPool(TestCase):
# increments the id
self.assertTrue(abs(pool2[1] - pool1[1]) > 0)
def get_dummy_allocator(self, check_vars):
dummy_allocator_source_vars = """
def test_mempool_with_allocator(self):
pool = torch.cuda.MemPool()
# MemPool doesn't have an allocator by default
self.assertEqual(pool.allocator, None)
from torch.utils.cpp_extension import load_inline
dummy_allocator_source = """
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_runtime_api.h>
extern "C" {
C10_EXPORT int called_dummy_alloc = 0;
C10_EXPORT int called_dummy_free = 0;
C10_EXPORT int called_dummy_alloc = 0;
C10_EXPORT int called_dummy_free = 0;
// Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
called_dummy_alloc = 123;
void* ptr;
C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
return ptr;
}
C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
called_dummy_free = 321;
C10_CUDA_CHECK(cudaFree(ptr));
}
}
"""
dummy_allocator_source_no_vars = """
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_runtime_api.h>
extern "C" {
// Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
called_dummy_alloc = 123;
void* ptr;
C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
return ptr;
}
C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
called_dummy_free = 321;
C10_CUDA_CHECK(cudaFree(ptr));
}
}
"""
from torch.utils.cpp_extension import load_inline
dummy_allocator_libname = "dummy_allocator"
dummy_allocator = load_inline(
name=dummy_allocator_libname,
cpp_sources=dummy_allocator_source_vars
if check_vars
else dummy_allocator_source_no_vars,
cpp_sources=dummy_allocator_source,
is_python_module=False,
keep_intermediates=False,
verbose=True,
@ -5110,15 +5094,6 @@ class TestMemPool(TestCase):
"dummy_alloc",
"dummy_free",
)
return allocator, dummy_allocator
def test_mempool_with_allocator(self):
pool = torch.cuda.MemPool()
# MemPool doesn't have an allocator by default
self.assertEqual(pool.allocator, None)
allocator, dummy_allocator = self.get_dummy_allocator(check_vars=True)
pool = torch.cuda.MemPool(allocator.allocator())
# pool should point to the same allocator as the one passed into it
@ -5153,8 +5128,6 @@ class TestMemPool(TestCase):
# out tensor
self.assertEqual(called_dummy_alloc.value, 123)
out_non_pool = torch.empty(nelem_1mb, device="cuda")
with torch.cuda.use_mem_pool(pool):
# pool should have 1 segment since we made a small allocation (1 MB)
# above and so the CUDACachingAllocator packed it into a 2 MB buffer
@ -5172,9 +5145,6 @@ class TestMemPool(TestCase):
# to make a new 2 MB buffer to accommodate out_2
self.assertEqual(len(pool.snapshot()), 2)
all_segments = torch.cuda.memory._snapshot()["segments"]
self.assertEqual(len(all_segments), 3)
del out_0, out_1, out_2
# pool's destructor calls emptyCache()
@ -5186,7 +5156,40 @@ class TestMemPool(TestCase):
@serialTest()
def test_mempool_limited_memory_with_allocator(self):
allocator, _ = self.get_dummy_allocator(check_vars=False)
from torch.utils.cpp_extension import load_inline
dummy_allocator_source = """
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_runtime_api.h>
extern "C" {
// Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
void* ptr;
C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
return ptr;
}
C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
C10_CUDA_CHECK(cudaFree(ptr));
}
}
"""
dummy_allocator_libname = "dummy_allocator"
dummy_allocator = load_inline(
name=dummy_allocator_libname,
cpp_sources=dummy_allocator_source,
is_python_module=False,
keep_intermediates=False,
verbose=True,
with_cuda=True,
)
allocator = torch.cuda.memory.CUDAPluggableAllocator(
dummy_allocator,
"dummy_alloc",
"dummy_free",
)
pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
@ -5255,13 +5258,38 @@ class TestMemPool(TestCase):
self._teardown_mempool_limited_memory_test()
def test_mempool_context(self):
active_pool = torch.cuda.MemPoolContext.active_pool()
# there is no active pool if none was made active
self.assertEqual(active_pool, None)
pool = torch.cuda.MemPool()
ctx = torch.cuda.MemPoolContext(pool)
active_pool = torch.cuda.MemPoolContext.active_pool()
# pool was made active
self.assertEqual(active_pool, pool)
del ctx
active_pool = torch.cuda.MemPoolContext.active_pool()
# ctx was deleted, so active pool is the previous one
self.assertEqual(active_pool, None)
def test_mempool_multithread(self):
pool_ids = []
active_pool_ids = []
def create_mempool_and_make_active():
pool = torch.cuda.MemPool()
pool_ids.extend([pool.id])
ctx = torch.cuda.MemPoolContext(pool)
active_pool = torch.cuda.MemPoolContext.active_pool()
active_pool_ids.extend([active_pool.id])
del ctx
num_threads = 4
threads = [
threading.Thread(target=create_mempool_and_make_active)
@ -5276,12 +5304,14 @@ class TestMemPool(TestCase):
# mempool id creation is atomic
self.assertEqual(len(set(pool_ids)), 4)
# each thread should have a different active mempool, since
# the pointer to the mempool is thread local
self.assertEqual(len(set(active_pool_ids)), 4)
@skipIfRocm(msg="expandable_segments mode is not supported on ROCm")
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode")
def test_mempool_expandable(self):
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
allocator, _ = self.get_dummy_allocator(check_vars=False)
pool = torch.cuda.MemPool(allocator.allocator())
pool = torch.cuda.MemPool()
# torch.cuda.MemPool doesn't work with expandable segments
with self.assertRaises(RuntimeError):

@ -2027,7 +2027,7 @@ def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_hostMemoryStats() -> dict[str, Any]: ...
def _cuda_resetAccumulatedHostMemoryStats() -> None: ...
def _cuda_resetPeakHostMemoryStats() -> None: ...
def _cuda_memorySnapshot(mempool_id: tuple[_int, _int] | None) -> dict[str, Any]: ...
def _cuda_memorySnapshot() -> dict[str, Any]: ...
def _cuda_record_memory_history_legacy(
enabled: _bool,
record_context: _bool,
@ -2304,6 +2304,11 @@ class _MemPool:
def allocator(self) -> _cuda_CUDAAllocator | None: ...
def use_count(self) -> _int: ...
class _MemPoolContext:
def __init__(self, pool: _MemPool) -> None: ...
@staticmethod
def active_pool() -> _MemPool | None: ...
def _cuda_isCurrentStreamCapturing() -> _bool: ...
def _graph_pool_handle() -> tuple[_int, _int]: ...

@ -184,8 +184,7 @@ void CUDAPluggableAllocator::setMemoryFraction(
}
}
void CUDAPluggableAllocator::emptyCache(
/*unused*/ c10::cuda::MempoolId_t mempool_id) {
void CUDAPluggableAllocator::emptyCache() {
if (reset_fn_) {
return reset_fn_();
}
@ -238,8 +237,8 @@ void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
"If you need it, please file an issue describing your use case.");
}
c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::snapshot(
c10::cuda::MempoolId_t mempool_id) {
c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
snapshot() {
TORCH_CHECK(
false,
"CUDAPluggableAllocator does not yet support snapshot. "

@ -114,7 +114,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
bool initialized() override;
double getMemoryFraction(c10::DeviceIndex device) override;
void setMemoryFraction(double fraction, c10::DeviceIndex device) override;
void emptyCache(c10::cuda::MempoolId_t mempool_id = {0, 0}) override;
void emptyCache() override;
void enable(bool) override {}
bool isEnabled() const override {
return true;
@ -128,8 +128,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
c10::DeviceIndex device) override;
void resetAccumulatedStats(c10::DeviceIndex device) override;
void resetPeakStats(c10::DeviceIndex device) override;
c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot(
c10::cuda::MempoolId_t mempool) override;
c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override;
void beginAllocateToPool(
c10::DeviceIndex device,
c10::cuda::MempoolId_t mempool_id,

@ -24,4 +24,8 @@ void THCPMemPool_init(PyObject* module) {
.def_property_readonly("id", &::c10::cuda::MemPool::id)
.def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
.def("use_count", &::c10::cuda::MemPool::use_count);
shared_ptr_class_<::c10::cuda::MemPoolContext>(torch_C_m, "_MemPoolContext")
.def(py::init<c10::cuda::MemPool*>())
.def_static(
"active_pool", &::c10::cuda::MemPoolContext::getActiveMemPool);
}

@ -721,24 +721,8 @@ CapturedTraceback* getFromContext(
"attempting to gather stack context from the wrong StackContext type.");
}
PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
c10::cuda::MempoolId_t mempool_id = {0, 0};
if (arg && arg != Py_None) {
TORCH_CHECK(PyTuple_Check(arg), "mempool_id must be a tuple");
Py_ssize_t size = PyTuple_Size(arg);
TORCH_CHECK(size == 2, "mempool_id must be a tuple of 2 integers");
auto id1 = THPObjectPtr(PyTuple_GetItem(arg, 0));
auto id2 = THPObjectPtr(PyTuple_GetItem(arg, 1));
TORCH_CHECK(
THPUtils_checkLong(id1) && THPUtils_checkLong(id2),
"mempool_id elements must be integers");
mempool_id = c10::cuda::MempoolId_t(
static_cast<int64_t>(THPUtils_unpackLong(id1)),
static_cast<int64_t>(THPUtils_unpackLong(id2)));
}
using c10::cuda::CUDACachingAllocator::BlockInfo;
using c10::cuda::CUDACachingAllocator::SegmentInfo;
@ -818,7 +802,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
return segmentDict;
};
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(mempool_id);
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
py::list segments;
@ -2027,7 +2011,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_resetPeakMemoryStats,
METH_O,
nullptr},
{"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_O, nullptr},
{"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr},
{"_cuda_attach_out_of_memory_observer",
THCPModule_attachOutOfMemoryObserver,
METH_O,

@ -1186,7 +1186,8 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
// We must ensure we're listening for allocator trace events in order to
// register future segments allocated in this pool (this call is idempotent).
attachAllocatorHooks();
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
auto ctx = c10::cuda::MemPoolContext(pool);
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
for (const auto& segmentInfo : snapshot.segments) {
TORCH_INTERNAL_ASSERT(
segmentInfo.device == pool->device(),
@ -1220,7 +1221,8 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
auto iter = ncclCommMemPoolMap.find(ncclComm);
iter->second.erase(pool->id());
}
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
auto ctx = c10::cuda::MemPoolContext(pool);
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
for (const auto& segmentInfo : snapshot.segments) {
TORCH_INTERNAL_ASSERT(
segmentInfo.device == pool->device(),
@ -5570,6 +5572,7 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
}
// Allocate tensor under this MemPool's context
auto ctx = c10::cuda::MemPoolContext(memPool_.get());
auto tid = std::this_thread::get_id();
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
memPool_->device(), memPool_->id(), [=](cudaStream_t) {

@ -1872,6 +1872,7 @@ __all__ = [
"memory_summary",
"memory_usage",
"MemPool",
"MemPoolContext",
"use_mem_pool",
"temperature",
"power_draw",

@ -60,6 +60,7 @@ __all__ = [
"CUDAPluggableAllocator",
"change_current_allocator",
"MemPool",
"MemPoolContext",
"use_mem_pool",
]
@ -72,6 +73,7 @@ if not hasattr(torch._C, "_cuda_CUDAAllocator"):
if not hasattr(torch._C, "_MemPool"):
# Define dummy base classes
torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
"_cuda_beginAllocateToPool"
)
@ -90,6 +92,7 @@ from torch._C import ( # noqa: F401
_cuda_endAllocateToPool,
_cuda_releasePool,
_MemPool,
_MemPoolContext,
)
@ -614,7 +617,7 @@ def max_memory_cached(device: "Device" = None) -> int:
return max_memory_reserved(device=device)
def memory_snapshot(mempool_id=None):
def memory_snapshot():
r"""Return a snapshot of the CUDA memory allocator state across all devices.
Interpreting the output of this function requires familiarity with the
@ -624,7 +627,7 @@ def memory_snapshot(mempool_id=None):
See :ref:`cuda-memory-management` for more details about GPU memory
management.
"""
return torch._C._cuda_memorySnapshot(mempool_id)["segments"]
return torch._C._cuda_memorySnapshot()["segments"]
def memory_summary(device: "Device" = None, abbreviated: bool = False) -> str:
@ -995,7 +998,7 @@ def _snapshot(device: "Device" = None):
Returns:
The Snapshot dictionary object
"""
return _C._cuda_memorySnapshot(None)
return _C._cuda_memorySnapshot()
def _dump_snapshot(filename="dump_snapshot.pickle"):
@ -1107,6 +1110,25 @@ def _get_current_allocator() -> _CUDAAllocator:
return _CUDAAllocator(torch._C._cuda_getAllocator())
class MemPoolContext(_MemPoolContext):
r"""MemPoolContext holds the currently active pool and stashes the previous
pool. On deletion it makes the previous pool active.
Args:
pool(torch.cuda.MemPool): a MemPool object to be made active so that
allocations route to this pool.
"""
def __init__(self, pool: _MemPool):
super().__init__(pool)
@staticmethod
def active_pool() -> Optional[_MemPool]:
r"""Returns the active MemPool"""
return _MemPoolContext.active_pool()
class MemPool(_MemPool):
r"""MemPool represents a pool of memory in a caching allocator. Currently,
it's just the ID of the pool object maintained in the CUDACachingAllocator.
@ -1155,7 +1177,11 @@ class MemPool(_MemPool):
See :ref:`cuda-memory-management` for more details about GPU memory
management.
"""
snapshot = torch.cuda.memory_snapshot(self.id)
try:
ctx = MemPoolContext(self)
snapshot = torch.cuda.memory_snapshot()
finally:
del ctx
return snapshot
@ -1176,6 +1202,7 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
(e.g. by calling backward) the allocations in that thread will not
route to the given pool.
"""
ctx = MemPoolContext(pool)
device_index = (
torch.cuda.current_device() if device is None else _get_device_index(device)
)
@ -1185,3 +1212,4 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
finally:
_cuda_endAllocateToPool(device_index, pool.id)
_cuda_releasePool(device_index, pool.id)
del ctx
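
Taken together, the memory.py changes above mean that torch.cuda.use_mem_pool installs a MemPoolContext for the calling thread and MemPool.snapshot() briefly re-activates its pool so the snapshot is scoped to that pool. A short usage sketch along the lines of the tests in this diff (assumes a CUDA-enabled build):

    import torch

    pool = torch.cuda.MemPool()

    with torch.cuda.use_mem_pool(pool):
        # New segments backing this allocation land in `pool`.
        x = torch.empty(1024, device="cuda")

        # MemPool.snapshot() wraps memory_snapshot() in a MemPoolContext, so only
        # this pool's segments are listed (a single small segment in the tests).
        pool_segments = pool.snapshot()

    # Without an active context, memory_snapshot() reports segments from all pools.
    all_segments = torch.cuda.memory_snapshot()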