diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index 32b326e4615..d6e10fec83f 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -2061,6 +2061,49 @@ if HAS_CUDA:
             with self.assertRaisesRegex(Exception, "custom error msg"):
                 device = x.untyped_storage()
 
+        def test_side_stream_memory_allocation(self):
+            from torch._inductor.cudagraph_trees import cudagraphify_impl
+
+            def multi_stream_allocation(args):
+                side_stream = torch.cuda.Stream()
+                side_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(side_stream):
+                    side_stream_buffer = torch.ones(
+                        *args, device="cuda:0", dtype=torch.float32
+                    )
+                torch.cuda.current_stream().wait_stream(side_stream)
+
+                main_stream_buffer = torch.ones(
+                    *args, device="cuda:0", dtype=torch.float32
+                )
+
+                if isinstance(args, list):
+                    args.clear()
+
+                return main_stream_buffer, side_stream_buffer
+
+            graphed_multi_stream_func = cudagraphify_impl(
+                multi_stream_allocation,
+                inputs=[],
+                static_input_idxs=[],
+                is_backward=False,
+                is_inference=False,
+                device_index=0,
+                stack_traces=["dummy stack trace1", "dummy stack trace2"],
+            )
+
+            ref_out = torch.ones((2, 3), device="cuda:0", dtype=torch.float32)
+
+            for _ in range(3):
+                torch.compiler.cudagraph_mark_step_begin()
+                main_stream_buffer, side_stream_buffer = graphed_multi_stream_func(
+                    [2, 3]
+                )
+                self.assertEqual(main_stream_buffer, ref_out)
+                self.assertEqual(side_stream_buffer, ref_out)
+
+            self.assertEqual(self.get_manager().new_graph_id().id, 1)
+
         @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
         @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
         def test_static_inputs_address_mutation_log(self):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 149df10195d..928d32c11fb 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1926,6 +1926,8 @@ def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
 def _cuda_cudaCachingAllocator_enable(val: _bool) -> None: ...
 def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
 def _cuda_beginAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
+def _cuda_beginAllocateCurrentThreadToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
+def _cuda_endAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index 70ab30a7076..66f72e142f3 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -555,11 +555,14 @@ def _use_cuda_memory_pool_manager(
     stream.wait_stream(torch.cuda.current_stream())
 
     with torch.cuda.stream(stream), torch.device(device):
-        torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
+        # Begin allocating to the mem pool for all memory allocations on the
+        # current thread. This is thread safe since a thread can only warm up
+        # or record one cudagraph at a time.
+        torch._C._cuda_beginAllocateCurrentThreadToPool(device, mem_pool)
         try:
             yield
         finally:
-            torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
+            torch._C._cuda_endAllocateToPool(device, mem_pool)
             torch._C._cuda_releasePool(device, mem_pool)
 
     torch.cuda.current_stream().wait_stream(stream)
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 4a1725b988c..bf8cecf76e6 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -26,6 +26,13 @@
 #include <torch/csrc/utils/python_numbers.h>
 #include <torch/csrc/utils/python_strings.h>
 
+#if !defined(_MSC_VER)
+#include <sys/syscall.h>
+#include <unistd.h>
+#elif defined(_MSC_VER)
+#include <windows.h>
+#endif
+
 #ifdef USE_NCCL
 #include <torch/csrc/cuda/python_nccl.h>
 #endif
@@ -1388,6 +1395,31 @@ static void registerCudaPluggableAllocator(PyObject* module) {
             device, mempool_id, [](cudaStream_t) { return true; });
       });
 
+  m.def(
+      "_cuda_beginAllocateCurrentThreadToPool",
+      [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
+#ifdef _MSC_VER
+        auto tid = GetCurrentThreadId();
+#else
+        auto tid = syscall(SYS_gettid);
+#endif
+        c10::cuda::CUDACachingAllocator::beginAllocateToPool(
+            device, mempool_id, [=](cudaStream_t) {
+#ifdef _MSC_VER
+              auto current_tid = GetCurrentThreadId();
+#else
+              auto current_tid = syscall(SYS_gettid);
+#endif
+              return current_tid == tid;
+            });
+      });
+
+  m.def(
+      "_cuda_endAllocateToPool",
+      [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
+        c10::cuda::CUDACachingAllocator::endAllocateToPool(device, mempool_id);
+      });
+
   m.def(
       "_cuda_endAllocateCurrentStreamToPool",
       [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
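
Usage note: the pair of new bindings brackets a region in which every allocation made by the current thread, on whatever stream, is routed to the given mempool. Below is a minimal sketch of how they compose (not part of the patch; it drives the bindings directly with a handle from the public torch.cuda.graph_pool_handle() API rather than going through cudagraph trees):

    import torch

    device = 0
    pool = torch.cuda.graph_pool_handle()  # opaque (int, int) mempool id

    torch.cuda.set_device(device)
    # Between begin/end, everything this thread allocates goes to `pool`,
    # including allocations made on a freshly created side stream.
    torch._C._cuda_beginAllocateCurrentThreadToPool(device, pool)
    try:
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream):
            buf = torch.ones(2, 3, device="cuda")
    finally:
        torch._C._cuda_endAllocateToPool(device, pool)
        torch._C._cuda_releasePool(device, pool)

Under the old _cuda_beginAllocateCurrentStreamToPool filter, buf would have bypassed the pool because the allocating stream is not the stream that was current when recording began; filtering by thread instead is what lets _use_cuda_memory_pool_manager capture the side-stream allocation that test_side_stream_memory_allocation exercises.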