mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Refine the logic of device construction when only device index is given (#129119)
# Motivation
Before this PR, a device constructed from only a device index defaulted to the `cuda` type (or to `PrivateUse1` if a `PrivateUse1` backend was registered).
```python
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
>>> b
tensor([1, 2], device='cuda:0')
```
This works as expected on a CUDA GPU, but on XPU it reports the wrong device type and then raises an error:
```python
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/xxx/pytorch/torch/cuda/__init__.py", line 302, in _lazy_init
raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
```
With this PR, the logic is refined to use the currently available accelerator type instead.
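The refined dispatch can be sketched in plain Python (hypothetical helper name `resolve_device`; the real change is in C++, where `at::getAccelerator(true)` supplies the accelerator type):

```python
# A minimal sketch, assuming the post-PR behavior described above:
# a bare device ordinal resolves against the currently available
# accelerator instead of a hard-coded 'cuda'.

def resolve_device(index, accelerator):
    """Map a bare device ordinal to a (type, index) device descriptor.

    `accelerator` stands in for the result of at::getAccelerator(true),
    i.e. whatever backend the current build exposes ('cuda', 'xpu', ...).
    """
    if index < 0:
        raise ValueError("Device index must not be negative")
    return (accelerator, index)

# On a CUDA build, torch.device(0) behaves as before:
resolve_device(0, accelerator="cuda")   # → ('cuda', 0)
# On an XPU build, the same call now resolves to XPU instead of failing:
resolve_device(0, accelerator="xpu")    # → ('xpu', 0)
```

The key design point is that the negative-index check happens before backend resolution, so the error message is backend-agnostic.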
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129119
Approved by: https://github.com/albanD, https://github.com/gujinghui, https://github.com/EikanWang
ghstack dependencies: #129463, #129205, #129363
This commit is contained in:
parent 9cae2160f5
commit 7cd48df2da
```diff
@@ -213,7 +213,8 @@ non-None device argument. To globally change the default device, see also

 .. note::
    For legacy reasons, a device can be constructed via a single device ordinal, which is treated
-   as a cuda device. This matches :meth:`Tensor.get_device`, which returns an ordinal for cuda
+   as the current :ref:`accelerator<accelerators>` type.
+   This matches :meth:`Tensor.get_device`, which returns an ordinal for device
    tensors and is not supported for cpu tensors.

    >>> torch.device(1)
```
```diff
@@ -66,6 +66,7 @@
 #include <torch/csrc/utils/python_symnode.h>
 #include <torch/csrc/utils/six.h>

+#include <ATen/DeviceAccelerator.h>
 #include <ATen/PythonTorchFunctionTLS.h>
 #include <ATen/core/Tensor.h>
 #include <c10/util/Exception.h>
```
```diff
@@ -811,13 +812,9 @@ inline at::Device toDevice(PyObject* obj) {
   if (THPUtils_checkLong(obj)) {
     const auto device_index = THPUtils_unpackLong(obj);
     TORCH_CHECK(device_index >= 0, "Device index must not be negative");
-    if (c10::is_privateuse1_backend_registered()) {
-      return at::Device(
-          c10::DeviceType::PrivateUse1,
-          static_cast<c10::DeviceIndex>(device_index));
-    }
     return at::Device(
-        c10::DeviceType::CUDA, static_cast<c10::DeviceIndex>(device_index));
+        at::getAccelerator(true).value(),
+        static_cast<c10::DeviceIndex>(device_index));
   }
   const std::string& device_str = THPUtils_unpackString(obj);
   return at::Device(device_str);
```