Expose cudaStreamCaptureMode in CUDA Graphs, use local setting in inductor (#107407)

> capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
> Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
> may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
> actions in the current thread, and "relaxed" will not error on these actions.

Inductor codegen is single-threaded, so it should be safe to enable "thread_local" for inductor's CUDA graph capture. We have seen errors when inductor cudagraphs is used concurrently with data preprocessing in other threads.
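
As a hedged illustration of the new argument (a minimal sketch patterned on the test added in this PR; `other_thread_alloc` is a stand-in for real preprocessing work, and it assumes a CUDA build where `torch.cuda.caching_allocator_alloc` reaches cudaMalloc):

```python
import threading
import torch

def other_thread_alloc():
    # Raw allocation from a second thread; this can call cudaMalloc.
    # "global" capture mode would flag it; "thread_local" permits it.
    mem = torch.cuda.caching_allocator_alloc(1024)
    torch.cuda.caching_allocator_delete(mem)

g = torch.cuda.CUDAGraph()
x = torch.zeros(2000, device="cuda")

# Warm up on a side stream before capturing.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x + x
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(g, capture_error_mode="thread_local"):
    y = x + x  # captured work
    t = threading.Thread(target=other_thread_alloc)
    t.start()
    t.join()

g.replay()  # re-runs the captured kernels; y holds the result
```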

Differential Revision: [D48656014](https://our.internmc.facebook.com/intern/diff/D48656014)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107407
Approved by: https://github.com/albanD, https://github.com/eqy
Authored by Elias Ellison on 2023-08-23 22:15:22 +00:00; committed by PyTorch MergeBot
parent c18d2a3c05
commit 0a9778a372
9 changed files with 119 additions and 22 deletions


@@ -63,7 +63,7 @@ CUDAGraph::CUDAGraph()
 #endif
 }

-void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
+void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
 #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   TORCH_CHECK(!has_graph_exec_,
               "This CUDAGraph instance already owns a captured graph. "
@@ -118,7 +118,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
   // cudaStreamCaptureModeGlobal is the most conservative option to
   // prevent potentially unsafe CUDA API calls during capture. See
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
-  AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
+  AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, capture_mode));

   cudaStreamCaptureStatus status;
   AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, nullptr));


@@ -19,7 +19,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
   CUDAGraph();
   ~CUDAGraph();

-  void capture_begin(MempoolId_t pool={0, 0});
+  void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
   void capture_end();
   void replay();
   void reset();


@@ -3186,6 +3186,61 @@ exit(2)
         for p_control, p_graphed in zip(params_control, params_graphed):
             self.assertEqual(p_control, p_graphed)

+    @unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
+    def test_cuda_graph_error_options(self):
+        def fn():
+            x = torch.zeros([2000], device="cuda")
+            y = x + x + x
+            return y
+
+        mem = None
+
+        def raw_malloc():
+            global mem
+            mem = None
+            stream = torch.cuda.Stream()
+            try:
+                with torch.cuda.stream(stream):
+                    mem = torch.cuda.caching_allocator_alloc(1024)
+            except BaseException:
+                if mem is None:
+                    return
+            try:
+                torch.cuda.caching_allocator_delete(mem)
+                mem = None
+                return None
+            except BaseException:
+                pass
+
+        def throws_on_cuda_event(capture_error_mode):
+            graph = torch.cuda.CUDAGraph()
+            torch.cuda.synchronize()
+            stream = torch.cuda.Stream()
+            stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(stream):
+                fn()
+            stream.synchronize()
+            torch.cuda.current_stream().wait_stream(stream)
+            torch.cuda.synchronize()
+
+            try:
+                with torch.cuda.graph(graph, stream=stream, capture_error_mode=capture_error_mode):
+                    out = fn()
+                    thread = threading.Thread(target=raw_malloc)
+                    thread.start()
+                    thread.join()
+            except Exception:
+                if mem is not None:
+                    torch.cuda.caching_allocator_delete(mem)
+                return True
+
+            return False
+
+        self.assertFalse(throws_on_cuda_event("thread_local"))
+        self.assertFalse(throws_on_cuda_event("relaxed"))
+
+        # Exception would Corrupt Process and make other tests fail
+        # self.assertTrue(throws_on_cuda_event("global"))
+
     def test_batch_norm_gather_stats(self):
         input = torch.randn(1, 3, 3, 3, device='cuda')
         mean, invstd = torch.batch_norm_gather_stats(


@@ -1705,7 +1705,7 @@ class _CudaEventBase:
 # Defined in torch/csrc/cuda/Graph.cpp
 class _CUDAGraph:
-    def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ...) -> None: ...
+    def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ..., capture_error_mode: str = "global") -> None: ...
     def capture_end(self) -> None: ...
     def replay(self) -> None: ...
     def reset(self) -> None: ...


@@ -770,7 +770,7 @@ def cudagraphify_impl(
     # record
     graph = torch.cuda.CUDAGraph()
-    with torch.cuda.graph(graph, stream=stream):
+    with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
         static_outputs = model(list(static_inputs))
     if not isinstance(static_outputs, (list, tuple)):
         static_outputs = (static_outputs,)


@@ -1076,7 +1076,10 @@ class CUDAGraphNode:
         with preserve_rng_state(), torch.cuda.device(
             self.device
         ), clear_cublas_manager(), torch.cuda.graph(
-            self.graph, stream=self.stream, pool=self.cuda_graphs_pool
+            self.graph,
+            stream=self.stream,
+            pool=self.cuda_graphs_pool,
+            capture_error_mode="thread_local",
         ), get_history_recording():
             static_outputs = model(inputs)
@@ -1686,6 +1689,7 @@ class CUDAGraphTreeManager:
             self.graph,
             pool=self.cuda_graphs_thread_pool,
             stream=self.stream,
+            capture_error_mode="thread_local",
         ):
             pass


@@ -6,6 +6,7 @@
 #include <torch/csrc/utils/pybind.h>
 #include <ATen/cuda/CUDAGraph.h>
+#include <c10/cuda/CUDAGraphsC10Utils.h>

 // Cargo culted partially from csrc/distributed/c10d/init.cpp
 // and partially from csrc/cuda/Stream.cpp.
@@ -26,13 +27,32 @@ void THCPGraph_init(PyObject* module) {
   shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
       .def(py::init<>())
+      // I'm not sure this is the correct order of all the arguments. Pybind11
+      // docs aren't clear. But it works.
       .def(
           "capture_begin",
-          torch::wrap_pybind_function_no_gil(
-              &at::cuda::CUDAGraph::capture_begin),
-          py::arg("pool") = c10::cuda::MempoolId_t{0, 0})
+          [](::at::cuda::CUDAGraph& self,
+             c10::optional<c10::cuda::MempoolId_t> pool_opt,
+             std::string capture_error_mode) {
+            cudaStreamCaptureMode capture_mode;
+            c10::cuda::MempoolId_t pool = pool_opt.has_value()
+                ? pool_opt.value()
+                : c10::cuda::MempoolId_t{0, 0};
+            if (capture_error_mode == "global") {
+              capture_mode = cudaStreamCaptureModeGlobal;
+            } else if (capture_error_mode == "thread_local") {
+              capture_mode = cudaStreamCaptureModeThreadLocal;
+            } else if (capture_error_mode == "relaxed") {
+              capture_mode = cudaStreamCaptureModeRelaxed;
+            } else {
+              TORCH_CHECK(
+                  false,
+                  "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
+                  capture_error_mode);
+            }
+            return self.capture_begin(pool, capture_mode);
+          },
+          py::arg("pool"),
+          py::arg("capture_error_mode"),
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "capture_end",
           torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
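
The binding validates the mode string before any capture work happens, so a bad value fails fast with the `TORCH_CHECK` message above. A small sketch of the Python-visible behavior (assumes a CUDA-enabled build; "lax" is a deliberately invalid mode):

```python
import torch

g = torch.cuda.CUDAGraph()
try:
    # Not one of global/thread_local/relaxed, so the pybind lambda
    # raises before cudaStreamBeginCapture is ever called.
    g.capture_begin(capture_error_mode="lax")
except RuntimeError as e:
    assert "Unknown capture error mode" in str(e)
```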


@@ -56,7 +56,7 @@ class CUDAGraph(torch._C._CUDAGraph):
     def __new__(cls):
         return super().__new__(cls)

-    def capture_begin(self, pool=None):
+    def capture_begin(self, pool=None, capture_error_mode="global"):
         r"""
         Begins capturing CUDA work on the current stream.
@@ -68,13 +68,13 @@ class CUDAGraph(torch._C._CUDAGraph):
             pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                 :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
                 with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
-        """
-        # I'm not sure if pybind11 converts a None arg to the default defined on the C++ side,
-        # so I'm not taking any chances.
-        if pool is None:
-            super().capture_begin()
-        else:
-            super().capture_begin(pool)
+            capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
+                Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
+                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
+                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
+                unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
+        """  # noqa: B950
+        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)

     def capture_end(self):
         r"""
@@ -139,6 +139,11 @@ class graph:
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
+        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
+            Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
+            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
+            actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
+            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
@@ -146,10 +151,19 @@ class graph:

    .. warning::
        This API is in beta and may change in future releases.
-    """
+
+    .. _cudaStreamCaptureMode:
+        https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
+    """  # noqa: B950

     default_capture_stream = None

-    def __init__(self, cuda_graph, pool=None, stream=None):
+    def __init__(
+        self,
+        cuda_graph,
+        pool=None,
+        stream=None,
+        capture_error_mode: str = "global",
+    ):
         # Lazy-init of default_capture_stream helps avoid circular-import errors.
         # Not thread safe, but graphs already have the general (explicitly documented)
         # restriction that only one capture may be underway at a time in the process.
@@ -163,6 +177,7 @@ class graph:
         assert self.capture_stream is not None
         self.stream_ctx = torch.cuda.stream(self.capture_stream)
         self.cuda_graph = cuda_graph
+        self.capture_error_mode = capture_error_mode

     def __enter__(self):
         # Free as much memory as we can for the graph
@@ -174,7 +189,9 @@ class graph:
         # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
         self.stream_ctx.__enter__()

-        self.cuda_graph.capture_begin(*self.pool)
+        self.cuda_graph.capture_begin(
+            *self.pool, capture_error_mode=self.capture_error_mode
+        )

     def __exit__(self, exc_type, exc_value, traceback):
         self.cuda_graph.capture_end()
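
The context manager simply stores the string and forwards it to `capture_begin`, so callers that drive `capture_begin`/`capture_end` directly get the same knob. A minimal sketch under the usual warm-up-then-capture pattern (the stream setup reflects typical CUDA graph usage, not anything added by this diff):

```python
import torch

g = torch.cuda.CUDAGraph()
x = torch.zeros(2000, device="cuda")

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x + x          # warm-up pass before capture
    torch.cuda.synchronize()
    g.capture_begin(capture_error_mode="thread_local")
    y = x + x          # captured work
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

g.replay()
```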


@@ -4161,6 +4161,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
     ("cudaStreamCaptureMode", ("hipStreamCaptureMode", CONV_TYPE, API_RUNTIME)),
     ("cudaStreamCaptureModeGlobal", ("hipStreamCaptureModeGlobal", CONV_TYPE, API_RUNTIME)),
     ("cudaStreamCaptureModeRelaxed", ("hipStreamCaptureModeRelaxed", CONV_TYPE, API_RUNTIME)),
+    ("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)),
     ("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)),
     ("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)),
     ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)),