[CUDAGraph Trees] support memory allocation on side stream (#152472)
I tried `beginAllocateToPool` instead of `_cuda_beginAllocateCurrentStreamToPool`, and the error in #151199 no longer happens. However, that approach is unsafe under multithreading: when multiple run_eager calls happen concurrently, we expect their memory allocations to go to different mem_pools, but since `beginAllocateToPool` does not check the stream, those allocations may all land in the same mem_pool.

So I use `_cuda_beginAllocateCurrentThreadToPool` instead, which directs all memory allocation on one thread to a given mem_pool. In particular, `_cuda_beginAllocateCurrentThreadToPool` records the launching thread id, and at allocation time checks whether the current thread id matches the launching thread id.

Fixes #151199

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152472
Approved by: https://github.com/eellison
This commit is contained in:
parent 0904a182c2
commit c620763ec2
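
As a mental model for the mechanism described above, here is a minimal pure-Python sketch of a thread-keyed allocation filter (illustrative only: `PoolRedirect` and its methods are made-up names, not PyTorch APIs):

import threading

class PoolRedirect:
    """Toy model of an allocator that can redirect allocations to mem_pools."""

    def __init__(self):
        self.filters = {}  # mem_pool id -> predicate evaluated per allocation

    def begin_allocate_current_thread_to_pool(self, pool_id):
        # Record the launching thread id once, at begin time ...
        launching_tid = threading.get_ident()
        # ... then compare it with the allocating thread on every allocation.
        self.filters[pool_id] = lambda: threading.get_ident() == launching_tid

    def end_allocate_to_pool(self, pool_id):
        self.filters.pop(pool_id, None)

    def pool_for_allocation(self):
        # Unlike a stream-keyed filter, this also matches allocations made on
        # a side stream, as long as they come from the launching thread.
        for pool_id, predicate in list(self.filters.items()):
            if predicate():
                return pool_id
        return None  # fall through to the default pool
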
@@ -2061,6 +2061,49 @@ if HAS_CUDA:
             with self.assertRaisesRegex(Exception, "custom error msg"):
                 device = x.untyped_storage()

+        def test_side_stream_memory_allocation(self):
+            from torch._inductor.cudagraph_trees import cudagraphify_impl
+
+            def multi_stream_allocation(args):
+                side_stream = torch.cuda.Stream()
+                side_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(side_stream):
+                    side_stream_buffer = torch.ones(
+                        *args, device="cuda:0", dtype=torch.float32
+                    )
+                torch.cuda.current_stream().wait_stream(side_stream)
+
+                main_stream_buffer = torch.ones(
+                    *args, device="cuda:0", dtype=torch.float32
+                )
+
+                if isinstance(args, list):
+                    args.clear()
+
+                return main_stream_buffer, side_stream_buffer
+
+            graphed_multi_stream_func = cudagraphify_impl(
+                multi_stream_allocation,
+                inputs=[],
+                static_input_idxs=[],
+                is_backward=False,
+                is_inference=False,
+                device_index=0,
+                stack_traces=["dummy stack trace1", "dummy stack trace2"],
+            )
+
+            ref_out = torch.ones((2, 3), device="cuda:0", dtype=torch.float32)
+
+            for _ in range(3):
+                torch.compiler.cudagraph_mark_step_begin()
+                main_stream_buffer, side_stream_buffer = graphed_multi_stream_func(
+                    [2, 3]
+                )
+                self.assertEqual(main_stream_buffer, ref_out)
+                self.assertEqual(side_stream_buffer, ref_out)
+
+            self.assertEqual(self.get_manager().new_graph_id().id, 1)
+
         @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
         @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
         def test_static_inputs_address_mutation_log(self):
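
For orientation, a stripped-down usage sketch of `cudagraphify_impl`, mirroring the call signature the test above exercises (the model, shapes, and stack trace string are arbitrary placeholders, and a CUDA build is assumed; this is not part of the diff):

import torch
from torch._inductor.cudagraph_trees import cudagraphify_impl

def model(args):
    (x,) = args
    args.clear()  # cudagraph trees expects the callee to consume the input list
    return (x * 2,)

graphed_model = cudagraphify_impl(
    model,
    inputs=[torch.ones(4, device="cuda:0")],
    static_input_idxs=[],
    is_backward=False,
    is_inference=False,
    device_index=0,
    stack_traces=["example stack trace"],
)

torch.compiler.cudagraph_mark_step_begin()
(out,) = graphed_model([torch.ones(4, device="cuda:0")])
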
@@ -1926,6 +1926,8 @@ def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
 def _cuda_cudaCachingAllocator_enable(val: _bool) -> None: ...
 def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
 def _cuda_beginAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
+def _cuda_beginAllocateCurrentThreadToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
+def _cuda_endAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
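
A hedged sketch of how the new stubs pair with the existing ones, assuming `torch.cuda.MemPool` supplies the `Tuple[_int, _int]` mempool id these functions expect (these are private `torch._C` entry points, so the exact calling convention may differ across builds):

import torch

device = torch.cuda.current_device()
pool = torch.cuda.MemPool()  # assumption: pool.id is the (int, int) mempool id

# Redirect every allocation made by this thread, on any stream, to `pool`.
torch._C._cuda_beginAllocateCurrentThreadToPool(device, pool.id)
try:
    buf = torch.empty(1024, device="cuda")  # should land in `pool`
finally:
    torch._C._cuda_endAllocateToPool(device, pool.id)
    torch._C._cuda_releasePool(device, pool.id)
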
@@ -555,11 +555,14 @@ def _use_cuda_memory_pool_manager(
     stream.wait_stream(torch.cuda.current_stream())

     with torch.cuda.stream(stream), torch.device(device):
-        torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
+        # Begin allocate to mem pool for all memory allocation on the current thread.
+        # This is thread safe since a thread can only warmup or record 1 cudagraph
+        # at the same time.
+        torch._C._cuda_beginAllocateCurrentThreadToPool(device, mem_pool)
         try:
             yield
         finally:
-            torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
+            torch._C._cuda_endAllocateToPool(device, mem_pool)
             torch._C._cuda_releasePool(device, mem_pool)

     torch.cuda.current_stream().wait_stream(stream)
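
The thread-safety claim in the new comment can be pictured with a small pure-Python model (again with made-up names, not PyTorch APIs): two threads begin redirection concurrently, and because each predicate is keyed on its launching thread id, neither thread's allocations can land in the other's pool:

import threading

filters = {}  # mem_pool id -> predicate, as in the sketch near the top

def begin_allocate_current_thread_to_pool(pool_id):
    launching_tid = threading.get_ident()
    filters[pool_id] = lambda: threading.get_ident() == launching_tid

def pool_for_allocation():
    # Snapshot the dict: other threads may insert while we iterate.
    for pool_id, predicate in list(filters.items()):
        if predicate():
            return pool_id
    return None

def record_one_graph(pool_id, observed):
    begin_allocate_current_thread_to_pool(pool_id)  # each thread redirects itself ...
    observed[pool_id] = pool_for_allocation()       # ... and then "allocates"

observed = {}
threads = [
    threading.Thread(target=record_one_graph, args=(pool, observed))
    for pool in ("pool_a", "pool_b")
]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert observed == {"pool_a": "pool_a", "pool_b": "pool_b"}
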
@@ -26,6 +26,13 @@
 #include <c10/cuda/CUDAFunctions.h>
 #include <ATen/cuda/CUDAGraphsUtils.cuh>

+#if !defined(_MSC_VER)
+#include <sys/types.h>
+#include <unistd.h>
+#elif defined(_MSC_VER)
+#include <c10/util/win32-headers.h>
+#endif
+
 #ifdef USE_NCCL
 #include <torch/csrc/cuda/python_nccl.h>
 #endif
@@ -1388,6 +1395,31 @@ static void registerCudaPluggableAllocator(PyObject* module) {
             device, mempool_id, [](cudaStream_t) { return true; });
       });
+
+  m.def(
+      "_cuda_beginAllocateCurrentThreadToPool",
+      [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
+#ifdef _MSC_VER
+        auto tid = GetCurrentThreadId();
+#else
+        auto tid = gettid();
+#endif
+        c10::cuda::CUDACachingAllocator::beginAllocateToPool(
+            device, mempool_id, [=](cudaStream_t) {
+#ifdef _MSC_VER
+              auto current_tid = GetCurrentThreadId();
+#else
+              auto current_tid = gettid();
+#endif
+              return current_tid == tid;
+            });
+      });
+
+  m.def(
+      "_cuda_endAllocateToPool",
+      [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
+        c10::cuda::CUDACachingAllocator::endAllocateToPool(device, mempool_id);
+      });

   m.def(
       "_cuda_endAllocateCurrentStreamToPool",
       [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
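
The begin-side filter can also be read in isolation. Below is a self-contained C++ sketch of the same capture-and-compare pattern, using the portable std::this_thread::get_id() in place of the platform-specific calls in the diff (plain C++, no PyTorch):

#include <functional>
#include <iostream>
#include <thread>

using AllocFilter = std::function<bool()>;

// Capture the launching thread's id once; the allocator then evaluates the
// returned predicate on every allocation to decide whether to redirect it.
AllocFilter makeCurrentThreadFilter() {
  auto launching_tid = std::this_thread::get_id();
  return [launching_tid] {
    return std::this_thread::get_id() == launching_tid;
  };
}

int main() {
  AllocFilter filter = makeCurrentThreadFilter();
  std::cout << "launching thread: "
            << (filter() ? "redirect to pool" : "default allocator") << "\n";

  std::thread other([&filter] {
    std::cout << "other thread:     "
              << (filter() ? "redirect to pool" : "default allocator") << "\n";
  });
  other.join();
  return 0;
}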