ggml: Avoid cudaMemsetAsync during memory fitting

We pass invalid pointers when we check the size of the required
compute graph before fitting. Some CUDA APIs validate these pointers
but we can just skip them during this phase. cudaMemsetAsync is one
of these that we weren't skipping but never took the code path that
used it. Now that we have enabled op_offload, we can hit it in
memory pressured situations.
This commit is contained in:
Jesse Gross 2025-10-31 14:16:20 -07:00 committed by Jesse Gross
parent 3bee3af6ed
commit 392a270261
2 changed files with 28 additions and 8 deletions

View File

@ -11,9 +11,9 @@ must be recreated with no-alloc set to false before loading data.
ggml/include/ggml-backend.h | 1 + ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-backend-impl.h | 16 +++ ggml/src/ggml-backend-impl.h | 16 +++
ggml/src/ggml-backend.cpp | 72 ++++++++++- ggml/src/ggml-backend.cpp | 72 ++++++++++-
ggml/src/ggml-cuda/common.cuh | 48 ++++++- ggml/src/ggml-cuda/common.cuh | 58 ++++++++-
ggml/src/ggml-cuda/ggml-cuda.cu | 217 ++++++++++++++++++++++++++------ ggml/src/ggml-cuda/ggml-cuda.cu | 217 ++++++++++++++++++++++++++------
5 files changed, 310 insertions(+), 44 deletions(-) 5 files changed, 320 insertions(+), 44 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 2763f2bd6..b3b5b356a 100644 index 2763f2bd6..b3b5b356a 100644
@ -219,10 +219,10 @@ index 41eef3b5f..c81a2e48a 100644
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index e0abde542..28d6bcd71 100644 index e0abde542..e98044bd8 100644
--- a/ggml/src/ggml-cuda/common.cuh --- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh
@@ -35,6 +35,31 @@ @@ -35,6 +35,41 @@
#include "vendors/cuda.h" #include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP) #endif // defined(GGML_USE_HIP)
@ -246,15 +246,25 @@ index e0abde542..28d6bcd71 100644
+ } + }
+} +}
+ +
+static cudaError_t cudaMemsetAsyncReserve ( void* devPtr, int value, size_t count, cudaStream_t stream = 0 ) {
+ if (!reserving_graph) {
+ return cudaMemsetAsync(devPtr, value, count, stream);
+ } else {
+ return cudaSuccess;
+ }
+}
+
+#undef cudaMemcpyAsync +#undef cudaMemcpyAsync
+#define cudaMemcpyAsync cudaMemcpyAsyncReserve +#define cudaMemcpyAsync cudaMemcpyAsyncReserve
+#undef cudaMemcpy2DAsync +#undef cudaMemcpy2DAsync
+#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve +#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve
+#undef cudaMemsetAsync
+#define cudaMemsetAsync cudaMemsetAsyncReserve
+ +
#define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
@@ -856,6 +881,9 @@ struct ggml_cuda_pool { @@ -856,6 +891,9 @@ struct ggml_cuda_pool {
virtual void * alloc(size_t size, size_t * actual_size) = 0; virtual void * alloc(size_t size, size_t * actual_size) = 0;
virtual void free(void * ptr, size_t size) = 0; virtual void free(void * ptr, size_t size) = 0;
@ -264,7 +274,7 @@ index e0abde542..28d6bcd71 100644
}; };
template<typename T> template<typename T>
@@ -999,11 +1027,11 @@ struct ggml_backend_cuda_context { @@ -999,11 +1037,11 @@ struct ggml_backend_cuda_context {
// pool // pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES]; std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
@ -278,7 +288,7 @@ index e0abde542..28d6bcd71 100644
} }
return *pools[device]; return *pools[device];
} }
@@ -1011,4 +1039,20 @@ struct ggml_backend_cuda_context { @@ -1011,4 +1049,20 @@ struct ggml_backend_cuda_context {
ggml_cuda_pool & pool() { ggml_cuda_pool & pool() {
return pool(device); return pool(device);
} }
@ -300,7 +310,7 @@ index e0abde542..28d6bcd71 100644
+ } + }
}; };
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f4d4a4267..ac70dcac8 100644 index c555cd30f..eb3db0f19 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -350,6 +350,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { @@ -350,6 +350,8 @@ const ggml_cuda_device_info & ggml_cuda_info() {

View File

@ -55,10 +55,20 @@ static cudaError_t cudaMemcpy2DAsyncReserve ( void* dst, size_t dpitch, const vo
} }
} }
static cudaError_t cudaMemsetAsyncReserve ( void* devPtr, int value, size_t count, cudaStream_t stream = 0 ) {
if (!reserving_graph) {
return cudaMemsetAsync(devPtr, value, count, stream);
} else {
return cudaSuccess;
}
}
#undef cudaMemcpyAsync #undef cudaMemcpyAsync
#define cudaMemcpyAsync cudaMemcpyAsyncReserve #define cudaMemcpyAsync cudaMemcpyAsyncReserve
#undef cudaMemcpy2DAsync #undef cudaMemcpy2DAsync
#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve #define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve
#undef cudaMemsetAsync
#define cudaMemsetAsync cudaMemsetAsyncReserve
#define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)