mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Expose cudaStreamCaptureMode in CUDA Graphs, use local setting in inductor (#107407)
> capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream. Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc, may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for actions in the current thread, and "relaxed" will not error on these actions. Inductor codegen is single-threaded, so it should be safe to enable "thread_local" for inductor's cuda graph capturing. We have seen errors when inductor cudagraphs has been used concurrently with data preprocessing in other threads. Differential Revision: [D48656014](https://our.internmc.facebook.com/intern/diff/D48656014) Pull Request resolved: https://github.com/pytorch/pytorch/pull/107407 Approved by: https://github.com/albanD, https://github.com/eqy
This commit is contained in:
parent
c18d2a3c05
commit
0a9778a372
|
|
@ -63,7 +63,7 @@ CUDAGraph::CUDAGraph()
|
|||
#endif
|
||||
}
|
||||
|
||||
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
|
||||
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
|
||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
|
||||
TORCH_CHECK(!has_graph_exec_,
|
||||
"This CUDAGraph instance already owns a captured graph. "
|
||||
|
|
@ -118,7 +118,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
|
|||
// cudaStreamCaptureModeGlobal is the most conservative option to
|
||||
// prevent potentially unsafe CUDA API calls during capture. See
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
|
||||
AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
|
||||
AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, capture_mode));
|
||||
|
||||
cudaStreamCaptureStatus status;
|
||||
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, nullptr));
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
|
|||
CUDAGraph();
|
||||
~CUDAGraph();
|
||||
|
||||
void capture_begin(MempoolId_t pool={0, 0});
|
||||
void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
|
||||
void capture_end();
|
||||
void replay();
|
||||
void reset();
|
||||
|
|
|
|||
|
|
@ -3186,6 +3186,61 @@ exit(2)
|
|||
for p_control, p_graphed in zip(params_control, params_graphed):
|
||||
self.assertEqual(p_control, p_graphed)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
|
||||
def test_cuda_graph_error_options(self):
|
||||
def fn():
|
||||
x = torch.zeros([2000], device="cuda")
|
||||
y = x + x + x
|
||||
return y
|
||||
|
||||
mem = None
|
||||
|
||||
def raw_malloc():
|
||||
global mem
|
||||
mem = None
|
||||
stream = torch.cuda.Stream()
|
||||
try:
|
||||
with torch.cuda.stream(stream):
|
||||
mem = torch.cuda.caching_allocator_alloc(1024)
|
||||
except BaseException:
|
||||
if mem is None:
|
||||
return
|
||||
try:
|
||||
torch.cuda.caching_allocator_delete(mem)
|
||||
mem = None
|
||||
return None
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
def throws_on_cuda_event(capture_error_mode):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
torch.cuda.synchronize()
|
||||
stream = torch.cuda.Stream()
|
||||
stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(stream):
|
||||
fn()
|
||||
stream.synchronize()
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
torch.cuda.synchronize()
|
||||
try:
|
||||
with torch.cuda.graph(graph, stream=stream, capture_error_mode=capture_error_mode):
|
||||
out = fn()
|
||||
thread = threading.Thread(target=raw_malloc)
|
||||
thread.start()
|
||||
thread.join()
|
||||
except Exception:
|
||||
if mem is not None:
|
||||
torch.cuda.caching_allocator_delete(mem)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
self.assertFalse(throws_on_cuda_event("thread_local"))
|
||||
self.assertFalse(throws_on_cuda_event("relaxed"))
|
||||
|
||||
# Exception would Corrupt Process and make other tests fail
|
||||
# self.assertTrue(throws_on_cuda_event("global"))
|
||||
|
||||
def test_batch_norm_gather_stats(self):
|
||||
input = torch.randn(1, 3, 3, 3, device='cuda')
|
||||
mean, invstd = torch.batch_norm_gather_stats(
|
||||
|
|
|
|||
|
|
@ -1705,7 +1705,7 @@ class _CudaEventBase:
|
|||
|
||||
# Defined in torch/csrc/cuda/Graph.cpp
|
||||
class _CUDAGraph:
|
||||
def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ...) -> None: ...
|
||||
def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ..., capture_error_mode: str = "global") -> None: ...
|
||||
def capture_end(self) -> None: ...
|
||||
def replay(self) -> None: ...
|
||||
def reset(self) -> None: ...
|
||||
|
|
|
|||
|
|
@ -770,7 +770,7 @@ def cudagraphify_impl(
|
|||
|
||||
# record
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
|
||||
static_outputs = model(list(static_inputs))
|
||||
if not isinstance(static_outputs, (list, tuple)):
|
||||
static_outputs = (static_outputs,)
|
||||
|
|
|
|||
|
|
@ -1076,7 +1076,10 @@ class CUDAGraphNode:
|
|||
with preserve_rng_state(), torch.cuda.device(
|
||||
self.device
|
||||
), clear_cublas_manager(), torch.cuda.graph(
|
||||
self.graph, stream=self.stream, pool=self.cuda_graphs_pool
|
||||
self.graph,
|
||||
stream=self.stream,
|
||||
pool=self.cuda_graphs_pool,
|
||||
capture_error_mode="thread_local",
|
||||
), get_history_recording():
|
||||
static_outputs = model(inputs)
|
||||
|
||||
|
|
@ -1686,6 +1689,7 @@ class CUDAGraphTreeManager:
|
|||
self.graph,
|
||||
pool=self.cuda_graphs_thread_pool,
|
||||
stream=self.stream,
|
||||
capture_error_mode="thread_local",
|
||||
):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#include <ATen/cuda/CUDAGraph.h>
|
||||
#include <c10/cuda/CUDAGraphsC10Utils.h>
|
||||
|
||||
// Cargo culted partially from csrc/distributed/c10d/init.cpp
|
||||
// and partially from csrc/cuda/Stream.cpp.
|
||||
|
|
@ -26,13 +27,32 @@ void THCPGraph_init(PyObject* module) {
|
|||
|
||||
shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
|
||||
.def(py::init<>())
|
||||
// I'm not sure this is the correct order of all the arguments. Pybind11
|
||||
// docs aren't clear. But it works.
|
||||
.def(
|
||||
"capture_begin",
|
||||
torch::wrap_pybind_function_no_gil(
|
||||
&at::cuda::CUDAGraph::capture_begin),
|
||||
py::arg("pool") = c10::cuda::MempoolId_t{0, 0})
|
||||
[](::at::cuda::CUDAGraph& self,
|
||||
c10::optional<c10::cuda::MempoolId_t> pool_opt,
|
||||
std::string capture_error_mode) {
|
||||
cudaStreamCaptureMode capture_mode;
|
||||
c10::cuda::MempoolId_t pool = pool_opt.has_value()
|
||||
? pool_opt.value()
|
||||
: c10::cuda::MempoolId_t{0, 0};
|
||||
if (capture_error_mode == "global") {
|
||||
capture_mode = cudaStreamCaptureModeGlobal;
|
||||
} else if (capture_error_mode == "thread_local") {
|
||||
capture_mode = cudaStreamCaptureModeThreadLocal;
|
||||
} else if (capture_error_mode == "relaxed") {
|
||||
capture_mode = cudaStreamCaptureModeRelaxed;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
|
||||
capture_error_mode);
|
||||
}
|
||||
return self.capture_begin(pool, capture_mode);
|
||||
},
|
||||
py::arg("pool"),
|
||||
py::arg("capture_error_mode"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"capture_end",
|
||||
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class CUDAGraph(torch._C._CUDAGraph):
|
|||
def __new__(cls):
|
||||
return super().__new__(cls)
|
||||
|
||||
def capture_begin(self, pool=None):
|
||||
def capture_begin(self, pool=None, capture_error_mode="global"):
|
||||
r"""
|
||||
Begins capturing CUDA work on the current stream.
|
||||
|
||||
|
|
@ -68,13 +68,13 @@ class CUDAGraph(torch._C._CUDAGraph):
|
|||
pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
|
||||
:meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
|
||||
with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
|
||||
"""
|
||||
# I'm not sure if pybind11 converts a None arg to the default defined on the C++ side,
|
||||
# so I'm not taking any chances.
|
||||
if pool is None:
|
||||
super().capture_begin()
|
||||
else:
|
||||
super().capture_begin(pool)
|
||||
capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
|
||||
Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
|
||||
may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
|
||||
actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
|
||||
unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
|
||||
""" # noqa: B950
|
||||
super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
|
||||
|
||||
def capture_end(self):
|
||||
r"""
|
||||
|
|
@ -139,6 +139,11 @@ class graph:
|
|||
may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
|
||||
stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
|
||||
If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
|
||||
capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
|
||||
Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
|
||||
may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
|
||||
actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
|
||||
unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
|
||||
|
||||
.. note::
|
||||
For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
|
||||
|
|
@ -146,10 +151,19 @@ class graph:
|
|||
|
||||
.. warning::
|
||||
This API is in beta and may change in future releases.
|
||||
"""
|
||||
|
||||
.. _cudaStreamCaptureMode:
|
||||
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
|
||||
""" # noqa: B950
|
||||
default_capture_stream = None
|
||||
|
||||
def __init__(self, cuda_graph, pool=None, stream=None):
|
||||
def __init__(
|
||||
self,
|
||||
cuda_graph,
|
||||
pool=None,
|
||||
stream=None,
|
||||
capture_error_mode: str = "global",
|
||||
):
|
||||
# Lazy-init of default_capture_stream helps avoid circular-import errors.
|
||||
# Not thread safe, but graphs already have the general (explicitly documented)
|
||||
# restriction that only one capture may be underway at a time in the process.
|
||||
|
|
@ -163,6 +177,7 @@ class graph:
|
|||
assert self.capture_stream is not None
|
||||
self.stream_ctx = torch.cuda.stream(self.capture_stream)
|
||||
self.cuda_graph = cuda_graph
|
||||
self.capture_error_mode = capture_error_mode
|
||||
|
||||
def __enter__(self):
|
||||
# Free as much memory as we can for the graph
|
||||
|
|
@ -174,7 +189,9 @@ class graph:
|
|||
# https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
|
||||
self.stream_ctx.__enter__()
|
||||
|
||||
self.cuda_graph.capture_begin(*self.pool)
|
||||
self.cuda_graph.capture_begin(
|
||||
*self.pool, capture_error_mode=self.capture_error_mode
|
||||
)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.cuda_graph.capture_end()
|
||||
|
|
|
|||
|
|
@ -4161,6 +4161,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
|
|||
("cudaStreamCaptureMode", ("hipStreamCaptureMode", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaStreamCaptureModeGlobal", ("hipStreamCaptureModeGlobal", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaStreamCaptureModeRelaxed", ("hipStreamCaptureModeRelaxed", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user