Refine the logic of device construction when only device index is given (#129119)

# Motivation
Before this PR, constructing a device from only a device index produced a `cuda` device, or a `PrivateUse1` device if a `PrivateUse1` backend was registered.
```python
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
>>> b
tensor([1, 2], device='cuda:0')
```
This works well on a CUDA GPU, but on XPU it reports a misleading device type and raises an unexpected error:
```python
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xxx/pytorch/torch/cuda/__init__.py", line 302, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
```
With this PR, the logic is refined to bind a bare device index to the currently available accelerator type instead.
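The refined resolution rule can be sketched in plain Python (a hypothetical `resolve_device` helper for illustration only, not part of the PyTorch API): an integer index binds to whatever accelerator backend is currently available, and construction fails loudly when none is.

```python
from typing import Optional

def resolve_device(index: int, accelerator: Optional[str]) -> str:
    """Sketch of the refined toDevice() logic: a bare integer index
    binds to the currently available accelerator type (hypothetical helper)."""
    if index < 0:
        raise ValueError("Device index must not be negative")
    if accelerator is None:
        # Mirrors failing fast when no accelerator is detected,
        # instead of silently assuming CUDA.
        raise RuntimeError("No available accelerator detected")
    return f"{accelerator}:{index}"

# On a CUDA build this yields 'cuda:0'; on an XPU build, 'xpu:0'.
print(resolve_device(0, "cuda"))  # cuda:0
print(resolve_device(0, "xpu"))   # xpu:0
```

With this rule, the same `torch.device(0)` call resolves correctly on either backend rather than hard-coding CUDA.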
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129119
Approved by: https://github.com/albanD, https://github.com/gujinghui, https://github.com/EikanWang
ghstack dependencies: #129463, #129205, #129363
Author: Yu, Guangye
Date: 2024-07-04 13:03:03 +00:00
Committed by: PyTorch MergeBot
Parent: 9cae2160f5
Commit: 7cd48df2da
2 changed files with 5 additions and 7 deletions


```diff
@@ -213,7 +213,8 @@ non-None device argument. To globally change the default device, see also
 .. note::
    For legacy reasons, a device can be constructed via a single device ordinal, which is treated
-   as a cuda device. This matches :meth:`Tensor.get_device`, which returns an ordinal for cuda
+   as the current :ref:`accelerator<accelerators>` type.
+   This matches :meth:`Tensor.get_device`, which returns an ordinal for device
    tensors and is not supported for cpu tensors.
 >>> torch.device(1)
```


```diff
@@ -66,6 +66,7 @@
 #include <torch/csrc/utils/python_symnode.h>
 #include <torch/csrc/utils/six.h>
+#include <ATen/DeviceAccelerator.h>
 #include <ATen/PythonTorchFunctionTLS.h>
 #include <ATen/core/Tensor.h>
 #include <c10/util/Exception.h>
```
```diff
@@ -811,13 +812,9 @@ inline at::Device toDevice(PyObject* obj) {
   if (THPUtils_checkLong(obj)) {
     const auto device_index = THPUtils_unpackLong(obj);
     TORCH_CHECK(device_index >= 0, "Device index must not be negative");
-    if (c10::is_privateuse1_backend_registered()) {
-      return at::Device(
-          c10::DeviceType::PrivateUse1,
-          static_cast<c10::DeviceIndex>(device_index));
-    }
     return at::Device(
-        c10::DeviceType::CUDA, static_cast<c10::DeviceIndex>(device_index));
+        at::getAccelerator(true).value(),
+        static_cast<c10::DeviceIndex>(device_index));
   }
   const std::string& device_str = THPUtils_unpackString(obj);
   return at::Device(device_str);
```