pytorch/torch/csrc/cuda/Graph.cpp
Daniel Galvez c7515da7b0 Implement CUDA graphs support for torch.cond and torch.while_loop (#140979)
This is a new PR for #130386, which went stale and was closed. Since I force-pushed to that branch in order to rebase it on top of main, the PR can no longer be reopened, according to https://github.com/isaacs/github/issues/361

I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534

Since starting this, torch.cond and torch.while_loop have apparently gained support for backward passes. I will look into what it might take to support that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140979
Approved by: https://github.com/eqy, https://github.com/eellison
2025-02-11 18:16:15 +00:00
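The bindings below are reached from Python through torch.cuda.CUDAGraph (which wraps the _CUDAGraph class registered here) and the torch.cuda.graph helper. As orientation, a minimal capture/replay sketch follows; it assumes a CUDA-capable device, and the conditional-node entry points added by this PR (begin_capture_to_if_node and friends) are not shown, since per the commit message they back torch.cond and torch.while_loop rather than being called by hand:

import torch

g = torch.cuda.CUDAGraph()
x = torch.zeros(16, device="cuda")

# Warm up on a side stream before capturing (the usual CUDA graphs recipe).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x * 2
torch.cuda.current_stream().wait_stream(s)

# torch.cuda.graph(g) drives the capture_begin/capture_end bindings below.
with torch.cuda.graph(g):
    y = x * 2

x.fill_(3.0)
g.replay()  # re-runs the captured kernels; y now reflects the updated x

Because replay re-executes the captured work against the same memory, refilling x before replay() changes what y holds afterwards.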


#include <torch/csrc/python_headers.h>
#include <pybind11/chrono.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
// Cargo culted partially from csrc/distributed/c10d/init.cpp
// and partially from csrc/cuda/Stream.cpp.
// THCPStream_init is also declared at global scope.
// Because THCPGraph_init is forward declared in the only consumer
// (csrc/Module.cpp) I don't think we need a Graph.h.
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
void THCPGraph_init(PyObject* module) {
// Pybind11 patch notes say "py::module_" is more up-to-date syntax,
// but CI linter and some builds prefer "module".
auto torch_C_m = py::handle(module).cast<py::module>();
torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle);
shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
.def(py::init<>())
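// capture_begin: translate the Python-level capture_error_mode string into a
// cudaStreamCaptureMode and begin capture into the requested mempool
// (falling back to {0, 0} when no pool is supplied).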
.def(
"capture_begin",
[](::at::cuda::CUDAGraph& self,
std::optional<c10::cuda::MempoolId_t> pool_opt,
const std::string& capture_error_mode) {
cudaStreamCaptureMode capture_mode{};
c10::cuda::MempoolId_t pool = pool_opt.has_value()
? pool_opt.value()
: c10::cuda::MempoolId_t{0, 0};
if (capture_error_mode == "global") {
capture_mode = cudaStreamCaptureModeGlobal;
} else if (capture_error_mode == "thread_local") {
capture_mode = cudaStreamCaptureModeThreadLocal;
} else if (capture_error_mode == "relaxed") {
capture_mode = cudaStreamCaptureModeRelaxed;
} else {
TORCH_CHECK(
false,
"Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
capture_error_mode);
}
return self.capture_begin(pool, capture_mode);
},
py::arg("pool"),
py::arg("capture_error_mode"),
py::call_guard<py::gil_scoped_release>())
.def(
"capture_end",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
.def(
"register_generator_state",
[](::at::cuda::CUDAGraph& self, py::handle raw_generator) {
auto generator = THPGenerator_Unwrap(raw_generator.ptr());
// We've unwrapped the Python object into a C++ object,
// so we can release the GIL before calling into C++.
py::gil_scoped_release release;
return self.register_generator_state(generator);
},
py::arg("generator"))
.def(
"replay",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))
.def(
"reset",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset))
.def(
"pool",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool))
.def(
"debug_dump",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::debug_dump))
.def(
"enable_debug_mode",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::enable_debug_mode))
.def(
"debug_dump",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::debug_dump),
py::arg("debug_path"))
.def_static(
"get_currently_capturing_graph",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::get_currently_capturing_graph),
py::return_value_policy::reference)
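// Conditional-node capture entry points. Per the PR description above,
// these back the CUDA graphs implementation of torch.cond and
// torch.while_loop.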
.def(
"begin_capture_to_if_node",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::begin_capture_to_if_node),
py::arg("scalar_cuda_pred_tensor"))
.def(
"begin_capture_to_while_loop_node",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::begin_capture_to_while_loop_node),
py::arg("scalar_cuda_pred_tensor"))
.def(
"end_capture_to_conditional_node",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::end_capture_to_conditional_node))
.def_static(
"set_conditional_handle",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::set_conditional_handle),
py::arg("handle"),
py::arg("scalar_cuda_pred_tensor"));
}