pytorch/torch/csrc/cuda/shared/cudart.cpp
Daniel Galvez c7515da7b0 Add a CUDA graphs implementation of torch.cond and torch.while_loop (#140979)
This is a new PR for #130386, which went stale and was closed. Since I force-pushed to that branch in order to rebase it onto main, the PR can no longer be reopened, per https://github.com/isaacs/github/issues/361

I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534

Since this work started, torch.cond and torch.while_loop have apparently gained support for backward passes; I will look into what it would take to support that here as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140979
Approved by: https://github.com/eqy, https://github.com/eellison
2025-02-11 18:16:15 +00:00


#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/csrc/utils/pybind.h>
#if !defined(USE_ROCM)
#include <cuda_profiler_api.h>
#else
#include <hip/hip_runtime_api.h>
#endif
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
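
// This file exposes a small, direct subset of the CUDA runtime (cudart) API
// to Python. The submodule created below becomes torch._C._cudart and is
// normally reached through torch.cuda.cudart(), e.g.:
//
//   rt = torch.cuda.cudart()
//   rt.cudaProfilerStart()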
namespace torch::cuda::shared {
#ifdef USE_ROCM
namespace {
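// On ROCm the profiler start/stop bindings below are wired to this stub,
// which does nothing and reports success.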
hipError_t hipReturnSuccess() {
  return hipSuccess;
}
} // namespace
#endif
void initCudartBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto cudart = m.def_submodule("_cudart", "libcudart.so bindings");
  // By splitting the names of these objects into two literals we prevent the
  // HIP rewrite rules from changing these names when building with HIP.
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12000
  // cudaOutputMode_t is used in cudaProfilerInitialize only. The latter is
  // gone in CUDA 12.
  py::enum_<cudaOutputMode_t>(
      cudart,
      "cuda"
      "OutputMode")
      .value("KeyValuePair", cudaKeyValuePair)
      .value("CSV", cudaCSV);
#endif
  py::enum_<cudaError_t>(
      cudart,
      "cuda"
      "Error")
      .value("success", cudaSuccess);
  cudart.def(
      "cuda"
      "GetErrorString",
      cudaGetErrorString);
  cudart.def(
      "cuda"
      "ProfilerStart",
#ifdef USE_ROCM
      hipReturnSuccess
#else
      cudaProfilerStart
#endif
  );
  cudart.def(
      "cuda"
      "ProfilerStop",
#ifdef USE_ROCM
      hipReturnSuccess
#else
      cudaProfilerStop
#endif
  );
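  // The memory and stream bindings below share a pattern: handles cross the
  // Python boundary as plain integers (uintptr_t) and are cast back to the
  // real pointer types here, and the GIL is released around each runtime
  // call since it may block (cudaHostRegister, for one, has to page-lock
  // the whole registered range).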
  cudart.def(
      "cuda"
      "HostRegister",
      [](uintptr_t ptr, size_t size, unsigned int flags) -> cudaError_t {
        py::gil_scoped_release no_gil;
        return C10_CUDA_ERROR_HANDLED(
            // NOLINTNEXTLINE(performance-no-int-to-ptr)
            cudaHostRegister((void*)ptr, size, flags));
      });
  cudart.def(
      "cuda"
      "HostUnregister",
      [](uintptr_t ptr) -> cudaError_t {
        py::gil_scoped_release no_gil;
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        return C10_CUDA_ERROR_HANDLED(cudaHostUnregister((void*)ptr));
      });
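  // The caller passes the address of a cudaStream_t, again as an integer;
  // cudaStreamCreate writes the new stream handle through that pointer.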
  cudart.def(
      "cuda"
      "StreamCreate",
      [](uintptr_t ptr) -> cudaError_t {
        py::gil_scoped_release no_gil;
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        return C10_CUDA_ERROR_HANDLED(cudaStreamCreate((cudaStream_t*)ptr));
      });
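  // Stream-creation flag constants, re-exported so Python callers can pass
  // them to cudaStreamCreateWithFlags below.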
  cudart.attr(
      "cuda"
      "StreamDefault") = cudaStreamDefault;
  cudart.attr(
      "cuda"
      "StreamNonBlocking") = cudaStreamNonBlocking;
  cudart.def(
      "cuda"
      "StreamCreateWithFlags",
      [](uintptr_t ptr, unsigned int flags) -> cudaError_t {
        // Release the GIL here as well, for consistency with the other
        // stream bindings; the runtime call may block.
        py::gil_scoped_release no_gil;
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        return C10_CUDA_ERROR_HANDLED(
            cudaStreamCreateWithFlags((cudaStream_t*)ptr, flags));
      });
  cudart.def(
      "cuda"
      "StreamDestroy",
      [](uintptr_t ptr) -> cudaError_t {
        py::gil_scoped_release no_gil;
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        return C10_CUDA_ERROR_HANDLED(cudaStreamDestroy((cudaStream_t)ptr));
      });
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12000
  // cudaProfilerInitialize is no longer needed after CUDA 12:
  // https://forums.developer.nvidia.com/t/cudaprofilerinitialize-is-deprecated-alternative/200776/3
  cudart.def(
      "cuda"
      "ProfilerInitialize",
      cudaProfilerInitialize,
      py::call_guard<py::gil_scoped_release>());
#endif
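  // Unlike the raw runtime call, this binding takes an explicit device
  // index: CUDAGuard switches to that device for the duration of the call,
  // and the result is a (free, total) pair of byte counts.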
  cudart.def(
      "cuda"
      "MemGetInfo",
      [](c10::DeviceIndex device) -> std::pair<size_t, size_t> {
        c10::cuda::CUDAGuard guard(device);
        size_t device_free = 0;
        size_t device_total = 0;
        py::gil_scoped_release no_gil;
        C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
        return {device_free, device_total};
      });
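  // Stream-capture machinery, used when recording CUDA graphs. The capture
  // mode controls how much unrelated GPU work other threads may issue while
  // a capture is in progress: Global is the strictest, Relaxed the most
  // permissive.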
  py::enum_<cudaStreamCaptureMode>(
      cudart,
      "cuda"
      "StreamCaptureMode")
      .value("Global", cudaStreamCaptureModeGlobal)
      .value("ThreadLocal", cudaStreamCaptureModeThreadLocal)
      .value("Relaxed", cudaStreamCaptureModeRelaxed);
  cudart.def(
      "cuda"
      "ThreadExchangeStreamCaptureMode",
      [](cudaStreamCaptureMode mode) -> cudaStreamCaptureMode {
        C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
        return mode;
      });
}
} // namespace torch::cuda::shared