[C10][CUDA] Eagerly create context on torch.cuda.set_device(device) call (#155900)

Fixes #155668

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155900
Approved by: https://github.com/ngimel
This commit is contained in:
Aidyn-A 2025-06-17 18:59:41 +00:00 committed by PyTorch MergeBot
parent fc177801af
commit 4a26bb8a12
4 changed files with 34 additions and 10 deletions

View File

@ -130,8 +130,8 @@ DeviceIndex current_device() {
return cur_device;
}
void set_device(DeviceIndex device) {
C10_CUDA_CHECK(c10::cuda::SetDevice(device));
void set_device(DeviceIndex device, const bool force) {
C10_CUDA_CHECK(c10::cuda::SetDevice(device, force));
}
void device_synchronize() {
@ -231,9 +231,12 @@ cudaError_t GetDevice(DeviceIndex* device) {
return err;
}
cudaError_t SetDevice(DeviceIndex device) {
TORCH_CHECK(device >= 0, "device id must be positive!", device);
cudaError_t SetDevice(DeviceIndex device, const bool force) {
TORCH_CHECK(device >= 0, "device id must be non-negative!", device);
targetDeviceIndex = -1;
if (force) {
return cudaSetDevice(device);
}
int cur_device = -1;
C10_CUDA_CHECK(cudaGetDevice(&cur_device));
if (device == cur_device) {
@ -309,8 +312,11 @@ cudaError_t GetDevice(DeviceIndex* device) {
return err;
}
cudaError_t SetDevice(DeviceIndex device) {
TORCH_CHECK(device >= 0, "device id must be positive!", device);
cudaError_t SetDevice(DeviceIndex device, const bool force) {
TORCH_CHECK(device >= 0, "device id must be non-negative!", device);
if (force) {
return cudaSetDevice(device);
}
int cur_device = -1;
C10_CUDA_CHECK(cudaGetDevice(&cur_device));
if (device == cur_device) {

View File

@ -27,7 +27,7 @@ C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
C10_CUDA_API DeviceIndex current_device();
C10_CUDA_API void set_device(DeviceIndex device);
C10_CUDA_API void set_device(DeviceIndex device, const bool force = false);
C10_CUDA_API void device_synchronize();
@ -38,7 +38,8 @@ C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
C10_CUDA_API cudaError_t SetDevice(DeviceIndex device);
C10_CUDA_API cudaError_t
SetDevice(DeviceIndex device, const bool force = false);
C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);

View File

@ -4,7 +4,11 @@ import sys
import unittest
import torch
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
TEST_CUDA,
TEST_MULTIGPU,
)
from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
@ -31,6 +35,19 @@ class TestCudaPrimaryCtx(TestCase):
TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG,
)
def test_set_device_0(self):
# In CUDA 12 the behavior of cudaSetDevice has changed. It eagerly creates context on target.
# The behavior of `torch.cuda.set_device(0)` should also create context on the device 0.
# Initially, we should not have any context on device 0.
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
torch.cuda.set_device(0)
if _get_torch_cuda_version() >= (12, 0):
# Now after the device was set, the contex should present in CUDA 12.
self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
else:
# In CUDA 11 the context should not be created.
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_str_repr(self):
x = torch.randn(1, device="cuda:1")

View File

@ -64,7 +64,7 @@ PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
auto device = THPUtils_unpackLong(arg);
torch::utils::device_lazy_init(at::kCUDA);
c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));
c10::cuda::set_device(static_cast<c10::DeviceIndex>(device), /*force*/ true);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS