From 0a9778a3727faa12266819e8f7b8d28e59900b51 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Wed, 23 Aug 2023 22:15:22 +0000 Subject: [PATCH] Expose cudaStreamCaptureMode in CUDA Graphs, use local setting in inductor (#107407) > capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream. Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc, may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for actions in the current thread, and "relaxed" will not error on these actions. Inductor codegen is single-threaded, so it should be safe to enable "thread_local" for inductor's cuda graph capturing. We have seen errors when inductor cudagraphs has been used concurrently with data preprocessing in other threads. Differential Revision: [D48656014](https://our.internmc.facebook.com/intern/diff/D48656014) Pull Request resolved: https://github.com/pytorch/pytorch/pull/107407 Approved by: https://github.com/albanD, https://github.com/eqy --- aten/src/ATen/cuda/CUDAGraph.cpp | 4 +- aten/src/ATen/cuda/CUDAGraph.h | 2 +- test/test_cuda.py | 55 ++++++++++++++++++++++ torch/_C/__init__.pyi.in | 2 +- torch/_inductor/compile_fx.py | 2 +- torch/_inductor/cudagraph_trees.py | 6 ++- torch/csrc/cuda/Graph.cpp | 30 ++++++++++-- torch/cuda/graphs.py | 39 ++++++++++----- torch/utils/hipify/cuda_to_hip_mappings.py | 1 + 9 files changed, 119 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 8f499fa9fd6..017d29f748c 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -63,7 +63,7 @@ CUDAGraph::CUDAGraph() #endif } -void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) { +void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) { #if !defined(USE_ROCM) || ROCM_VERSION >= 50300 TORCH_CHECK(!has_graph_exec_, "This CUDAGraph instance already owns a captured graph. " @@ -118,7 +118,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) { // cudaStreamCaptureModeGlobal is the most conservative option to // prevent potentially unsafe CUDA API calls during capture. See // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 - AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); + AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, capture_mode)); cudaStreamCaptureStatus status; AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, nullptr)); diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index c4b6fe44d95..00113180e3f 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -19,7 +19,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph { CUDAGraph(); ~CUDAGraph(); - void capture_begin(MempoolId_t pool={0, 0}); + void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal); void capture_end(); void replay(); void reset(); diff --git a/test/test_cuda.py b/test/test_cuda.py index 82e9a55439b..e81c9365139 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -3186,6 +3186,61 @@ exit(2) for p_control, p_graphed in zip(params_control, params_graphed): self.assertEqual(p_control, p_graphed) + @unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs") + def test_cuda_graph_error_options(self): + def fn(): + x = torch.zeros([2000], device="cuda") + y = x + x + x + return y + + mem = None + + def raw_malloc(): + global mem + mem = None + stream = torch.cuda.Stream() + try: + with torch.cuda.stream(stream): + mem = torch.cuda.caching_allocator_alloc(1024) + except BaseException: + if mem is None: + return + try: + torch.cuda.caching_allocator_delete(mem) + mem = None + return None + except BaseException: + pass + + def throws_on_cuda_event(capture_error_mode): + graph = torch.cuda.CUDAGraph() + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + fn() + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + try: + with torch.cuda.graph(graph, stream=stream, capture_error_mode=capture_error_mode): + out = fn() + thread = threading.Thread(target=raw_malloc) + thread.start() + thread.join() + except Exception: + if mem is not None: + torch.cuda.caching_allocator_delete(mem) + return True + + return False + + self.assertFalse(throws_on_cuda_event("thread_local")) + self.assertFalse(throws_on_cuda_event("relaxed")) + + # Exception would Corrupt Process and make other tests fail + # self.assertTrue(throws_on_cuda_event("global")) + def test_batch_norm_gather_stats(self): input = torch.randn(1, 3, 3, 3, device='cuda') mean, invstd = torch.batch_norm_gather_stats( diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 3e60bf7b807..e7073296bb8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1705,7 +1705,7 @@ class _CudaEventBase: # Defined in torch/csrc/cuda/Graph.cpp class _CUDAGraph: - def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ...) -> None: ... + def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ..., capture_error_mode: str = "global") -> None: ... def capture_end(self) -> None: ... def replay(self) -> None: ... def reset(self) -> None: ... diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 28189bf90a3..01188b49055 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -770,7 +770,7 @@ def cudagraphify_impl( # record graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): + with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"): static_outputs = model(list(static_inputs)) if not isinstance(static_outputs, (list, tuple)): static_outputs = (static_outputs,) diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 627e240a768..bdbae6d00ba 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -1076,7 +1076,10 @@ class CUDAGraphNode: with preserve_rng_state(), torch.cuda.device( self.device ), clear_cublas_manager(), torch.cuda.graph( - self.graph, stream=self.stream, pool=self.cuda_graphs_pool + self.graph, + stream=self.stream, + pool=self.cuda_graphs_pool, + capture_error_mode="thread_local", ), get_history_recording(): static_outputs = model(inputs) @@ -1686,6 +1689,7 @@ class CUDAGraphTreeManager: self.graph, pool=self.cuda_graphs_thread_pool, stream=self.stream, + capture_error_mode="thread_local", ): pass diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp index f0781f9b0ca..1cc514480ed 100644 --- a/torch/csrc/cuda/Graph.cpp +++ b/torch/csrc/cuda/Graph.cpp @@ -6,6 +6,7 @@ #include #include +#include // Cargo culted partially from csrc/distributed/c10d/init.cpp // and partially from csrc/cuda/Stream.cpp. @@ -26,13 +27,32 @@ void THCPGraph_init(PyObject* module) { shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph") .def(py::init<>()) - // I'm not sure this is the correct order of all the arguments. Pybind11 - // docs aren't clear. But it works. .def( "capture_begin", - torch::wrap_pybind_function_no_gil( - &at::cuda::CUDAGraph::capture_begin), - py::arg("pool") = c10::cuda::MempoolId_t{0, 0}) + [](::at::cuda::CUDAGraph& self, + c10::optional pool_opt, + std::string capture_error_mode) { + cudaStreamCaptureMode capture_mode; + c10::cuda::MempoolId_t pool = pool_opt.has_value() + ? pool_opt.value() + : c10::cuda::MempoolId_t{0, 0}; + if (capture_error_mode == "global") { + capture_mode = cudaStreamCaptureModeGlobal; + } else if (capture_error_mode == "thread_local") { + capture_mode = cudaStreamCaptureModeThreadLocal; + } else if (capture_error_mode == "relaxed") { + capture_mode = cudaStreamCaptureModeRelaxed; + } else { + TORCH_CHECK( + false, + "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ", + capture_error_mode); + } + return self.capture_begin(pool, capture_mode); + }, + py::arg("pool"), + py::arg("capture_error_mode"), + py::call_guard()) .def( "capture_end", torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end)) diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 144959f5340..df9477d6cfd 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -56,7 +56,7 @@ class CUDAGraph(torch._C._CUDAGraph): def __new__(cls): return super().__new__(cls) - def capture_begin(self, pool=None): + def capture_begin(self, pool=None, capture_error_mode="global"): r""" Begins capturing CUDA work on the current stream. @@ -68,13 +68,13 @@ class CUDAGraph(torch._C._CUDAGraph): pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or :meth:`other_Graph_instance.pool()`) that hints this graph may share memory with the indicated pool. See :ref:`Graph memory management`. - """ - # I'm not sure if pybind11 converts a None arg to the default defined on the C++ side, - # so I'm not taking any chances. - if pool is None: - super().capture_begin() - else: - super().capture_begin(pool) + capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream. + Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc, + may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for + actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting + unless you're familiar with `cudaStreamCaptureMode `_ + """ # noqa: B950 + super().capture_begin(pool=pool, capture_error_mode=capture_error_mode) def capture_end(self): r""" @@ -139,6 +139,11 @@ class graph: may share memory from the specified pool. See :ref:`Graph memory management`. stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context. If not supplied, ``graph`` sets its own internal side stream as the current stream in the context. + capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream. + Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc, + may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for + actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting + unless you're familiar with `cudaStreamCaptureMode `_ .. note:: For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture @@ -146,10 +151,19 @@ class graph: .. warning:: This API is in beta and may change in future releases. - """ + + .. _cudaStreamCaptureMode: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 + """ # noqa: B950 default_capture_stream = None - def __init__(self, cuda_graph, pool=None, stream=None): + def __init__( + self, + cuda_graph, + pool=None, + stream=None, + capture_error_mode: str = "global", + ): # Lazy-init of default_capture_stream helps avoid circular-import errors. # Not thread safe, but graphs already have the general (explicitly documented) # restriction that only one capture may be underway at a time in the process. @@ -163,6 +177,7 @@ class graph: assert self.capture_stream is not None self.stream_ctx = torch.cuda.stream(self.capture_stream) self.cuda_graph = cuda_graph + self.capture_error_mode = capture_error_mode def __enter__(self): # Free as much memory as we can for the graph @@ -174,7 +189,9 @@ class graph: # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487 self.stream_ctx.__enter__() - self.cuda_graph.capture_begin(*self.pool) + self.cuda_graph.capture_begin( + *self.pool, capture_error_mode=self.capture_error_mode + ) def __exit__(self, exc_type, exc_value, traceback): self.cuda_graph.capture_end() diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index c0ac38dc7c0..73586440e72 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -4161,6 +4161,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict( ("cudaStreamCaptureMode", ("hipStreamCaptureMode", CONV_TYPE, API_RUNTIME)), ("cudaStreamCaptureModeGlobal", ("hipStreamCaptureModeGlobal", CONV_TYPE, API_RUNTIME)), ("cudaStreamCaptureModeRelaxed", ("hipStreamCaptureModeRelaxed", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)), ("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)), ("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)), ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)),