[CUDAGraph Trees] support memory allocation on side stream (#152472)

I tried `beginAllocateToPool` instead of `_cuda_beginAllocateCurrentStreamToPool`, and the error in #151199 no longer happens.

However, this approach is unsafe under multithreading. When multiple `run_eager` calls happen concurrently, we expect each one to allocate memory into its own mem_pool. Since `beginAllocateToPool` does not filter by stream (its filter always returns true), allocations from different threads may all land in the same mem_pool.

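As a rough sketch of the failure mode (illustrative only, not a real repro; `run_eager` here stands in for CUDAGraph Trees' eager warmup, and `_cuda_endAllocateToPool` is the generic end binding added by this PR):

import threading

import torch

def run_eager(pool):
    # _cuda_beginAllocateToPool installs an always-true filter (see the
    # Module.cpp hunk below), so ANY allocation on this device matches --
    # including allocations made concurrently by the other thread.
    torch._C._cuda_beginAllocateToPool(0, pool)
    try:
        buf = torch.empty(1024, device="cuda:0")  # may land in either pool
    finally:
        torch._C._cuda_endAllocateToPool(0, pool)
        torch._C._cuda_releasePool(0, pool)

pools = [torch.cuda.graph_pool_handle() for _ in range(2)]
threads = [threading.Thread(target=run_eager, args=(p,)) for p in pools]
for t in threads:
    t.start()
for t in threads:
    t.join()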
So I use `_cuda_beginAllocateCurrentThreadToPool` to direct all memory allocations on the same thread to a given mem_pool. In particular, `_cuda_beginAllocateCurrentThreadToPool` records the launching thread id and, at allocation time, checks whether the current thread id matches the launching thread id.

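Roughly, the intended begin/end sequence for the new binding (a sketch; it mirrors the `_use_cuda_memory_pool_manager` change below):

import torch

device = torch.cuda.current_device()
pool = torch.cuda.graph_pool_handle()  # a MempoolId_t, i.e. a tuple of two ints

torch._C._cuda_beginAllocateCurrentThreadToPool(device, pool)
try:
    # Allocations made by *this* thread now go to `pool`, regardless of
    # which stream they happen on; other threads are unaffected.
    buf = torch.empty(1024, device="cuda")
finally:
    torch._C._cuda_endAllocateToPool(device, pool)
    torch._C._cuda_releasePool(device, pool)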
Fixes #151199

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152472
Approved by: https://github.com/eellison
Boyuan Feng, 2025-04-30 17:45:07 +00:00, committed by PyTorch MergeBot
parent 0904a182c2
commit c620763ec2
4 changed files with 82 additions and 2 deletions

test/inductor/test_cudagraph_trees.py

@@ -2061,6 +2061,49 @@ if HAS_CUDA:
             with self.assertRaisesRegex(Exception, "custom error msg"):
                 device = x.untyped_storage()
 
+        def test_side_stream_memory_allocation(self):
+            from torch._inductor.cudagraph_trees import cudagraphify_impl
+
+            def multi_stream_allocation(args):
+                side_stream = torch.cuda.Stream()
+                side_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(side_stream):
+                    side_stream_buffer = torch.ones(
+                        *args, device="cuda:0", dtype=torch.float32
+                    )
+                torch.cuda.current_stream().wait_stream(side_stream)
+                main_stream_buffer = torch.ones(
+                    *args, device="cuda:0", dtype=torch.float32
+                )
+                if isinstance(args, list):
+                    args.clear()
+                return main_stream_buffer, side_stream_buffer
+
+            graphed_multi_stream_func = cudagraphify_impl(
+                multi_stream_allocation,
+                inputs=[],
+                static_input_idxs=[],
+                is_backward=False,
+                is_inference=False,
+                device_index=0,
+                stack_traces=["dummy stack trace1", "dummy stack trace2"],
+            )
+
+            ref_out = torch.ones((2, 3), device="cuda:0", dtype=torch.float32)
+            for _ in range(3):
+                torch.compiler.cudagraph_mark_step_begin()
+                main_stream_buffer, side_stream_buffer = graphed_multi_stream_func(
+                    [2, 3]
+                )
+                self.assertEqual(main_stream_buffer, ref_out)
+                self.assertEqual(side_stream_buffer, ref_out)
+            self.assertEqual(self.get_manager().new_graph_id().id, 1)
+
         @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
         @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
         def test_static_inputs_address_mutation_log(self):

torch/_C/__init__.pyi.in

@@ -1926,6 +1926,8 @@ def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
 def _cuda_cudaCachingAllocator_enable(val: _bool) -> None: ...
 def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
 def _cuda_beginAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
+def _cuda_beginAllocateCurrentThreadToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
+def _cuda_endAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...

torch/_inductor/cudagraph_trees.py

@@ -555,11 +555,14 @@ def _use_cuda_memory_pool_manager(
     stream.wait_stream(torch.cuda.current_stream())
 
     with torch.cuda.stream(stream), torch.device(device):
-        torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
+        # Route all memory allocations on the current thread to the mem pool.
+        # This is thread safe since a thread can only warm up or record one
+        # cudagraph at a time.
+        torch._C._cuda_beginAllocateCurrentThreadToPool(device, mem_pool)
         try:
             yield
         finally:
-            torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
+            torch._C._cuda_endAllocateToPool(device, mem_pool)
             torch._C._cuda_releasePool(device, mem_pool)
             torch.cuda.current_stream().wait_stream(stream)

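For context, a rough sketch of how this context manager gets used (assuming the private signature `_use_cuda_memory_pool_manager(device, mem_pool, stream)`):

import torch
from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager

device = 0
pool = torch.cuda.graph_pool_handle()
stream = torch.cuda.Stream()

with _use_cuda_memory_pool_manager(device, pool, stream):
    # Every allocation this thread makes inside the block lands in `pool`,
    # even allocations made on a side stream -- which is exactly what the
    # new test above exercises.
    t = torch.ones(2, 3, device="cuda:0")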
torch/csrc/cuda/Module.cpp

@@ -26,6 +26,13 @@
 #include <c10/cuda/CUDAFunctions.h>
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
 
+#if !defined(_MSC_VER)
+#include <sys/types.h>
+#include <unistd.h>
+#elif defined(_MSC_VER)
+#include <c10/util/win32-headers.h>
+#endif
+
 #ifdef USE_NCCL
 #include <torch/csrc/cuda/python_nccl.h>
 #endif
@@ -1388,6 +1395,31 @@ static void registerCudaPluggableAllocator(PyObject* module) {
             device, mempool_id, [](cudaStream_t) { return true; });
       });
   m.def(
+      "_cuda_beginAllocateCurrentThreadToPool",
+      [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
+        // Record the launching thread id; the filter below routes an
+        // allocation into the pool only when it comes from that thread.
+#ifdef _MSC_VER
+        auto tid = GetCurrentThreadId();
+#else
+        auto tid = gettid();
+#endif
+        c10::cuda::CUDACachingAllocator::beginAllocateToPool(
+            device, mempool_id, [=](cudaStream_t) {
+#ifdef _MSC_VER
+              auto current_tid = GetCurrentThreadId();
+#else
+              auto current_tid = gettid();
+#endif
+              return current_tid == tid;
+            });
+      });
+  m.def(
+      "_cuda_endAllocateToPool",
+      [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
+        c10::cuda::CUDACachingAllocator::endAllocateToPool(device, mempool_id);
+      });
+  m.def(
       "_cuda_endAllocateCurrentStreamToPool",
       [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {