mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[CUDAGraph] Add getter for cuda graph exec (#161294)
This is far simpler than #155164 since we never destroy the cudaGraphExec_t. The request comes from TRT-LLM specifically. The motivation is that some power users would like to mutate specific kernel parameters via APIs like `cudaGraphExec*SetParams` after a cuda graph has been instantiated. For example, a common request has been to be able to change the sequence length of attention kernels, after having captured a graph for the largest possible sequence length. It turns out that the host overhead you eliminate via cuda graphs in LLM inference ends up causing an increase in computation time when you size your kernels to the maximum possible sequence length (which I believe is done in both TRT-LLM and vLLM). Attention is the most problematic kernel because its computation time is quadratic in the sequence length, rather than linear. This can work if your attention kernel can work for arbitrary shapes (this is not the case for all attention implementations! Many of them specialize with templates), and you have a persistent kernel that allocates only as many blocks as you have SM's (so you don't have to figure out how many blocks to allocate for a specific sequence length). Using a conditional SWITCH node is a better generic approach to this problem, but that requires more infrastructure work. Note that this requires knowledge of the exact location of the value in your kernel's parameter buffer to mutate. It won't work with arbitrary stream capture code whose kernels you don't know before hand. So I expect this code path to be rarely used. Testing: ``` pytest -s -k raw_graph_exec test/test_cuda.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/161294 Approved by: https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison, https://github.com/eqy
This commit is contained in:
parent
995397d47a
commit
cf94cadbee
|
|
@ -252,6 +252,13 @@ cudaGraph_t CUDAGraph::raw_cuda_graph() {
|
|||
return graph_;
|
||||
}
|
||||
|
||||
cudaGraphExec_t CUDAGraph::raw_cuda_graph_exec() {
|
||||
TORCH_CHECK(
|
||||
has_graph_exec_,
|
||||
"You cannot access the raw cudaGraphExec_t instance until instantiate() has been called");
|
||||
return graph_exec_;
|
||||
}
|
||||
|
||||
void CUDAGraph::reset() {
|
||||
// I'd prefer these checks throw exceptions, not print warnings,
|
||||
// but the destructor calls reset(), and at least one CI build
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
|
|||
void enable_debug_mode();
|
||||
void debug_dump(const std::string& debug_path);
|
||||
cudaGraph_t raw_cuda_graph();
|
||||
cudaGraphExec_t raw_cuda_graph_exec();
|
||||
|
||||
protected:
|
||||
cudaGraph_t graph_ = nullptr;
|
||||
|
|
|
|||
|
|
@ -3666,6 +3666,35 @@ exit(2)
|
|||
|
||||
graph.replay()
|
||||
|
||||
@unittest.skipIf(
|
||||
not TEST_CUDA_GRAPH or not TEST_CUDA_PYTHON_BINDINGS,
|
||||
"CUDA >= 11.0 or ROCM >= 5.3 required for graphs, cuda-bindings must be installed",
|
||||
)
|
||||
@parametrize("keep_graph", [True, False])
|
||||
def test_cuda_graph_raw_graph_exec(self, keep_graph):
|
||||
import cuda.bindings.runtime as cudart
|
||||
|
||||
graph = torch.cuda.CUDAGraph(keep_graph=keep_graph)
|
||||
x = torch.zeros([2000], device="cuda")
|
||||
y = torch.ones([2000], device="cuda")
|
||||
with torch.cuda.graph(graph, capture_error_mode="relaxed"):
|
||||
z = x + y
|
||||
|
||||
if keep_graph:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"You cannot access the raw (cuda|hip)GraphExec_t instance until instantiate\(\) has been called",
|
||||
):
|
||||
graph.raw_cuda_graph_exec()
|
||||
|
||||
graph.instantiate()
|
||||
raw_pointer = graph.raw_cuda_graph_exec()
|
||||
|
||||
cudart_cuda_graph_exec = cudart.cudaGraphExec_t(init_value=raw_pointer)
|
||||
cuda_python_error_check(cudart.cudaGraphExecGetFlags(cudart_cuda_graph_exec))
|
||||
|
||||
graph.replay()
|
||||
|
||||
@unittest.skipIf(
|
||||
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2328,6 +2328,7 @@ class _CUDAGraph:
|
|||
def enable_debug_mode(self) -> None: ...
|
||||
def debug_dump(self, debug_path: str) -> None: ...
|
||||
def raw_cuda_graph(self) -> _int: ...
|
||||
def raw_cuda_graph_exec(self) -> _int: ...
|
||||
|
||||
# Defined in torch/csrc/cuda/MemPool.cpp
|
||||
class _MemPool:
|
||||
|
|
|
|||
|
|
@ -101,5 +101,16 @@ void THCPGraph_init(PyObject* module) {
|
|||
// compile error.
|
||||
return reinterpret_cast<uintptr_t>(graph);
|
||||
},
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"raw_cuda_graph_exec",
|
||||
[](::at::cuda::CUDAGraph& self) {
|
||||
cudaGraphExec_t graph_exec = self.raw_cuda_graph_exec();
|
||||
// We return a raw int here, since otherwise pybind11 will
|
||||
// try to return the underlying struct of cudaGraphExec_t
|
||||
// points to, which is opaque and therefore causes a
|
||||
// compile error.
|
||||
return reinterpret_cast<uintptr_t>(graph_exec);
|
||||
},
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -173,6 +173,13 @@ class CUDAGraph(torch._C._CUDAGraph):
|
|||
""" # noqa: B950
|
||||
return super().raw_cuda_graph()
|
||||
|
||||
def raw_cuda_graph_exec(self) -> int:
|
||||
r"""Returns the underlying cudaGraphExec_t. ``instantiate`` must have been called if ``keep_graph`` is True, or ``capture_end`` must have been called if ``keep_graph`` is False. If you call ``instantiate()`` after ``raw_cuda_graph_exec()``, the previously returned cudaGraphExec_t will be destroyed. It is your responsibility not to use this object after destruction.
|
||||
|
||||
See the following for APIs for how to manipulate this object: `Graph Execution <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH__EXEC.html>`_ and `cuda-python Graph Execution bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-execution>`_
|
||||
""" # noqa: B950
|
||||
return super().raw_cuda_graph_exec()
|
||||
|
||||
|
||||
class graph:
|
||||
r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user