[C10][CUDA] Eagerly create context on torch.cuda.set_device(device) call (#155900)
Fixes #155668
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155900
Approved by: https://github.com/ngimel
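In user-facing terms: on CUDA 12 builds, cudaSetDevice itself eagerly creates a primary context on the target device, and after this change torch.cuda.set_device always reaches that call instead of short-circuiting when the device is already current. Below is a minimal sketch of the observable behavior, assuming a CUDA-enabled build with at least one GPU; it reuses the internal torch._C._cuda_hasPrimaryContext probe that the new test in this commit also uses.

    import torch

    # Nothing has touched device 0 yet, so no primary context exists.
    assert not torch._C._cuda_hasPrimaryContext(0)

    torch.cuda.set_device(0)  # after this change, always calls cudaSetDevice

    if int(torch.version.cuda.split(".")[0]) >= 12:
        # CUDA 12: cudaSetDevice eagerly creates the primary context.
        assert torch._C._cuda_hasPrimaryContext(0)
    else:
        # CUDA 11: context creation stays lazy until first real use.
        assert not torch._C._cuda_hasPrimaryContext(0)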
This commit is contained in:
parent fc177801af
commit 4a26bb8a12
c10/cuda/CUDAFunctions.cpp

@@ -130,8 +130,8 @@ DeviceIndex current_device() {
   return cur_device;
 }
 
-void set_device(DeviceIndex device) {
-  C10_CUDA_CHECK(c10::cuda::SetDevice(device));
+void set_device(DeviceIndex device, const bool force) {
+  C10_CUDA_CHECK(c10::cuda::SetDevice(device, force));
 }
 
 void device_synchronize() {
@@ -231,9 +231,12 @@ cudaError_t GetDevice(DeviceIndex* device) {
   return err;
 }
 
-cudaError_t SetDevice(DeviceIndex device) {
-  TORCH_CHECK(device >= 0, "device id must be positive!", device);
+cudaError_t SetDevice(DeviceIndex device, const bool force) {
+  TORCH_CHECK(device >= 0, "device id must be non-negative!", device);
   targetDeviceIndex = -1;
+  if (force) {
+    return cudaSetDevice(device);
+  }
   int cur_device = -1;
   C10_CUDA_CHECK(cudaGetDevice(&cur_device));
   if (device == cur_device) {
@@ -309,8 +312,11 @@ cudaError_t GetDevice(DeviceIndex* device) {
   return err;
 }
 
-cudaError_t SetDevice(DeviceIndex device) {
-  TORCH_CHECK(device >= 0, "device id must be positive!", device);
+cudaError_t SetDevice(DeviceIndex device, const bool force) {
+  TORCH_CHECK(device >= 0, "device id must be non-negative!", device);
+  if (force) {
+    return cudaSetDevice(device);
+  }
   int cur_device = -1;
   C10_CUDA_CHECK(cudaGetDevice(&cur_device));
   if (device == cur_device) {
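Both SetDevice variants gain the same early exit; the first (the one resetting targetDeviceIndex) is the CUDA 12 build path that defers device switches, the second is the pre-CUDA 12 path. The point of force is that the old code returned without ever calling cudaSetDevice when the requested device was already current, so on CUDA 12 the primary context was never created in that common case. A rough, self-contained Python analogue of the new control flow; the runtime_* helpers are made-up stand-ins for cudaGetDevice/cudaSetDevice, not real APIs:

    # Hypothetical sketch of the new SetDevice control flow, illustration only.
    _current = 0  # stand-in for the runtime's notion of the current device

    def runtime_get_device() -> int:
        return _current

    def runtime_set_device(device: int) -> None:
        global _current
        _current = device  # the real call also creates the context on CUDA 12

    def set_device(device: int, force: bool = False) -> None:
        if device < 0:
            raise ValueError("device id must be non-negative!")
        if force:
            # force: always reach the runtime, mirroring the new early return
            runtime_set_device(device)
            return
        if device == runtime_get_device():
            return  # fast path: device already current, runtime never touched
        runtime_set_device(device)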
c10/cuda/CUDAFunctions.h

@@ -27,7 +27,7 @@ C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
 
 C10_CUDA_API DeviceIndex current_device();
 
-C10_CUDA_API void set_device(DeviceIndex device);
+C10_CUDA_API void set_device(DeviceIndex device, const bool force = false);
 
 C10_CUDA_API void device_synchronize();
 
@@ -38,7 +38,8 @@ C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
 
 C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
 
-C10_CUDA_API cudaError_t SetDevice(DeviceIndex device);
+C10_CUDA_API cudaError_t
+SetDevice(DeviceIndex device, const bool force = false);
 
 C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);
 
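Both declarations default force to false, so existing C++ call sites of set_device/SetDevice keep the lazy fast-path behavior unchanged; only callers that explicitly pass true (notably the Python binding in torch/csrc/cuda/Module.cpp, below) opt into the unconditional cudaSetDevice call.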
test/test_cuda_primary_ctx.py

@@ -4,7 +4,11 @@ import sys
 import unittest
 
 import torch
-from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
+from torch.testing._internal.common_cuda import (
+    _get_torch_cuda_version,
+    TEST_CUDA,
+    TEST_MULTIGPU,
+)
 from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
 
 
@@ -31,6 +35,19 @@ class TestCudaPrimaryCtx(TestCase):
             TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG,
         )
 
+    def test_set_device_0(self):
+        # In CUDA 12 the behavior of cudaSetDevice changed: it eagerly creates a context on the target device.
+        # `torch.cuda.set_device(0)` should therefore also create a context on device 0.
+        # Initially, we should not have any context on device 0.
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        torch.cuda.set_device(0)
+        if _get_torch_cuda_version() >= (12, 0):
+            # Now that the device has been set, the context should be present on CUDA 12.
+            self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
+        else:
+            # On CUDA 11 the context should not be created.
+            self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_str_repr(self):
         x = torch.randn(1, device="cuda:1")
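With at least one CUDA GPU available, the new test can be exercised on its own; run_tests wraps the standard unittest entry point, so a name filter such as the following should work:

    python test/test_cuda_primary_ctx.py -k test_set_device_0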
torch/csrc/cuda/Module.cpp

@@ -64,7 +64,7 @@ PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
   auto device = THPUtils_unpackLong(arg);
 
   torch::utils::device_lazy_init(at::kCUDA);
-  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));
+  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device), /*force*/ true);
 
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
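The /*force*/ true here is what makes the change visible from Python: THCPModule_setDevice_wrap backs torch.cuda.set_device, which now reaches cudaSetDevice even when the requested device is already current. torch::utils::device_lazy_init still runs first, so the CUDA initialization order is otherwise unchanged.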