(Re-open) Adds cudaMallocAsync as an alternative backend for the CUDA allocator (#82682)
Rebased version of @mcarilli's cudaMallocAsync PR #65365 for continued testing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82682
Approved by: https://github.com/ngimel
This commit is contained in:
parent a216f4700c
commit 25725fd624
@ -19,10 +19,10 @@ namespace at {
|
||||||
*
|
*
|
||||||
* A CUDA graph containing multiple RNG ops behaves like a
|
* A CUDA graph containing multiple RNG ops behaves like a
|
||||||
* single giant kernel from the perspective of ops external
|
* single giant kernel from the perspective of ops external
|
||||||
* to the graph. During graph capture, logic below records
|
* to the graph. During graph capture, logic in CUDAGeneratorImpl
|
||||||
* the total of all offset increments that occur in the graphed
|
* records the total of all offset increments that occur in the
|
||||||
* region, and records the final total as the offset for the
|
* graphed region, and records the final total as the offset for
|
||||||
* entire graph.
|
* the entire graph.
|
||||||
*
|
*
|
||||||
* When the graph reruns, the logic that reruns it
|
* When the graph reruns, the logic that reruns it
|
||||||
* increments this device's CUDA generator's offset
|
* increments this device's CUDA generator's offset
|
||||||
|
|
@ -30,8 +30,8 @@ namespace at {
|
||||||
*
|
*
|
||||||
* Meanwhile, within the graph, at capture time, instead of
|
* Meanwhile, within the graph, at capture time, instead of
|
||||||
* populating PhiloxCudaStates with the uint64_t offset pulled
|
* populating PhiloxCudaStates with the uint64_t offset pulled
|
||||||
* directly from the global state, PhiloxCudaState instead
|
* directly from the global state, PhiloxCudaState uses a pointer
|
||||||
* holds a pointer to one-element stream-local int64_t device tensor
|
* to a one-element stream-local int64_t device tensor
|
||||||
* holding an initial offset value, and a uint64_t holding an
|
* holding an initial offset value, and a uint64_t holding an
|
||||||
* intra-graph offset. (The intra-graph offset starts from zero
|
* intra-graph offset. (The intra-graph offset starts from zero
|
||||||
* when capture begins.) In each consumer kernel,
|
* when capture begins.) In each consumer kernel,
|
||||||
|
|
|
||||||
|
|
@ -133,16 +133,42 @@ void CUDAGraph::capture_end() {
|
||||||
TORCH_CHECK(stream == capture_stream_,
|
TORCH_CHECK(stream == capture_stream_,
|
||||||
"Capture must end on the same stream it began on.");
|
"Capture must end on the same stream it began on.");
|
||||||
|
|
||||||
c10::cuda::CUDACachingAllocator::notifyCaptureEnd(capture_dev_, id_);
|
c10::cuda::CUDACachingAllocator::notifyCaptureAboutToEnd(capture_dev_, id_);
|
||||||
|
|
||||||
AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
|
AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
|
||||||
TORCH_CHECK(graph_ != NULL, "Invalid capture.");
|
TORCH_CHECK(graph_ != NULL, "Invalid capture.");
|
||||||
has_graph_ = true;
|
has_graph_ = true;
|
||||||
|
|
||||||
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
c10::cuda::CUDACachingAllocator::notifyCaptureEnded(capture_dev_, id_);
|
||||||
// who prefer not to report error message through these arguments moving forward
|
|
||||||
// (they prefer return value, or errors on api calls internal to the capture)
|
// In typical graph usage some tensors (e.g. the tensors used for graph IO) are not freed
|
||||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
|
// between replays.
|
||||||
|
// If Pytorch compiles and runs with a CUDA 11.4+ toolkit, there's a chance the allocator backend
|
||||||
|
// is cudaMallocAsync.
|
||||||
|
// cudaMallocAsync is generally graph-safe, but if some tensors are not freed between replays,
|
||||||
|
// the graph's internal bookkeeping requires that we instantiate with
|
||||||
|
// cudaGraphInstantiateFlagAutoFreeOnLaunch. See
|
||||||
|
// cudaGraphLaunch
|
||||||
|
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
|
||||||
|
// cudaGraphInstantiateWithFlags
|
||||||
|
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
|
||||||
|
#if CUDA_VERSION >= 11040
|
||||||
|
int version;
|
||||||
|
AT_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||||
|
if (version < 11040) {
|
||||||
|
#endif
|
||||||
|
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
||||||
|
// who prefer not to report error message through these arguments moving forward
|
||||||
|
// (they prefer return value, or errors on api calls internal to the capture)
|
||||||
|
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
|
||||||
|
#if CUDA_VERSION >= 11040
|
||||||
|
} else {
|
||||||
|
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||||
|
graph_,
|
||||||
|
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
has_graph_exec_ = true;
|
has_graph_exec_ = true;
|
||||||
|
|
||||||
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
|
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
|
||||||
|
|
|
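The version-gated instantiation added above can be exercised outside PyTorch. The following is a minimal sketch, not PyTorch code, assuming a CUDA 11.4+ toolkit: it captures a trivial kernel launch into a graph and instantiates it with cudaGraphInstantiateFlagAutoFreeOnLaunch when the driver reports 11.4 or newer, falling back to the classic cudaGraphInstantiate signature otherwise. The kernel, buffer size, and error-check macro are illustrative.

#include <cuda_runtime.h>
#include <cstdio>

#define CHECK(call)                                               \
  do {                                                            \
    cudaError_t err_ = (call);                                    \
    if (err_ != cudaSuccess) {                                    \
      std::printf("CUDA error: %s\n", cudaGetErrorString(err_));  \
      return 1;                                                   \
    }                                                             \
  } while (0)

__global__ void dummy_kernel(float* p) { p[threadIdx.x] += 1.f; }

int main() {
  cudaStream_t stream;
  CHECK(cudaStreamCreate(&stream));

  float* buf = nullptr;
  CHECK(cudaMallocAsync(reinterpret_cast<void**>(&buf), 256 * sizeof(float), stream));

  // Capture a tiny graph on a non-default stream.
  cudaGraph_t graph;
  CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
  dummy_kernel<<<1, 256, 0, stream>>>(buf);
  CHECK(cudaStreamEndCapture(stream, &graph));

  int version = 0;
  CHECK(cudaDriverGetVersion(&version));

  cudaGraphExec_t graph_exec;
  if (version >= 11040) {
    // Auto-free-on-launch lets the graph's internal bookkeeping cope with
    // allocations that are not freed between replays (see the comments above).
    CHECK(cudaGraphInstantiateWithFlags(
        &graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));
  } else {
    // Classic signature with trailing NULL, NULL, 0 arguments.
    CHECK(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
  }

  CHECK(cudaGraphLaunch(graph_exec, stream));
  CHECK(cudaStreamSynchronize(stream));

  CHECK(cudaGraphExecDestroy(graph_exec));
  CHECK(cudaGraphDestroy(graph));
  CHECK(cudaFreeAsync(buf, stream));
  CHECK(cudaStreamDestroy(stream));
  return 0;
}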
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
#include <ATen/cuda/PeerToPeerAccess.h>
|
#include <ATen/cuda/PeerToPeerAccess.h>
|
||||||
|
|
||||||
|
#include <c10/cuda/CUDACachingAllocator.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
namespace cuda {
|
namespace cuda {
|
||||||
|
|
@ -38,6 +39,13 @@ bool get_p2p_access(int dev, int dev_to_access) {
|
||||||
dev_to_access, " is not a device");
|
dev_to_access, " is not a device");
|
||||||
TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized");
|
TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized");
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
bool using_cudaMallocAsync = false;
|
||||||
|
#else
|
||||||
|
bool using_cudaMallocAsync = (CUDACachingAllocator::allocatorBackend() ==
|
||||||
|
CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto &cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access];
|
auto &cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access];
|
||||||
|
|
||||||
if (cache != -1) {
|
if (cache != -1) {
|
||||||
|
|
@ -49,12 +57,36 @@ bool get_p2p_access(int dev, int dev_to_access) {
|
||||||
int access = 0;
|
int access = 0;
|
||||||
C10_CUDA_CHECK(cudaDeviceCanAccessPeer(&access, dev, dev_to_access));
|
C10_CUDA_CHECK(cudaDeviceCanAccessPeer(&access, dev, dev_to_access));
|
||||||
if (access) {
|
if (access) {
|
||||||
cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
|
if (using_cudaMallocAsync) {
|
||||||
if (err == cudaErrorPeerAccessAlreadyEnabled) {
|
#if CUDA_VERSION >= 11040
|
||||||
// ignore and clear the error if access was already enabled
|
// Double-checks allocator backend hasn't changed, which would definitely be an error.
|
||||||
cudaGetLastError();
|
#ifndef USE_ROCM
|
||||||
|
TORCH_INTERNAL_ASSERT(CUDACachingAllocator::allocatorBackend() ==
|
||||||
|
CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
|
||||||
|
#endif
|
||||||
|
// cudaMallocAsync pools are unaffected by cudaDeviceEnablePeerAccess.
|
||||||
|
// We need pool-specific enablement. See
|
||||||
|
// https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
|
||||||
|
cudaMemPool_t mempool;
|
||||||
|
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev_to_access));
|
||||||
|
cudaMemAccessDesc desc = {};
|
||||||
|
desc.location.type = cudaMemLocationTypeDevice;
|
||||||
|
desc.location.id = dev;
|
||||||
|
desc.flags = cudaMemAccessFlagsProtReadWrite;
|
||||||
|
C10_CUDA_CHECK(cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */));
|
||||||
|
#else
|
||||||
|
TORCH_INTERNAL_ASSERT(false);
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
C10_CUDA_CHECK(err);
|
TORCH_INTERNAL_ASSERT(CUDACachingAllocator::allocatorBackend() ==
|
||||||
|
CUDACachingAllocator::AllocatorBackend::NATIVE);
|
||||||
|
cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
|
||||||
|
if (err == cudaErrorPeerAccessAlreadyEnabled) {
|
||||||
|
// ignore and clear the error if access was already enabled
|
||||||
|
cudaGetLastError();
|
||||||
|
} else {
|
||||||
|
C10_CUDA_CHECK(err);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cache = 1;
|
cache = 1;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
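As a standalone illustration of the pool-specific enablement described in the hunk above (a sketch assuming two peer-capable GPUs and a CUDA 11.4+ toolkit; the device indices 0 and 1 are arbitrary), the same calls look like this:

// Grant device 0 read/write access to device 1's default cudaMallocAsync pool.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int dev = 0;            // device that will access the memory
  int dev_to_access = 1;  // device that owns the pool (illustrative indices)

  int access = 0;
  cudaDeviceCanAccessPeer(&access, dev, dev_to_access);
  if (!access) {
    std::printf("Devices are not peer-capable; nothing to do.\n");
    return 0;
  }

  cudaMemPool_t mempool;
  cudaDeviceGetDefaultMemPool(&mempool, dev_to_access);

  cudaMemAccessDesc desc = {};
  desc.location.type = cudaMemLocationTypeDevice;
  desc.location.id = dev;  // the accessing device
  desc.flags = cudaMemAccessFlagsProtReadWrite;

  // cudaDeviceEnablePeerAccess does not affect cudaMallocAsync pools;
  // the pool itself must be opened up with cudaMemPoolSetAccess.
  cudaError_t err = cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */);
  std::printf("cudaMemPoolSetAccess: %s\n", cudaGetErrorString(err));
  return 0;
}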
||||||
|
|
@ -6,7 +6,6 @@
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <ATen/cuda/CUDAEvent.h>
|
#include <ATen/cuda/CUDAEvent.h>
|
||||||
#include <ATen/cuda/PeerToPeerAccess.h>
|
#include <ATen/cuda/PeerToPeerAccess.h>
|
||||||
#include <c10/cuda/CUDAStream.h>
|
|
||||||
#include <ATen/native/Copy.h>
|
#include <ATen/native/Copy.h>
|
||||||
#include <ATen/native/TensorIterator.h>
|
#include <ATen/native/TensorIterator.h>
|
||||||
#include <ATen/native/cuda/Loops.cuh>
|
#include <ATen/native/cuda/Loops.cuh>
|
||||||
|
|
@ -17,6 +16,9 @@
|
||||||
#include <ATen/ops/empty_like.h>
|
#include <ATen/ops/empty_like.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include <c10/cuda/CUDACachingAllocator.h>
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
namespace native {
|
namespace native {
|
||||||
|
|
||||||
|
|
@ -46,7 +48,9 @@ void neg_conj_kernel_cuda(TensorIteratorBase &iter) {
|
||||||
using namespace at::cuda;
|
using namespace at::cuda;
|
||||||
|
|
||||||
// device-to-device copy, does type conversion
|
// device-to-device copy, does type conversion
|
||||||
void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
|
void copy_device_to_device(TensorIterator& iter,
|
||||||
|
bool non_blocking,
|
||||||
|
bool p2p_enabled) {
|
||||||
int64_t numel = iter.numel();
|
int64_t numel = iter.numel();
|
||||||
|
|
||||||
// We can memcpy the memory if both tensors have the same type AND both
|
// We can memcpy the memory if both tensors have the same type AND both
|
||||||
|
|
@ -87,11 +91,29 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
|
||||||
void *src = iter.data_ptr(1);
|
void *src = iter.data_ptr(1);
|
||||||
size_t size = numel * iter.element_size(0);
|
size_t size = numel * iter.element_size(0);
|
||||||
if (src != dst || src_device != dst_device) {
|
if (src != dst || src_device != dst_device) {
|
||||||
// Perform the copy
|
// Due to bizarre cuda driver intricacies, copies of
|
||||||
AT_CUDA_CHECK(cudaMemcpyAsync(
|
// cudaMallocAsynced memory between devices that aren't
|
||||||
dst, src, size,
|
// peer-to-peer-capable need "cudaMemcpyPeerAsync".
|
||||||
cudaMemcpyDeviceToDevice,
|
#ifdef USE_ROCM
|
||||||
copy_stream));
|
bool using_cudaMallocAsync = false;
|
||||||
|
#else
|
||||||
|
bool using_cudaMallocAsync = (CUDACachingAllocator::allocatorBackend() ==
|
||||||
|
CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
|
||||||
|
#endif
|
||||||
|
bool needs_MemcpyPeer = (src_device != dst_device &&
|
||||||
|
using_cudaMallocAsync &&
|
||||||
|
!p2p_enabled);
|
||||||
|
if (needs_MemcpyPeer) {
|
||||||
|
AT_CUDA_CHECK(cudaMemcpyPeerAsync(
|
||||||
|
dst, dst_device.index(),
|
||||||
|
src, src_device.index(),
|
||||||
|
size, copy_stream));
|
||||||
|
} else {
|
||||||
|
AT_CUDA_CHECK(cudaMemcpyAsync(
|
||||||
|
dst, src, size,
|
||||||
|
cudaMemcpyDeviceToDevice,
|
||||||
|
copy_stream));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (same_neg) {
|
if (same_neg) {
|
||||||
|
|
@ -205,7 +227,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
|
||||||
|
|
||||||
// Copy on GPU (or between GPUs)
|
// Copy on GPU (or between GPUs)
|
||||||
if (dst_device.is_cuda() && src_device.is_cuda()) {
|
if (dst_device.is_cuda() && src_device.is_cuda()) {
|
||||||
copy_device_to_device(iter, non_blocking);
|
copy_device_to_device(iter, non_blocking, p2p_enabled);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
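The new branch in copy_device_to_device boils down to one decision: cross-device copies of stream-ordered (cudaMallocAsync) memory between devices without peer access must go through cudaMemcpyPeerAsync, while everything else keeps using cudaMemcpyAsync. A condensed sketch of just that decision follows; the helper name and parameter list are illustrative, not PyTorch API.

#include <cuda_runtime.h>

// Picks the copy API the same way the diff above does.
cudaError_t d2d_copy(void* dst, int dst_device,
                     const void* src, int src_device,
                     size_t size, cudaStream_t copy_stream,
                     bool using_cudaMallocAsync, bool p2p_enabled) {
  const bool needs_memcpy_peer =
      (src_device != dst_device) && using_cudaMallocAsync && !p2p_enabled;
  if (needs_memcpy_peer) {
    // Non-peer cross-device copies of cudaMallocAsync memory need the
    // peer variant, which names both devices explicitly.
    return cudaMemcpyPeerAsync(dst, dst_device, src, src_device,
                               size, copy_stream);
  }
  return cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, copy_stream);
}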
|
||||||
|
|
@ -205,10 +205,11 @@ size_t getMaxWorkspaceSize(
|
||||||
{
|
{
|
||||||
size_t max_ws_size = 0;
|
size_t max_ws_size = 0;
|
||||||
size_t max_block_size = 0;
|
size_t max_block_size = 0;
|
||||||
size_t tmp_bytes = 0; // Only used for filling pointer parameters that aren't used later
|
|
||||||
|
|
||||||
const auto device = c10::cuda::current_device();
|
const auto device = c10::cuda::current_device();
|
||||||
c10::cuda::CUDACachingAllocator::cacheInfo(device, &tmp_bytes, &max_block_size);
|
// For the native allocator, retrieves the size of the largest unused block.
|
||||||
|
// For cudaMallocAsync, see c10/cuda/CUDAMallocAsync.cpp:cacheInfo for details.
|
||||||
|
c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
|
||||||
|
|
||||||
for (const auto i : c10::irange(n_algo)) {
|
for (const auto i : c10::irange(n_algo)) {
|
||||||
cudnnStatus_t err;
|
cudnnStatus_t err;
|
||||||
|
|
|
||||||
|
|
@ -307,8 +307,7 @@ size_t get_available_workspace() {
|
||||||
int device;
|
int device;
|
||||||
C10_CUDA_CHECK(cudaGetDevice(&device));
|
C10_CUDA_CHECK(cudaGetDevice(&device));
|
||||||
size_t max_block_size = 0;
|
size_t max_block_size = 0;
|
||||||
size_t tmp_bytes = 0; // Only used for filling pointer parameters that aren't used later
|
c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
|
||||||
c10::cuda::CUDACachingAllocator::cacheInfo(device, &tmp_bytes, &max_block_size);
|
|
||||||
return max_block_size;
|
return max_block_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,8 @@ set(C10_CUDA_SRCS
|
||||||
CUDAFunctions.cpp
|
CUDAFunctions.cpp
|
||||||
CUDAMiscFunctions.cpp
|
CUDAMiscFunctions.cpp
|
||||||
CUDAStream.cpp
|
CUDAStream.cpp
|
||||||
|
CUDACachingAllocator.cpp
|
||||||
|
CUDAMallocAsyncAllocator.cpp
|
||||||
impl/CUDAGuardImpl.cpp
|
impl/CUDAGuardImpl.cpp
|
||||||
impl/CUDATest.cpp
|
impl/CUDATest.cpp
|
||||||
)
|
)
|
||||||
|
|
@ -36,6 +38,7 @@ set(C10_CUDA_HEADERS
|
||||||
CUDAMathCompat.h
|
CUDAMathCompat.h
|
||||||
CUDAMiscFunctions.h
|
CUDAMiscFunctions.h
|
||||||
CUDAStream.h
|
CUDAStream.h
|
||||||
|
CUDACachingAllocator.h
|
||||||
impl/CUDAGuardImpl.h
|
impl/CUDAGuardImpl.h
|
||||||
impl/CUDATest.h
|
impl/CUDATest.h
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
#include <c10/cuda/CUDACachingAllocator.h>
|
#include <c10/cuda/CUDACachingAllocator.h>
|
||||||
|
|
||||||
#include <c10/core/impl/GPUTrace.h>
|
#include <c10/core/impl/GPUTrace.h>
|
||||||
|
|
@ -28,6 +27,7 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
|
||||||
|
|
||||||
namespace cuda {
|
namespace cuda {
|
||||||
namespace CUDACachingAllocator {
|
namespace CUDACachingAllocator {
|
||||||
|
namespace Native {
|
||||||
|
|
||||||
//
|
//
|
||||||
// Yet another caching allocator for CUDA device allocations.
|
// Yet another caching allocator for CUDA device allocations.
|
||||||
|
|
@ -73,7 +73,7 @@ namespace CUDACachingAllocator {
|
||||||
* must be available for the graph to use during replay. DeviceCachingAllocator
|
* must be available for the graph to use during replay. DeviceCachingAllocator
|
||||||
* assigns and frees memory eagerly and dynamically, so if we're not careful
|
* assigns and frees memory eagerly and dynamically, so if we're not careful
|
||||||
* about managing graphs' memory, at replay time those memory addresses could be
|
* about managing graphs' memory, at replay time those memory addresses could be
|
||||||
* use by other tensors.
|
* used by other tensors.
|
||||||
*
|
*
|
||||||
* To guarantee a graph's baked in addresses are safe to reuse in replay,
|
* To guarantee a graph's baked in addresses are safe to reuse in replay,
|
||||||
* DeviceAllocator satisfies allocations from a graph-private memory pool during
|
* DeviceAllocator satisfies allocations from a graph-private memory pool during
|
||||||
|
|
@ -88,14 +88,12 @@ namespace CUDACachingAllocator {
|
||||||
* (regardless whether those captures are idle or replaying).
|
* (regardless whether those captures are idle or replaying).
|
||||||
*
|
*
|
||||||
* CUDAGraph's requests for private pools are mediated by
|
* CUDAGraph's requests for private pools are mediated by
|
||||||
* DeviceAllocator::notifyCaptureBegin, notifyCaptureEnd, and
|
* DeviceAllocator::notifyCaptureBegin,
|
||||||
* notifyCaptureDestroy.
|
* notifyCaptureAboutToEnd,
|
||||||
|
* notifyCaptureEnded,
|
||||||
|
* notifyCaptureDestroy.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using stream_set = ska::flat_hash_set<cuda::CUDAStream>;
|
|
||||||
|
|
||||||
constexpr size_t kMinBlockSize =
|
constexpr size_t kMinBlockSize =
|
||||||
512; // all sizes are rounded to at least 512 bytes
|
512; // all sizes are rounded to at least 512 bytes
|
||||||
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
|
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
|
||||||
|
|
@ -107,6 +105,10 @@ constexpr size_t kMinLargeAlloc =
|
||||||
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
|
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
|
||||||
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
|
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using stream_set = ska::flat_hash_set<cuda::CUDAStream>;
|
||||||
|
|
||||||
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
|
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
|
||||||
|
|
||||||
void update_stat(Stat& stat, int64_t amount) {
|
void update_stat(Stat& stat, int64_t amount) {
|
||||||
|
|
@ -237,25 +239,6 @@ static bool BlockComparator(const Block* a, const Block* b) {
|
||||||
return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
|
return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string format_size(uint64_t size) {
|
|
||||||
std::ostringstream os;
|
|
||||||
os.precision(2);
|
|
||||||
os << std::fixed;
|
|
||||||
if (size <= 1024) {
|
|
||||||
os << size << " bytes";
|
|
||||||
} else if (size <= 1048576) {
|
|
||||||
os << (size / 1024.0);
|
|
||||||
os << " KiB";
|
|
||||||
} else if (size <= 1073741824ULL) {
|
|
||||||
os << size / 1048576.0;
|
|
||||||
os << " MiB";
|
|
||||||
} else {
|
|
||||||
os << size / 1073741824.0;
|
|
||||||
os << " GiB";
|
|
||||||
}
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
struct AllocParams {
|
struct AllocParams {
|
||||||
AllocParams(
|
AllocParams(
|
||||||
int device,
|
int device,
|
||||||
|
|
@ -403,10 +386,102 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // anonymous namespace
|
||||||
|
} // namespace Native
|
||||||
|
|
||||||
|
// Backend static initialization.
|
||||||
|
#define DECLARE_BACKEND_INTERFACE(RET, FUNC, ARGS) RET FUNC ARGS;
|
||||||
|
|
||||||
|
// Not called directly by clients.
|
||||||
|
namespace Native {
|
||||||
|
FORALL_ALLOCATOR_INTERFACE(DECLARE_BACKEND_INTERFACE)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not called directly by clients.
|
||||||
|
namespace CudaMallocAsync {
|
||||||
|
FORALL_ALLOCATOR_INTERFACE(DECLARE_BACKEND_INTERFACE)
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef DECLARE_BACKEND_INTERFACE
|
||||||
|
|
||||||
|
#define DEFINE_CHOSEN(RET, FUNC, ARGS) RET(*FUNC) ARGS = 0;
|
||||||
|
|
||||||
|
namespace Chosen {
|
||||||
|
FORALL_ALLOCATOR_INTERFACE(DEFINE_CHOSEN);
|
||||||
|
} // namespace Chosen
|
||||||
|
|
||||||
|
#define INITIALIZE_NATIVE(RET, FUNC, ARGS) Chosen::FUNC = Native::FUNC;
|
||||||
|
|
||||||
|
#define INITIALIZE_CUDAMALLOCASYNC(RET, FUNC, ARGS) \
|
||||||
|
Chosen::FUNC = CudaMallocAsync::FUNC;
|
||||||
|
|
||||||
|
struct BackendStaticInitializer {
|
||||||
|
AllocatorBackend backend;
|
||||||
|
|
||||||
|
// Parses env for backend at load time, duplicating some logic from
|
||||||
|
// CachingAllocatorConfig. CachingAllocatorConfig double-checks it later (at
|
||||||
|
// runtime). Defers verbose exceptions and error checks, including Cuda
|
||||||
|
// version checks, to CachingAllocatorConfig's runtime doublecheck. If this
|
||||||
|
// works, maybe we should move all of CachingAllocatorConfig here?
|
||||||
|
AllocatorBackend parseEnvForBackend() {
|
||||||
|
const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF");
|
||||||
|
|
||||||
|
if (val == NULL) {
|
||||||
|
return AllocatorBackend::NATIVE;
|
||||||
|
} else {
|
||||||
|
const std::string config(val);
|
||||||
|
|
||||||
|
std::regex exp("[\\s,]+");
|
||||||
|
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
|
||||||
|
std::sregex_token_iterator end;
|
||||||
|
std::vector<std::string> options(it, end);
|
||||||
|
|
||||||
|
for (auto option : options) {
|
||||||
|
std::regex exp2("[:]+");
|
||||||
|
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
|
||||||
|
std::sregex_token_iterator end2;
|
||||||
|
std::vector<std::string> kv(it2, end2);
|
||||||
|
if (kv.size() >= 2) {
|
||||||
|
if (kv[0] == "backend") {
|
||||||
|
if (kv[1] == "cudaMallocAsync")
|
||||||
|
return AllocatorBackend::CUDAMALLOCASYNC;
|
||||||
|
if (kv[1] == "native")
|
||||||
|
return AllocatorBackend::NATIVE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return AllocatorBackend::NATIVE;
|
||||||
|
}
|
||||||
|
|
||||||
|
BackendStaticInitializer() {
|
||||||
|
backend = parseEnvForBackend();
|
||||||
|
switch (backend) {
|
||||||
|
case AllocatorBackend::NATIVE:
|
||||||
|
FORALL_ALLOCATOR_INTERFACE(INITIALIZE_NATIVE)
|
||||||
|
break;
|
||||||
|
case AllocatorBackend::CUDAMALLOCASYNC:
|
||||||
|
FORALL_ALLOCATOR_INTERFACE(INITIALIZE_CUDAMALLOCASYNC)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#undef INITIALIZE_NATIVE
|
||||||
|
#undef INITIALIZE_CUDAMALLOCASYNC
|
||||||
|
|
||||||
|
BackendStaticInitializer backend_static_initializer{};
|
||||||
|
|
||||||
|
// Environment config parser
|
||||||
|
// Defined here, rather than its own .cpp file,
|
||||||
|
// because parseArgs needs to know kLargeBuffer.
|
||||||
|
// Defined outside namespace Native because it's not Native-specific.
|
||||||
class CachingAllocatorConfig {
|
class CachingAllocatorConfig {
|
||||||
public:
|
public:
|
||||||
|
static AllocatorBackend allocator_backend() {
|
||||||
|
return instance().m_allocator_backend;
|
||||||
|
}
|
||||||
|
|
||||||
static size_t max_split_size() {
|
static size_t max_split_size() {
|
||||||
return instance().m_max_split_size;
|
return instance().m_max_split_size;
|
||||||
}
|
}
|
||||||
|
|
@ -441,6 +516,7 @@ class CachingAllocatorConfig {
|
||||||
m_roundup_power2_divisions = 0;
|
m_roundup_power2_divisions = 0;
|
||||||
m_roundup_bypass_threshold = std::numeric_limits<size_t>::max();
|
m_roundup_bypass_threshold = std::numeric_limits<size_t>::max();
|
||||||
m_garbage_collection_threshold = 0;
|
m_garbage_collection_threshold = 0;
|
||||||
|
m_allocator_backend = AllocatorBackend::NATIVE;
|
||||||
|
|
||||||
if (env == nullptr) {
|
if (env == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
|
@ -453,6 +529,9 @@ class CachingAllocatorConfig {
|
||||||
std::sregex_token_iterator end;
|
std::sregex_token_iterator end;
|
||||||
std::vector<std::string> options(it, end);
|
std::vector<std::string> options(it, end);
|
||||||
|
|
||||||
|
bool used_cudaMallocAsync = false;
|
||||||
|
bool used_native_specific_option = false;
|
||||||
|
|
||||||
for (auto option : options) {
|
for (auto option : options) {
|
||||||
std::regex exp2("[:]+");
|
std::regex exp2("[:]+");
|
||||||
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
|
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
|
||||||
|
|
@ -460,28 +539,59 @@ class CachingAllocatorConfig {
|
||||||
std::vector<std::string> kv(it2, end2);
|
std::vector<std::string> kv(it2, end2);
|
||||||
if (kv.size() >= 2) {
|
if (kv.size() >= 2) {
|
||||||
/* Maximum split size in MB. Limited to large size blocks */
|
/* Maximum split size in MB. Limited to large size blocks */
|
||||||
if (kv[0].compare("max_split_size_mb") == 0) {
|
if (kv[0] == "max_split_size_mb") {
|
||||||
size_t val2 = stoi(kv[1]);
|
size_t val2 = stoi(kv[1]);
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
val2 > kLargeBuffer / (1024 * 1024),
|
val2 > Native::kLargeBuffer / (1024 * 1024),
|
||||||
"CachingAllocator option max_split_size_mb too small, must be > ",
|
"CachingAllocator option max_split_size_mb too small, must be > ",
|
||||||
kLargeBuffer / (1024 * 1024),
|
Native::kLargeBuffer / (1024 * 1024),
|
||||||
"");
|
"");
|
||||||
val2 = std::max(val2, kLargeBuffer / (1024 * 1024));
|
val2 = std::max(val2, Native::kLargeBuffer / (1024 * 1024));
|
||||||
val2 = std::min(
|
val2 = std::min(
|
||||||
val2, (std::numeric_limits<size_t>::max() / (1024 * 1024)));
|
val2, (std::numeric_limits<size_t>::max() / (1024 * 1024)));
|
||||||
m_max_split_size = val2 * 1024 * 1024;
|
m_max_split_size = val2 * 1024 * 1024;
|
||||||
} else if (kv[0].compare("roundup_power2_divisions") == 0) {
|
used_native_specific_option = true;
|
||||||
|
} else if (kv[0] == "roundup_power2_divisions") {
|
||||||
size_t val2 = stoi(kv[1]);
|
size_t val2 = stoi(kv[1]);
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
llvm::isPowerOf2_64(val2),
|
llvm::isPowerOf2_64(val2),
|
||||||
"For roundups, the divisons has to be power of 2 ",
|
"For roundups, the divisons has to be power of 2 ",
|
||||||
"");
|
"");
|
||||||
m_roundup_power2_divisions = val2;
|
m_roundup_power2_divisions = val2;
|
||||||
} else if (kv[0].compare("roundup_bypass_threshold_mb") == 0) {
|
used_native_specific_option = true;
|
||||||
|
} else if (kv[0] == "roundup_bypass_threshold_mb") {
|
||||||
size_t val2 = stoi(kv[1]);
|
size_t val2 = stoi(kv[1]);
|
||||||
m_roundup_bypass_threshold = val2 * 1024 * 1024;
|
m_roundup_bypass_threshold = val2 * 1024 * 1024;
|
||||||
} else if (kv[0].compare("garbage_collection_threshold") == 0) {
|
used_native_specific_option = true;
|
||||||
|
} else if (kv[0] == "backend") {
|
||||||
|
TORCH_CHECK(
|
||||||
|
((kv[1] == "native") || (kv[1] == "cudaMallocAsync")),
|
||||||
|
"Unknown allocator backend, "
|
||||||
|
"options are native and cudaMallocAsync");
|
||||||
|
used_cudaMallocAsync = (kv[1] == "cudaMallocAsync");
|
||||||
|
if (used_cudaMallocAsync) {
|
||||||
|
#if CUDA_VERSION >= 11040
|
||||||
|
int version;
|
||||||
|
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||||
|
TORCH_CHECK(
|
||||||
|
version >= 11040,
|
||||||
|
"backend:cudaMallocAsync requires CUDA runtime "
|
||||||
|
"11.4 or newer, but cudaDriverGetVersion returned ",
|
||||||
|
version);
|
||||||
|
m_allocator_backend = AllocatorBackend::CUDAMALLOCASYNC;
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
"backend:cudaMallocAsync requires PyTorch to be built with "
|
||||||
|
"CUDA 11.4 or newer, but CUDA_VERSION is ",
|
||||||
|
CUDA_VERSION);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
TORCH_INTERNAL_ASSERT(
|
||||||
|
m_allocator_backend == backend_static_initializer.backend,
|
||||||
|
"Allocator backend parsed at runtime != "
|
||||||
|
"allocator backend parsed at load time");
|
||||||
|
} else if (kv[0] == "garbage_collection_threshold") {
|
||||||
/*
|
/*
|
||||||
* Perform garbage collection of GPU memory blocks to avoid
|
* Perform garbage collection of GPU memory blocks to avoid
|
||||||
* triggering expensive sync-and-reclaim-all operation. Upon setting
|
* triggering expensive sync-and-reclaim-all operation. Upon setting
|
||||||
|
|
@ -500,10 +610,17 @@ class CachingAllocatorConfig {
|
||||||
"garbage_collect_threshold too big, set it 0.0~1.0",
|
"garbage_collect_threshold too big, set it 0.0~1.0",
|
||||||
"");
|
"");
|
||||||
m_garbage_collection_threshold = val2;
|
m_garbage_collection_threshold = val2;
|
||||||
|
used_native_specific_option = true;
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]);
|
TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (used_cudaMallocAsync && used_native_specific_option) {
|
||||||
|
TORCH_WARN(
|
||||||
|
"backend:cudaMallocAsync ignores max_split_size_mb, roundup_bypass_threshold_mb,"
|
||||||
|
"roundup_power2_divisions, and garbage_collect_threshold.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
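The option string accepted here is a comma/whitespace-separated list of key:value pairs; the new backend key takes native or cudaMallocAsync, and the remaining keys are native-only. Below is a small self-contained sketch of the same parsing, with an illustrative sample option string. In practice the value is supplied through the PYTORCH_CUDA_ALLOC_CONF environment variable before the process starts, since BackendStaticInitializer reads it at library load time.

#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
  // Illustrative config string in the grammar parsed above.
  const std::string config = "max_split_size_mb:512,backend:cudaMallocAsync";

  // Options are separated by commas or whitespace.
  std::regex option_sep("[\\s,]+");
  std::sregex_token_iterator it(config.begin(), config.end(), option_sep, -1);
  std::vector<std::string> options(it, std::sregex_token_iterator{});

  for (const auto& option : options) {
    // Keys and values are separated by a colon.
    std::regex kv_sep("[:]+");
    std::sregex_token_iterator it2(option.begin(), option.end(), kv_sep, -1);
    std::vector<std::string> kv(it2, std::sregex_token_iterator{});
    if (kv.size() >= 2) {
      // kv[0] == "backend" with kv[1] == "cudaMallocAsync" or "native"
      // selects the allocator backend; other keys are native-only options.
      std::cout << "key=" << kv[0] << " value=" << kv[1] << "\n";
    }
  }
  return 0;
}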
|
|
@ -516,8 +633,11 @@ class CachingAllocatorConfig {
|
||||||
std::atomic<size_t> m_roundup_power2_divisions;
|
std::atomic<size_t> m_roundup_power2_divisions;
|
||||||
std::atomic<size_t> m_roundup_bypass_threshold;
|
std::atomic<size_t> m_roundup_bypass_threshold;
|
||||||
std::atomic<double> m_garbage_collection_threshold;
|
std::atomic<double> m_garbage_collection_threshold;
|
||||||
|
AllocatorBackend m_allocator_backend;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace Native {
|
||||||
|
|
||||||
class DeviceCachingAllocator {
|
class DeviceCachingAllocator {
|
||||||
private:
|
private:
|
||||||
// lock around all operations
|
// lock around all operations
|
||||||
|
|
@ -630,7 +750,7 @@ class DeviceCachingAllocator {
|
||||||
//
|
//
|
||||||
// Q. Why skip process_events if a capture might be underway?
|
// Q. Why skip process_events if a capture might be underway?
|
||||||
// A. process_events involves cudaEventQueries, illegal during CUDA graph
|
// A. process_events involves cudaEventQueries, illegal during CUDA graph
|
||||||
// capture.
|
// capture.
|
||||||
// Dumb simple solution: defer reclaiming these allocations until after
|
// Dumb simple solution: defer reclaiming these allocations until after
|
||||||
// capture. Cross-stream memory use is uncommon, so the deferral's
|
// capture. Cross-stream memory use is uncommon, so the deferral's
|
||||||
// effect on memory use during capture should be small.
|
// effect on memory use during capture should be small.
|
||||||
|
|
@ -938,8 +1058,8 @@ class DeviceCachingAllocator {
|
||||||
release_cached_blocks();
|
release_cached_blocks();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Retrieves info (total size + largest block) of the memory cache **/
|
/** Retrieves size of largest unused block held by the memory cache **/
|
||||||
void cacheInfo(size_t* total, size_t* largest) {
|
void cacheInfo(size_t* largest) {
|
||||||
std::lock_guard<std::recursive_mutex> lock(mutex);
|
std::lock_guard<std::recursive_mutex> lock(mutex);
|
||||||
if (*largest ==
|
if (*largest ==
|
||||||
0) { // make an initial guess if a zero *largest is passed in
|
0) { // make an initial guess if a zero *largest is passed in
|
||||||
|
|
@ -948,11 +1068,11 @@ class DeviceCachingAllocator {
|
||||||
largest, // Use free memory as an optimistic initial guess of *largest
|
largest, // Use free memory as an optimistic initial guess of *largest
|
||||||
&tmp_bytes));
|
&tmp_bytes));
|
||||||
}
|
}
|
||||||
cache_info_aux(large_blocks, total, largest);
|
cache_info_aux(large_blocks, largest);
|
||||||
cache_info_aux(small_blocks, total, largest);
|
cache_info_aux(small_blocks, largest);
|
||||||
for (const auto& gp : graph_pools) {
|
for (const auto& gp : graph_pools) {
|
||||||
cache_info_aux(gp.second->large_blocks, total, largest);
|
cache_info_aux(gp.second->large_blocks, largest);
|
||||||
cache_info_aux(gp.second->small_blocks, total, largest);
|
cache_info_aux(gp.second->small_blocks, largest);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1145,7 +1265,7 @@ class DeviceCachingAllocator {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Called by CUDAGraph::capture_end
|
// Called by CUDAGraph::capture_end
|
||||||
void notifyCaptureEnd(CaptureId_t graph_id) {
|
void notifyCaptureAboutToEnd(CaptureId_t graph_id) {
|
||||||
std::lock_guard<std::recursive_mutex> lock(mutex);
|
std::lock_guard<std::recursive_mutex> lock(mutex);
|
||||||
captures_underway--;
|
captures_underway--;
|
||||||
auto it = capture_to_pool_map.find(graph_id);
|
auto it = capture_to_pool_map.find(graph_id);
|
||||||
|
|
@ -1474,14 +1594,13 @@ class DeviceCachingAllocator {
|
||||||
if (p.err != cudaSuccess) {
|
if (p.err != cudaSuccess) {
|
||||||
if (p.err == cudaErrorMemoryAllocation) {
|
if (p.err == cudaErrorMemoryAllocation) {
|
||||||
// If this is the first attempt (!isRetry), we can forgive and clear
|
// If this is the first attempt (!isRetry), we can forgive and clear
|
||||||
// CUDA's
|
// CUDA's internal error state.
|
||||||
// internal error state.
|
//
|
||||||
// If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH
|
// If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH
|
||||||
// will take
|
// will take over to throw a helpful exception. The user can choose
|
||||||
// over to throw a helpful exception. The user can choose to catch
|
// to catch the exception, free some stuff in their script, and
|
||||||
// the exception, free some stuff in their script, and attempt their
|
// attempt the allocation again. In this case, we can also forgive and
|
||||||
// allocation again. In this case, we can also forgive and clear
|
// clear CUDA's internal error state.
|
||||||
// CUDA's internal error state.
|
|
||||||
cudaGetLastError();
|
cudaGetLastError();
|
||||||
} else {
|
} else {
|
||||||
// If the error's unrelated to memory allocation, we should throw
|
// If the error's unrelated to memory allocation, we should throw
|
||||||
|
|
@ -1735,11 +1854,10 @@ class DeviceCachingAllocator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulates sizes of all memory blocks for given device in given pool
|
// Iterates over sizes of all memory blocks for given device in given pool
|
||||||
void cache_info_aux(const BlockPool& pool, size_t* total, size_t* largest) {
|
void cache_info_aux(const BlockPool& pool, size_t* largest) {
|
||||||
for (const auto& block : pool.blocks) {
|
for (const auto& block : pool.blocks) {
|
||||||
const auto blocksize = block->size;
|
const auto blocksize = block->size;
|
||||||
*total += blocksize;
|
|
||||||
if (blocksize > *largest) {
|
if (blocksize > *largest) {
|
||||||
*largest = blocksize;
|
*largest = blocksize;
|
||||||
}
|
}
|
||||||
|
|
@ -1769,7 +1887,7 @@ class DeviceCachingAllocator {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class THCCachingAllocator {
|
class NativeCachingAllocator {
|
||||||
private:
|
private:
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
|
|
||||||
|
|
@ -1931,7 +2049,7 @@ class THCCachingAllocator {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
THCCachingAllocator caching_allocator;
|
NativeCachingAllocator caching_allocator;
|
||||||
|
|
||||||
// Returns whether to force all allocations to bypass the caching allocator and
|
// Returns whether to force all allocations to bypass the caching allocator and
|
||||||
// go straight to cudaMalloc. This setting is useful when debugging GPU memory
|
// go straight to cudaMalloc. This setting is useful when debugging GPU memory
|
||||||
|
|
@ -1950,8 +2068,8 @@ static void uncached_delete(void* ptr) {
|
||||||
C10_CUDA_CHECK(cudaFree(ptr));
|
C10_CUDA_CHECK(cudaFree(ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
// NB: I decided not to fold this into THCCachingAllocator, because the latter
|
// NB: I decided not to fold this into NativeCachingAllocator, because the
|
||||||
// has a lot more methods and it wasn't altogether clear that they should
|
// latter has a lot more methods and it wasn't altogether clear that they should
|
||||||
// actually be publicly exposed
|
// actually be publicly exposed
|
||||||
struct CudaCachingAllocator : public Allocator {
|
struct CudaCachingAllocator : public Allocator {
|
||||||
DataPtr allocate(size_t size) const override {
|
DataPtr allocate(size_t size) const override {
|
||||||
|
|
@ -2026,9 +2144,8 @@ void emptyCache(void) {
|
||||||
caching_allocator.emptyCache();
|
caching_allocator.emptyCache();
|
||||||
}
|
}
|
||||||
|
|
||||||
void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) {
|
void cacheInfo(int dev_id, size_t* largestBlock) {
|
||||||
caching_allocator.device_allocator[dev_id]->cacheInfo(
|
caching_allocator.device_allocator[dev_id]->cacheInfo(largestBlock);
|
||||||
cachedAndFree, largestBlock);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void* getBaseAllocation(void* ptr, size_t* size) {
|
void* getBaseAllocation(void* ptr, size_t* size) {
|
||||||
|
|
@ -2081,17 +2198,45 @@ void notifyCaptureBegin(
|
||||||
graph_id, mempool_id);
|
graph_id, mempool_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void notifyCaptureEnd(int device, CaptureId_t graph_id) {
|
void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
|
||||||
assertValidDevice(device);
|
assertValidDevice(device);
|
||||||
caching_allocator.device_allocator[device]->notifyCaptureEnd(graph_id);
|
caching_allocator.device_allocator[device]->notifyCaptureAboutToEnd(graph_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void notifyCaptureEnded(int device, CaptureId_t graph_id) {} // no-op
|
||||||
|
|
||||||
void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
|
void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
|
||||||
assertValidDevice(device);
|
assertValidDevice(device);
|
||||||
caching_allocator.device_allocator[device]->notifyCaptureDestroy(mempool_id);
|
caching_allocator.device_allocator[device]->notifyCaptureDestroy(mempool_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
void* raw_alloc(size_t nbytes) {
|
||||||
|
if (nbytes == 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
int device;
|
||||||
|
C10_CUDA_CHECK(cudaGetDevice(&device));
|
||||||
|
void* r = nullptr;
|
||||||
|
caching_allocator.malloc(
|
||||||
|
&r, device, nbytes, cuda::getCurrentCUDAStream(device));
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
|
||||||
|
if (nbytes == 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
int device;
|
||||||
|
C10_CUDA_CHECK(cudaGetDevice(&device));
|
||||||
|
void* r = nullptr;
|
||||||
|
caching_allocator.malloc(&r, device, nbytes, stream);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
void raw_delete(void* ptr) {
|
||||||
|
caching_allocator.free(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
// In CUDA IPC, sender sends a tensor to receiver, getIpcDevPtr
|
// In CUDA IPC, sender sends a tensor to receiver, getIpcDevPtr
|
||||||
// is called by the receiving process to map the CUDA memory from the sending
|
// is called by the receiving process to map the CUDA memory from the sending
|
||||||
// process into its own address space.
|
// process into its own address space.
|
||||||
|
|
@ -2146,34 +2291,42 @@ std::shared_ptr<void> getIpcDevPtr(std::string handle) {
|
||||||
return sp;
|
return sp;
|
||||||
}
|
}
|
||||||
|
|
||||||
void* raw_alloc(size_t nbytes) {
|
} // namespace Native
|
||||||
if (nbytes == 0) {
|
|
||||||
return nullptr;
|
// General caching allocator utilities
|
||||||
}
|
|
||||||
int device;
|
// External config interface (declared in CUDACachingAllocator.h)
|
||||||
C10_CUDA_CHECK(cudaGetDevice(&device));
|
// This is a useless layer of indirection with a minor
|
||||||
void* r = nullptr;
|
// code-cleanliness benefit: it alleviates the need to define
|
||||||
caching_allocator.malloc(
|
// CachingAllocatorConfig itself in CUDACachingAllocator.h.
|
||||||
&r, device, nbytes, cuda::getCurrentCUDAStream(device));
|
AllocatorBackend allocatorBackend() {
|
||||||
return r;
|
return CachingAllocatorConfig::allocator_backend();
|
||||||
}
|
}
|
||||||
|
|
||||||
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
|
void setAllocatorSettings(const std::string& env) {
|
||||||
if (nbytes == 0) {
|
CachingAllocatorConfig::instance().parseArgs(env.c_str());
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
int device;
|
|
||||||
C10_CUDA_CHECK(cudaGetDevice(&device));
|
|
||||||
void* r = nullptr;
|
|
||||||
caching_allocator.malloc(&r, device, nbytes, stream);
|
|
||||||
return r;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void raw_delete(void* ptr) {
|
// Size pretty-printer
|
||||||
caching_allocator.free(ptr);
|
inline std::string format_size(uint64_t size) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os.precision(2);
|
||||||
|
os << std::fixed;
|
||||||
|
if (size <= 1024) {
|
||||||
|
os << size << " bytes";
|
||||||
|
} else if (size <= 1048576) {
|
||||||
|
os << (size / 1024.0);
|
||||||
|
os << " KiB";
|
||||||
|
} else if (size <= 1073741824ULL) {
|
||||||
|
os << size / 1048576.0;
|
||||||
|
os << " MiB";
|
||||||
|
} else {
|
||||||
|
os << size / 1073741824.0;
|
||||||
|
os << " GiB";
|
||||||
|
}
|
||||||
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace CUDACachingAllocator
|
} // namespace CUDACachingAllocator
|
||||||
|
|
||||||
} // namespace cuda
|
} // namespace cuda
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
#ifndef THC_DEVICE_ALLOCATOR_INC
|
#pragma once
|
||||||
#define THC_DEVICE_ALLOCATOR_INC
|
|
||||||
#include <c10/core/Allocator.h>
|
#include <c10/core/Allocator.h>
|
||||||
#include <c10/cuda/CUDAGraphsC10Utils.h>
|
#include <c10/cuda/CUDAGraphsC10Utils.h>
|
||||||
#include <c10/cuda/CUDAMacros.h>
|
#include <c10/cuda/CUDAMacros.h>
|
||||||
|
|
@ -166,52 +166,174 @@ struct SnapshotInfo {
|
||||||
std::vector<std::vector<TraceEntry>> device_traces;
|
std::vector<std::vector<TraceEntry>> device_traces;
|
||||||
};
|
};
|
||||||
|
|
||||||
C10_CUDA_API void* raw_alloc(size_t nbytes);
|
// Allocator config options.
|
||||||
C10_CUDA_API void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream);
|
enum struct AllocatorBackend : uint8_t {
|
||||||
C10_CUDA_API void raw_delete(void* ptr);
|
NATIVE = 0,
|
||||||
|
CUDAMALLOCASYNC = 1,
|
||||||
|
};
|
||||||
|
|
||||||
C10_CUDA_API Allocator* get();
|
C10_CUDA_API AllocatorBackend allocatorBackend();
|
||||||
C10_CUDA_API void init(int device_count);
|
|
||||||
C10_CUDA_API void setMemoryFraction(double fraction, int device);
|
|
||||||
C10_CUDA_API void setAllocatorSettings(const std::string& env);
|
C10_CUDA_API void setAllocatorSettings(const std::string& env);
|
||||||
C10_CUDA_API void emptyCache();
|
|
||||||
C10_CUDA_API void cacheInfo(
|
|
||||||
int dev_id,
|
|
||||||
size_t* cachedAndFree,
|
|
||||||
size_t* largestBlock);
|
|
||||||
C10_CUDA_API void* getBaseAllocation(void* ptr, size_t* size);
|
|
||||||
C10_CUDA_API void recordStream(const DataPtr&, CUDAStream stream);
|
|
||||||
C10_CUDA_API DeviceStats getDeviceStats(int device);
|
|
||||||
C10_CUDA_API void resetAccumulatedStats(int device);
|
|
||||||
C10_CUDA_API void resetPeakStats(int device);
|
|
||||||
C10_CUDA_API SnapshotInfo snapshot();
|
|
||||||
|
|
||||||
// CUDAGraph interactions
|
// Size pretty-printer
|
||||||
C10_CUDA_API void notifyCaptureBegin(
|
std::string format_size(uint64_t size);
|
||||||
int device,
|
|
||||||
CaptureId_t graph_id,
|
|
||||||
MempoolId_t mempool_id);
|
|
||||||
C10_CUDA_API void notifyCaptureEnd(int device, CaptureId_t graph_id);
|
|
||||||
C10_CUDA_API void notifyCaptureDestroy(int device, MempoolId_t mempool_id);
|
|
||||||
|
|
||||||
C10_CUDA_API std::mutex* getFreeMutex();
|
|
||||||
|
|
||||||
C10_CUDA_API void recordHistory(
|
|
||||||
bool enabled,
|
|
||||||
CreateContextFn context_recorder,
|
|
||||||
size_t alloc_trace_max_entries,
|
|
||||||
bool alloc_trace_record_context);
|
|
||||||
using OutOfMemoryObserver = std::function<void(
|
using OutOfMemoryObserver = std::function<void(
|
||||||
int64_t device,
|
int64_t device,
|
||||||
int64_t allocated,
|
int64_t allocated,
|
||||||
int64_t device_total,
|
int64_t device_total,
|
||||||
int64_t device_free)>;
|
int64_t device_free)>;
|
||||||
C10_CUDA_API void attachOutOfMemoryObserver(OutOfMemoryObserver observer);
|
|
||||||
|
|
||||||
C10_CUDA_API std::shared_ptr<void> getIpcDevPtr(std::string handle);
|
#define FORALL_ALLOCATOR_INTERFACE(_) \
|
||||||
|
_(C10_CUDA_API void*, raw_alloc, (size_t nbytes)) \
|
||||||
|
_(C10_CUDA_API void*, \
|
||||||
|
raw_alloc_with_stream, \
|
||||||
|
(size_t nbytes, cudaStream_t stream)) \
|
||||||
|
_(C10_CUDA_API void, raw_delete, (void* ptr)) \
|
||||||
|
_(C10_CUDA_API Allocator*, get, ()) \
|
||||||
|
_(C10_CUDA_API void, init, (int device_count)) \
|
||||||
|
_(C10_CUDA_API void, setMemoryFraction, (double fraction, int device)) \
|
||||||
|
_(C10_CUDA_API void, emptyCache, ()) \
|
||||||
|
_(C10_CUDA_API void, cacheInfo, (int dev_id, size_t* largestBlock)) \
|
||||||
|
_(C10_CUDA_API void*, getBaseAllocation, (void* ptr, size_t* size)) \
|
||||||
|
_(C10_CUDA_API void, recordStream, (const DataPtr&, CUDAStream stream)) \
|
||||||
|
_(C10_CUDA_API DeviceStats, getDeviceStats, (int device)) \
|
||||||
|
_(C10_CUDA_API void, resetAccumulatedStats, (int device)) \
|
||||||
|
_(C10_CUDA_API void, resetPeakStats, (int device)) \
|
||||||
|
_(C10_CUDA_API SnapshotInfo, snapshot, ()) \
|
||||||
|
_(C10_CUDA_API void, \
|
||||||
|
notifyCaptureBegin, \
|
||||||
|
(int device, CaptureId_t graph_id, MempoolId_t mempool_id)) \
|
||||||
|
_(C10_CUDA_API void, \
|
||||||
|
notifyCaptureAboutToEnd, \
|
||||||
|
(int device, CaptureId_t graph_id)) \
|
||||||
|
_(C10_CUDA_API void, notifyCaptureEnded, (int device, CaptureId_t graph_id)) \
|
||||||
|
_(C10_CUDA_API void, \
|
||||||
|
notifyCaptureDestroy, \
|
||||||
|
(int device, MempoolId_t mempool_id)) \
|
||||||
|
_(C10_CUDA_API std::mutex*, getFreeMutex, ()) \
|
||||||
|
_(C10_CUDA_API std::shared_ptr<void>, getIpcDevPtr, (std::string handle)) \
|
||||||
|
_(C10_CUDA_API void, \
|
||||||
|
recordHistory, \
|
||||||
|
(bool enabled, \
|
||||||
|
CreateContextFn context_recorder, \
|
||||||
|
size_t alloc_trace_max_entries, \
|
||||||
|
bool alloc_trace_record_context)) \
|
||||||
|
_(C10_CUDA_API void, \
|
||||||
|
attachOutOfMemoryObserver, \
|
||||||
|
(OutOfMemoryObserver observer))
|
||||||
|
|
||||||
|
// Allocator backend function pointers, statically initialized
|
||||||
|
// according to PYTORCH_CUDA_ALLOC_CONF.
|
||||||
|
// See BackendInitializer in CUDACachingAllocator.cpp.
|
||||||
|
namespace Chosen {
|
||||||
|
#define DECLARE_CHOSEN(RET, FUNC, ARGS) extern RET(*FUNC) ARGS;
|
||||||
|
FORALL_ALLOCATOR_INTERFACE(DECLARE_CHOSEN)
|
||||||
|
#undef DECLARE_CHOSEN
|
||||||
|
} // namespace Chosen
|
||||||
|
|
||||||
|
// Called directly by clients.
|
||||||
|
inline void* raw_alloc(size_t nbytes) {
|
||||||
|
return Chosen::raw_alloc(nbytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
|
||||||
|
return Chosen::raw_alloc_with_stream(nbytes, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void raw_delete(void* ptr) {
|
||||||
|
return Chosen::raw_delete(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Allocator* get() {
|
||||||
|
return Chosen::get();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void init(int device_count) {
|
||||||
|
return Chosen::init(device_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void setMemoryFraction(double fraction, int device) {
|
||||||
|
return Chosen::setMemoryFraction(fraction, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void emptyCache() {
|
||||||
|
return Chosen::emptyCache();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void cacheInfo(int dev_id, size_t* largestBlock) {
|
||||||
|
return Chosen::cacheInfo(dev_id, largestBlock);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void* getBaseAllocation(void* ptr, size_t* size) {
|
||||||
|
return Chosen::getBaseAllocation(ptr, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void recordStream(const DataPtr& dataPtr, CUDAStream stream) {
|
||||||
|
return Chosen::recordStream(dataPtr, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline DeviceStats getDeviceStats(int device) {
|
||||||
|
return Chosen::getDeviceStats(device);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void resetAccumulatedStats(int device) {
|
||||||
|
return Chosen::resetAccumulatedStats(device);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void resetPeakStats(int device) {
|
||||||
|
return Chosen::resetPeakStats(device);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline SnapshotInfo snapshot() {
|
||||||
|
return Chosen::snapshot();
|
||||||
|
}
|
||||||
|
|
||||||
|
// CUDAGraph interactions
|
||||||
|
inline void notifyCaptureBegin(
|
||||||
|
int device,
|
||||||
|
CaptureId_t graph_id,
|
||||||
|
MempoolId_t mempool_id) {
|
||||||
|
return Chosen::notifyCaptureBegin(device, graph_id, mempool_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
|
||||||
|
return Chosen::notifyCaptureAboutToEnd(device, graph_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void recordHistory(
|
||||||
|
bool enabled,
|
||||||
|
CreateContextFn context_recorder,
|
||||||
|
size_t alloc_trace_max_entries,
|
||||||
|
bool alloc_trace_record_context) {
|
||||||
|
return Chosen::recordHistory(
|
||||||
|
enabled,
|
||||||
|
context_recorder,
|
||||||
|
alloc_trace_max_entries,
|
||||||
|
alloc_trace_record_context);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
|
||||||
|
return Chosen::attachOutOfMemoryObserver(observer);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void notifyCaptureEnded(int device, CaptureId_t graph_id) {
|
||||||
|
return Chosen::notifyCaptureEnded(device, graph_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
|
||||||
|
return Chosen::notifyCaptureDestroy(device, mempool_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::mutex* getFreeMutex() {
|
||||||
|
return Chosen::getFreeMutex();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
|
||||||
|
inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
|
||||||
|
return Chosen::getIpcDevPtr(handle);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace CUDACachingAllocator
|
} // namespace CUDACachingAllocator
|
||||||
|
|
||||||
} // namespace cuda
|
} // namespace cuda
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
||||||
#endif
|
|
||||||
|
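The FORALL_ALLOCATOR_INTERFACE / Chosen machinery above is an X-macro dispatch table: the interface is listed once, each backend provides a definition for every entry, and a static initializer points a set of Chosen function pointers at one backend before any client call. A stripped-down sketch of the same pattern, with a made-up two-function interface and an illustrative environment variable (not the real PYTORCH_CUDA_ALLOC_CONF logic):

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// The interface is listed exactly once.
#define FORALL_IFACE(_)                \
  _(void*, raw_alloc, (size_t nbytes)) \
  _(void, raw_delete, (void* ptr))

// Each backend defines every interface function.
namespace Native {
void* raw_alloc(size_t nbytes) { return std::malloc(nbytes); }
void raw_delete(void* ptr) { std::free(ptr); }
} // namespace Native

namespace CudaMallocAsync {
void* raw_alloc(size_t nbytes) { std::printf("async backend alloc\n"); return std::malloc(nbytes); }
void raw_delete(void* ptr) { std::free(ptr); }
} // namespace CudaMallocAsync

// "Chosen" holds one function pointer per interface entry.
namespace Chosen {
#define DEFINE_CHOSEN(RET, FUNC, ARGS) RET (*FUNC) ARGS = nullptr;
FORALL_IFACE(DEFINE_CHOSEN)
#undef DEFINE_CHOSEN
} // namespace Chosen

// A static initializer points Chosen at one backend, based on an env var.
#define PICK_NATIVE(RET, FUNC, ARGS) Chosen::FUNC = Native::FUNC;
#define PICK_ASYNC(RET, FUNC, ARGS) Chosen::FUNC = CudaMallocAsync::FUNC;
struct BackendStaticInitializer {
  BackendStaticInitializer() {
    const char* val = std::getenv("DEMO_ALLOC_BACKEND");  // illustrative name
    if (val != nullptr && std::strcmp(val, "cudaMallocAsync") == 0) {
      FORALL_IFACE(PICK_ASYNC)
    } else {
      FORALL_IFACE(PICK_NATIVE)
    }
  }
};
BackendStaticInitializer backend_static_initializer;

int main() {
  void* p = Chosen::raw_alloc(64);  // dispatches through the chosen backend
  Chosen::raw_delete(p);
  return 0;
}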
|
|
||||||
c10/cuda/CUDAMallocAsyncAllocator.cpp (new file, 924 lines)
|
|
@ -0,0 +1,924 @@
|
||||||
|
#include <c10/cuda/CUDACachingAllocator.h>
|
||||||
|
#include <c10/cuda/CUDAException.h>
|
||||||
|
#include <c10/cuda/CUDAFunctions.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <c10/util/UniqueVoidPtr.h>
|
||||||
|
#include <c10/util/flat_hash_map.h>
|
||||||
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace c10 {
|
||||||
|
namespace cuda {
|
||||||
|
namespace CUDACachingAllocator {
|
||||||
|
namespace CudaMallocAsync {
|
||||||
|
|
||||||
|
#if CUDA_VERSION >= 11040
|
||||||
|
// CUDA device allocator that uses cudaMallocAsync to implement
|
||||||
|
// the same interface as CUDACachingAllocator.cpp.
|
||||||
|
|
||||||
|
// Designed to be safe for CUDA graph capture.
|
||||||
|
// Interactions with CUDA graph capture are mediated by
|
||||||
|
// notifyCaptureBegin
|
||||||
|
// notifyCaptureAboutToEnd
|
||||||
|
// notifyCaptureEnded
|
||||||
|
// notifyCaptureDestroy
|
||||||
|
|
||||||
|
// Implementation details, not declared in CUDACachingAllocator.h
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// General helpers
|
||||||
|
|
||||||
|
struct UsageStream {
|
||||||
|
cudaStream_t stream;
|
||||||
|
int device;
|
||||||
|
UsageStream() {}
|
||||||
|
UsageStream(cudaStream_t s, int d) : stream(s), device(d) {}
|
||||||
|
UsageStream(const UsageStream& us) : stream(us.stream), device(us.device) {}
|
||||||
|
UsageStream(const UsageStream&& us) : stream(us.stream), device(us.device) {}
|
||||||
|
UsageStream& operator=(UsageStream other) {
|
||||||
|
stream = other.stream;
|
||||||
|
device = other.device;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
|
||||||
|
return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct UsageStreamHash {
|
||||||
|
size_t operator()(const UsageStream& us) const noexcept {
|
||||||
|
return std::hash<void*>{}(us.stream) + size_t(us.device);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PtrUsage {
|
||||||
|
// recorded_streams holds side usage streams added by record_stream calls.
|
||||||
|
// In other words, it does NOT include the original creation stream.
|
||||||
|
  ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
  UsageStream creation_stream;
  uint64_t size;
  bool captured;
  PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
};

int device_count = 0;
// these don't need to be c10::once_flags as in CUDAGeneratorImpl.cpp
// because they'll only be flipped by functions that have locked the mutex.
std::vector<bool> devs_initialized_flags;
std::vector<UsageStream> dummy_unifying_free_streams;

// Possible micro-optimization:
// Some accesses to ptr_info are read-only.
// We could let those be concurrent with a shared_mutex and
// have concurrent calls take a shared_lock.
// Keeping it simple with an ordinary mutex for now.
std::mutex general_mutex;

// Copy-paste CUDACachingAllocator's
// lock around calls to cudaFree (to prevent deadlocks with NCCL)
// is this safe?
std::mutex cuda_free_mutex;

/**
 * Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * During CUDA graph capture, it's illegal to call cudaFreeAsync
 * on a pointer that came from a non-captured cudaMallocAsync.
 * Unfortunately, Python being what it is, it's impossible to be
 * sure no uncaptured tensor will ever have its destructor called
 * in a capturing region.
 * We avoid errors by
 * 1. remembering if allocated pointers were captured or uncaptured
 * 2. during capture, if we detect an attempt to free an uncaptured
 *    allocation on a capturing stream, don't free it immediately,
 *    just remember it and defer its cudaFreeAsync call to after
 *    the end of capture (specifically, to notifyCaptureEnded).
 */

using PtrInfo = ska::flat_hash_map<void*, PtrUsage>;
PtrInfo ptr_info;
std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;

// These two help setMemoryFraction limit the amount of memory
// used by PyTorch in particular (as opposed to other libraries
// in the same process that might be sharing the same cudaMemPool_t).
std::vector<size_t> pytorch_used_bytes;
std::vector<size_t> pytorch_memory_limits;

// Graph-specific helpers

/**
 * Note [Avoid dangling free streams during CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * During capture, all stream dependencies must branch out from
 * the stream on which capture began and rejoin this initial stream
 * before capture ends.
 * The user rigs desired forking and joining with event waits.
 * But it's hard to be sure when tensor destructors get called relative
 * to the final joins.
 * For example, suppose a user
 *   forks work stream B from initial capture stream A
 *   creates a tensor T in B
 *   joins by syncing A with B
 *   ends capture.
 * All well and good, right? Maybe not: maybe T went out of scope
 * and its destructor got called AFTER the rejoin, leaving the graph with
 * "unjoined work": a dangling cudaFreeAsync node in stream B.
 * Ensuring that all tensor destructors for all side stream tensors
 * are called before side streams rejoin the main stream is
 * difficult. The user might have to add a bunch of explicit
 * "del"s at the right spots in code that was fine for ordinary
 * eager execution.
 * Fortunately, we can spare the user this burden:
 * during capture, we remember _all_ free streams,
 * and manually rejoin them with the capture stream during
 * notifyCaptureAboutToEnd.
 * This approach is heavy-handed, but hopefully capture only needs to
 * happen once, so we don't mind being heavy-handed.
 *
 * TODO: If, someday, we augment the graph bindings to support recapture
 * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#whole-graph-update
 * (eg, as a way to accommodate dynamic params) we should think more
 * carefully about the CPU overhead of remembering and rejoining
 * all free streams during capture. Maybe it's not a big deal.
 */
std::unordered_set<UsageStream, UsageStreamHash> capture_free_streams;
bool capture_underway = false;

// Implementation functions

// Assumes the caller holds general_mutex
inline void lazy_init_device(int device) {
  if (!devs_initialized_flags[device]) {
    CUDAGuard g(device);

    // See "Retaining memory in the pool" here:
    // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
    cudaMemPool_t mempool;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
    uint64_t threshold = UINT64_MAX;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolAttrReleaseThreshold, &threshold));

    // I think all these are on by default, but I want to enable them
    // explicitly to ensure awareness.
    int enable = 1;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseFollowEventDependencies, &enable));
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseAllowOpportunistic, &enable));
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseAllowInternalDependencies, &enable));

    // Grabs a stream from the current device to use as the "unifier" free
    // stream for allocations that end up used on multiple streams.
    const auto dufs = getStreamFromPool();
    dummy_unifying_free_streams[device] =
        UsageStream(dufs.stream(), dufs.device_index());

    pytorch_used_bytes[device] = 0;
    pytorch_memory_limits[device] = UINT64_MAX;

    devs_initialized_flags[device] = true;
  }
}

inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
  // CUDACachingAllocator.cpp uses raw cuda events, as do we.
  cudaEvent_t event;
  C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
  C10_CUDA_CHECK(cudaEventRecord(event, dependency));
  C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
  C10_CUDA_CHECK(cudaEventDestroy(event));
}

// Assumes the caller holds general_mutex
inline void free_impl(PtrInfo::iterator& it) {
  // Possible micro-optimization: If we did a value-copy here, we could move
  // ptr_info.erase(it) up here and drop the lock immediately.
  const auto& recorded_streams = it->second.recorded_streams;
  const auto& creation_stream = it->second.creation_stream;

  // If the usage stream is a null (default) stream,
  // cudaFreeAsync infers the device from the ambient context,
  // so we need to set the right ambient context.
  CUDAGuard g(creation_stream.device);

  if (recorded_streams.empty()) {
    // ptr was only used on one stream, which must have been
    // the original allocation stream.
    // Frees ptr in the original allocation stream.

    C10_CUDA_CHECK(cudaFreeAsync(it->first, creation_stream.stream));

    if (C10_UNLIKELY(capture_underway)) {
      // See Note [Avoid dangling free streams during CUDA graph capture]
      capture_free_streams.insert(creation_stream);
    }
  } else {
    // ptr was used on many streams. We don't know which was the most recent.
    // There could even have been multiple most recent usage streams acting
    // on different regions of the memory.
    // But cudaFreeAsync only accepts a single most recent usage stream.
    // We can still safely free ptr with a trick:
    // Use a dummy "unifying stream", sync the unifying stream with all of
    // ptr's usage streams, and pass the dummy stream to cudaFreeAsync.

    // Retrieves the dummy "unifier" stream from the device
    // on which the pointer was originally allocated.
    auto dummy_unifying_free_stream =
        dummy_unifying_free_streams[creation_stream.device];
    TORCH_INTERNAL_ASSERT(
        dummy_unifying_free_stream.device == creation_stream.device);

    // we're already on creation_stream.device, no need to re-guard
    sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream);

    // The number of usage streams is typically small (low single digits)
    for (const auto& recorded_stream : recorded_streams) {
      // Logic here accommodates the chance some of the usage streams were on
      // other devices, which is possible if some usage kernels accessed the
      // memory via p2p.

      // cudaEventRecord requires that the input event and stream are on the
      // same device.
      CUDAGuard g_usage(recorded_stream.device);

      sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
    }

    // Frees ptr in the dummy "unifier" stream.
    C10_CUDA_CHECK(cudaFreeAsync(it->first, dummy_unifying_free_stream.stream));
    // At this point, unless dummy_unifying_free_stream happens to alias some
    // future user stream, the allocation is only available for "opportunistic"
    // reuse, ie, if the CPU sees dummy_unifying_free_stream has reached the
    // point that all events recorded on all usage streams have resolved from
    // the CPU's perspective. In theory, we could remove the need for the driver
    // to do this tracking by e.g. replacing
    // cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event);
    // with
    // cudaStreamWaitEvent(creation_stream.stream, event);
    // then cudaFreeAsyncing straight back into creation_stream.stream,
    // but this forces a potentially false dependency of creation_stream.stream
    // on all the recorded_streams.

    if (C10_UNLIKELY(capture_underway)) {
      // See Note [Avoid dangling free streams during CUDA graph capture]
      capture_free_streams.insert(UsageStream(
          dummy_unifying_free_stream.stream,
          dummy_unifying_free_stream.device));
    }
  }

  pytorch_used_bytes[creation_stream.device] -= it->second.size;

  ptr_info.erase(it);
}

void freeAsync(void* ptr) {
  std::lock_guard<std::mutex> lk(general_mutex);

  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);
  auto it = ptr_info.find(ptr);
  TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

  if (C10_UNLIKELY(capture_underway)) {
    if (!it->second.captured) {
      TORCH_WARN_ONCE(
          "freeAsync() was called on an uncaptured allocation during graph capture "
          "(address = ",
          ptr,
          "). This may be benign, for example, a Python tensor in the capture "
          "might happen to shadow (use the same name as) an unrelated temporary "
          "tensor from somewhere before capture, pushing the earlier tensor "
          "out of scope. "
          "However, if the tensor we're freeing here IS used by the capture, "
          "freeing it is an error, and may cause illegal memory accesses or "
          "memory corruption during graph replay.");
      // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
      // Remembers the raw pointer, not the iterator.
      // This forces notifyCaptureEnded to do another lookup,
      // but avoids the risk the iterator might be invalidated
      // between now and then.
      ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
      return;
    }
  } else if (C10_UNLIKELY(it->second.captured)) {
    TORCH_WARN(
        "Attempting uncaptured free of a captured allocation with address ",
        ptr,
        "\nThis is technically allowed, but may indicate you are losing "
        "the last user-visible tensor through which the allocation can "
        "be accessed, so you'll have no way to view the data after "
        "future replays of the owning graph.");
  }

  free_impl(it);
}

// Symmetric with NativeCachingAllocator::malloc for now,
// although I don't think we absolutely need the symmetry.
void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
  TORCH_INTERNAL_ASSERT(
      0 <= device && device < device_count,
      "Invalid device index ",
      device,
      ": did you call init?");

  // If stream is a null (default) stream,
  // cudaMallocAsync infers the device from the ambient context,
  // so we need to set the right ambient context.
  CUDAGuard g(device);

  std::lock_guard<std::mutex> lk(general_mutex);

  lazy_init_device(device);

  // Defensively checks for preexisting CUDA error state.
  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);

  // TODO: Could we avoid calling cudaMallocAsync while holding general_mutex,
  // perhaps by letting lazy_init_device use separate once_flags or an internal
  // static initializer?
  if (pytorch_used_bytes[device] + size > pytorch_memory_limits[device]) {
    err = cudaErrorMemoryAllocation;
  } else {
    err = cudaMallocAsync(devPtr, size, stream);
  }

  if (err == cudaErrorMemoryAllocation) {
    // Clears CUDA's internal error state so the user, if desired, can catch the
    // OOM exception, free some stuff on the script side, and retry the
    // allocation. This aligns with the behavior of alloc_block in
    // CUDACachingAllocator.cpp.
    cudaGetLastError();
    size_t device_free;
    size_t device_total;
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    TORCH_CHECK_WITH(
        OutOfMemoryError,
        false,
        "Allocation on device ",
        device,
        " would exceed allowed memory. (out of memory)",
        "\nCurrently allocated : ",
        format_size(pytorch_used_bytes[device]),
        "\nRequested : ",
        format_size(size),
        "\nDevice limit : ",
        format_size(device_total),
        "\nFree (according to CUDA): ",
        format_size(device_free),
        "\nPyTorch limit (set by user-supplied memory fraction)"
        "\n : ",
        format_size(pytorch_memory_limits[device]));
  } else {
    C10_CUDA_CHECK(err);
  }

  auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
  TORCH_INTERNAL_ASSERT(
      inserted.second,
      "address returned by cudaMallocAsync already exists "
      "in ptr_info");

  inserted.first->second.creation_stream = {stream, device};

  pytorch_used_bytes[device] += size;
}

} // anonymous namespace

// Same pattern as CUDACachingAllocator.cpp.
struct CudaMallocAsyncAllocator : public Allocator {
  DataPtr allocate(size_t size) const override {
    constexpr size_t one_exa_bytes = 1152921504606846976ULL;
    TORCH_CHECK_WITH(
        OutOfMemoryError,
        size < one_exa_bytes,
        "CUDA out of memory. Tried to allocate more than 1EB memory.");
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    void* r = nullptr;
    if (size != 0) {
      mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
    }
    return {r, r, &raw_delete, Device(DeviceType::CUDA, device)};
  }
  DeleterFnPtr raw_deleter() const override {
    return &raw_delete;
  }
};

CudaMallocAsyncAllocator device_allocator;

// Interface functions declared in CUDACachingAllocator.h

Allocator* get(void) {
  return &device_allocator;
}

// This function should not issue any context-creating calls,
// just set up for later calls to init per-device pools based
// on the current device each later call sees.
void init(int dev_count) {
  static bool called = [](int dev_count) {
    ;
    // Are there external guarantees init will be called before
    // any of the allocator's other functions?
    // std::lock_guard<std::mutex> lk(general_mutex);
    device_count = dev_count;
    devs_initialized_flags.resize(dev_count, false);
    dummy_unifying_free_streams.resize(dev_count);
    pytorch_used_bytes.resize(dev_count);
    pytorch_memory_limits.resize(dev_count);
    return true;
  }(dev_count);
  (void)called;
}

static inline void assertValidDevice(int device) {
  TORCH_CHECK(0 <= device && device < device_count, "Invalid device argument.");
}

void setMemoryFraction(double fraction, int device) {
  TORCH_INTERNAL_ASSERT(
      0 <= fraction && fraction <= 1,
      "invalid fraction:",
      fraction,
      ". Please set within (0, 1).");

  std::lock_guard<std::mutex> lk(general_mutex);
  assertValidDevice(device);
  CUDAGuard g(device);
  // Should setMemoryFraction be allowed to trigger a full device context and
  // pool-creating lazy_init_device, or should we simply assert this device is
  // already initialized, ie
  // TORCH_CHECK(devs_initialized_flags[device], ...)?
  lazy_init_device(device);

  size_t device_free;
  size_t device_total;
  C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
  pytorch_memory_limits[device] =
      static_cast<uint64_t>(fraction * device_total);

  // Alternative: Instead of a manual hard limit, we could use
  // cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold,
  // &threshold); This is a soft hint: The driver allows the pool's reserved
  // memory to spike above threshold in regions of high cudaMallocAsync demand,
  // but opportunistically trims reserved memory back to threshold when the
  // memory in use is < threshold. I don't like this because it introduces
  // performance nondeterminism.
}

void emptyCache(void) {
  std::lock_guard<std::mutex> lk(general_mutex);

  for (int dev = 0; dev < device_count; dev++) {
    if (devs_initialized_flags[dev]) {
      CUDAGuard g(dev);

      cudaMemPool_t mempool;
      cudaDeviceGetDefaultMemPool(&mempool, dev);
      cudaDeviceSynchronize();
      cudaMemPoolTrimTo(mempool, 0);
    }
  }
}

void cacheInfo(int device, size_t* maxWorkspaceGuess) {
  // The only consumer of cacheInfo is getMaxWorkspaceSize in Conv_v7.cpp.
  // Afaict, the role of cacheInfo is to give getMaxWorkspaceSize a reasonable
  // maximum workspace size to use for an upcoming cudnnFind call.
  //
  // The native allocator's cacheInfo chooses to return the size of its largest
  // unused block (which is the largest allocation the native allocator can
  // service immediately and asynchronously without a cudaMalloc.
  //
  // Here, we use a different heuristic: figure out the max usable workspace
  // size with a bit of educated trial and error. It's ok to be perf-inefficient
  // because cacheInfo is a prelude to cudnnFind.
  //
  // The algo cache then stores the best-performing algo with workspace <=
  // maxWorkspaceGuess. Later calls with the same param set hit in cache and try
  // to allocate the same workspace. If, in one of those future calls, workspace
  // allocation fails (ie because less ambient memory is available), the
  // bindings rerun cudnnFind, including calling cacheInfo again beforehand to
  // estimate a new (smaller) largest-available workspace. Over a few such
  // calls, the cache should settle to the algo with a workspace size that's
  // small enough to succeed every time (for that param set).
  //
  // So the strategy here is to return a rough, largeish guess and let the
  // bindings retry to trim as needed over time.
  //
  // The only caveat is, even if a workspace is allocated without OOM errors now
  // and in future calls, it's hard to be sure those later error-free
  // cudaMallocAsyncs are fast and come straight from the pool (ie,
  // cudaMallocAsync didn't need to reserve more memory from the system).
  // Hopefully, after repeated workspace requests, the pool's reserved memory
  // also stabilizes to a point where they all come straight from the pool.
  std::lock_guard<std::mutex> lk(general_mutex);
  assertValidDevice(device);
  CUDAGuard g(device);
  lazy_init_device(device);

  size_t free_upper_bound;
  size_t device_total;
  C10_CUDA_CHECK(cudaMemGetInfo(&free_upper_bound, &device_total));
  TORCH_INTERNAL_ASSERT(
      free_upper_bound + pytorch_used_bytes[device] <= device_total);
  size_t guess = std::min(
      free_upper_bound,
      pytorch_memory_limits[device] - pytorch_used_bytes[device]);
  auto stream = c10::cuda::getCurrentCUDAStream();
  void* dummy;

  // Defensively checks for preexisting CUDA error state.
  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);

  while (true) {
    // Duplicates some logic from mallocAsync to work with the error state
    // directly instead of repeatedly catching an exception thrown by
    // mallocAsync.
    if (pytorch_used_bytes[device] + guess > pytorch_memory_limits[device]) {
      err = cudaErrorMemoryAllocation;
    } else {
      err = cudaMallocAsync(&dummy, guess, stream);
    }

    if (err == cudaSuccess) {
      cudaFreeAsync(dummy, stream);
      *maxWorkspaceGuess = guess;
      return;
    } else if (err == cudaErrorMemoryAllocation) {
      cudaGetLastError(); // clear CUDA error
      guess >>= 1; // quick and dirty: try half the size next iteration
    } else {
      C10_CUDA_CHECK(err);
    }
  }
}

void* getBaseAllocation(void* ptr, size_t* size) {
  std::lock_guard<std::mutex> lk(general_mutex);

  auto it = ptr_info.find(ptr);
  TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

  if (size) {
    *size = it->second.size;
  }

  return ptr;
}

void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
  std::lock_guard<std::mutex> lk(general_mutex);
  auto ptr_val = ptr.get();
  // Empty tensor's storage().data() might be a null ptr. As there is no
  // blocks associated with those tensors, it is fine to do nothing here.
  if (!ptr_val) {
    return;
  }

  // The pointer should exist in the map already.
  auto it = ptr_info.find(ptr_val);
  TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

  UsageStream to_record{stream.stream(), stream.device_index()};
  if (to_record == it->second.creation_stream) {
    TORCH_WARN(
        "Called record_stream on tensor whose original creation stream "
        "matches the recorded stream. This is unnecessary and has no effect.");
  } else {
    it->second.recorded_streams.insert(to_record);
  }
}

std::mutex* getFreeMutex() {
  return &cuda_free_mutex;
}

std::shared_ptr<void> getIpcDevPtr(std::string handle) {
  TORCH_CHECK(
      false,
      "cudaMallocAsync does not yet support getIpcDevPtr. "
      "If you need it, please file an issue describing your use case.");
}

void recordHistory(
    bool enabled,
    CreateContextFn context_recorder,
    size_t alloc_trace_max_entries,
    bool alloc_trace_record_context) {
  TORCH_CHECK(
      false,
      "cudaMallocAsync does not yet support recordHistory. "
      "If you need it, please file an issue describing your use case.");
}

void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
  TORCH_CHECK(
      false,
      "cudaMallocAsync does not yet support attachOutOfMemoryObserver. "
      "If you need it, please file an issue describing your use case.");
}

// Collects stats for device.
// If device hasn't been used yet, returns 0s without creating a context.
DeviceStats getDeviceStats(int device) {
  assertValidDevice(device);

  // Memory currently reserved by the mempool
  uint64_t reserved_mem_current = 0;
  // High-water mark of memory reserved by the mempool since last reset
  uint64_t reserved_mem_peak = 0;
  // Memory currently in use by the mempool
  uint64_t used_mem_current = 0;
  // High-water mark of memory
  uint64_t used_mem_peak = 0;

  std::lock_guard<std::mutex> lk(general_mutex);

  if (devs_initialized_flags[device]) {
    CUDAGuard g(device);

    cudaMemPool_t mempool;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
    C10_CUDA_CHECK(cudaMemPoolGetAttribute(
        mempool, cudaMemPoolAttrReservedMemCurrent, &reserved_mem_current));

    C10_CUDA_CHECK(cudaMemPoolGetAttribute(
        mempool, cudaMemPoolAttrReservedMemHigh, &reserved_mem_peak));

    C10_CUDA_CHECK(cudaMemPoolGetAttribute(
        mempool, cudaMemPoolAttrUsedMemCurrent, &used_mem_current));

    C10_CUDA_CHECK(cudaMemPoolGetAttribute(
        mempool, cudaMemPoolAttrUsedMemHigh, &used_mem_peak));
  }

  // Many stat types are specific to the native allocator. We leave these
  // untouched. Their "struct Stat"s will contain zeroed values.
  DeviceStats stats;

  // In the native allocator:
  // allocated_bytes is the total bytes of blocks that have been malloc()ed
  // and not yet free()d.
  // active_bytes is the total bytes of blocks that have been malloc()ed but not
  // yet released back into a free pool. In other words, it includes all
  // allocated_bytes, as well as the bytes of "limbo state" blocks that have
  // already been free()ed but not yet free_block()ed back into a pool due to
  // outstanding stream_uses.
  //
  // Here, in the cudaMallocAsync allocator:
  // We simply ask the driver's opinion about active memory.
  // We don't bother distinguishing between allocated_bytes and active_bytes.
  stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
      used_mem_current;
  stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
      used_mem_peak;
  stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
      used_mem_current;
  stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
      used_mem_peak;
  stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
      reserved_mem_current;
  stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
      reserved_mem_peak;

  return stats;
}

void resetAccumulatedStats(int device) {
  assertValidDevice(device);
  TORCH_WARN_ONCE(
      "For backend:cudaMallocAsync, resetAccumulatedStats has no effect.");
}

void resetPeakStats(int device) {
  assertValidDevice(device);

  CUDAGuard g(device);
  cudaMemPool_t mempool;
  C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
  // Using zero as the reset value is the method recommended by Cuda driver
  // team. Vivek Kini says:
  //   "Resetting to zero (which is the only valid value when setting
  //    ReservedMemHigh) resets it to ReservedMemCurrent inside the driver
  //    (same goes for UsedMemHigh/UsedMemCurrent)"
  uint64_t zero = 0;
  C10_CUDA_CHECK(
      cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReservedMemHigh, &zero));
  C10_CUDA_CHECK(
      cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
}

SnapshotInfo snapshot() {
  TORCH_CHECK(
      false,
      "Calling snapshot with backend:cudaMallocAsync is not meaningful. "
      "(For backend:native, snapshot returns a detailed summary of all "
      "blocks tracked by the allocator, but the cudaMallocAsync backend "
      "does not track individual blocks.)");
  // Alternative: TORCH_WARN
  return {};
}

// CUDAGraph interactions
void notifyCaptureBegin(
    int device,
    CaptureId_t graph_id,
    MempoolId_t mempool_id) {
  std::lock_guard<std::mutex> lk(general_mutex);

  TORCH_INTERNAL_ASSERT(capture_free_streams.empty());
  TORCH_CHECK(
      !capture_underway, "Only one capture at a time is allowed in a process.")
  capture_underway = true;
}

void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
  assertValidDevice(device);

  std::lock_guard<std::mutex> lk(general_mutex);

  TORCH_CHECK(
      capture_underway,
      "CudaMallocAsync::notifyCaptureAboutToEnd called, "
      "but CudaMallocAsync::capture_underway is false.");

  auto capture_stream = cuda::getCurrentCUDAStream(device);

  // See Note [Avoid dangling free streams during CUDA graph capture]
  for (const auto& free_stream : capture_free_streams) {
    // cudaEventRecord requires that the input event and stream are on the same
    // device.
    CUDAGuard g(free_stream.device);

    // CUDACachingAllocator.cpp uses raw cuda events, as do we.
    cudaEvent_t event;
    C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
    C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
    C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
    C10_CUDA_CHECK(cudaEventDestroy(event));
  }

  capture_free_streams.clear();
}

void notifyCaptureEnded(int device, CaptureId_t graph_id) {
  assertValidDevice(device);

  std::lock_guard<std::mutex> lk(general_mutex);

  TORCH_CHECK(
      capture_underway,
      "CudaMallocAsync::notifyCaptureEnded called, "
      "but CudaMallocAsync::capture_underway is false.");
  capture_underway = false;

  // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
  for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
    auto it = ptr_info.find(ptr);
    TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
    free_impl(it);
  }

  ungraphed_ptrs_defer_free_until_no_capture.clear();
}

void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
  // Q: Do we need to do anything special here, like clear long-lived
  //    pointers created during the original capture (for example,
  //    tensors intended as the graph's I/O surface) that might still
  //    be resident in ptr_info?
  // A: I don't think so.
  //    Those allocations survived capture because the user held
  //    explicit tensor references to them,
  //    Those tensors' destructors will call freeAsync() on each pointer
  //    when the user is done with them.
  //    The freeAsync()s will probably incur
  //    TORCH_WARN("Attempting uncaptured free of a captured allocation..."
  //    but stale ptrs will not permanently leak into ptr_info.
}

void* raw_alloc(size_t nbytes) {
  if (nbytes == 0) {
    return nullptr;
  }
  int device;
  C10_CUDA_CHECK(cudaGetDevice(&device));
  void* r = nullptr;
  mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
  return r;
}

void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
  if (nbytes == 0) {
    return nullptr;
  }
  int device;
  C10_CUDA_CHECK(cudaGetDevice(&device));
  void* r = nullptr;
  mallocAsync(&r, device, nbytes, stream);
  return r;
}

void raw_delete(void* ptr) {
  freeAsync(ptr);
}

#else
// CUDA_VERSION is < 11040

#define NOT_AVAILABLE(name) \
  TORCH_CHECK( \
      false, \
      "Called CudaMallocAsync::" name \
      " but PyTorch was built with cuda < 11.4.");

void* raw_alloc(size_t nbytes) {
  NOT_AVAILABLE("raw_alloc");
}
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
  NOT_AVAILABLE("raw_alloc_with_stream");
}
void raw_delete(void* ptr) {
  NOT_AVAILABLE("raw_delete")
}
Allocator* get() {
  NOT_AVAILABLE("get");
}
void init(int device_count) {
  NOT_AVAILABLE("init");
}
void setMemoryFraction(double fraction, int device) {
  NOT_AVAILABLE("setMemoryFraction");
}
void emptyCache() {
  NOT_AVAILABLE("emptyCache");
}
void cacheInfo(int device, size_t* maxWorkspaceGuess) {
  NOT_AVAILABLE("cacheInfo");
}
void* getBaseAllocation(void* ptr, size_t* size) {
  NOT_AVAILABLE("getBaseAllocation");
}
void recordStream(const DataPtr&, CUDAStream stream) {
  NOT_AVAILABLE("recordStream");
}
DeviceStats getDeviceStats(int device) {
  NOT_AVAILABLE("getDeviceStats");
}
void resetAccumulatedStats(int device) {
  NOT_AVAILABLE("resetAccumulatedStats");
}
void resetPeakStats(int device) {
  NOT_AVAILABLE("resetPeakStats");
}
SnapshotInfo snapshot() {
  NOT_AVAILABLE("snapshot");
}
void notifyCaptureBegin(
    int device,
    CaptureId_t graph_id,
    MempoolId_t mempool_id) {
  NOT_AVAILABLE("notifyCaptureBegin");
}
void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
  NOT_AVAILABLE("notifyCaptureAboutToEnd");
}
void notifyCaptureEnded(int device, CaptureId_t graph_id) {
  NOT_AVAILABLE("notifyCaptureEnded");
}
void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
  NOT_AVAILABLE("notifyCaptureDestroy");
}
std::mutex* getFreeMutex() {
  NOT_AVAILABLE("getFreeMutex");
}
std::shared_ptr<void> getIpcDevPtr(std::string handle) {
  NOT_AVAILABLE("getIpcDevPtr");
}
void recordHistory(
    bool enabled,
    CreateContextFn context_recorder,
    size_t alloc_trace_max_entries,
    bool alloc_trace_record_context) {
  NOT_AVAILABLE("recordHistory");
}
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
  NOT_AVAILABLE("attachOutOfMemoryObserver");
}

#endif

} // namespace CudaMallocAsync
} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace c10
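The two Note comments in the allocator above describe capture-time bookkeeping that is easiest to see from the user's side. The sketch below is illustrative only and is not part of this commit: it assumes a CUDA 11.4+ build launched with PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync, and the tensor names are invented for the example.

import torch

# Requires the process to have been started with
# PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync (an assumption of this sketch).
assert torch.cuda.get_allocator_backend() == "cudaMallocAsync"

uncaptured = torch.ones(1024, device="cuda")  # allocated outside any capture
static_in = torch.zeros(1024, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.graph(g):
    static_out = static_in * 2
    # Dropping the last reference to an uncaptured tensor inside the capture
    # region calls freeAsync() on an uncaptured pointer. Per the Note above,
    # the allocator warns and defers the real cudaFreeAsync to notifyCaptureEnded.
    del uncaptured

g.replay()
print(static_out.sum().item())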
@@ -87,6 +87,8 @@ Graphs (beta)
graph
make_graphed_callables

.. _cuda-memory-management-api:

Memory management
-----------------
.. autosummary::

@@ -111,6 +113,7 @@ Memory management
reset_peak_memory_stats
caching_allocator_alloc
caching_allocator_delete
get_allocator_backend
.. FIXME The following doesn't seem to exist. Is it supposed to?
https://github.com/pytorch/pytorch/issues/27785
.. autofunction:: reset_max_memory_reserved

@@ -40,7 +40,7 @@ only use the subset of Python supported in TorchScript. This section documents
what is supported in TorchScript as if it were a language reference for a stand
alone language. Any features of Python not mentioned in this reference are not
part of TorchScript. See `Builtin Functions` for a complete reference of available
Pytorch tensor methods, modules, and functions.
PyTorch tensor methods, modules, and functions.

As a subset of Python, any valid TorchScript function is also a valid Python
function. This makes it possible to `disable TorchScript` and debug the

@@ -1,6 +1,6 @@
.. _jit_unsupported:

TorchScript Unsupported Pytorch Constructs
TorchScript Unsupported PyTorch Constructs
============================================

Torch and Tensor Unsupported Attributes

@@ -349,27 +349,42 @@ complete snapshot of the memory allocator state via
:meth:`~torch.cuda.memory_snapshot`, which can help you understand the
underlying allocation patterns produced by your code.

.. _cuda-memory-envvars:

Environment variables
^^^^^^^^^^^^^^^^^^^^^

Use of a caching allocator can interfere with memory checking tools such as
``cuda-memcheck``. To debug memory errors using ``cuda-memcheck``, set
``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.

The behavior of caching allocator can be controlled via environment variable
The behavior of the caching allocator can be controlled via the environment variable
``PYTORCH_CUDA_ALLOC_CONF``.
The format is ``PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>...``
Available options:

* ``max_split_size_mb`` prevents the allocator from splitting blocks larger
than this size (in MB). This can help prevent fragmentation and may allow
some borderline workloads to complete without running out of memory.
Performance cost can range from 'zero' to 'substatial' depending on
allocation patterns. Default value is unlimited, i.e. all blocks can be
split. The :meth:`~torch.cuda.memory_stats` and
* ``backend`` allows selecting the underlying allocator implementation.
Currently, valid options are ``native``, which uses PyTorch's native
implementation, and ``cudaMallocAsync``, which uses
`CUDA's built-in asynchronous allocator`_.
``cudaMallocAsync`` requires CUDA 11.4 or newer. The default is ``native``.
``backend`` applies to all devices used by the process, and can't be
specified on a per-device basis.
* ``max_split_size_mb`` prevents the native allocator
from splitting blocks larger than this size (in MB). This can reduce
fragmentation and may allow some borderline workloads to complete without
running out of memory. Performance cost can range from 'zero' to 'substantial'
depending on allocation patterns. Default value is unlimited, i.e. all blocks
can be split. The
:meth:`~torch.cuda.memory_stats` and
:meth:`~torch.cuda.memory_summary` methods are useful for tuning. This
option should be used as a last resort for a workload that is aborting
due to 'out of memory' and showing a large amount of inactive split blocks.
``max_split_size_mb`` is only meaningful with ``backend:native``.
With ``backend:cudaMallocAsync``, ``max_split_size_mb`` is ignored.
* ``roundup_power2_divisions`` helps with rounding the requested allocation
size to nearest power-2 division and making better use of the blocks. In
the current CUDACachingAllocator, the sizes are rounded up in multiple
the native CUDACachingAllocator, the sizes are rounded up in multiple
of blocks size of 512, so this works fine for smaller sizes. However, this
can be inefficient for large near-by allocations as each will go to different
size of blocks and re-use of those blocks are minimized. This might create

@@ -379,10 +394,14 @@ Available options:
the size 1200 lies between 1024 and 2048 and if we do 4 divisions between
them, the values are 1024, 1280, 1536, and 1792. So, allocation size of 1200
will be rounded to 1280 as the nearest ceiling of power-2 division.
``roundup_power2_divisions`` is only meaningful with ``backend:native``.
With ``backend:cudaMallocAsync``, ``roundup_power2_divisions`` is ignored.
* ``roundup_bypass_threshold_mb`` bypass rounding the requested allocation size,
for allocation requests larger than the threshold value (in MB). This can help
reduce the memory footprint when making large allocations that are expected to
be persistent or have a large lifetime.
``roundup_bypass_threshold_mb`` is only meaningful with ``backend:native``.
With ``backend:cudaMallocAsync``, ``roundup_bypass_threshold_mb`` is ignored.
* ``garbage_collection_threshold`` helps actively reclaiming unused GPU memory to
avoid triggering expensive sync-and-reclaim-all operation (release_cached_blocks),
which can be unfavorable to latency-critical GPU applications (e.g., servers).

@@ -391,6 +410,19 @@ Available options:
80% of the total memory allocated to the GPU application). The algorithm prefers
to free old & unused blocks first to avoid freeing blocks that are actively being
reused. The threshold value should be between greater than 0.0 and less than 1.0.
``garbage_collection_threshold`` is only meaningful with ``backend:native``.
With ``backend:cudaMallocAsync``, ``garbage_collection_threshold`` is ignored.

.. note::

Some stats reported by the
:ref:`CUDA memory management API<cuda-memory-management-api>`
are specific to ``backend:native``, and are not meaningful with
``backend:cudaMallocAsync``.
See each function's docstring for details.

.. _CUDA's built-in asynchronous allocator:
https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/

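To make the new ``backend`` option concrete, here is a short illustrative example; it is not part of the commit and assumes a CUDA 11.4+ build, a visible GPU, and that the environment variable is set before the Python process starts.

# For example, launch the process as:
#   PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync,max_split_size_mb:128 python train.py
# (max_split_size_mb is simply ignored by the cudaMallocAsync backend, as noted above).
import torch

print(torch.cuda.get_allocator_backend())  # "native" or "cudaMallocAsync"

# Aggregate byte counts stay meaningful for both backends; with
# backend:cudaMallocAsync they come from the driver's pool attributes.
x = torch.empty(1024, 1024, device="cuda")
stats = torch.cuda.memory_stats()
print(stats["allocated_bytes.all.current"], stats["reserved_bytes.all.current"])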
.. _cufft-plan-cache:

@@ -23,7 +23,7 @@ feature. A section at the end discusses the extensions for forward mode AD.
When to use
^^^^^^^^^^^
In general, implement a custom function if you want to perform computations in your model
that are not differentiable or rely on non-Pytorch libraries (e.g., NumPy), but
that are not differentiable or rely on non-PyTorch libraries (e.g., NumPy), but
still wish for your operation to chain with other ops and work with the autograd engine.

In some situations, custom functions can also be used to improve performance and

@@ -103,7 +103,7 @@ There are three types of quantization supported:
3. static quantization aware training (weights quantized, activations quantized,
quantization numerics modeled during training)

Please see our `Introduction to Quantization on Pytorch
Please see our `Introduction to Quantization on PyTorch
<https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_ blog post
for a more comprehensive overview of the tradeoffs between these quantization
types.

@@ -142,7 +142,7 @@ only:
Sparse hybrid COO tensors
-------------------------

Pytorch implements an extension of sparse tensors with scalar values
PyTorch implements an extension of sparse tensors with scalar values
to sparse tensors with (contiguous) tensor values. Such tensors are
called hybrid tensors.

@@ -44,6 +44,7 @@ if not TEST_CUDA:
print('CUDA not available, skipping tests', file=sys.stderr)
TestCase = object # noqa: F811

TEST_CUDAMALLOCASYNC = TEST_CUDA and (torch.cuda.get_allocator_backend() == "cudaMallocAsync")
TEST_LARGE_TENSOR = TEST_CUDA
TEST_MEDIUM_TENSOR = TEST_CUDA
TEST_CUDNN = TEST_CUDA

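A hypothetical way to exercise the TEST_CUDAMALLOCASYNC branches added above is to launch the test file with the backend selected in the child process environment; the test path below is an assumption of this sketch, not something the commit specifies.

import os
import subprocess
import sys

# Selecting the backend for the child process flips TEST_CUDAMALLOCASYNC to True.
env = dict(os.environ, PYTORCH_CUDA_ALLOC_CONF="backend:cudaMallocAsync")
subprocess.run([sys.executable, "test/test_cuda.py", "-v"], env=env, check=True)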
@@ -271,6 +272,7 @@ class TestCuda(TestCase):
self.assertEqual(r, 0)
self.assertFalse(t.is_pinned())

@unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
def test_memory_stats(self):
gc.collect()
torch.cuda.empty_cache()

@@ -322,6 +324,7 @@ class TestCuda(TestCase):
device_capability_no_argument = torch.cuda.get_device_capability()
self.assertEqual(current_device_capability, device_capability_no_argument)

@unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_memory_stats_multigpu(self):
# advance a generator with a end flag

@@ -362,7 +365,9 @@ class TestCuda(TestCase):
def test_out_of_memory(self):
tensor = torch.zeros(1024, device='cuda')

with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"):
oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
"Tried to allocate 800000000.00 GiB"
with self.assertRaisesRegex(RuntimeError, oom_regex):
torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device='cuda')

with self.assertRaisesRegex(RuntimeError, "Tried to allocate more than 1EB memory"):

@@ -372,6 +377,22 @@ class TestCuda(TestCase):
tensor.fill_(1)
self.assertTrue((tensor == 1).all())

def test_out_of_memory_retry(self):
torch.cuda.empty_cache()
total_memory = torch.cuda.get_device_properties(0).total_memory
oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
"Tried to allocate"
size = int(total_memory * 0.5)
a = torch.empty(size , dtype=torch.int8, device='cuda')
with self.assertRaisesRegex(RuntimeError, oom_regex):
b = torch.empty(size, dtype=torch.int8, device='cuda')
del a
b = torch.empty(size, dtype=torch.int8, device='cuda')
del b
# We used a lot of memory here, clean up so we don't affect other tests too much
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

def test_set_per_process_memory_fraction(self):
# test invalid fraction value.
with self.assertRaisesRegex(TypeError, "Invalid type"):

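The new test_out_of_memory_retry above exercises the allocator's promise that the CUDA error state is cleared on OOM, so a script can catch the error, free something, and retry. A minimal standalone sketch of that pattern, not part of the commit and with arbitrary sizes, might look like this:

import torch

def allocate_with_retry(nbytes, scratch):
    # OOM surfaces as a RuntimeError for both allocator backends
    # ("Tried to allocate ..." for native, "would exceed allowed memory"
    # for cudaMallocAsync).
    try:
        return torch.empty(nbytes, dtype=torch.int8, device="cuda")
    except RuntimeError:
        scratch.clear()           # drop references so their memory can be freed
        torch.cuda.empty_cache()  # optionally return cached/pool memory to the driver
        return torch.empty(nbytes, dtype=torch.int8, device="cuda")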
@ -394,7 +415,9 @@ class TestCuda(TestCase):
|
||||||
|
|
||||||
application = int(total_memory * 0.5)
|
application = int(total_memory * 0.5)
|
||||||
# it will get OOM when try to allocate more than half memory.
|
# it will get OOM when try to allocate more than half memory.
|
||||||
with self.assertRaisesRegex(RuntimeError, "out of memory"):
|
oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
|
||||||
|
"out of memory"
|
||||||
|
with self.assertRaisesRegex(RuntimeError, oom_regex):
|
||||||
torch.empty(application, dtype=torch.int8, device='cuda')
|
torch.empty(application, dtype=torch.int8, device='cuda')
|
||||||
|
|
||||||
# ensure out of memory error doesn't disturb subsequent kernel
|
# ensure out of memory error doesn't disturb subsequent kernel
|
||||||
|
|
@ -1297,11 +1320,13 @@ class TestCuda(TestCase):
|
||||||
|
|
||||||
self.assertEqual(result.tolist(), [1, 2, 3, 4])
|
self.assertEqual(result.tolist(), [1, 2, 3, 4])
|
||||||
|
|
||||||
# Check that the block will be re-used after the main stream finishes
|
if not TEST_CUDAMALLOCASYNC:
|
||||||
torch.cuda.current_stream().synchronize()
|
# In the native allocator, we expect "tmp"'s side-stream-tagged block will be reused
|
||||||
with torch.cuda.stream(stream):
|
# in that side stream after result.copy_(tmp) in the main stream finishes.
|
||||||
tmp3 = torch.cuda.FloatTensor(t.size())
|
torch.cuda.current_stream().synchronize()
|
||||||
self.assertEqual(tmp3.data_ptr(), ptr[0], msg='allocation not re-used')
|
with torch.cuda.stream(stream):
|
||||||
|
tmp3 = torch.cuda.FloatTensor(t.size())
|
||||||
|
self.assertEqual(tmp3.data_ptr(), ptr[0], msg='allocation not re-used')
|
||||||
|
|
||||||
def test_record_stream_on_shifted_view(self):
|
def test_record_stream_on_shifted_view(self):
|
||||||
# See issue #27366
|
# See issue #27366
|
||||||
|
|
@ -3231,7 +3256,9 @@ torch.cuda.synchronize()
|
||||||
TEST_WITH_ROCM or
|
TEST_WITH_ROCM or
|
||||||
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
|
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
|
||||||
def test_graph_capture_oom(self):
|
def test_graph_capture_oom(self):
|
||||||
with self.assertRaisesRegex(RuntimeError, "out of memory"):
|
oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
|
||||||
|
"out of memory"
|
||||||
|
with self.assertRaisesRegex(RuntimeError, oom_regex):
|
||||||
with torch.cuda.graph(torch.cuda.CUDAGraph()):
|
with torch.cuda.graph(torch.cuda.CUDAGraph()):
|
||||||
torch.zeros(2 ** 40, device="cuda")
|
torch.zeros(2 ** 40, device="cuda")
|
||||||
|
|
||||||
|
|
@ -3416,11 +3443,19 @@ torch.cuda.synchronize()
|
||||||
g.capture_end()
|
g.capture_end()
|
||||||
torch.cuda.current_stream().wait_stream(stream)
|
torch.cuda.current_stream().wait_stream(stream)
|
||||||
|
|
||||||
try:
|
if not TEST_CUDAMALLOCASYNC:
|
||||||
self.assertNotEqual(control1, t1)
|
# Makes sure values haven't been populated yet
|
||||||
self.assertNotEqual(control2, t2)
|
# (in other words, makes sure capture didn't actually run ops).
|
||||||
except Exception as e:
|
# We can only try this with the native allocator, for which captured
|
||||||
raise RuntimeError("Failed on " + module + "." + op) from e
|
# addresses are already backed by cudaMalloced memory.
|
||||||
|
# If we try it with cudaMallocAsync, CUDA won't event consider
|
||||||
|
# the captured addresses allocated until replay(), and if we
|
||||||
|
# access them before replay() we get IMAs.
|
||||||
|
try:
|
||||||
|
self.assertNotEqual(control1, t1)
|
||||||
|
self.assertNotEqual(control2, t2)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError("Failed on " + module + "." + op) from e
|
||||||
|
|
||||||
# Set a new seed to check if graph would use it
|
# Set a new seed to check if graph would use it
|
||||||
for seed in [6, 314, 271]:
|
for seed in [6, 314, 271]:
|
||||||
|
|
@ -3442,8 +3477,11 @@ torch.cuda.synchronize()
|
||||||
else:
|
else:
|
||||||
getattr(dummy, op)(*args)
|
getattr(dummy, op)(*args)
|
||||||
|
|
||||||
t1.copy_(alloc)
|
# see above comment on TEST_CUDAMALLOCASYNC
|
||||||
t2.copy_(alloc)
|
if not TEST_CUDAMALLOCASYNC:
|
||||||
|
t1.copy_(alloc)
|
||||||
|
t2.copy_(alloc)
|
||||||
|
|
||||||
# Runs RNG ops that fill t1 and t2.
|
# Runs RNG ops that fill t1 and t2.
|
||||||
g.replay()
|
g.replay()
|
||||||
|
|
||||||
|
|
@@ -3516,22 +3554,27 @@ torch.cuda.synchronize()
             self.assertEqual(b.sum().item(), size * 3070)
             self.assertEqual(c.sum().item(), size * 442)

-            if share_mem != "Don't share":
-                self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
-                                 kSmallBuffer)
-            else:
-                reserved_no_sharing = torch.cuda.memory_stats()["reserved_bytes.all.current"]
+            if not TEST_CUDAMALLOCASYNC:
+                # These stat checks are specific to the native allocator.
+                if share_mem != "Don't share":
+                    self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
+                                     kSmallBuffer)
+                else:
+                    reserved_no_sharing = torch.cuda.memory_stats()["reserved_bytes.all.current"]

             del a, b, c, g0, g1
             # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
             torch.cuda.synchronize()
             torch.cuda.empty_cache()

-    @unittest.skip("Temporarily disabled due to a graphs bug in libcuda.so, " +
-                   "see https://github.com/pytorch/pytorch/pull/57556")
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
-                     int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+                     IS_WINDOWS or  # appears to still be broken on Windows as of 11.4+
+                     int(torch.version.cuda.split(".")[0]) < 11 or
+                     (int(torch.version.cuda.split(".")[0]) == 11 and
+                      int(torch.version.cuda.split(".")[1]) < 4),
+                     "Graph bindings disallow concurrent replay for CUDA < 11.4, see " +
+                     "https://github.com/pytorch/pytorch/pull/57556")
     def test_graph_concurrent_replay(self):
         torch.cuda.empty_cache()

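For illustration only (not part of the diff): the CUDA-version gate in the decorator above can be written as a small helper. This is a hypothetical sketch; `_supports_concurrent_graph_replay` is not a name used in the codebase, and it assumes `torch.version.cuda` is a "major.minor[.patch]" string on CUDA builds and None otherwise.

    import torch

    def _supports_concurrent_graph_replay() -> bool:
        # Hypothetical helper mirroring the skip condition above:
        # concurrent replay of captured graphs needs CUDA >= 11.4.
        if torch.version.cuda is None:  # ROCm / CPU-only builds
            return False
        major, minor = (int(v) for v in torch.version.cuda.split(".")[:2])
        return (major, minor) >= (11, 4)
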
@@ -3582,12 +3625,16 @@ torch.cuda.synchronize()
             torch.cuda.current_stream().wait_stream(s0)
             torch.cuda.current_stream().wait_stream(s1)

-            if share_mem != "Don't share":
-                # Confirms concurrent replays using the same mempool corrupted each other.
+            if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"):
+                # If we used the native allocator and shared mempools,
+                # we expect the concurrent replays corrupted each other.
                 self.assertNotEqual(b.sum().item(), size * 94)
                 self.assertNotEqual(c.sum().item(), size * 156)
             else:
-                # Confirms concurrent replays using different mempools did not corrupt each other.
+                # If we EITHER
+                #   - used the native allocator without sharing mempools, OR
+                #   - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe
+                # we don't expect memory corruption.
                 self.assertEqual(b.sum().item(), size * 94)
                 self.assertEqual(c.sum().item(), size * 156)

@@ -3647,9 +3694,10 @@ torch.cuda.synchronize()
             g2.replay()
             g1.replay()

-            # If share_mem is True, g2's capture should have reused c's memory for f. We replayed g2 then g1,
-            # so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
-            self.assertEqual(e.sum().item(), size * (7 + 3) if share_mem != "Don't share" else size * 5)
+            expect_corruption = (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share")
+            # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f.
+            # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
+            self.assertEqual(e.sum().item(), size * (7 + 3) if expect_corruption else size * 5)
             self.assertEqual(f.sum().item(), size * 7)

             del a, b, d, e, f, g0, g1, g2

@@ -3659,6 +3707,7 @@ torch.cuda.synchronize()

     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
+                     TEST_CUDAMALLOCASYNC or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
         kSmallSize = 1048576

@@ -3831,6 +3880,8 @@ torch.cuda.synchronize()
             g.capture_end()
         torch.cuda.current_stream().wait_stream(s)

+        g.replay()
+
         y = model(x)

     @unittest.skipIf((not TEST_CUDA) or

@@ -4624,6 +4675,7 @@ class TestCudaComm(TestCase):
                 cat = torch.cat((outputs[0][i].to('cpu'), outputs[1][i].to('cpu')))
                 self.assertTrue(torch.equal(x, cat))

+    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
     def test_memory_snapshot(self):
         try:
             torch.cuda.memory.empty_cache()

@@ -4667,7 +4719,7 @@ class TestCudaComm(TestCase):
         finally:
             torch.cuda.memory._record_memory_history(False)

+    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
     def test_memory_snapshot_with_cpp(self):
         try:
             torch.cuda.memory.empty_cache()

@@ -4703,7 +4755,7 @@ class TestCudaComm(TestCase):
             return ret

         torch.cuda.memory.empty_cache()
-        key = 'active_bytes.all.allocated'
+        key = 'active_bytes.all.allocated' if not TEST_CUDAMALLOCASYNC else 'allocated_bytes.all.current'

         nelems = 21 * 1024 * 1024
         nbytes = 4 * nelems  # floats are 4 bytes

@@ -4719,7 +4771,9 @@ class TestCudaComm(TestCase):
         pow2_div4_mem = torch.cuda.memory_stats()[key]

         self.assertTrue(reg_mem - start_mem == nbytes)
-        self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))
+        if not TEST_CUDAMALLOCASYNC:
+            # not supported with the cudaMallocAsync backend
+            self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))

         torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5")
         torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5,max_split_size_mb:40")

@@ -4747,6 +4801,7 @@ class TestCudaComm(TestCase):
             torch.empty(1024 * 1024 * 1024 * 1024, device='cuda')

     @unittest.skipIf(IS_WINDOWS, 'Windows CI does not like the load_inline')
+    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
     def test_cpp_memory_snapshot_pickle(self):
         from torch.utils.cpp_extension import load_inline
         source = """

@@ -4778,6 +4833,7 @@ class TestCudaComm(TestCase):
         finally:
             m.record(False)

+    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
     def test_notifies_oom(self):
         x = False

@@ -1181,6 +1181,7 @@ def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
 def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
 def _cuda_memorySnapshot() -> Dict[str, Any]: ...
 def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ...
+def _cuda_getAllocatorBackend() -> str: ...
 def _cuda_lock_mutex() -> None: ...
 def _cuda_unlock_mutex() -> None: ...
 def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...

@@ -355,6 +355,26 @@ PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings(
   END_HANDLE_TH_ERRORS
 }

+PyObject* THCPModule_getAllocatorBackend(PyObject* _unused, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  using c10::cuda::CUDACachingAllocator::AllocatorBackend;
+  AllocatorBackend backend =
+      c10::cuda::CUDACachingAllocator::allocatorBackend();
+  // this call should be uncommon, don't bother interning strings
+  switch (backend) {
+    case AllocatorBackend::NATIVE:
+      return THPUtils_packString("native");
+#ifndef USE_ROCM
+    case AllocatorBackend::CUDAMALLOCASYNC:
+      return THPUtils_packString("cudaMallocAsync");
+#endif
+    default:
+      THPUtils_assert(false, "Unexpected value for backend");
+      return nullptr;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject* THCPModule_cudaSynchronize(PyObject* _unused, PyObject* noargs) {
   HANDLE_TH_ERRORS
   c10::cuda::device_synchronize();

@@ -1061,6 +1081,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cudaCachingAllocator_set_allocator_settings,
      METH_O,
      nullptr},
+    {"_cuda_getAllocatorBackend",
+     THCPModule_getAllocatorBackend,
+     METH_NOARGS,
+     nullptr},
     {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr},
     {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr},
     {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr},

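With the entry above registered, the binding is reachable from Python as `torch._C._cuda_getAllocatorBackend()`. A minimal sketch, not part of the diff (assumes a CUDA build; the explicit init call is a precaution, nothing in the diff requires it):

    import torch

    torch.cuda.init()  # make sure CUDA state exists before querying
    print(torch._C._cuda_getAllocatorBackend())  # "native" or "cudaMallocAsync"
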
@@ -837,10 +837,10 @@ __all__ = [
     'CUDAGraph', 'CudaError', 'DeferredCudaCallError', 'Event', 'ExternalStream', 'OutOfMemoryError',
     'Stream', 'StreamContext', 'amp', 'caching_allocator_alloc', 'caching_allocator_delete', 'can_device_access_peer',
     'check_error', 'cudaStatus', 'cudart', 'current_blas_handle', 'current_device', 'current_stream', 'default_generators',
-    'default_stream', 'device', 'device_count', 'device_of', 'empty_cache', 'get_arch_list', 'get_device_capability',
-    'get_device_name', 'get_device_properties', 'get_gencode_flags', 'get_rng_state', 'get_rng_state_all', 'get_sync_debug_mode',
-    'graph', 'graph_pool_handle', 'graphs', 'has_half', 'has_magma', 'init', 'initial_seed', 'ipc_collect', 'is_available',
-    'is_bf16_supported', 'is_current_stream_capturing', 'is_initialized', 'jiterator', 'list_gpu_processes',
+    'default_stream', 'device', 'device_count', 'device_of', 'empty_cache', 'get_allocator_backend', 'get_arch_list',
+    'get_device_capability', 'get_device_name', 'get_device_properties', 'get_gencode_flags', 'get_rng_state', 'get_rng_state_all',
+    'get_sync_debug_mode', 'graph', 'graph_pool_handle', 'graphs', 'has_half', 'has_magma', 'init', 'initial_seed', 'ipc_collect',
+    'is_available', 'is_bf16_supported', 'is_current_stream_capturing', 'is_initialized', 'jiterator', 'list_gpu_processes',
     'make_graphed_callables', 'manual_seed', 'manual_seed_all', 'max_memory_allocated', 'max_memory_cached', 'max_memory_reserved',
     'mem_get_info', 'memory', 'memory_allocated', 'memory_cached', 'memory_reserved', 'memory_snapshot', 'memory_stats',
     'memory_stats_as_nested_dict', 'memory_summary', 'memory_usage', 'nccl', 'nvtx', 'profiler', 'random',

@@ -16,7 +16,7 @@ __all__ = ["caching_allocator_alloc", "caching_allocator_delete", "set_per_proce
            "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached",
            "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved",
            "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", "list_gpu_processes",
-           "mem_get_info"]
+           "mem_get_info", "get_allocator_backend"]

 def _host_allocator():
     _lazy_init()

@@ -194,6 +194,10 @@ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
     .. note::
         See :ref:`cuda-memory-management` for more details about GPU memory
         management.
+
+    .. note::
+        With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
+        meaningful, and are always reported as zero.
     """
     result = []

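To illustrate the note, a hedged sketch of a backend-aware stats check (not part of the diff; the key names are taken from the test changes above, and `get_allocator_backend()` is the helper added in the next hunk; assumes a CUDA build):

    import torch

    torch.ones(1, device="cuda")  # force allocator initialization
    stats = torch.cuda.memory_stats()
    if torch.cuda.get_allocator_backend() == "native":
        # Caching/splitting stats are specific to the native allocator.
        print("reserved:", stats["reserved_bytes.all.current"])
    else:
        # Under cudaMallocAsync several native-only stats read as zero,
        # so rely on the allocation counters instead.
        print("allocated:", stats["allocated_bytes.all.current"])
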
@@ -636,3 +640,14 @@ def _save_memory_usage(filename='output.svg', snapshot=None):

 def _set_allocator_settings(env: str):
     return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
+
+def get_allocator_backend() -> str:
+    r"""Returns a string describing the active allocator backend as set by
+    ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
+    ``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
+    (CUDA's built-in asynchronous allocator).
+
+    .. note::
+        See :ref:`cuda-memory-management` for details on choosing the allocator backend.
+    """
+    return torch._C._cuda_getAllocatorBackend()

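Putting the pieces together, a usage sketch (not part of the diff): the `backend:cudaMallocAsync` setting follows the `PYTORCH_CUDA_ALLOC_CONF` syntax referenced above, and is assumed here to be set in the environment before the first CUDA allocation.

    # Shell: PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync python example.py
    import torch

    x = torch.empty(8, device="cuda")          # first allocation binds the backend choice
    print(torch.cuda.get_allocator_backend())  # -> "cudaMallocAsync" (default is "native")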