(Re-open) Adds cudaMallocAsync as an alternative backend for the CUDA allocator (#82682)

Rebased version of @mcarilli's cudaMallocAsync #65365 for continued testing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82682
Approved by: https://github.com/ngimel
This commit is contained in:
Eddie Yan 2022-10-12 03:44:21 +00:00 committed by PyTorch MergeBot
parent a216f4700c
commit 25725fd624
22 changed files with 1615 additions and 202 deletions

View File

@ -19,10 +19,10 @@ namespace at {
*
* A CUDA graph containing multiple RNG ops behaves like a
* single giant kernel from the perspective of ops external
* to the graph. During graph capture, logic in CUDAGeneratorImpl
* records the total of all offset increments that occur in the
* graphed region, and records the final total as the offset for
* the entire graph.
*
* When the graph reruns, the logic that reruns it
* increments this device's CUDA generator's offset
@ -30,8 +30,8 @@ namespace at {
*
* Meanwhile, within the graph, at capture time, instead of
* populating PhiloxCudaStates with the uint64_t offset pulled
* directly from the global state, PhiloxCudaState uses a pointer
* to a one-element stream-local int64_t device tensor
* holding an initial offset value, and a uint64_t holding an
* intra-graph offset. (The intra-graph offset starts from zero
* when capture begins.) In each consumer kernel,
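Aside: the comment above describes how PhiloxCudaState switches between a literal offset and a device-pointer offset under graph capture. Below is a minimal, hedged sketch of that consumer-side logic with hypothetical names, independent of the real ATen types.

#include <cstdint>

// Hypothetical stand-in for PhiloxCudaState; field names are illustrative only.
struct PhiloxStateSketch {
  uint64_t seed_;
  uint64_t offset_val_;          // used on the non-captured path
  const int64_t* offset_ptr_;    // one-element device tensor, captured path
  uint32_t offset_intragraph_;   // starts at zero when capture begins
  bool captured_;
};

// A consumer kernel resolves its effective offset: under capture it reads the
// base offset through the pointer (refreshed before each replay) and adds its
// own intra-graph offset; otherwise it uses the literal value.
__device__ uint64_t resolve_offset(const PhiloxStateSketch& s) {
  return s.captured_
      ? static_cast<uint64_t>(*s.offset_ptr_) + s.offset_intragraph_
      : s.offset_val_;
}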

View File

@ -133,16 +133,42 @@ void CUDAGraph::capture_end() {
TORCH_CHECK(stream == capture_stream_,
            "Capture must end on the same stream it began on.");
c10::cuda::CUDACachingAllocator::notifyCaptureAboutToEnd(capture_dev_, id_);
AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
TORCH_CHECK(graph_ != NULL, "Invalid capture.");
has_graph_ = true;
c10::cuda::CUDACachingAllocator::notifyCaptureEnded(capture_dev_, id_);
// In typical graph usage some tensors (e.g. the tensors used for graph IO) are not freed
// between replays.
// If Pytorch compiles and runs with a CUDA 11.4+ toolkit, there's a chance the allocator backend
// is cudaMallocAsync.
// cudaMallocAsync is generally graph-safe, but if some tensors are not freed between replays,
// the graph's internal bookkeeping requires that we instantiate with
// cudaGraphInstantiateFlagAutoFreeOnLaunch. See
// cudaGraphLaunch
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
// cudaGraphInstantiateWithFlags
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
#if CUDA_VERSION >= 11040
int version;
AT_CUDA_CHECK(cudaDriverGetVersion(&version));
if (version < 11040) {
#endif
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
// who prefer not to report error message through these arguments moving forward
// (they prefer return value, or errors on api calls internal to the capture)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
#if CUDA_VERSION >= 11040
} else {
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch));
}
#endif
has_graph_exec_ = true;
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
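Aside: distilled from the hunk above, a standalone sketch of the same gating. The auto-free instantiation path needs both a CUDA 11.4+ toolkit (compile time) and an 11.4+ driver (run time); otherwise the classic cudaGraphInstantiate call is used. Error handling is simplified relative to the AT_CUDA_CHECK macros, and the function name is illustrative.

#include <cuda_runtime.h>

cudaError_t instantiate_for_replay(cudaGraphExec_t* exec, cudaGraph_t graph) {
#if CUDA_VERSION >= 11040
  int driver_version = 0;
  cudaError_t err = cudaDriverGetVersion(&driver_version);
  if (err != cudaSuccess) return err;
  if (driver_version >= 11040) {
    // Lets cudaMallocAsync-backed graphs keep unfreed allocations (e.g. graph
    // I/O tensors) valid across replays without manual bookkeeping.
    return cudaGraphInstantiateWithFlags(
        exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);
  }
#endif
  // Trailing NULL, NULL, 0: errors are reported via the return value instead.
  return cudaGraphInstantiate(exec, graph, NULL, NULL, 0);
}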

View File

@ -1,10 +1,11 @@
#include <ATen/cuda/PeerToPeerAccess.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <vector>
#include <algorithm>
namespace at {
namespace cuda {
@ -38,6 +39,13 @@ bool get_p2p_access(int dev, int dev_to_access) {
dev_to_access, " is not a device"); dev_to_access, " is not a device");
TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized"); TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized");
#ifdef USE_ROCM
bool using_cudaMallocAsync = false;
#else
bool using_cudaMallocAsync = (CUDACachingAllocator::allocatorBackend() ==
CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
#endif
auto &cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access];
if (cache != -1) {
@ -49,6 +57,29 @@ bool get_p2p_access(int dev, int dev_to_access) {
int access = 0;
C10_CUDA_CHECK(cudaDeviceCanAccessPeer(&access, dev, dev_to_access));
if (access) {
if (using_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
// Double-checks allocator backend hasn't changed, which would definitely be an error.
#ifndef USE_ROCM
TORCH_INTERNAL_ASSERT(CUDACachingAllocator::allocatorBackend() ==
CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
#endif
// cudaMallocAsync pools are unaffected by cudaDeviceEnablePeerAccess.
// We need pool-specific enablement. See
// https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
cudaMemPool_t mempool;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev_to_access));
cudaMemAccessDesc desc = {};
desc.location.type = cudaMemLocationTypeDevice;
desc.location.id = dev;
desc.flags = cudaMemAccessFlagsProtReadWrite;
C10_CUDA_CHECK(cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */));
#else
TORCH_INTERNAL_ASSERT(false);
#endif
} else {
TORCH_INTERNAL_ASSERT(CUDACachingAllocator::allocatorBackend() ==
CUDACachingAllocator::AllocatorBackend::NATIVE);
cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
if (err == cudaErrorPeerAccessAlreadyEnabled) {
// ignore and clear the error if access was already enabled
@ -56,6 +87,7 @@ bool get_p2p_access(int dev, int dev_to_access) {
} else {
C10_CUDA_CHECK(err);
}
}
cache = 1;
} else {
cache = 0;
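Aside: a hedged standalone sketch of the pool-level peer-access dance the hunk above relies on for cudaMallocAsync allocations, which are unaffected by cudaDeviceEnablePeerAccess. The function name is hypothetical; the CUDA calls are the documented pool-access APIs.

#include <cuda_runtime.h>

cudaError_t ensure_pool_access(int accessing_dev, int owning_dev) {
  cudaMemPool_t pool;
  cudaError_t err = cudaDeviceGetDefaultMemPool(&pool, owning_dev);
  if (err != cudaSuccess) return err;

  cudaMemLocation loc{};
  loc.type = cudaMemLocationTypeDevice;
  loc.id = accessing_dev;

  // Query current access of accessing_dev to owning_dev's default pool.
  cudaMemAccessFlags flags;
  err = cudaMemPoolGetAccess(&flags, pool, &loc);
  if (err != cudaSuccess) return err;
  if (flags == cudaMemAccessFlagsProtReadWrite) return cudaSuccess;

  // Enable read/write access for the accessing device on this pool.
  cudaMemAccessDesc desc{};
  desc.location = loc;
  desc.flags = cudaMemAccessFlagsProtReadWrite;
  return cudaMemPoolSetAccess(pool, &desc, 1 /* numDescs */);
}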

View File

@ -6,7 +6,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/PeerToPeerAccess.h>
#include <ATen/native/Copy.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
@ -17,6 +16,9 @@
#include <ATen/ops/empty_like.h>
#endif
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
namespace at {
namespace native {
@ -46,7 +48,9 @@ void neg_conj_kernel_cuda(TensorIteratorBase &iter) {
using namespace at::cuda;
// device-to-device copy, does type conversion
void copy_device_to_device(TensorIterator& iter,
bool non_blocking,
bool p2p_enabled) {
int64_t numel = iter.numel();
// We can memcpy the memory if both tensors have the same type AND both
@ -87,12 +91,30 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
void *src = iter.data_ptr(1);
size_t size = numel * iter.element_size(0);
if (src != dst || src_device != dst_device) {
// Due to bizarre cuda driver intricacies, copies of
// cudaMallocAsynced memory between devices that aren't
// peer-to-peer-capable need "cudaMemcpyPeerAsync".
#ifdef USE_ROCM
bool using_cudaMallocAsync = false;
#else
bool using_cudaMallocAsync = (CUDACachingAllocator::allocatorBackend() ==
CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
#endif
bool needs_MemcpyPeer = (src_device != dst_device &&
using_cudaMallocAsync &&
!p2p_enabled);
if (needs_MemcpyPeer) {
AT_CUDA_CHECK(cudaMemcpyPeerAsync(
dst, dst_device.index(),
src, src_device.index(),
size, copy_stream));
} else {
AT_CUDA_CHECK(cudaMemcpyAsync(
dst, src, size,
cudaMemcpyDeviceToDevice,
copy_stream));
}
}
} else {
if (same_neg) {
if (!same_conj) {
@ -205,7 +227,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
// Copy on GPU (or between GPUs)
if (dst_device.is_cuda() && src_device.is_cuda()) {
copy_device_to_device(iter, non_blocking, p2p_enabled);
return;
}
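Aside: a hedged standalone sketch of the copy-path choice introduced above. For memory from cudaMallocAsync pools, a cross-device copy without peer access must go through cudaMemcpyPeerAsync; otherwise plain cudaMemcpyAsync suffices. The helper name is hypothetical.

#include <cuda_runtime.h>

cudaError_t d2d_copy(void* dst, int dst_dev, const void* src, int src_dev,
                     size_t bytes, cudaStream_t stream,
                     bool from_async_pool, bool p2p_enabled) {
  if (src_dev != dst_dev && from_async_pool && !p2p_enabled) {
    // Pool-backed memory across non-peer devices: use the peer copy API.
    return cudaMemcpyPeerAsync(dst, dst_dev, src, src_dev, bytes, stream);
  }
  return cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream);
}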

View File

@ -205,10 +205,11 @@ size_t getMaxWorkspaceSize(
{
size_t max_ws_size = 0;
size_t max_block_size = 0;
const auto device = c10::cuda::current_device();
// For the native allocator, retrieves the size of the largest unused block.
// For cudaMallocAsync, see c10/cuda/CUDAMallocAsync.cpp:cacheInfo for details.
c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
for (const auto i : c10::irange(n_algo)) {
cudnnStatus_t err;

View File

@ -307,8 +307,7 @@ size_t get_available_workspace() {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
size_t max_block_size = 0;
c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
return max_block_size;
}

View File

@ -25,6 +25,8 @@ set(C10_CUDA_SRCS
CUDAFunctions.cpp
CUDAMiscFunctions.cpp
CUDAStream.cpp
CUDACachingAllocator.cpp
CUDAMallocAsyncAllocator.cpp
impl/CUDAGuardImpl.cpp
impl/CUDATest.cpp
)
@ -36,6 +38,7 @@ set(C10_CUDA_HEADERS
CUDAMathCompat.h
CUDAMiscFunctions.h
CUDAStream.h
CUDACachingAllocator.h
impl/CUDAGuardImpl.h
impl/CUDATest.h
)

View File

@ -1,4 +1,3 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/core/impl/GPUTrace.h>
@ -28,6 +27,7 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
namespace cuda {
namespace CUDACachingAllocator {
namespace Native {
//
// Yet another caching allocator for CUDA device allocations.
@ -73,7 +73,7 @@ namespace CUDACachingAllocator {
* must be available for the graph to use during replay. DeviceCachingAllocator
* assigns and frees memory eagerly and dynamically, so if we're not careful
* about managing graphs' memory, at replay time those memory addresses could be
* used by other tensors.
*
* To guarantee a graph's baked in addresses are safe to reuse in replay,
* DeviceAllocator satisfies allocations from a graph-private memory pool during
@ -88,14 +88,12 @@ namespace CUDACachingAllocator {
* (regardless whether those captures are idle or replaying).
*
* CUDAGraph's requests for private pools are mediated by
* DeviceAllocator::notifyCaptureBegin,
* notifyCaptureAboutToEnd,
* notifyCaptureEnded,
* notifyCaptureDestroy.
*/
constexpr size_t kMinBlockSize = constexpr size_t kMinBlockSize =
512; // all sizes are rounded to at least 512 bytes
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
@ -107,6 +105,10 @@ constexpr size_t kMinLargeAlloc =
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
namespace {
using stream_set = ska::flat_hash_set<cuda::CUDAStream>;
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
void update_stat(Stat& stat, int64_t amount) {
@ -237,25 +239,6 @@ static bool BlockComparator(const Block* a, const Block* b) {
return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
}
struct AllocParams {
AllocParams(
int device,
@ -403,10 +386,102 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
#endif
}
} // anonymous namespace
} // namespace Native
// Backend static initialization.
#define DECLARE_BACKEND_INTERFACE(RET, FUNC, ARGS) RET FUNC ARGS;
// Not called directly by clients.
namespace Native {
FORALL_ALLOCATOR_INTERFACE(DECLARE_BACKEND_INTERFACE)
}
// Not called directly by clients.
namespace CudaMallocAsync {
FORALL_ALLOCATOR_INTERFACE(DECLARE_BACKEND_INTERFACE)
}
#undef DECLARE_BACKEND_INTERFACE
#define DEFINE_CHOSEN(RET, FUNC, ARGS) RET(*FUNC) ARGS = 0;
namespace Chosen {
FORALL_ALLOCATOR_INTERFACE(DEFINE_CHOSEN);
} // namespace Chosen
#define INITIALIZE_NATIVE(RET, FUNC, ARGS) Chosen::FUNC = Native::FUNC;
#define INITIALIZE_CUDAMALLOCASYNC(RET, FUNC, ARGS) \
Chosen::FUNC = CudaMallocAsync::FUNC;
struct BackendStaticInitializer {
AllocatorBackend backend;
// Parses env for backend at load time, duplicating some logic from
// CachingAllocatorConfig. CachingAllocatorConfig double-checks it later (at
// runtime). Defers verbose exceptions and error checks, including Cuda
// version checks, to CachingAllocatorConfig's runtime doublecheck. If this
// works, maybe we should move all of CachingAllocatorConfig here?
AllocatorBackend parseEnvForBackend() {
const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF");
if (val == NULL) {
return AllocatorBackend::NATIVE;
} else {
const std::string config(val);
std::regex exp("[\\s,]+");
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);
for (auto option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
std::sregex_token_iterator end2;
std::vector<std::string> kv(it2, end2);
if (kv.size() >= 2) {
if (kv[0] == "backend") {
if (kv[1] == "cudaMallocAsync")
return AllocatorBackend::CUDAMALLOCASYNC;
if (kv[1] == "native")
return AllocatorBackend::NATIVE;
}
}
}
}
return AllocatorBackend::NATIVE;
}
BackendStaticInitializer() {
backend = parseEnvForBackend();
switch (backend) {
case AllocatorBackend::NATIVE:
FORALL_ALLOCATOR_INTERFACE(INITIALIZE_NATIVE)
break;
case AllocatorBackend::CUDAMALLOCASYNC:
FORALL_ALLOCATOR_INTERFACE(INITIALIZE_CUDAMALLOCASYNC)
break;
}
}
};
#undef INITIALIZE_NATIVE
#undef INITIALIZE_CUDAMALLOCAYSNC
BackendStaticInitializer backend_static_initializer{};
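Aside: for readers unfamiliar with the pattern above, here is a minimal, self-contained sketch (hypothetical backend and function names) of the same idea: an X-macro lists the interface once, each backend defines it, and a static initializer fills a table of "Chosen" function pointers from an environment variable before any caller runs. It is a sketch of the dispatch mechanism, not the PyTorch code itself.

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// The interface list: one entry per function every backend must provide.
#define FORALL_IFACE(_)            \
  _(void*, raw_alloc, (size_t n)) \
  _(void, raw_delete, (void* p))

namespace BackendA {
void* raw_alloc(size_t n) { return std::malloc(n); }
void raw_delete(void* p) { std::free(p); }
} // namespace BackendA

namespace BackendB {
void* raw_alloc(size_t n) { std::puts("backend B"); return std::malloc(n); }
void raw_delete(void* p) { std::free(p); }
} // namespace BackendB

// Function-pointer table the rest of the program calls through.
namespace Chosen {
#define DEFINE_PTR(RET, FUNC, ARGS) RET (*FUNC) ARGS = nullptr;
FORALL_IFACE(DEFINE_PTR)
#undef DEFINE_PTR
} // namespace Chosen

// Static initializer that points the table at one backend before main().
struct BackendInit {
  BackendInit() {
    const bool use_b = std::getenv("USE_BACKEND_B") != nullptr;
#define PICK_A(RET, FUNC, ARGS) Chosen::FUNC = BackendA::FUNC;
#define PICK_B(RET, FUNC, ARGS) Chosen::FUNC = BackendB::FUNC;
    if (use_b) {
      FORALL_IFACE(PICK_B)
    } else {
      FORALL_IFACE(PICK_A)
    }
#undef PICK_A
#undef PICK_B
  }
} backend_init;

int main() {
  void* p = Chosen::raw_alloc(64); // routed to whichever backend was chosen
  Chosen::raw_delete(p);
}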
// Environment config parser
// Defined here, rather than its own .cpp file,
// because parseArgs needs to know kLargeBuffer.
// Defined outside namespace Native because it's not Native-specific.
class CachingAllocatorConfig {
public:
static AllocatorBackend allocator_backend() {
return instance().m_allocator_backend;
}
static size_t max_split_size() {
return instance().m_max_split_size;
}
@ -441,6 +516,7 @@ class CachingAllocatorConfig {
m_roundup_power2_divisions = 0;
m_roundup_bypass_threshold = std::numeric_limits<size_t>::max();
m_garbage_collection_threshold = 0;
m_allocator_backend = AllocatorBackend::NATIVE;
if (env == nullptr) {
return;
@ -453,6 +529,9 @@ class CachingAllocatorConfig {
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);
bool used_cudaMallocAsync = false;
bool used_native_specific_option = false;
for (auto option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
@ -460,28 +539,59 @@ class CachingAllocatorConfig {
std::vector<std::string> kv(it2, end2);
if (kv.size() >= 2) {
/* Maximum split size in MB. Limited to large size blocks */
if (kv[0] == "max_split_size_mb") {
size_t val2 = stoi(kv[1]);
TORCH_CHECK(
val2 > Native::kLargeBuffer / (1024 * 1024),
"CachingAllocator option max_split_size_mb too small, must be > ",
Native::kLargeBuffer / (1024 * 1024),
"");
val2 = std::max(val2, Native::kLargeBuffer / (1024 * 1024));
val2 = std::min(
val2, (std::numeric_limits<size_t>::max() / (1024 * 1024)));
m_max_split_size = val2 * 1024 * 1024;
used_native_specific_option = true;
} else if (kv[0] == "roundup_power2_divisions") {
size_t val2 = stoi(kv[1]);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"For roundups, the divisons has to be power of 2 ",
"");
m_roundup_power2_divisions = val2;
used_native_specific_option = true;
} else if (kv[0] == "roundup_bypass_threshold_mb") {
size_t val2 = stoi(kv[1]);
m_roundup_bypass_threshold = val2 * 1024 * 1024;
used_native_specific_option = true;
} else if (kv[0] == "backend") {
TORCH_CHECK(
((kv[1] == "native") || (kv[1] == "cudaMallocAsync")),
"Unknown allocator backend, "
"options are native and cudaMallocAsync");
used_cudaMallocAsync = (kv[1] == "cudaMallocAsync");
if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
int version;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
m_allocator_backend = AllocatorBackend::CUDAMALLOCASYNC;
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
TORCH_INTERNAL_ASSERT(
m_allocator_backend == backend_static_initializer.backend,
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time");
} else if (kv[0] == "garbage_collection_threshold") {
/*
* Perform garbage collection of GPU memory blocks to avoid
* triggering expensive sync-and-reclaim-all operation. Upon setting
@ -500,10 +610,17 @@ class CachingAllocatorConfig {
"garbage_collect_threshold too big, set it 0.0~1.0", "garbage_collect_threshold too big, set it 0.0~1.0",
""); "");
m_garbage_collection_threshold = val2; m_garbage_collection_threshold = val2;
used_native_specific_option = true;
} else {
TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]);
}
}
if (used_cudaMallocAsync && used_native_specific_option) {
TORCH_WARN(
"backend:cudaMallocAsync ignores max_split_size_mb, roundup_bypass_threshold_mb,"
"roundup_power2_divisions, and garbage_collect_threshold.");
}
}
}
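Aside: the parser above consumes PYTORCH_CUDA_ALLOC_CONF as a comma-separated list of key:value pairs, e.g. "backend:cudaMallocAsync". A hedged standalone sketch of that option format (not the real parser, which uses std::regex and validates each key) is:

#include <map>
#include <sstream>
#include <string>

// Splits "k1:v1,k2:v2" into a key/value map; helper name is hypothetical.
std::map<std::string, std::string> parse_alloc_conf(const std::string& conf) {
  std::map<std::string, std::string> kv;
  std::stringstream ss(conf);
  std::string option;
  while (std::getline(ss, option, ',')) {
    const auto colon = option.find(':');
    if (colon != std::string::npos) {
      kv[option.substr(0, colon)] = option.substr(colon + 1);
    }
  }
  return kv;
}

// parse_alloc_conf("backend:cudaMallocAsync")["backend"] == "cudaMallocAsync"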
@ -516,8 +633,11 @@ class CachingAllocatorConfig {
std::atomic<size_t> m_roundup_power2_divisions;
std::atomic<size_t> m_roundup_bypass_threshold;
std::atomic<double> m_garbage_collection_threshold;
AllocatorBackend m_allocator_backend;
};
namespace Native {
class DeviceCachingAllocator {
private:
// lock around all operations
@ -938,8 +1058,8 @@ class DeviceCachingAllocator {
release_cached_blocks();
}
/** Retrieves size of largest unused block held by the memory cache **/
void cacheInfo(size_t* largest) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (*largest ==
0) { // make an initial guess if a zero *largest is passed in
@ -948,11 +1068,11 @@ class DeviceCachingAllocator {
largest, // Use free memory as an optimistic initial guess of *largest
&tmp_bytes));
}
cache_info_aux(large_blocks, largest);
cache_info_aux(small_blocks, largest);
for (const auto& gp : graph_pools) {
cache_info_aux(gp.second->large_blocks, largest);
cache_info_aux(gp.second->small_blocks, largest);
}
}
@ -1145,7 +1265,7 @@ class DeviceCachingAllocator {
}
// Called by CUDAGraph::capture_end
void notifyCaptureAboutToEnd(CaptureId_t graph_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
captures_underway--;
auto it = capture_to_pool_map.find(graph_id);
@ -1474,14 +1594,13 @@ class DeviceCachingAllocator {
if (p.err != cudaSuccess) {
if (p.err == cudaErrorMemoryAllocation) {
// If this is the first attempt (!isRetry), we can forgive and clear
// CUDA's internal error state.
//
// If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH
// will take over to throw a helpful exception. The user can choose
// to catch the exception, free some stuff in their script, and
// attempt the allocation again. In this case, we can also forgive and
// clear CUDA's internal error state.
cudaGetLastError();
} else {
// If the error's unrelated to memory allocation, we should throw
@ -1735,11 +1854,10 @@ class DeviceCachingAllocator {
}
}
// Iterates over sizes of all memory blocks for given device in given pool
void cache_info_aux(const BlockPool& pool, size_t* largest) {
for (const auto& block : pool.blocks) {
const auto blocksize = block->size;
if (blocksize > *largest) {
*largest = blocksize;
}
@ -1769,7 +1887,7 @@ class DeviceCachingAllocator {
}
};
class NativeCachingAllocator {
private:
std::mutex mutex;
@ -1931,7 +2049,7 @@ class THCCachingAllocator {
}
};
NativeCachingAllocator caching_allocator;
// Returns whether to force all allocations to bypass the caching allocator and
// go straight to cudaMalloc. This setting is useful when debugging GPU memory
@ -1950,8 +2068,8 @@ static void uncached_delete(void* ptr) {
C10_CUDA_CHECK(cudaFree(ptr));
}
// NB: I decided not to fold this into NativeCachingAllocator, because the
// latter has a lot more methods and it wasn't altogether clear that they should
// actually be publicly exposed
struct CudaCachingAllocator : public Allocator {
DataPtr allocate(size_t size) const override {
@ -2026,9 +2144,8 @@ void emptyCache(void) {
caching_allocator.emptyCache();
}
void cacheInfo(int dev_id, size_t* largestBlock) {
caching_allocator.device_allocator[dev_id]->cacheInfo(largestBlock);
}
void* getBaseAllocation(void* ptr, size_t* size) {
@ -2081,17 +2198,45 @@ void notifyCaptureBegin(
graph_id, mempool_id);
}
void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
assertValidDevice(device);
caching_allocator.device_allocator[device]->notifyCaptureAboutToEnd(graph_id);
}
void notifyCaptureEnded(int device, CaptureId_t graph_id) {} // no-op
void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
assertValidDevice(device);
caching_allocator.device_allocator[device]->notifyCaptureDestroy(mempool_id);
}
void* raw_alloc(size_t nbytes) {
if (nbytes == 0) {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
caching_allocator.malloc(
&r, device, nbytes, cuda::getCurrentCUDAStream(device));
return r;
}
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
if (nbytes == 0) {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
caching_allocator.malloc(&r, device, nbytes, stream);
return r;
}
void raw_delete(void* ptr) {
caching_allocator.free(ptr);
}
// In CUDA IPC, sender sends a tensor to receiver, getIpcDevPtr
// is called by the receiving process to map the CUDA memory from the sending
// process into its own address space.
@ -2146,34 +2291,42 @@ std::shared_ptr<void> getIpcDevPtr(std::string handle) {
return sp;
}
} // namespace Native
// General caching allocator utilities
// External config interface (declared in CUDACachingAllocator.h)
// This is a useless layer of indirection with a minor
// code-cleanliness benefit: it alleviates the need to define
// CachingAllocatorConfig itself in CUDACachingAllocator.h.
AllocatorBackend allocatorBackend() {
return CachingAllocatorConfig::allocator_backend();
}
void setAllocatorSettings(const std::string& env) {
CachingAllocatorConfig::instance().parseArgs(env.c_str());
}
// Size pretty-printer
inline std::string format_size(uint64_t size) {
std::ostringstream os;
os.precision(2);
os << std::fixed;
if (size <= 1024) {
os << size << " bytes";
} else if (size <= 1048576) {
os << (size / 1024.0);
os << " KiB";
} else if (size <= 1073741824ULL) {
os << size / 1048576.0;
os << " MiB";
} else {
os << size / 1073741824.0;
os << " GiB";
}
return os.str();
}
} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace c10

View File

@ -1,5 +1,5 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
@ -166,52 +166,174 @@ struct SnapshotInfo {
std::vector<std::vector<TraceEntry>> device_traces;
};
// Allocator config options.
enum struct AllocatorBackend : uint8_t {
NATIVE = 0,
CUDAMALLOCASYNC = 1,
};
C10_CUDA_API AllocatorBackend allocatorBackend();
C10_CUDA_API void setAllocatorSettings(const std::string& env);
// Size pretty-printer
std::string format_size(uint64_t size);
using OutOfMemoryObserver = std::function<void(
int64_t device,
int64_t allocated,
int64_t device_total,
int64_t device_free)>;
#define FORALL_ALLOCATOR_INTERFACE(_) \
_(C10_CUDA_API void*, raw_alloc, (size_t nbytes)) \
_(C10_CUDA_API void*, \
raw_alloc_with_stream, \
(size_t nbytes, cudaStream_t stream)) \
_(C10_CUDA_API void, raw_delete, (void* ptr)) \
_(C10_CUDA_API Allocator*, get, ()) \
_(C10_CUDA_API void, init, (int device_count)) \
_(C10_CUDA_API void, setMemoryFraction, (double fraction, int device)) \
_(C10_CUDA_API void, emptyCache, ()) \
_(C10_CUDA_API void, cacheInfo, (int dev_id, size_t* largestBlock)) \
_(C10_CUDA_API void*, getBaseAllocation, (void* ptr, size_t* size)) \
_(C10_CUDA_API void, recordStream, (const DataPtr&, CUDAStream stream)) \
_(C10_CUDA_API DeviceStats, getDeviceStats, (int device)) \
_(C10_CUDA_API void, resetAccumulatedStats, (int device)) \
_(C10_CUDA_API void, resetPeakStats, (int device)) \
_(C10_CUDA_API SnapshotInfo, snapshot, ()) \
_(C10_CUDA_API void, \
notifyCaptureBegin, \
(int device, CaptureId_t graph_id, MempoolId_t mempool_id)) \
_(C10_CUDA_API void, \
notifyCaptureAboutToEnd, \
(int device, CaptureId_t graph_id)) \
_(C10_CUDA_API void, notifyCaptureEnded, (int device, CaptureId_t graph_id)) \
_(C10_CUDA_API void, \
notifyCaptureDestroy, \
(int device, MempoolId_t mempool_id)) \
_(C10_CUDA_API std::mutex*, getFreeMutex, ()) \
_(C10_CUDA_API std::shared_ptr<void>, getIpcDevPtr, (std::string handle)) \
_(C10_CUDA_API void, \
recordHistory, \
(bool enabled, \
CreateContextFn context_recorder, \
size_t alloc_trace_max_entries, \
bool alloc_trace_record_context)) \
_(C10_CUDA_API void, \
attachOutOfMemoryObserver, \
(OutOfMemoryObserver observer))
// Allocator backend function pointers, statically initialized
// according to PYTORCH_CUDA_ALLOC_CONF.
// See BackendInitializer in CUDACachingAllocator.cpp.
namespace Chosen {
#define DECLARE_CHOSEN(RET, FUNC, ARGS) extern RET(*FUNC) ARGS;
FORALL_ALLOCATOR_INTERFACE(DECLARE_CHOSEN)
#undef DECLARE_CHOSEN
} // namespace Chosen
// Called directly by clients.
inline void* raw_alloc(size_t nbytes) {
return Chosen::raw_alloc(nbytes);
}
inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
return Chosen::raw_alloc_with_stream(nbytes, stream);
}
inline void raw_delete(void* ptr) {
return Chosen::raw_delete(ptr);
}
inline Allocator* get() {
return Chosen::get();
}
inline void init(int device_count) {
return Chosen::init(device_count);
}
inline void setMemoryFraction(double fraction, int device) {
return Chosen::setMemoryFraction(fraction, device);
}
inline void emptyCache() {
return Chosen::emptyCache();
}
inline void cacheInfo(int dev_id, size_t* largestBlock) {
return Chosen::cacheInfo(dev_id, largestBlock);
}
inline void* getBaseAllocation(void* ptr, size_t* size) {
return Chosen::getBaseAllocation(ptr, size);
}
inline void recordStream(const DataPtr& dataPtr, CUDAStream stream) {
return Chosen::recordStream(dataPtr, stream);
}
inline DeviceStats getDeviceStats(int device) {
return Chosen::getDeviceStats(device);
}
inline void resetAccumulatedStats(int device) {
return Chosen::resetAccumulatedStats(device);
}
inline void resetPeakStats(int device) {
return Chosen::resetPeakStats(device);
}
inline SnapshotInfo snapshot() {
return Chosen::snapshot();
}
// CUDAGraph interactions
inline void notifyCaptureBegin(
int device,
CaptureId_t graph_id,
MempoolId_t mempool_id) {
return Chosen::notifyCaptureBegin(device, graph_id, mempool_id);
}
inline void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
return Chosen::notifyCaptureAboutToEnd(device, graph_id);
}
inline void recordHistory(
bool enabled,
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
bool alloc_trace_record_context) {
return Chosen::recordHistory(
enabled,
context_recorder,
alloc_trace_max_entries,
alloc_trace_record_context);
}
inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
return Chosen::attachOutOfMemoryObserver(observer);
}
inline void notifyCaptureEnded(int device, CaptureId_t graph_id) {
return Chosen::notifyCaptureEnded(device, graph_id);
}
inline void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
return Chosen::notifyCaptureDestroy(device, mempool_id);
}
inline std::mutex* getFreeMutex() {
return Chosen::getFreeMutex();
}
// Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
return Chosen::getIpcDevPtr(handle);
}
} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace c10
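Aside: a small caller's-eye sketch of the header above. Client code keeps using the same CUDACachingAllocator entry points, and the inline wrappers route through the Chosen function pointers to whichever backend the static initializer selected; nothing at the call site changes between the native allocator and cudaMallocAsync. The helper name is hypothetical.

#include <cstddef>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

void scratch_roundtrip(size_t nbytes) {
  namespace Alloc = c10::cuda::CUDACachingAllocator;
  // Allocates on the current stream via Chosen::raw_alloc_with_stream,
  // which points at either the native backend or cudaMallocAsync.
  void* p = Alloc::raw_alloc_with_stream(
      nbytes, c10::cuda::getCurrentCUDAStream().stream());
  // ... launch kernels that use p on that stream ...
  Alloc::raw_delete(p);
}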

View File

@ -0,0 +1,924 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <unordered_set>
#include <vector>
namespace c10 {
namespace cuda {
namespace CUDACachingAllocator {
namespace CudaMallocAsync {
#if CUDA_VERSION >= 11040
// CUDA device allocator that uses cudaMallocAsync to implement
// the same interface as CUDACachingAllocator.cpp.
// Designed to be safe for CUDA graph capture.
// Interactions with CUDA graph capture are mediated by
// notifyCaptureBegin
// notifyCaptureAboutToEnd
// notifyCaptureEnded
// notifyCaptureDestroy
// Implementation details, not declared in CUDACachingAllocator.h
namespace {
// General helpers
struct UsageStream {
cudaStream_t stream;
int device;
UsageStream() {}
UsageStream(cudaStream_t s, int d) : stream(s), device(d) {}
UsageStream(const UsageStream& us) : stream(us.stream), device(us.device) {}
UsageStream(const UsageStream&& us) : stream(us.stream), device(us.device) {}
UsageStream& operator=(UsageStream other) {
stream = other.stream;
device = other.device;
return *this;
}
};
bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
}
struct UsageStreamHash {
size_t operator()(const UsageStream& us) const noexcept {
return std::hash<void*>{}(us.stream) + size_t(us.device);
}
};
struct PtrUsage {
// recorded_streams holds side usage streams added by record_stream calls.
// In other words, it does NOT include the original creation stream.
ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
UsageStream creation_stream;
uint64_t size;
bool captured;
PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
};
int device_count = 0;
// these don't need to be c10::once_flags as in CUDAGeneratorImpl.cpp
// because they'll only be flipped by functions that have locked the mutex.
std::vector<bool> devs_initialized_flags;
std::vector<UsageStream> dummy_unifying_free_streams;
// Possible micro-optimization:
// Some accesses to ptr_info are read-only.
// We could let those be concurrent with a shared_mutex and
// have concurrent calls take a shared_lock.
// Keeping it simple with an ordinary mutex for now.
std::mutex general_mutex;
// Copy-paste CUDACachingAllocator's
// lock around calls to cudaFree (to prevent deadlocks with NCCL)
// is this safe?
std::mutex cuda_free_mutex;
/**
* Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* During CUDA graph capture, it's illegal to call cudaFreeAsync
* on a pointer that came from a non-captured cudaMallocAsync.
* Unfortunately, Python being what it is, it's impossible to be
* sure no uncaptured tensor will ever have its destructor called
* in a capturing region.
* We avoid errors by
* 1. remembering if allocated pointers were captured or uncaptured
* 2. during capture, if we detect an attempt to free an uncaptured
* allocation on a capturing stream, don't free it immediately,
* just remember it and defer its cudaFreeAsync call to after
* the end of capture (specifically, to notifyCaptureEnded).
*/
using PtrInfo = ska::flat_hash_map<void*, PtrUsage>;
PtrInfo ptr_info;
std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;
// These two help setMemoryFraction limit the amount of memory
// used by PyTorch in particular (as opposed to other libraries
// in the same process that might be sharing the same cudaMemPool_t).
std::vector<size_t> pytorch_used_bytes;
std::vector<size_t> pytorch_memory_limits;
// Graph-specific helpers
/**
* Note [Avoid dangling free streams during CUDA graph capture]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* During capture, all stream dependencies must branch out from
* the stream on which capture began and rejoin this initial stream
* before capture ends.
* The user rigs desired forking and joining with event waits.
* But it's hard to be sure when tensor destructors get called relative
* to the final joins.
* For example, suppose a user
* forks work stream B from initial capture stream A
* creates a tensor T in B
* joins by syncing A with B
* ends capture.
* All well and good, right? Maybe not: maybe T went out of scope
* and its destructor got called AFTER the rejoin, leaving the graph with
* "unjoined work": a dangling cudaFreeAsync node in stream B.
* Ensuring that all tensor destructors for all side stream tensors
* are called before side streams rejoin the main stream is
* difficult. The user might have to add a bunch of explicit
* "del"s at the right spots in code that was fine for ordinary
* eager execution.
* Fortunately, we can spare the user this burden:
* during capture, we remember _all_ free streams,
* and manually rejoin them with the capture stream during
* notifyCaptureAboutToEnd.
* This approach is heavy-handed, but hopefully capture only needs to
* happen once, so we don't mind being heavy-handed.
*
* TODO: If, someday, we augment the graph bindings to support recapture
* https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#whole-graph-update
* (eg, as a way to accommodate dynamic params) we should think more
* carefully about the CPU overhead of remembering and rejoining
* all free streams during capture. Maybe it's not a big deal.
*/
std::unordered_set<UsageStream, UsageStreamHash> capture_free_streams;
bool capture_underway = false;
// Implementation functions
// Assumes the caller holds general_mutex
inline void lazy_init_device(int device) {
if (!devs_initialized_flags[device]) {
CUDAGuard g(device);
// See "Retaining memory in the pool" here:
// https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
cudaMemPool_t mempool;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
C10_CUDA_CHECK(cudaMemPoolSetAttribute(
mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
// I think all these are on by default, but I want to enable them
// explicitly to ensure awareness.
int enable = 1;
C10_CUDA_CHECK(cudaMemPoolSetAttribute(
mempool, cudaMemPoolReuseFollowEventDependencies, &enable));
C10_CUDA_CHECK(cudaMemPoolSetAttribute(
mempool, cudaMemPoolReuseAllowOpportunistic, &enable));
C10_CUDA_CHECK(cudaMemPoolSetAttribute(
mempool, cudaMemPoolReuseAllowInternalDependencies, &enable));
// Grabs a stream from the current device to use as the "unifier" free
// stream for allocations that end up used on multiple streams.
const auto dufs = getStreamFromPool();
dummy_unifying_free_streams[device] =
UsageStream(dufs.stream(), dufs.device_index());
pytorch_used_bytes[device] = 0;
pytorch_memory_limits[device] = UINT64_MAX;
devs_initialized_flags[device] = true;
}
}
inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
// CUDACachingAllocator.cpp uses raw cuda events, as do we.
cudaEvent_t event;
C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
C10_CUDA_CHECK(cudaEventRecord(event, dependency));
C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
C10_CUDA_CHECK(cudaEventDestroy(event));
}
// Assumes the caller holds general_mutex
inline void free_impl(PtrInfo::iterator& it) {
// Possible micro-optimization: If we did a value-copy here, we could move
// ptr_info.erase(it) up here and drop the lock immediately.
const auto& recorded_streams = it->second.recorded_streams;
const auto& creation_stream = it->second.creation_stream;
// If the usage stream is a null (default) stream,
// cudaFreeAsync infers the device from the ambient context,
// so we need to set the right ambient context.
CUDAGuard g(creation_stream.device);
if (recorded_streams.empty()) {
// ptr was only used on one stream, which must have been
// the original allocation stream.
// Frees ptr in the original allocation stream.
C10_CUDA_CHECK(cudaFreeAsync(it->first, creation_stream.stream));
if (C10_UNLIKELY(capture_underway)) {
// See Note [Avoid dangling free streams during CUDA graph capture]
capture_free_streams.insert(creation_stream);
}
} else {
// ptr was used on many streams. We don't know which was the most recent.
// There could even have been multiple most recent usage streams acting
// on different regions of the memory.
// But cudaFreeAsync only accepts a single most recent usage stream.
// We can still safely free ptr with a trick:
// Use a dummy "unifying stream", sync the unifying stream with all of
// ptr's usage streams, and pass the dummy stream to cudaFreeAsync.
// Retrieves the dummy "unifier" stream from the device
// on which the pointer was originally allocated.
auto dummy_unifying_free_stream =
dummy_unifying_free_streams[creation_stream.device];
TORCH_INTERNAL_ASSERT(
dummy_unifying_free_stream.device == creation_stream.device);
// we're already on creation_stream.device, no need to re-guard
sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream);
// The number of usage streams is typically small (low single digits)
for (const auto& recorded_stream : recorded_streams) {
// Logic here accommodates the chance some of the usage streams were on
// other devices, which is possible if some usage kernels accessed the
// memory via p2p.
// cudaEventRecord requires that the input event and stream are on the
// same device.
CUDAGuard g_usage(recorded_stream.device);
sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
}
// Frees ptr in the dummy "unifier" stream.
C10_CUDA_CHECK(cudaFreeAsync(it->first, dummy_unifying_free_stream.stream));
// At this point, unless dummy_unifying_free_stream happens to alias some
// future user stream, the allocation is only available for "opportunistic"
// reuse, ie, if the CPU sees dummy_unifying_free_stream has reached the
// point that all events recorded on all usage streams have resolved from
// the CPU's perspective. In theory, we could remove the need for the driver
// to do this tracking by e.g. replacing
// cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event);
// with
// cudaStreamWaitEvent(creation_stream.stream, event);
// then cudaFreeAsyncing straight back into creation_stream.stream,
// but this forces a potentially false dependency of creation_stream.stream
// on all the recorded_streams.
if (C10_UNLIKELY(capture_underway)) {
// See Note [Avoid dangling free streams during CUDA graph capture]
capture_free_streams.insert(UsageStream(
dummy_unifying_free_stream.stream,
dummy_unifying_free_stream.device));
}
}
pytorch_used_bytes[creation_stream.device] -= it->second.size;
ptr_info.erase(it);
}
void freeAsync(void* ptr) {
std::lock_guard<std::mutex> lk(general_mutex);
auto err = cudaGetLastError();
C10_CUDA_CHECK(err);
auto it = ptr_info.find(ptr);
TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
if (C10_UNLIKELY(capture_underway)) {
if (!it->second.captured) {
TORCH_WARN_ONCE(
"freeAsync() was called on an uncaptured allocation during graph capture "
"(address = ",
ptr,
"). This may be benign, for example, a Python tensor in the capture "
"might happen to shadow (use the same name as) an unrelated temporary "
"tensor from somewhere before capture, pushing the earlier tensor "
"out of scope. "
"However, if the tensor we're freeing here IS used by the capture, "
"freeing it is an error, and may cause illegal memory accesses or "
"memory corruption during graph replay.");
// See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
// Remembers the raw pointer, not the iterator.
// This forces notifyCaptureEnded to do another lookup,
// but avoids the risk the iterator might be invalidated
// between now and then.
ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
return;
}
} else if (C10_UNLIKELY(it->second.captured)) {
TORCH_WARN(
"Attempting uncaptured free of a captured allocation with address ",
ptr,
"\nThis is technically allowed, but may indicate you are losing "
"the last user-visible tensor through which the allocation can "
"be accessed, so you'll have no way to view the data after "
"future replays of the owning graph.");
}
free_impl(it);
}
// Symmetric with NativeCachingAllocator::malloc for now,
// although I don't think we absolutely need the symmetry.
void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
TORCH_INTERNAL_ASSERT(
0 <= device && device < device_count,
"Invalid device index ",
device,
": did you call init?");
// If stream is a null (default) stream,
// cudaMallocAsync infers the device from the ambient context,
// so we need to set the right ambient context.
CUDAGuard g(device);
std::lock_guard<std::mutex> lk(general_mutex);
lazy_init_device(device);
// Defensively checks for preexisting CUDA error state.
auto err = cudaGetLastError();
C10_CUDA_CHECK(err);
// TODO: Could we avoid calling cudaMallocAsync while holding general_mutex,
// perhaps by letting lazy_init_device use separate once_flags or an internal
// static initializer?
if (pytorch_used_bytes[device] + size > pytorch_memory_limits[device]) {
err = cudaErrorMemoryAllocation;
} else {
err = cudaMallocAsync(devPtr, size, stream);
}
if (err == cudaErrorMemoryAllocation) {
// Clears CUDA's internal error state so the user, if desired, can catch the
// OOM exception, free some stuff on the script side, and retry the
// allocation. This aligns with the behavior of alloc_block in
// CUDACachingAllocator.cpp.
cudaGetLastError();
size_t device_free;
size_t device_total;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
TORCH_CHECK_WITH(
OutOfMemoryError,
false,
"Allocation on device ",
device,
" would exceed allowed memory. (out of memory)",
"\nCurrently allocated : ",
format_size(pytorch_used_bytes[device]),
"\nRequested : ",
format_size(size),
"\nDevice limit : ",
format_size(device_total),
"\nFree (according to CUDA): ",
format_size(device_free),
"\nPyTorch limit (set by user-supplied memory fraction)"
"\n : ",
format_size(pytorch_memory_limits[device]));
} else {
C10_CUDA_CHECK(err);
}
auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
TORCH_INTERNAL_ASSERT(
inserted.second,
"address returned by cudaMallocAsync already exists "
"in ptr_info");
inserted.first->second.creation_stream = {stream, device};
pytorch_used_bytes[device] += size;
}
} // anonymous namespace
// Same pattern as CUDACachingAllocator.cpp.
struct CudaMallocAsyncAllocator : public Allocator {
DataPtr allocate(size_t size) const override {
constexpr size_t one_exa_bytes = 1152921504606846976ULL;
TORCH_CHECK_WITH(
OutOfMemoryError,
size < one_exa_bytes,
"CUDA out of memory. Tried to allocate more than 1EB memory.");
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
if (size != 0) {
mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
}
return {r, r, &raw_delete, Device(DeviceType::CUDA, device)};
}
DeleterFnPtr raw_deleter() const override {
return &raw_delete;
}
};
CudaMallocAsyncAllocator device_allocator;
// Interface functions declared in CUDACachingAllocator.h
Allocator* get(void) {
return &device_allocator;
}
// This function should not issue any context-creating calls,
// just set up for later calls to init per-device pools based
// on the current device each later call sees.
void init(int dev_count) {
static bool called = [](int dev_count) {
;
// Are there external guarantees init will be called before
// any of the allocator's other functions?
// std::lock_guard<std::mutex> lk(general_mutex);
device_count = dev_count;
devs_initialized_flags.resize(dev_count, false);
dummy_unifying_free_streams.resize(dev_count);
pytorch_used_bytes.resize(dev_count);
pytorch_memory_limits.resize(dev_count);
return true;
}(dev_count);
(void)called;
}
static inline void assertValidDevice(int device) {
TORCH_CHECK(0 <= device && device < device_count, "Invalid device argument.");
}
void setMemoryFraction(double fraction, int device) {
TORCH_INTERNAL_ASSERT(
0 <= fraction && fraction <= 1,
"invalid fraction:",
fraction,
". Please set within (0, 1).");
std::lock_guard<std::mutex> lk(general_mutex);
assertValidDevice(device);
CUDAGuard g(device);
// Should setMemoryFraction be allowed to trigger a full device context and
// pool-creating lazy_init_device, or should we simply assert this device is
// already initialized, ie
// TORCH_CHECK(devs_initialized_flags[device], ...)?
lazy_init_device(device);
size_t device_free;
size_t device_total;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
pytorch_memory_limits[device] =
static_cast<uint64_t>(fraction * device_total);
// Alternative: Instead of a manual hard limit, we could use
// cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold,
// &threshold); This is a soft hint: The driver allows the pool's reserved
// memory to spike above threshold in regions of high cudaMallocAsync demand,
// but opportunistically trims reserved memory back to threshold when the
// memory in use is < threshold. I don't like this because it introduces
// performance nondeterminism.
}
void emptyCache(void) {
std::lock_guard<std::mutex> lk(general_mutex);
for (int dev = 0; dev < device_count; dev++) {
if (devs_initialized_flags[dev]) {
CUDAGuard g(dev);
cudaMemPool_t mempool;
cudaDeviceGetDefaultMemPool(&mempool, dev);
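// Synchronize the whole device first so pending stream-ordered frees have
// completed; otherwise that memory still counts as in use and trimming the
// pool to zero won't return it to the system.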
cudaDeviceSynchronize();
cudaMemPoolTrimTo(mempool, 0);
}
}
}
void cacheInfo(int device, size_t* maxWorkspaceGuess) {
// The only consumer of cacheInfo is getMaxWorkspaceSize in Conv_v7.cpp.
// Afaict, the role of cacheInfo is to give getMaxWorkspaceSize a reasonable
// maximum workspace size to use for an upcoming cudnnFind call.
//
// The native allocator's cacheInfo chooses to return the size of its largest
// unused block (which is the largest allocation the native allocator can
// service immediately and asynchronously without a cudaMalloc).
//
// Here, we use a different heuristic: figure out the max usable workspace
// size with a bit of educated trial and error. It's ok to be perf-inefficient
// because cacheInfo is a prelude to cudnnFind.
//
// The algo cache then stores the best-performing algo with workspace <=
// maxWorkspaceGuess. Later calls with the same param set hit in cache and try
// to allocate the same workspace. If, in one of those future calls, workspace
// allocation fails (ie because less ambient memory is available), the
// bindings rerun cudnnFind, including calling cacheInfo again beforehand to
// estimate a new (smaller) largest-available workspace. Over a few such
// calls, the cache should settle to the algo with a workspace size that's
// small enough to succeed every time (for that param set).
//
// So the strategy here is to return a rough, largeish guess and let the
// bindings retry to trim as needed over time.
//
// The only caveat is, even if a workspace is allocated without OOM errors now
// and in future calls, it's hard to be sure those later error-free
// cudaMallocAsyncs are fast and come straight from the pool (ie,
// cudaMallocAsync didn't need to reserve more memory from the system).
// Hopefully, after repeated workspace requests, the pool's reserved memory
// also stabilizes to a point where they all come straight from the pool.
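// In short: the loop below starts from min(driver-reported free memory,
// remaining headroom under the user-set limit) and halves the guess on each
// cudaErrorMemoryAllocation until a trial cudaMallocAsync of that size
// succeeds (the trial allocation is freed immediately).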
std::lock_guard<std::mutex> lk(general_mutex);
assertValidDevice(device);
CUDAGuard g(device);
lazy_init_device(device);
size_t free_upper_bound;
size_t device_total;
C10_CUDA_CHECK(cudaMemGetInfo(&free_upper_bound, &device_total));
TORCH_INTERNAL_ASSERT(
free_upper_bound + pytorch_used_bytes[device] <= device_total);
size_t guess = std::min(
free_upper_bound,
pytorch_memory_limits[device] - pytorch_used_bytes[device]);
auto stream = c10::cuda::getCurrentCUDAStream();
void* dummy;
// Defensively checks for preexisting CUDA error state.
auto err = cudaGetLastError();
C10_CUDA_CHECK(err);
while (true) {
// Duplicates some logic from mallocAsync to work with the error state
// directly instead of repeatedly catching an exception thrown by
// mallocAsync.
if (pytorch_used_bytes[device] + guess > pytorch_memory_limits[device]) {
err = cudaErrorMemoryAllocation;
} else {
err = cudaMallocAsync(&dummy, guess, stream);
}
if (err == cudaSuccess) {
cudaFreeAsync(dummy, stream);
*maxWorkspaceGuess = guess;
return;
} else if (err == cudaErrorMemoryAllocation) {
cudaGetLastError(); // clear CUDA error
guess >>= 1; // quick and dirty: try half the size next iteration
} else {
C10_CUDA_CHECK(err);
}
}
}
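// Unlike the native allocator, this backend never splits an allocation into
// smaller blocks, so each pointer is its own base allocation and the size
// recorded at allocation time is reported unchanged.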
void* getBaseAllocation(void* ptr, size_t* size) {
std::lock_guard<std::mutex> lk(general_mutex);
auto it = ptr_info.find(ptr);
TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
if (size) {
*size = it->second.size;
}
return ptr;
}
void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
std::lock_guard<std::mutex> lk(general_mutex);
auto ptr_val = ptr.get();
// An empty tensor's storage().data() might be a null ptr. As there are no
// blocks associated with those tensors, it is fine to do nothing here.
if (!ptr_val) {
return;
}
// The pointer should exist in the map already.
auto it = ptr_info.find(ptr_val);
TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
UsageStream to_record{stream.stream(), stream.device_index()};
if (to_record == it->second.creation_stream) {
TORCH_WARN(
"Called record_stream on tensor whose original creation stream "
"matches the recorded stream. This is unnecessary and has no effect.");
} else {
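// Remember every extra stream that used this allocation. freeAsync consults
// recorded_streams so the eventual stream-ordered free is sequenced after
// outstanding work on those streams (this backend's analogue of the native
// allocator's stream_uses).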
it->second.recorded_streams.insert(to_record);
}
}
std::mutex* getFreeMutex() {
return &cuda_free_mutex;
}
std::shared_ptr<void> getIpcDevPtr(std::string handle) {
TORCH_CHECK(
false,
"cudaMallocAsync does not yet support getIpcDevPtr. "
"If you need it, please file an issue describing your use case.");
}
void recordHistory(
bool enabled,
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
bool alloc_trace_record_context) {
TORCH_CHECK(
false,
"cudaMallocAsync does not yet support recordHistory. "
"If you need it, please file an issue describing your use case.");
}
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
TORCH_CHECK(
false,
"cudaMallocAsync does not yet support attachOutOfMemoryObserver. "
"If you need it, please file an issue describing your use case.");
}
// Collects stats for device.
// If device hasn't been used yet, returns 0s without creating a context.
DeviceStats getDeviceStats(int device) {
assertValidDevice(device);
// Memory currently reserved by the mempool
uint64_t reserved_mem_current = 0;
// High-water mark of memory reserved by the mempool since last reset
uint64_t reserved_mem_peak = 0;
// Memory currently in use by the mempool
uint64_t used_mem_current = 0;
// High-water mark of memory in use by the mempool since last reset
uint64_t used_mem_peak = 0;
std::lock_guard<std::mutex> lk(general_mutex);
if (devs_initialized_flags[device]) {
CUDAGuard g(device);
cudaMemPool_t mempool;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
C10_CUDA_CHECK(cudaMemPoolGetAttribute(
mempool, cudaMemPoolAttrReservedMemCurrent, &reserved_mem_current));
C10_CUDA_CHECK(cudaMemPoolGetAttribute(
mempool, cudaMemPoolAttrReservedMemHigh, &reserved_mem_peak));
C10_CUDA_CHECK(cudaMemPoolGetAttribute(
mempool, cudaMemPoolAttrUsedMemCurrent, &used_mem_current));
C10_CUDA_CHECK(cudaMemPoolGetAttribute(
mempool, cudaMemPoolAttrUsedMemHigh, &used_mem_peak));
}
// Many stat types are specific to the native allocator. We leave these
// untouched. Their "struct Stat"s will contain zeroed values.
DeviceStats stats;
// In the native allocator:
// allocated_bytes is the total bytes of blocks that have been malloc()ed
// and not yet free()d.
// active_bytes is the total bytes of blocks that have been malloc()ed but not
// yet released back into a free pool. In other words, it includes all
// allocated_bytes, as well as the bytes of "limbo state" blocks that have
// already been free()d but not yet free_block()ed back into a pool due to
// outstanding stream_uses.
//
// Here, in the cudaMallocAsync allocator:
// We simply ask the driver's opinion about active memory.
// We don't bother distinguishing between allocated_bytes and active_bytes.
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
used_mem_current;
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
used_mem_peak;
stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
used_mem_current;
stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
used_mem_peak;
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
reserved_mem_current;
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
reserved_mem_peak;
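// Downstream these populate the aggregate entries of
// torch.cuda.memory_stats(), e.g. "allocated_bytes.all.current" and
// "reserved_bytes.all.current"; stat types the driver doesn't track keep
// their zero-initialized values.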
return stats;
}
void resetAccumulatedStats(int device) {
assertValidDevice(device);
TORCH_WARN_ONCE(
"For backend:cudaMallocAsync, resetAccumulatedStats has no effect.");
}
void resetPeakStats(int device) {
assertValidDevice(device);
CUDAGuard g(device);
cudaMemPool_t mempool;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
// Using zero as the reset value is the method recommended by the CUDA
// driver team. Vivek Kini says:
// "Resetting to zero (which is the only valid value when setting
// ReservedMemHigh) resets it to ReservedMemCurrent inside the driver
// (same goes for UsedMemHigh/UsedMemCurrent)"
uint64_t zero = 0;
C10_CUDA_CHECK(
cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReservedMemHigh, &zero));
C10_CUDA_CHECK(
cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
}
SnapshotInfo snapshot() {
TORCH_CHECK(
false,
"Calling snapshot with backend:cudaMallocAsync is not meaningful. "
"(For backend:native, snapshot returns a detailed summary of all "
"blocks tracked by the allocator, but the cudaMallocAsync backend "
"does not track individual blocks.)");
// Alternative: TORCH_WARN
return {};
}
// CUDAGraph interactions
void notifyCaptureBegin(
int device,
CaptureId_t graph_id,
MempoolId_t mempool_id) {
std::lock_guard<std::mutex> lk(general_mutex);
TORCH_INTERNAL_ASSERT(capture_free_streams.empty());
TORCH_CHECK(
!capture_underway, "Only one capture at a time is allowed in a process.");
capture_underway = true;
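// graph_id and mempool_id are accepted for interface parity with the native
// allocator but unused: this backend has no private graph pools, so
// pool-sharing hints have no effect here.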
}
void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
assertValidDevice(device);
std::lock_guard<std::mutex> lk(general_mutex);
TORCH_CHECK(
capture_underway,
"CudaMallocAsync::notifyCaptureAboutToEnd called, "
"but CudaMallocAsync::capture_underway is false.");
auto capture_stream = cuda::getCurrentCUDAStream(device);
// See Note [Avoid dangling free streams during CUDA graph capture]
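// For each stream that serviced a free during capture, record a temporary
// event on it and make the capture stream wait on that event, so every
// stream-ordered free performed during capture is joined back into the
// capture stream before capture ends.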
for (const auto& free_stream : capture_free_streams) {
// cudaEventRecord requires that the input event and stream are on the same
// device.
CUDAGuard g(free_stream.device);
// CUDACachingAllocator.cpp uses raw cuda events, as do we.
cudaEvent_t event;
C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
C10_CUDA_CHECK(cudaEventDestroy(event));
}
capture_free_streams.clear();
}
void notifyCaptureEnded(int device, CaptureId_t graph_id) {
assertValidDevice(device);
std::lock_guard<std::mutex> lk(general_mutex);
TORCH_CHECK(
capture_underway,
"CudaMallocAsync::notifyCaptureEnded called, "
"but CudaMallocAsync::capture_underway is false.");
capture_underway = false;
// See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
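// Frees requested for uncaptured (pre-existing) allocations while capture was
// underway were deferred rather than issued into the capturing stream; now
// that capture is over, release them for real.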
for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
auto it = ptr_info.find(ptr);
TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
free_impl(it);
}
ungraphed_ptrs_defer_free_until_no_capture.clear();
}
void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
// Q: Do we need to do anything special here, like clear long-lived
// pointers created during the original capture (for example,
// tensors intended as the graph's I/O surface) that might still
// be resident in ptr_info?
// A: I don't think so.
// Those allocations survived capture because the user held
// explicit tensor references to them. Those tensors' destructors
// will call freeAsync() on each pointer when the user is done with them.
// The freeAsync()s will probably incur
// TORCH_WARN("Attempting uncaptured free of a captured allocation..."
// but stale ptrs will not permanently leak into ptr_info.
}
void* raw_alloc(size_t nbytes) {
if (nbytes == 0) {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
return r;
}
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
if (nbytes == 0) {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
mallocAsync(&r, device, nbytes, stream);
return r;
}
void raw_delete(void* ptr) {
freeAsync(ptr);
}
#else
// CUDA_VERSION is < 11040
#define NOT_AVAILABLE(name) \
TORCH_CHECK( \
false, \
"Called CudaMallocAsync::" name \
" but PyTorch was built with cuda < 11.4.");
void* raw_alloc(size_t nbytes) {
NOT_AVAILABLE("raw_alloc");
}
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
NOT_AVAILABLE("raw_alloc_with_stream");
}
void raw_delete(void* ptr) {
NOT_AVAILABLE("raw_delete")
}
Allocator* get() {
NOT_AVAILABLE("get");
}
void init(int device_count) {
NOT_AVAILABLE("init");
}
void setMemoryFraction(double fraction, int device) {
NOT_AVAILABLE("setMemoryFraction");
}
void emptyCache() {
NOT_AVAILABLE("emptyCache");
}
void cacheInfo(int device, size_t* maxWorkspaceGuess) {
NOT_AVAILABLE("cacheInfo");
}
void* getBaseAllocation(void* ptr, size_t* size) {
NOT_AVAILABLE("getBaseAllocation");
}
void recordStream(const DataPtr&, CUDAStream stream) {
NOT_AVAILABLE("recordStream");
}
DeviceStats getDeviceStats(int device) {
NOT_AVAILABLE("getDeviceStats");
}
void resetAccumulatedStats(int device) {
NOT_AVAILABLE("resetAccumulatedStats");
}
void resetPeakStats(int device) {
NOT_AVAILABLE("resetPeakStats");
}
SnapshotInfo snapshot() {
NOT_AVAILABLE("snapshot");
}
void notifyCaptureBegin(
int device,
CaptureId_t graph_id,
MempoolId_t mempool_id) {
NOT_AVAILABLE("notifyCaptureBegin");
}
void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
NOT_AVAILABLE("notifyCaptureAboutToEnd");
}
void notifyCaptureEnded(int device, CaptureId_t graph_id) {
NOT_AVAILABLE("notifyCaptureEnded");
}
void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
NOT_AVAILABLE("notifyCaptureDestroy");
}
std::mutex* getFreeMutex() {
NOT_AVAILABLE("getFreeMutex");
}
std::shared_ptr<void> getIpcDevPtr(std::string handle) {
NOT_AVAILABLE("getIpcDevPtr");
}
void recordHistory(
bool enabled,
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
bool alloc_trace_record_context) {
NOT_AVAILABLE("recordHistory");
}
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
NOT_AVAILABLE("attachOutOfMemoryObserver");
}
#endif
} // namespace CudaMallocAsync
} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace c10
View File
@ -87,6 +87,8 @@ Graphs (beta)
graph graph
make_graphed_callables make_graphed_callables
.. _cuda-memory-management-api:
Memory management Memory management
----------------- -----------------
.. autosummary:: .. autosummary::
@ -111,6 +113,7 @@ Memory management
reset_peak_memory_stats reset_peak_memory_stats
caching_allocator_alloc caching_allocator_alloc
caching_allocator_delete caching_allocator_delete
get_allocator_backend
.. FIXME The following doesn't seem to exist. Is it supposed to? .. FIXME The following doesn't seem to exist. Is it supposed to?
https://github.com/pytorch/pytorch/issues/27785 https://github.com/pytorch/pytorch/issues/27785
.. autofunction:: reset_max_memory_reserved .. autofunction:: reset_max_memory_reserved
View File
@ -40,7 +40,7 @@ only use the subset of Python supported in TorchScript. This section documents
what is supported in TorchScript as if it were a language reference for a stand what is supported in TorchScript as if it were a language reference for a stand
alone language. Any features of Python not mentioned in this reference are not alone language. Any features of Python not mentioned in this reference are not
part of TorchScript. See `Builtin Functions` for a complete reference of available part of TorchScript. See `Builtin Functions` for a complete reference of available
Pytorch tensor methods, modules, and functions. PyTorch tensor methods, modules, and functions.
As a subset of Python, any valid TorchScript function is also a valid Python As a subset of Python, any valid TorchScript function is also a valid Python
function. This makes it possible to `disable TorchScript` and debug the function. This makes it possible to `disable TorchScript` and debug the
View File
@ -1,6 +1,6 @@
.. _jit_unsupported: .. _jit_unsupported:
TorchScript Unsupported Pytorch Constructs TorchScript Unsupported PyTorch Constructs
============================================ ============================================
Torch and Tensor Unsupported Attributes Torch and Tensor Unsupported Attributes
View File
@ -349,27 +349,42 @@ complete snapshot of the memory allocator state via
:meth:`~torch.cuda.memory_snapshot`, which can help you understand the :meth:`~torch.cuda.memory_snapshot`, which can help you understand the
underlying allocation patterns produced by your code. underlying allocation patterns produced by your code.
.. _cuda-memory-envvars:
Environment variables
^^^^^^^^^^^^^^^^^^^^^
Use of a caching allocator can interfere with memory checking tools such as Use of a caching allocator can interfere with memory checking tools such as
``cuda-memcheck``. To debug memory errors using ``cuda-memcheck``, set ``cuda-memcheck``. To debug memory errors using ``cuda-memcheck``, set
``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching. ``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.
The behavior of caching allocator can be controlled via environment variable The behavior of the caching allocator can be controlled via the environment variable
``PYTORCH_CUDA_ALLOC_CONF``. ``PYTORCH_CUDA_ALLOC_CONF``.
The format is ``PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>...`` The format is ``PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>...``
Available options: Available options:
* ``max_split_size_mb`` prevents the allocator from splitting blocks larger
  than this size (in MB). This can help prevent fragmentation and may allow
  some borderline workloads to complete without running out of memory.
  Performance cost can range from 'zero' to 'substatial' depending on
  allocation patterns. Default value is unlimited, i.e. all blocks can be
  split. The :meth:`~torch.cuda.memory_stats` and
* ``backend`` allows selecting the underlying allocator implementation.
  Currently, valid options are ``native``, which uses PyTorch's native
  implementation, and ``cudaMallocAsync``, which uses
  `CUDA's built-in asynchronous allocator`_.
  ``cudaMallocAsync`` requires CUDA 11.4 or newer. The default is ``native``.
  ``backend`` applies to all devices used by the process, and can't be
  specified on a per-device basis.
* ``max_split_size_mb`` prevents the native allocator
  from splitting blocks larger than this size (in MB). This can reduce
  fragmentation and may allow some borderline workloads to complete without
  running out of memory. Performance cost can range from 'zero' to 'substantial'
  depending on allocation patterns. Default value is unlimited, i.e. all blocks
  can be split. The
  :meth:`~torch.cuda.memory_stats` and
:meth:`~torch.cuda.memory_summary` methods are useful for tuning. This :meth:`~torch.cuda.memory_summary` methods are useful for tuning. This
option should be used as a last resort for a workload that is aborting option should be used as a last resort for a workload that is aborting
due to 'out of memory' and showing a large amount of inactive split blocks. due to 'out of memory' and showing a large amount of inactive split blocks.
``max_split_size_mb`` is only meaningful with ``backend:native``.
With ``backend:cudaMallocAsync``, ``max_split_size_mb`` is ignored.
* ``roundup_power2_divisions`` helps with rounding the requested allocation * ``roundup_power2_divisions`` helps with rounding the requested allocation
size to nearest power-2 division and making better use of the blocks. In size to nearest power-2 division and making better use of the blocks. In
the current CUDACachingAllocator, the sizes are rounded up in multiple the native CUDACachingAllocator, the sizes are rounded up in multiple
of blocks size of 512, so this works fine for smaller sizes. However, this of blocks size of 512, so this works fine for smaller sizes. However, this
can be inefficient for large near-by allocations as each will go to different can be inefficient for large near-by allocations as each will go to different
size of blocks and re-use of those blocks are minimized. This might create size of blocks and re-use of those blocks are minimized. This might create
@ -379,10 +394,14 @@ Available options:
the size 1200 lies between 1024 and 2048 and if we do 4 divisions between the size 1200 lies between 1024 and 2048 and if we do 4 divisions between
them, the values are 1024, 1280, 1536, and 1792. So, allocation size of 1200 them, the values are 1024, 1280, 1536, and 1792. So, allocation size of 1200
will be rounded to 1280 as the nearest ceiling of power-2 division. will be rounded to 1280 as the nearest ceiling of power-2 division.
``roundup_power2_divisions`` is only meaningful with ``backend:native``.
With ``backend:cudaMallocAsync``, ``roundup_power2_divisions`` is ignored.
* ``roundup_bypass_threshold_mb`` bypass rounding the requested allocation size, * ``roundup_bypass_threshold_mb`` bypass rounding the requested allocation size,
for allocation requests larger than the threshold value (in MB). This can help for allocation requests larger than the threshold value (in MB). This can help
reduce the memory footprint when making large allocations that are expected to reduce the memory footprint when making large allocations that are expected to
be persistent or have a large lifetime. be persistent or have a large lifetime.
``roundup_bypass_threshold_mb`` is only meaningful with ``backend:native``.
With ``backend:cudaMallocAsync``, ``roundup_bypass_threshold_mb`` is ignored.
* ``garbage_collection_threshold`` helps actively reclaiming unused GPU memory to * ``garbage_collection_threshold`` helps actively reclaiming unused GPU memory to
avoid triggering expensive sync-and-reclaim-all operation (release_cached_blocks), avoid triggering expensive sync-and-reclaim-all operation (release_cached_blocks),
which can be unfavorable to latency-critical GPU applications (e.g., servers). which can be unfavorable to latency-critical GPU applications (e.g., servers).
@ -391,6 +410,19 @@ Available options:
80% of the total memory allocated to the GPU application). The algorithm prefers 80% of the total memory allocated to the GPU application). The algorithm prefers
to free old & unused blocks first to avoid freeing blocks that are actively being to free old & unused blocks first to avoid freeing blocks that are actively being
reused. The threshold value should be between greater than 0.0 and less than 1.0. reused. The threshold value should be between greater than 0.0 and less than 1.0.
``garbage_collection_threshold`` is only meaningful with ``backend:native``.
With ``backend:cudaMallocAsync``, ``garbage_collection_threshold`` is ignored.
.. note::
Some stats reported by the
:ref:`CUDA memory management API<cuda-memory-management-api>`
are specific to ``backend:native``, and are not meaningful with
``backend:cudaMallocAsync``.
See each function's docstring for details.
.. _CUDA's built-in asynchronous allocator:
https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
.. _cufft-plan-cache: .. _cufft-plan-cache:
View File
@ -23,7 +23,7 @@ feature. A section at the end discusses the extensions for forward mode AD.
When to use When to use
^^^^^^^^^^^ ^^^^^^^^^^^
In general, implement a custom function if you want to perform computations in your model In general, implement a custom function if you want to perform computations in your model
that are not differentiable or rely on non-Pytorch libraries (e.g., NumPy), but that are not differentiable or rely on non-PyTorch libraries (e.g., NumPy), but
still wish for your operation to chain with other ops and work with the autograd engine. still wish for your operation to chain with other ops and work with the autograd engine.
In some situations, custom functions can also be used to improve performance and In some situations, custom functions can also be used to improve performance and
View File
@ -103,7 +103,7 @@ There are three types of quantization supported:
3. static quantization aware training (weights quantized, activations quantized, 3. static quantization aware training (weights quantized, activations quantized,
quantization numerics modeled during training) quantization numerics modeled during training)
Please see our `Introduction to Quantization on Pytorch Please see our `Introduction to Quantization on PyTorch
<https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_ blog post <https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_ blog post
for a more comprehensive overview of the tradeoffs between these quantization for a more comprehensive overview of the tradeoffs between these quantization
types. types.
View File
@ -142,7 +142,7 @@ only:
Sparse hybrid COO tensors Sparse hybrid COO tensors
------------------------- -------------------------
Pytorch implements an extension of sparse tensors with scalar values PyTorch implements an extension of sparse tensors with scalar values
to sparse tensors with (contiguous) tensor values. Such tensors are to sparse tensors with (contiguous) tensor values. Such tensors are
called hybrid tensors. called hybrid tensors.
View File
@ -44,6 +44,7 @@ if not TEST_CUDA:
print('CUDA not available, skipping tests', file=sys.stderr) print('CUDA not available, skipping tests', file=sys.stderr)
TestCase = object # noqa: F811 TestCase = object # noqa: F811
TEST_CUDAMALLOCASYNC = TEST_CUDA and (torch.cuda.get_allocator_backend() == "cudaMallocAsync")
TEST_LARGE_TENSOR = TEST_CUDA TEST_LARGE_TENSOR = TEST_CUDA
TEST_MEDIUM_TENSOR = TEST_CUDA TEST_MEDIUM_TENSOR = TEST_CUDA
TEST_CUDNN = TEST_CUDA TEST_CUDNN = TEST_CUDA
@ -271,6 +272,7 @@ class TestCuda(TestCase):
self.assertEqual(r, 0) self.assertEqual(r, 0)
self.assertFalse(t.is_pinned()) self.assertFalse(t.is_pinned())
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
def test_memory_stats(self): def test_memory_stats(self):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -322,6 +324,7 @@ class TestCuda(TestCase):
device_capability_no_argument = torch.cuda.get_device_capability() device_capability_no_argument = torch.cuda.get_device_capability()
self.assertEqual(current_device_capability, device_capability_no_argument) self.assertEqual(current_device_capability, device_capability_no_argument)
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_memory_stats_multigpu(self): def test_memory_stats_multigpu(self):
# advance a generator with a end flag # advance a generator with a end flag
@ -362,7 +365,9 @@ class TestCuda(TestCase):
def test_out_of_memory(self): def test_out_of_memory(self):
tensor = torch.zeros(1024, device='cuda') tensor = torch.zeros(1024, device='cuda')
with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"): oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
"Tried to allocate 800000000.00 GiB"
with self.assertRaisesRegex(RuntimeError, oom_regex):
torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device='cuda') torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device='cuda')
with self.assertRaisesRegex(RuntimeError, "Tried to allocate more than 1EB memory"): with self.assertRaisesRegex(RuntimeError, "Tried to allocate more than 1EB memory"):
@ -372,6 +377,22 @@ class TestCuda(TestCase):
tensor.fill_(1) tensor.fill_(1)
self.assertTrue((tensor == 1).all()) self.assertTrue((tensor == 1).all())
def test_out_of_memory_retry(self):
torch.cuda.empty_cache()
total_memory = torch.cuda.get_device_properties(0).total_memory
oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
"Tried to allocate"
size = int(total_memory * 0.5)
a = torch.empty(size, dtype=torch.int8, device='cuda')
with self.assertRaisesRegex(RuntimeError, oom_regex):
b = torch.empty(size, dtype=torch.int8, device='cuda')
del a
b = torch.empty(size, dtype=torch.int8, device='cuda')
del b
# We used a lot of memory here, clean up so we don't affect other tests too much
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
def test_set_per_process_memory_fraction(self): def test_set_per_process_memory_fraction(self):
# test invalid fraction value. # test invalid fraction value.
with self.assertRaisesRegex(TypeError, "Invalid type"): with self.assertRaisesRegex(TypeError, "Invalid type"):
@ -394,7 +415,9 @@ class TestCuda(TestCase):
application = int(total_memory * 0.5) application = int(total_memory * 0.5)
# it will get OOM when try to allocate more than half memory. # it will get OOM when try to allocate more than half memory.
with self.assertRaisesRegex(RuntimeError, "out of memory"): oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
"out of memory"
with self.assertRaisesRegex(RuntimeError, oom_regex):
torch.empty(application, dtype=torch.int8, device='cuda') torch.empty(application, dtype=torch.int8, device='cuda')
# ensure out of memory error doesn't disturb subsequent kernel # ensure out of memory error doesn't disturb subsequent kernel
@ -1297,7 +1320,9 @@ class TestCuda(TestCase):
self.assertEqual(result.tolist(), [1, 2, 3, 4]) self.assertEqual(result.tolist(), [1, 2, 3, 4])
# Check that the block will be re-used after the main stream finishes if not TEST_CUDAMALLOCASYNC:
# In the native allocator, we expect "tmp"'s side-stream-tagged block will be reused
# in that side stream after result.copy_(tmp) in the main stream finishes.
torch.cuda.current_stream().synchronize() torch.cuda.current_stream().synchronize()
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
tmp3 = torch.cuda.FloatTensor(t.size()) tmp3 = torch.cuda.FloatTensor(t.size())
@ -3231,7 +3256,9 @@ torch.cuda.synchronize()
TEST_WITH_ROCM or TEST_WITH_ROCM or
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
def test_graph_capture_oom(self): def test_graph_capture_oom(self):
with self.assertRaisesRegex(RuntimeError, "out of memory"): oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
"out of memory"
with self.assertRaisesRegex(RuntimeError, oom_regex):
with torch.cuda.graph(torch.cuda.CUDAGraph()): with torch.cuda.graph(torch.cuda.CUDAGraph()):
torch.zeros(2 ** 40, device="cuda") torch.zeros(2 ** 40, device="cuda")
@ -3416,6 +3443,14 @@ torch.cuda.synchronize()
g.capture_end() g.capture_end()
torch.cuda.current_stream().wait_stream(stream) torch.cuda.current_stream().wait_stream(stream)
if not TEST_CUDAMALLOCASYNC:
# Makes sure values haven't been populated yet
# (in other words, makes sure capture didn't actually run ops).
# We can only try this with the native allocator, for which captured
# addresses are already backed by cudaMalloced memory.
# If we try it with cudaMallocAsync, CUDA won't even consider
# the captured addresses allocated until replay(), and if we
# access them before replay() we get IMAs.
try: try:
self.assertNotEqual(control1, t1) self.assertNotEqual(control1, t1)
self.assertNotEqual(control2, t2) self.assertNotEqual(control2, t2)
@ -3442,8 +3477,11 @@ torch.cuda.synchronize()
else: else:
getattr(dummy, op)(*args) getattr(dummy, op)(*args)
# see above comment on TEST_CUDAMALLOCASYNC
if not TEST_CUDAMALLOCASYNC:
t1.copy_(alloc) t1.copy_(alloc)
t2.copy_(alloc) t2.copy_(alloc)
# Runs RNG ops that fill t1 and t2. # Runs RNG ops that fill t1 and t2.
g.replay() g.replay()
@ -3516,6 +3554,8 @@ torch.cuda.synchronize()
self.assertEqual(b.sum().item(), size * 3070) self.assertEqual(b.sum().item(), size * 3070)
self.assertEqual(c.sum().item(), size * 442) self.assertEqual(c.sum().item(), size * 442)
if not TEST_CUDAMALLOCASYNC:
# These stat checks are specific to the native allocator.
if share_mem != "Don't share": if share_mem != "Don't share":
self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"], self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
kSmallBuffer) kSmallBuffer)
@ -3527,11 +3567,14 @@ torch.cuda.synchronize()
torch.cuda.synchronize() torch.cuda.synchronize()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@unittest.skip("Temporarily disabled due to a graphs bug in libcuda.so, " +
"see https://github.com/pytorch/pytorch/pull/57556")
@unittest.skipIf((not TEST_CUDA) or @unittest.skipIf((not TEST_CUDA) or
TEST_WITH_ROCM or TEST_WITH_ROCM or
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") IS_WINDOWS or # appears to still be broken on Windows as of 11.4+
int(torch.version.cuda.split(".")[0]) < 11 or
(int(torch.version.cuda.split(".")[0]) == 11 and
int(torch.version.cuda.split(".")[1]) < 4),
"Graph bindings disallow concurrent replay for CUDA < 11.4, see " +
"https://github.com/pytorch/pytorch/pull/57556")
def test_graph_concurrent_replay(self): def test_graph_concurrent_replay(self):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -3582,12 +3625,16 @@ torch.cuda.synchronize()
torch.cuda.current_stream().wait_stream(s0) torch.cuda.current_stream().wait_stream(s0)
torch.cuda.current_stream().wait_stream(s1) torch.cuda.current_stream().wait_stream(s1)
if share_mem != "Don't share": if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"):
# Confirms concurrent replays using the same mempool corrupted each other. # If we used the native allocator and shared mempools,
# we expect the concurrent replays corrupted each other.
self.assertNotEqual(b.sum().item(), size * 94) self.assertNotEqual(b.sum().item(), size * 94)
self.assertNotEqual(c.sum().item(), size * 156) self.assertNotEqual(c.sum().item(), size * 156)
else: else:
# Confirms concurrent replays using different mempools did not corrupt each other. # If we EITHER
# - used the native allocator without sharing mempools, OR
# - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe
# we don't expect memory corruption.
self.assertEqual(b.sum().item(), size * 94) self.assertEqual(b.sum().item(), size * 94)
self.assertEqual(c.sum().item(), size * 156) self.assertEqual(c.sum().item(), size * 156)
@ -3647,9 +3694,10 @@ torch.cuda.synchronize()
g2.replay() g2.replay()
g1.replay() g1.replay()
# If share_mem is True, g2's capture should have reused c's memory for f. We replayed g2 then g1, expect_corruption = (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share")
# so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3". # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f.
self.assertEqual(e.sum().item(), size * (7 + 3) if share_mem != "Don't share" else size * 5) # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
self.assertEqual(e.sum().item(), size * (7 + 3) if expect_corruption else size * 5)
self.assertEqual(f.sum().item(), size * 7) self.assertEqual(f.sum().item(), size * 7)
del a, b, d, e, f, g0, g1, g2 del a, b, d, e, f, g0, g1, g2
@ -3659,6 +3707,7 @@ torch.cuda.synchronize()
@unittest.skipIf((not TEST_CUDA) or @unittest.skipIf((not TEST_CUDA) or
TEST_WITH_ROCM or TEST_WITH_ROCM or
TEST_CUDAMALLOCASYNC or
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
def test_graph_memory_stats_and_use_result_after_destroy_graph(self): def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
kSmallSize = 1048576 kSmallSize = 1048576
@ -3831,6 +3880,8 @@ torch.cuda.synchronize()
g.capture_end() g.capture_end()
torch.cuda.current_stream().wait_stream(s) torch.cuda.current_stream().wait_stream(s)
g.replay()
y = model(x) y = model(x)
@unittest.skipIf((not TEST_CUDA) or @unittest.skipIf((not TEST_CUDA) or
@ -4624,6 +4675,7 @@ class TestCudaComm(TestCase):
cat = torch.cat((outputs[0][i].to('cpu'), outputs[1][i].to('cpu'))) cat = torch.cat((outputs[0][i].to('cpu'), outputs[1][i].to('cpu')))
self.assertTrue(torch.equal(x, cat)) self.assertTrue(torch.equal(x, cat))
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
def test_memory_snapshot(self): def test_memory_snapshot(self):
try: try:
torch.cuda.memory.empty_cache() torch.cuda.memory.empty_cache()
@ -4667,7 +4719,7 @@ class TestCudaComm(TestCase):
finally: finally:
torch.cuda.memory._record_memory_history(False) torch.cuda.memory._record_memory_history(False)
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
def test_memory_snapshot_with_cpp(self): def test_memory_snapshot_with_cpp(self):
try: try:
torch.cuda.memory.empty_cache() torch.cuda.memory.empty_cache()
@ -4703,7 +4755,7 @@ class TestCudaComm(TestCase):
return ret return ret
torch.cuda.memory.empty_cache() torch.cuda.memory.empty_cache()
key = 'active_bytes.all.allocated' key = 'active_bytes.all.allocated' if not TEST_CUDAMALLOCASYNC else 'allocated_bytes.all.current'
nelems = 21 * 1024 * 1024 nelems = 21 * 1024 * 1024
nbytes = 4 * nelems # floats are 4 bytes nbytes = 4 * nelems # floats are 4 bytes
@ -4719,6 +4771,8 @@ class TestCudaComm(TestCase):
pow2_div4_mem = torch.cuda.memory_stats()[key] pow2_div4_mem = torch.cuda.memory_stats()[key]
self.assertTrue(reg_mem - start_mem == nbytes) self.assertTrue(reg_mem - start_mem == nbytes)
if not TEST_CUDAMALLOCASYNC:
# not supported with the cudaMallocAsync backend
self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4)) self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))
torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5") torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5")
@ -4747,6 +4801,7 @@ class TestCudaComm(TestCase):
torch.empty(1024 * 1024 * 1024 * 1024, device='cuda') torch.empty(1024 * 1024 * 1024 * 1024, device='cuda')
@unittest.skipIf(IS_WINDOWS, 'Windows CI does not like the load_inline') @unittest.skipIf(IS_WINDOWS, 'Windows CI does not like the load_inline')
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
def test_cpp_memory_snapshot_pickle(self): def test_cpp_memory_snapshot_pickle(self):
from torch.utils.cpp_extension import load_inline from torch.utils.cpp_extension import load_inline
source = """ source = """
@ -4778,6 +4833,7 @@ class TestCudaComm(TestCase):
finally: finally:
m.record(False) m.record(False)
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
def test_notifies_oom(self): def test_notifies_oom(self):
x = False x = False
View File
@ -1181,6 +1181,7 @@ def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _cuda_resetPeakMemoryStats(device: _int) -> None: ... def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_memorySnapshot() -> Dict[str, Any]: ... def _cuda_memorySnapshot() -> Dict[str, Any]: ...
def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ... def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ...
def _cuda_getAllocatorBackend() -> str: ...
def _cuda_lock_mutex() -> None: ... def _cuda_lock_mutex() -> None: ...
def _cuda_unlock_mutex() -> None: ... def _cuda_unlock_mutex() -> None: ...
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ... def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
View File
@ -355,6 +355,26 @@ PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings(
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
PyObject* THCPModule_getAllocatorBackend(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
using c10::cuda::CUDACachingAllocator::AllocatorBackend;
AllocatorBackend backend =
c10::cuda::CUDACachingAllocator::allocatorBackend();
// this call should be uncommon, don't bother interning strings
switch (backend) {
case AllocatorBackend::NATIVE:
return THPUtils_packString("native");
#ifndef USE_ROCM
case AllocatorBackend::CUDAMALLOCASYNC:
return THPUtils_packString("cudaMallocAsync");
#endif
default:
THPUtils_assert(false, "Unexpected value for backend");
return nullptr;
}
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_cudaSynchronize(PyObject* _unused, PyObject* noargs) { PyObject* THCPModule_cudaSynchronize(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
c10::cuda::device_synchronize(); c10::cuda::device_synchronize();
@ -1061,6 +1081,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_cudaCachingAllocator_set_allocator_settings, THCPModule_cudaCachingAllocator_set_allocator_settings,
METH_O, METH_O,
nullptr}, nullptr},
{"_cuda_getAllocatorBackend",
THCPModule_getAllocatorBackend,
METH_NOARGS,
nullptr},
{"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr}, {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr},
{"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr}, {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr},
{"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr}, {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr},
View File
@ -837,10 +837,10 @@ __all__ = [
'CUDAGraph', 'CudaError', 'DeferredCudaCallError', 'Event', 'ExternalStream', 'OutOfMemoryError', 'CUDAGraph', 'CudaError', 'DeferredCudaCallError', 'Event', 'ExternalStream', 'OutOfMemoryError',
'Stream', 'StreamContext', 'amp', 'caching_allocator_alloc', 'caching_allocator_delete', 'can_device_access_peer', 'Stream', 'StreamContext', 'amp', 'caching_allocator_alloc', 'caching_allocator_delete', 'can_device_access_peer',
'check_error', 'cudaStatus', 'cudart', 'current_blas_handle', 'current_device', 'current_stream', 'default_generators', 'check_error', 'cudaStatus', 'cudart', 'current_blas_handle', 'current_device', 'current_stream', 'default_generators',
'default_stream', 'device', 'device_count', 'device_of', 'empty_cache', 'get_arch_list', 'get_device_capability', 'default_stream', 'device', 'device_count', 'device_of', 'empty_cache', 'get_allocator_backend', 'get_arch_list',
'get_device_name', 'get_device_properties', 'get_gencode_flags', 'get_rng_state', 'get_rng_state_all', 'get_sync_debug_mode', 'get_device_capability', 'get_device_name', 'get_device_properties', 'get_gencode_flags', 'get_rng_state', 'get_rng_state_all',
'graph', 'graph_pool_handle', 'graphs', 'has_half', 'has_magma', 'init', 'initial_seed', 'ipc_collect', 'is_available', 'get_sync_debug_mode', 'graph', 'graph_pool_handle', 'graphs', 'has_half', 'has_magma', 'init', 'initial_seed', 'ipc_collect',
'is_bf16_supported', 'is_current_stream_capturing', 'is_initialized', 'jiterator', 'list_gpu_processes', 'is_available', 'is_bf16_supported', 'is_current_stream_capturing', 'is_initialized', 'jiterator', 'list_gpu_processes',
'make_graphed_callables', 'manual_seed', 'manual_seed_all', 'max_memory_allocated', 'max_memory_cached', 'max_memory_reserved', 'make_graphed_callables', 'manual_seed', 'manual_seed_all', 'max_memory_allocated', 'max_memory_cached', 'max_memory_reserved',
'mem_get_info', 'memory', 'memory_allocated', 'memory_cached', 'memory_reserved', 'memory_snapshot', 'memory_stats', 'mem_get_info', 'memory', 'memory_allocated', 'memory_cached', 'memory_reserved', 'memory_snapshot', 'memory_stats',
'memory_stats_as_nested_dict', 'memory_summary', 'memory_usage', 'nccl', 'nvtx', 'profiler', 'random', 'memory_stats_as_nested_dict', 'memory_summary', 'memory_usage', 'nccl', 'nvtx', 'profiler', 'random',
View File
@ -16,7 +16,7 @@ __all__ = ["caching_allocator_alloc", "caching_allocator_delete", "set_per_proce
"reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached", "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached",
"memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved", "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved",
"memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", "list_gpu_processes", "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", "list_gpu_processes",
"mem_get_info"] "mem_get_info", "get_allocator_backend"]
def _host_allocator(): def _host_allocator():
_lazy_init() _lazy_init()
@ -194,6 +194,10 @@ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
.. note:: .. note::
See :ref:`cuda-memory-management` for more details about GPU memory See :ref:`cuda-memory-management` for more details about GPU memory
management. management.
.. note::
With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
meaningful, and are always reported as zero.
""" """
result = [] result = []
@ -636,3 +640,14 @@ def _save_memory_usage(filename='output.svg', snapshot=None):
def _set_allocator_settings(env: str): def _set_allocator_settings(env: str):
return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env) return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
def get_allocator_backend() -> str:
r"""Returns a string describing the active allocator backend as set by
``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
(CUDA's built-in asynchronous allocator).
.. note::
See :ref:`cuda-memory-management` for details on choosing the allocator backend.
"""
return torch._C._cuda_getAllocatorBackend()