[CUDA12] set_device change (#94864)

This PR adds a workaround for the CUDA 12 [`cudaSetDevice` change](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb), which now always creates a primary context on the target device. As a result, operations like this:
```Python
import torch
x = torch.randn(1, device="cuda:1")
```
would create a primary context on device `cuda:1` (because a tensor is allocated there) and also on device `cuda:0` (because the destructor of the CUDA device guard calls `cudaSetDevice(0)`).
After this PR, the CUDA device guard no longer calls `cudaSetDevice(0)` if no primary context exists on `cuda:0`.
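
For illustration, here is a minimal sketch of the behavior this PR targets, assuming a machine with at least two GPUs and the private `torch._C._cuda_hasPrimaryContext` helper used by PyTorch's primary-context tests:
```Python
import torch

# Allocating on cuda:1 initializes the primary context on that device only.
x = torch.randn(1, device="cuda:1")

# Before this PR, the device guard's destructor called cudaSetDevice(0), which
# on CUDA 12 eagerly created a (never used) primary context on cuda:0 as well.
print(torch._C._cuda_hasPrimaryContext(1))  # True
print(torch._C._cuda_hasPrimaryContext(0))  # expected False after this PR
```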

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94864
Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/ezyang
Aidyn-A 2023-04-10 17:31:12 +00:00 committed by PyTorch MergeBot
parent 3fcc5ff0d6
commit 69eef5a4be
26 changed files with 282 additions and 63 deletions

View File

@ -637,6 +637,37 @@ command = [
'@{{PATHSFILE}}'
]
[[linter]]
code = 'RAWCUDADEVICE'
include_patterns = [
'aten/**',
'c10/**',
'torch/csrc/**',
]
exclude_patterns = [
'aten/src/ATen/cuda/CUDAContext.cpp',
'aten/src/ATen/cuda/CUDAGeneratorImpl.cpp',
'aten/src/ATen/test/**',
'c10/core/impl/InlineDeviceGuard.h',
'c10/cuda/CUDAFunctions.cpp',
'c10/cuda/CUDAGuard.h',
'c10/cuda/impl/CUDATest.cpp',
'torch/csrc/cuda/nccl.cpp',
]
command = [
'python3',
'tools/linter/adapters/grep_linter.py',
'--pattern=cudaSetDevice',
'--pattern=cudaGetDevice',
'--linter-name=RAWCUDADEVICE',
'--error-name=raw CUDA API usage',
"""--error-description=\
This line calls raw CUDA APIs directly; please use c10::cuda wrappers instead.
""",
'--',
'@{{PATHSFILE}}'
]
[[linter]]
code = 'ROOT_LOGGING'
include_patterns = [

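Aside: the new RAWCUDADEVICE rule above is a plain pattern grep over the listed paths. A simplified Python sketch of the idea (not the actual `tools/linter/adapters/grep_linter.py` adapter) could look like this:
```Python
import re
import sys

# Flag raw CUDA device-management calls so that the c10::cuda wrappers
# (GetDevice/SetDevice/MaybeSetDevice/...) are used instead.
PATTERNS = [re.compile(r"cudaSetDevice"), re.compile(r"cudaGetDevice")]

def lint_file(path: str) -> int:
    """Print one message per offending line and return the error count."""
    errors = 0
    with open(path, encoding="utf-8", errors="replace") as f:
        for lineno, line in enumerate(f, start=1):
            if any(p.search(line) for p in PATTERNS):
                print(f"{path}:{lineno}: raw CUDA API usage: {line.strip()}")
                errors += 1
    return errors

if __name__ == "__main__":
    sys.exit(1 if sum(lint_file(p) for p in sys.argv[1:]) else 0)
```
The `exclude_patterns` in the config cover the handful of files (the wrappers themselves, device guards, tests, NCCL glue) that still need to call the raw APIs directly.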
View File

@ -27,7 +27,7 @@ using CuSparsePoolType = DeviceThreadHandlePool<cusparseHandle_t, createCusparse
cusparseHandle_t getCurrentCUDASparseHandle() {
int device;
AT_CUDA_CHECK(cudaGetDevice(&device));
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
// Thread local PoolWindows are lazily-initialized
// to avoid initialization issues that caused hangs on Windows.

View File

@ -81,7 +81,7 @@ at::DataPtr getNewWorkspace() {
cublasHandle_t getCurrentCUDABlasHandle() {
int device;
AT_CUDA_CHECK(cudaGetDevice(&device));
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
// Thread local PoolWindows are lazily-initialized
// to avoid initialization issues that caused hangs on Windows.

View File

@ -15,6 +15,7 @@
#include <ATen/native/cuda/CuFFTPlanCache.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/irange.h>
#if AT_CUDNN_ENABLED()
@ -225,7 +226,7 @@ const at::cuda::NVRTC& CUDAHooks::nvrtc() const {
int64_t current_device() {
int device;
cudaError_t err = cudaGetDevice(&device);
cudaError_t err = c10::cuda::GetDevice(&device);
if (err == cudaSuccess) {
return device;
}

View File

@ -33,7 +33,7 @@ using CudnnPoolType = at::cuda::DeviceThreadHandlePool<cudnnHandle_t, createCuDN
cudnnHandle_t getCudnnHandle() {
int device;
AT_CUDA_CHECK(cudaGetDevice(&device));
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
// Thread local PoolWindows are lazily-initialized
// to avoid initialization issues that caused hangs on Windows.

View File

@ -56,7 +56,7 @@ bool allContiguous(at::TensorList tensors) {
void getLaunchConfig(dim3* block, dim3* grid, int64_t numel) {
int curDevice = -1;
cudaGetDevice(&curDevice);
c10::cuda::GetDevice(&curDevice);
*block = cuda::getApplyBlock();
TORCH_INTERNAL_ASSERT(cuda::getApplyGrid(numel, *grid, curDevice),
"Could not get grid size for pointwise apply.");

View File

@ -85,7 +85,7 @@ std::tuple<Tensor, Tensor, Tensor> compute_unique(
dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), num_inp));
dim3 grid;
int curDevice = -1;
cudaGetDevice(&curDevice);
c10::cuda::GetDevice(&curDevice);
cuda::getApplyGrid(num_inp, grid, curDevice);
adjacent_difference_kernel<<<grid, block, 0, stream>>>(
num_inp, data, inv_loc_ptr);

View File

@ -30,7 +30,7 @@ using CuSolverDnPoolType = DeviceThreadHandlePool<cusolverDnHandle_t, createCuso
cusolverDnHandle_t getCurrentCUDASolverDnHandle() {
int device;
AT_CUDA_CHECK(cudaGetDevice(&device));
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
// Thread local PoolWindows are lazily-initialized
// to avoid initialization issues that caused hangs on Windows.

View File

@ -326,7 +326,7 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso
int64_t get_available_workspace() {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t max_block_size = 0;
c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
return static_cast<int64_t>(max_block_size);

View File

@ -314,7 +314,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
const dim3 block = cuda::getApplyBlock();
dim3 grid;
int curDevice = -1;
cudaGetDevice(&curDevice);
c10::cuda::GetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
if (sparse.dense_dim() == 0) {
TORCH_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions");
@ -606,7 +606,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_
}
else {
int curDevice = -1;
cudaGetDevice(&curDevice);
c10::cuda::GetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
at::cuda::ThrustAllocator allocator;
auto policy = thrust::cuda::par(allocator).on(stream);
@ -711,7 +711,7 @@ __global__ void search_end_matrix_indices_cuda_kernel(
// indices to find the end index for each matrix
void search_end_matrix_indices(int64_t* mat_el_end_indices, int64_t num_matrices, const Tensor& indices_1D) {
int curDevice = -1;
cudaGetDevice(&curDevice);
c10::cuda::GetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
auto indices_1D_ti = getTensorInfo<int64_t, int64_t>(indices_1D);

View File

@ -2225,12 +2225,12 @@ class DeviceCachingAllocator {
void insert_events(Block* block) {
int prev_device;
C10_CUDA_CHECK(cudaGetDevice(&prev_device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&prev_device));
stream_set streams(std::move(block->stream_uses));
AT_ASSERT(block->stream_uses.empty());
for (auto& stream : streams) {
C10_CUDA_CHECK(cudaSetDevice(stream.device_index()));
C10_CUDA_CHECK(c10::cuda::SetDevice(stream.device_index()));
EventPool::Event event =
create_event_internal(static_cast<int>(stream.device_index()));
@ -2240,7 +2240,7 @@ class DeviceCachingAllocator {
cuda_events[stream].emplace_back(std::move(event), block);
}
C10_CUDA_CHECK(cudaSetDevice(prev_device));
C10_CUDA_CHECK(c10::cuda::MaybeSetDevice(prev_device));
}
void insert_events_deferred_until_no_capture() {
@ -2434,11 +2434,7 @@ class NativeCachingAllocator : public CUDAAllocator {
"invalid fraction:",
fraction,
". Please set within (0, 1).");
int activated_device;
C10_CUDA_CHECK(cudaGetDevice(&activated_device));
if (activated_device != device) {
C10_CUDA_CHECK(cudaSetDevice(device));
}
C10_CUDA_CHECK(c10::cuda::SetDevice(device));
device_allocator[device]->setMemoryFraction(fraction);
}
@ -2448,7 +2444,7 @@ class NativeCachingAllocator : public CUDAAllocator {
size_t alloc_trace_max_entries,
bool alloc_trace_record_context) override {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
device_allocator[device]->recordHistory(
enabled,
std::move(context_recorder),
@ -2458,7 +2454,7 @@ class NativeCachingAllocator : public CUDAAllocator {
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
device_allocator[device]->attachOutOfMemoryObserver(std::move(observer));
}
@ -2557,7 +2553,7 @@ class NativeCachingAllocator : public CUDAAllocator {
size < one_exa_bytes,
"CUDA out of memory. Tried to allocate more than 1EB memory.");
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
if (forceUncachedAllocator()) {
// Deliberately don't use cudaMallocMaybeCapturing here, to force an error
@ -2634,7 +2630,7 @@ class NativeCachingAllocator : public CUDAAllocator {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
return r;
@ -2645,7 +2641,7 @@ class NativeCachingAllocator : public CUDAAllocator {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
malloc(&r, device, nbytes, stream);
return r;
@ -2712,7 +2708,7 @@ class NativeCachingAllocator : public CUDAAllocator {
&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess));
// devPtr has to be deleted in same device when created.
int curr_device;
C10_CUDA_CHECK(cudaGetDevice(&curr_device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
auto sp =
std::shared_ptr<void>(dev, [handle, curr_device, this](void* ptr) {
cuda::CUDAGuard device_guard(curr_device);

View File

@ -1,5 +1,6 @@
#include <c10/cuda/CUDADeviceAssertionHost.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
@ -36,7 +37,7 @@ namespace {
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_id() {
int device = -1;
C10_CUDA_CHECK_WO_DSA(cudaGetDevice(&device));
C10_CUDA_CHECK_WO_DSA(c10::cuda::GetDevice(&device));
return device;
}
@ -60,7 +61,7 @@ int dsa_get_device_compute_capability(const int device_num) {
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_count() {
int device_count = -1;
C10_CUDA_CHECK_WO_DSA(cudaGetDeviceCount(&device_count));
C10_CUDA_CHECK_WO_DSA(c10::cuda::GetDeviceCount(&device_count));
return device_count;
}

View File

@ -15,7 +15,7 @@ int32_t driver_version() {
int device_count_impl(bool fail_if_no_driver) {
int count;
auto err = C10_CUDA_ERROR_HANDLED(cudaGetDeviceCount(&count));
auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDeviceCount(&count));
if (err == cudaSuccess) {
return count;
}
@ -122,12 +122,12 @@ DeviceIndex device_count_ensure_non_zero() {
DeviceIndex current_device() {
int cur_device;
C10_CUDA_CHECK(cudaGetDevice(&cur_device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
return static_cast<DeviceIndex>(cur_device);
}
void set_device(DeviceIndex device) {
C10_CUDA_CHECK(cudaSetDevice(static_cast<int>(device)));
C10_CUDA_CHECK(c10::cuda::SetDevice(static_cast<int>(device)));
}
void device_synchronize() {
@ -182,4 +182,130 @@ bool hasPrimaryContext(int64_t device_index) {
return _internal::hasPrimaryContext(device_index);
}
// Wrappers for raw CUDA device management functions
cudaError_t GetDeviceCount(int* dev_count) {
return cudaGetDeviceCount(dev_count);
}
// This is a codepath for CUDA 12 that comes with a critical change in the
// behavior of `cudaSetDevice`. Unlike previous CUDA versions, which allocate
// the context lazily, CUDA 12.x eagerly allocates a primary context the moment
// `cudaSetDevice` is called. This can have dramatic consequences and pollute
// device memory in distributed runs. To avoid unnecessary context creation, a
// new function called `MaybeSetDevice` was introduced. It is meant to be
// called in the device guard destructor and on exit of the torch.cuda.device
// context manager. The behavior of `MaybeSetDevice` is simple: it calls
// `cudaSetDevice` if a context already exists on the target device; otherwise
// it merely saves the device index. This keeps PyTorch backward compatible for
// applications like this:
//
// ```
// import torch
// x = torch.empty(1, device="cuda:1") # no CUDA context on cuda:0 after this call
// y = torch.empty(1, device="cuda")   # CUDA context is created on cuda:0
// ```
#if CUDA_VERSION >= 12000
thread_local int targetDeviceIndex = -1;
cudaError_t GetDevice(int* device) {
if (targetDeviceIndex >= 0) {
*device = targetDeviceIndex;
return cudaSuccess;
}
return cudaGetDevice(device);
}
cudaError_t SetDevice(int device) {
TORCH_CHECK(device >= 0, "device id must be non-negative!");
targetDeviceIndex = -1;
int cur_device = -1;
C10_CUDA_CHECK(cudaGetDevice(&cur_device));
if (device == cur_device) {
return cudaSuccess;
}
return cudaSetDevice(device);
}
cudaError_t MaybeSetDevice(int device) {
if (hasPrimaryContext(device)) {
return c10::cuda::SetDevice(device);
}
targetDeviceIndex = device;
return cudaSuccess;
}
// This function always initializes the CUDA context
// on to_device
int ExchangeDevice(int to_device) {
int cur_device = targetDeviceIndex;
targetDeviceIndex = -1;
if (cur_device < 0) {
C10_CUDA_CHECK(cudaGetDevice(&cur_device));
if (to_device == cur_device) {
return cur_device;
}
}
C10_CUDA_CHECK(cudaSetDevice(to_device));
return cur_device;
}
// This function does not initialize the CUDA context
// on to_device if it does not already exist
int MaybeExchangeDevice(int to_device) {
int cur_device = -1;
C10_CUDA_CHECK(cudaGetDevice(&cur_device));
if (to_device == cur_device) {
return cur_device;
}
if (hasPrimaryContext(to_device)) {
C10_CUDA_CHECK(cudaSetDevice(to_device));
} else {
targetDeviceIndex = to_device;
}
return cur_device;
}
void SetTargetDevice() {
if (targetDeviceIndex >= 0) {
C10_CUDA_CHECK(c10::cuda::SetDevice(targetDeviceIndex));
}
}
#else
cudaError_t GetDevice(int* device) {
return cudaGetDevice(device);
}
cudaError_t SetDevice(int device) {
TORCH_CHECK(device >= 0, "device id must be non-negative!");
int cur_device = -1;
C10_CUDA_CHECK(cudaGetDevice(&cur_device));
if (device == cur_device) {
return cudaSuccess;
}
return cudaSetDevice(device);
}
cudaError_t MaybeSetDevice(int device) {
return c10::cuda::SetDevice(device);
}
int ExchangeDevice(int to_device) {
int cur_device = -1;
C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
if (to_device == cur_device) {
return cur_device;
}
C10_CUDA_CHECK(cudaSetDevice(to_device));
return cur_device;
}
int MaybeExchangeDevice(int to_device) {
return c10::cuda::ExchangeDevice(to_device);
}
void SetTargetDevice() {
// no-op on CUDA version < 12.x
}
#endif
} // namespace c10::cuda

View File

@ -34,6 +34,21 @@ C10_CUDA_API void device_synchronize();
C10_CUDA_API void warn_or_error_on_sync();
// Raw CUDA device management functions
C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
C10_CUDA_API cudaError_t GetDevice(int* device);
C10_CUDA_API cudaError_t SetDevice(int device);
C10_CUDA_API cudaError_t MaybeSetDevice(int device);
C10_CUDA_API int ExchangeDevice(int device);
C10_CUDA_API int MaybeExchangeDevice(int device);
C10_CUDA_API void SetTargetDevice();
enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
// this is a holder for c10 global state (similar to at GlobalContext)

View File

@ -411,7 +411,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
size < one_exa_bytes,
"CUDA out of memory. Tried to allocate more than 1EB memory.");
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
if (size != 0) {
mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
@ -818,7 +818,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
return r;
@ -829,7 +829,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
mallocAsync(&r, device, nbytes, stream);
return r;

View File

@ -261,8 +261,10 @@ CUDAStream getStreamFromPool(
const bool isHighPriority,
DeviceIndex device_index) {
initCUDAStreamsOnce();
if (device_index == -1)
if (device_index == -1) {
device_index = current_device();
c10::cuda::SetTargetDevice();
}
check_gpu(device_index);
// Initializes the stream pools (once)
@ -289,6 +291,7 @@ CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
initCUDAStreamsOnce();
if (device_index == -1) {
device_index = current_device();
c10::cuda::SetTargetDevice();
}
check_gpu(device_index);
return CUDAStreamForId(device_index, makeStreamId(StreamIdType::DEFAULT, 0));
@ -298,6 +301,7 @@ CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
initCUDAStreamsOnce();
if (device_index == -1) {
device_index = current_device();
c10::cuda::SetTargetDevice();
}
check_gpu(device_index);
return CUDAStreamForId(device_index, current_streams[device_index]);

View File

@ -29,20 +29,17 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
}
Device exchangeDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
Device old_device = getDevice();
if (old_device.index() != d.index()) {
C10_CUDA_CHECK(cudaSetDevice(d.index()));
}
return old_device;
int old_device_index = c10::cuda::ExchangeDevice(d.index());
return Device(DeviceType::CUDA, old_device_index);
}
Device getDevice() const override {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return Device(DeviceType::CUDA, device);
}
c10::optional<Device> uncheckedGetDevice() const noexcept {
int device;
const auto err = C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device));
const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
C10_CUDA_CHECK_WARN(err);
if (err != cudaSuccess) {
return c10::nullopt;
@ -51,16 +48,10 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
}
void setDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
Device current_device = getDevice();
if (current_device != d) {
C10_CUDA_CHECK(cudaSetDevice(d.index()));
}
C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
}
void uncheckedSetDevice(Device d) const noexcept override {
auto current_device = uncheckedGetDevice();
if (!current_device.has_value() || current_device.value() != d) {
C10_CUDA_CHECK_WARN(cudaSetDevice(d.index()));
}
C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
}
Stream getStream(Device d) const noexcept override {
return getCurrentCUDAStream(d.index()).unwrap();
@ -114,15 +105,15 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
return;
auto cuda_event = static_cast<cudaEvent_t>(event);
int orig_device;
C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device));
C10_CUDA_CHECK_WARN(cudaSetDevice(device_index));
C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
reinterpret_cast<uintptr_t>(cuda_event));
}
C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device));
C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
}
void record(

View File

@ -604,3 +604,18 @@ class TestCUDA(JitTestCase):
FileCheck().check("cuda::_exchange_device(").run(g)
torch._C._jit_pass_inline(g)
FileCheck().check("cuda::_exchange_device(").run(g)
# Make sure that cuda._maybe_exchange_device doesn't get DCE'ed
@unittest.skipIf(not TEST_CUDA, "Cuda not available")
def test__maybe_exchange_device_op(self):
def fn(device: int, tensor):
torch.cuda._maybe_exchange_device(device)
return tensor.cos().relu()
fn_s = torch.jit.script(fn)
# Just check the graph, don't run it. Otherwise, we'd need to
# run this test on a multi-gpu CI runner, which is overkill.
g = fn_s.graph
FileCheck().check("cuda::_maybe_exchange_device(").run(g)
torch._C._jit_pass_inline(g)
FileCheck().check("cuda::_maybe_exchange_device(").run(g)

View File

@ -1448,6 +1448,7 @@ def _cuda_getCurrentBlasHandle() -> _int: ...
def _cuda_clearCublasWorkspaces() -> None: ...
def _cuda_setDevice(device: _int) -> None: ...
def _cuda_exchangeDevice(device: _int) -> _int: ...
def _cuda_maybeExchangeDevice(device: _int) -> _int: ...
def _cuda_getDevice() -> _int: ...
def _cuda_getDeviceCount() -> _int: ...
def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ...

View File

@ -774,7 +774,7 @@ void set_device(int device) {
// as in some settings we compile with cuda, but
// have lazy stubs for CUDA functionality (so actually
// attempting to setup a guard(CPU_DEVICE) will cause an
// error, because it will still query cudaGetDevice).
// error, because it will still query GetDevice).
//
// Don't use DeviceGuard here because its destructor may be called before the
// device is reset. This is fine because the device is thread local.

View File

@ -98,7 +98,7 @@ void* CUDAPluggableAllocator::malloc(
c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) const {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
void* r =
const_cast<CUDAPluggableAllocator*>(this)->malloc(size, device, stream);
@ -113,7 +113,7 @@ c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const {
void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
return malloc(nbytes, device, stream);
}
@ -122,7 +122,7 @@ void* CUDAPluggableAllocator::raw_alloc_with_stream(
size_t nbytes,
cudaStream_t stream) {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return malloc(nbytes, device, stream);
}

View File

@ -103,12 +103,24 @@ PyObject* THCPModule_exchangeDevice(PyObject* self, PyObject* arg) {
}
torch::utils::cuda_lazy_init();
auto current_device = c10::cuda::current_device();
if (current_device != device) {
THCPModule_setDevice(device);
}
return THPUtils_packInt32(static_cast<int>(current_device));
int current_device = c10::cuda::ExchangeDevice(device);
return THPUtils_packInt32(current_device);
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_maybeExchangeDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");
int64_t device = THPUtils_unpackLong(arg);
if (device < 0) {
return THPUtils_packInt32(-1);
}
torch::utils::cuda_lazy_init();
int current_device = c10::cuda::MaybeExchangeDevice(device);
return THPUtils_packInt32(current_device);
END_HANDLE_TH_ERRORS
}
@ -1348,6 +1360,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
{"_cuda_init", THCPModule_initExtension, METH_NOARGS, nullptr},
{"_cuda_setDevice", THCPModule_setDevice_wrap, METH_O, nullptr},
{"_cuda_exchangeDevice", THCPModule_exchangeDevice, METH_O, nullptr},
{"_cuda_maybeExchangeDevice",
THCPModule_maybeExchangeDevice,
METH_O,
nullptr},
{"_cuda_getDevice", THCPModule_getDevice_wrap, METH_NOARGS, nullptr},
{"_cuda_getDeviceCount",
THCPModule_getDeviceCount_wrap,

View File

@ -229,6 +229,7 @@ std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
"default_stream",
"current_device",
"_exchange_device",
"_maybe_exchange_device",
"set_device",
"device_index",
"device_count",

View File

@ -103,6 +103,19 @@ RegisterOperators const reg({
},
// cuda::set_device has side effects.
c10::AliasAnalysisKind::CONSERVATIVE),
Operator(
"cuda::_maybe_exchange_device(int64_t index) -> int",
[](Stack& stack) {
int64_t idx = -1;
pop(stack, idx);
if (idx < 0) {
push(stack, -1);
return;
}
int prev_idx = c10::cuda::MaybeExchangeDevice(static_cast<int>(idx));
push(stack, prev_idx);
},
c10::AliasAnalysisKind::CONSERVATIVE),
Operator(
"cuda::_set_device(int64_t val) -> ()",
[](Stack& stack) {

View File

@ -39,7 +39,7 @@ struct CUDAMethods : public ProfilerStubs {
void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
const override {
if (device) {
TORCH_CUDA_CHECK(cudaGetDevice(device));
TORCH_CUDA_CHECK(c10::cuda::GetDevice(device));
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CUevent_st* cuda_event_ptr;

View File

@ -87,6 +87,14 @@ else:
return -1
raise RuntimeError("PyTorch was compiled without CUDA support")
if hasattr(torch._C, '_cuda_maybeExchangeDevice'):
_maybe_exchange_device = torch._C._cuda_maybeExchangeDevice
else:
def _maybe_exchange_device(device: int) -> int:
if device < 0:
return -1
raise RuntimeError("PyTorch was compiled without CUDA support")
# Global variables dynamically populated by native code
has_magma: bool = False
@ -305,7 +313,7 @@ class _DeviceGuard:
self.prev_idx = torch.cuda._exchange_device(self.idx)
def __exit__(self, type: Any, value: Any, traceback: Any):
self.idx = torch.cuda._exchange_device(self.prev_idx)
self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
return False
@ -325,7 +333,7 @@ class device:
self.prev_idx = torch.cuda._exchange_device(self.idx)
def __exit__(self, type: Any, value: Any, traceback: Any):
self.idx = torch.cuda._exchange_device(self.prev_idx)
self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
return False
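
Finally, a small usage sketch of the updated Python-side behavior, assuming at least two visible GPUs (and reusing the assumed `torch._C._cuda_hasPrimaryContext` helper from above): on exit, both `_DeviceGuard` and `torch.cuda.device` now restore the previous index via `_maybe_exchange_device`, so leaving the context manager no longer forces a primary context onto the previous device.
```Python
import torch

with torch.cuda.device(1):
    y = torch.ones(4, device="cuda")  # allocated on cuda:1

# __exit__ called torch.cuda._maybe_exchange_device(0); under CUDA 12 this only
# records the target index and does not create a primary context on cuda:0
# unless one already exists there.
print(torch._C._cuda_hasPrimaryContext(0))  # expected False
```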