Expose cudaStreamCaptureMode in CUDA Graphs, use local setting in inductor (#107407)

> capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
> Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
> may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
> actions in the current thread, and "relaxed" will not error on these actions.

Inductor codegen is single-threaded, so it should be safe to enable "thread_local" for Inductor's CUDA graph capture. We have seen errors when Inductor cudagraphs run concurrently with data preprocessing in other threads.
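For illustration, a minimal sketch of the new knob from the diffs below; the warmup-on-a-side-stream pattern, tensor sizes, and values are illustrative, not part of this PR:

```python
import torch

g = torch.cuda.CUDAGraph()
static_x = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture (standard CUDA graphs practice).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y = static_x + 1
torch.cuda.current_stream().wait_stream(s)

# "thread_local" confines capture-safety errors to CUDA calls made on this
# thread, so e.g. a data-preprocessing thread may still cudaMalloc.
with torch.cuda.graph(g, capture_error_mode="thread_local"):
    static_y = static_x + 1

static_x.fill_(41.0)
g.replay()  # static_y now reflects the updated static_x
```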

Differential Revision: [D48656014](https://our.internmc.facebook.com/intern/diff/D48656014)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107407
Approved by: https://github.com/albanD, https://github.com/eqy
Commit: 0a9778a372 (parent: c18d2a3c05)
Author: Elias Ellison, 2023-08-23 22:15:22 +00:00; committed by PyTorch MergeBot
9 changed files with 119 additions and 22 deletions

aten/src/ATen/cuda/CUDAGraph.cpp

@@ -63,7 +63,7 @@ CUDAGraph::CUDAGraph()
 #endif
 }
 
-void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
+void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
 #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   TORCH_CHECK(!has_graph_exec_,
               "This CUDAGraph instance already owns a captured graph. "
@@ -118,7 +118,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
   // cudaStreamCaptureModeGlobal is the most conservative option to
   // prevent potentially unsafe CUDA API calls during capture. See
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
-  AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
+  AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, capture_mode));
 
   cudaStreamCaptureStatus status;
   AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, nullptr));

aten/src/ATen/cuda/CUDAGraph.h

@@ -19,7 +19,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
   CUDAGraph();
   ~CUDAGraph();
 
-  void capture_begin(MempoolId_t pool={0, 0});
+  void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
   void capture_end();
   void replay();
   void reset();

test/test_cuda.py

@@ -3186,6 +3186,61 @@ exit(2)
         for p_control, p_graphed in zip(params_control, params_graphed):
             self.assertEqual(p_control, p_graphed)
 
+    @unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
+    def test_cuda_graph_error_options(self):
+        def fn():
+            x = torch.zeros([2000], device="cuda")
+            y = x + x + x
+            return y
+
+        mem = None
+
+        def raw_malloc():
+            global mem
+            mem = None
+            stream = torch.cuda.Stream()
+            try:
+                with torch.cuda.stream(stream):
+                    mem = torch.cuda.caching_allocator_alloc(1024)
+            except BaseException:
+                if mem is None:
+                    return
+                try:
+                    torch.cuda.caching_allocator_delete(mem)
+                    mem = None
+                    return None
+                except BaseException:
+                    pass
+
+        def throws_on_cuda_event(capture_error_mode):
+            graph = torch.cuda.CUDAGraph()
+            torch.cuda.synchronize()
+            stream = torch.cuda.Stream()
+            stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(stream):
+                fn()
+            stream.synchronize()
+            torch.cuda.current_stream().wait_stream(stream)
+            torch.cuda.synchronize()
+            try:
+                with torch.cuda.graph(graph, stream=stream, capture_error_mode=capture_error_mode):
+                    out = fn()
+                    thread = threading.Thread(target=raw_malloc)
+                    thread.start()
+                    thread.join()
+            except Exception:
+                if mem is not None:
+                    torch.cuda.caching_allocator_delete(mem)
+                return True
+
+            return False
+
+        self.assertFalse(throws_on_cuda_event("thread_local"))
+        self.assertFalse(throws_on_cuda_event("relaxed"))
+
+        # Exception would Corrupt Process and make other tests fail
+        # self.assertTrue(throws_on_cuda_event("global"))
+
     def test_batch_norm_gather_stats(self):
         input = torch.randn(1, 3, 3, 3, device='cuda')
         mean, invstd = torch.batch_norm_gather_stats(

torch/_C/__init__.pyi

@@ -1705,7 +1705,7 @@ class _CudaEventBase:
 
 # Defined in torch/csrc/cuda/Graph.cpp
 class _CUDAGraph:
-    def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ...) -> None: ...
+    def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ..., capture_error_mode: str = "global") -> None: ...
     def capture_end(self) -> None: ...
     def replay(self) -> None: ...
     def reset(self) -> None: ...

torch/_inductor/compile_fx.py

@@ -770,7 +770,7 @@ def cudagraphify_impl(
 
     # record
     graph = torch.cuda.CUDAGraph()
-    with torch.cuda.graph(graph, stream=stream):
+    with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
         static_outputs = model(list(static_inputs))
     if not isinstance(static_outputs, (list, tuple)):
         static_outputs = (static_outputs,)

torch/_inductor/cudagraph_trees.py

@@ -1076,7 +1076,10 @@ class CUDAGraphNode:
         with preserve_rng_state(), torch.cuda.device(
             self.device
         ), clear_cublas_manager(), torch.cuda.graph(
-            self.graph, stream=self.stream, pool=self.cuda_graphs_pool
+            self.graph,
+            stream=self.stream,
+            pool=self.cuda_graphs_pool,
+            capture_error_mode="thread_local",
         ), get_history_recording():
             static_outputs = model(inputs)
@@ -1686,6 +1689,7 @@ class CUDAGraphTreeManager:
         with torch.cuda.graph(
             self.graph,
             pool=self.cuda_graphs_thread_pool,
             stream=self.stream,
+            capture_error_mode="thread_local",
         ):
             pass

torch/csrc/cuda/Graph.cpp

@@ -6,6 +6,7 @@
 #include <torch/csrc/utils/pybind.h>
 
 #include <ATen/cuda/CUDAGraph.h>
+#include <c10/cuda/CUDAGraphsC10Utils.h>
 
 // Cargo culted partially from csrc/distributed/c10d/init.cpp
 // and partially from csrc/cuda/Stream.cpp.
@@ -26,13 +27,32 @@ void THCPGraph_init(PyObject* module) {
   shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
       .def(py::init<>())
-      // I'm not sure this is the correct order of all the arguments. Pybind11
-      // docs aren't clear. But it works.
       .def(
           "capture_begin",
-          torch::wrap_pybind_function_no_gil(
-              &at::cuda::CUDAGraph::capture_begin),
-          py::arg("pool") = c10::cuda::MempoolId_t{0, 0})
+          [](::at::cuda::CUDAGraph& self,
+             c10::optional<c10::cuda::MempoolId_t> pool_opt,
+             std::string capture_error_mode) {
+            cudaStreamCaptureMode capture_mode;
+            c10::cuda::MempoolId_t pool = pool_opt.has_value()
+                ? pool_opt.value()
+                : c10::cuda::MempoolId_t{0, 0};
+            if (capture_error_mode == "global") {
+              capture_mode = cudaStreamCaptureModeGlobal;
+            } else if (capture_error_mode == "thread_local") {
+              capture_mode = cudaStreamCaptureModeThreadLocal;
+            } else if (capture_error_mode == "relaxed") {
+              capture_mode = cudaStreamCaptureModeRelaxed;
+            } else {
+              TORCH_CHECK(
+                  false,
+                  "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
+                  capture_error_mode);
+            }
+            return self.capture_begin(pool, capture_mode);
+          },
+          py::arg("pool"),
+          py::arg("capture_error_mode"),
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "capture_end",
           torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
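A note on the binding above: `py::arg("pool")` and `py::arg("capture_error_mode")` carry no defaults, so the raw `torch._C._CUDAGraph.capture_begin` expects both arguments; the Python wrapper in `torch/cuda/graphs.py` (below) always forwards them, which also retires the old "does pybind11 convert a None arg?" workaround. A small sketch of the error path guarded by the `TORCH_CHECK`; the mode string "sideways" is made up purely to trigger it:

```python
import torch

g = torch.cuda.CUDAGraph()
try:
    # An unrecognized mode never reaches cudaStreamBeginCapture; the
    # TORCH_CHECK in the lambda rejects it first.
    g.capture_begin(capture_error_mode="sideways")
except RuntimeError as e:
    print(e)  # Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed` ...
```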

torch/cuda/graphs.py

@@ -56,7 +56,7 @@ class CUDAGraph(torch._C._CUDAGraph):
     def __new__(cls):
         return super().__new__(cls)
 
-    def capture_begin(self, pool=None):
+    def capture_begin(self, pool=None, capture_error_mode="global"):
         r"""
         Begins capturing CUDA work on the current stream.
 
@@ -68,13 +68,13 @@
             pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                 :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
                 with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
-        """
-        # I'm not sure if pybind11 converts a None arg to the default defined on the C++ side,
-        # so I'm not taking any chances.
-        if pool is None:
-            super().capture_begin()
-        else:
-            super().capture_begin(pool)
+            capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
+                Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
+                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
+                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
+                unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
+        """  # noqa: B950
+        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
 
     def capture_end(self):
         r"""
@@ -139,6 +139,11 @@ class graph:
             may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
         stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
             If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
+        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
+            Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
+            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
+            actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
+            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
 
     .. note::
         For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
@@ -146,10 +151,19 @@
     .. warning::
         This API is in beta and may change in future releases.
-    """
+
+    .. _cudaStreamCaptureMode:
+        https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
+    """  # noqa: B950
 
     default_capture_stream = None
 
-    def __init__(self, cuda_graph, pool=None, stream=None):
+    def __init__(
+        self,
+        cuda_graph,
+        pool=None,
+        stream=None,
+        capture_error_mode: str = "global",
+    ):
         # Lazy-init of default_capture_stream helps avoid circular-import errors.
         # Not thread safe, but graphs already have the general (explicitly documented)
         # restriction that only one capture may be underway at a time in the process.
@@ -163,6 +177,7 @@ class graph:
         assert self.capture_stream is not None
         self.stream_ctx = torch.cuda.stream(self.capture_stream)
         self.cuda_graph = cuda_graph
+        self.capture_error_mode = capture_error_mode
 
     def __enter__(self):
         # Free as much memory as we can for the graph
@@ -174,7 +189,9 @@
         # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
         self.stream_ctx.__enter__()
 
-        self.cuda_graph.capture_begin(*self.pool)
+        self.cuda_graph.capture_begin(
+            *self.pool, capture_error_mode=self.capture_error_mode
+        )
 
     def __exit__(self, exc_type, exc_value, traceback):
         self.cuda_graph.capture_end()
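Putting this file's pieces together, a hedged end-to-end sketch of the manual `capture_begin`/`capture_end` path that the `graph` context manager automates; shapes and fill values are illustrative:

```python
import torch

g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())

static_x = torch.zeros(16, device="cuda")
with torch.cuda.stream(s):
    static_y = static_x * 2  # warmup iteration
torch.cuda.current_stream().wait_stream(s)

# Capture must run on a non-default stream; __enter__/__exit__ above
# normally handle this plus the capture_error_mode plumbing.
with torch.cuda.stream(s):
    g.capture_begin(capture_error_mode="relaxed")
    static_y = static_x * 2
    g.capture_end()

static_x.fill_(21.0)
g.replay()
torch.cuda.synchronize()
# static_y now holds 42.0 in every element
```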

torch/utils/hipify/cuda_to_hip_mappings.py

@@ -4161,6 +4161,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
         ("cudaStreamCaptureMode", ("hipStreamCaptureMode", CONV_TYPE, API_RUNTIME)),
         ("cudaStreamCaptureModeGlobal", ("hipStreamCaptureModeGlobal", CONV_TYPE, API_RUNTIME)),
         ("cudaStreamCaptureModeRelaxed", ("hipStreamCaptureModeRelaxed", CONV_TYPE, API_RUNTIME)),
+        ("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)),
         ("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)),
         ("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)),
         ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)),