mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Typing] Improve device typing for torch.set_default_device() (#153028)
Part of: #152952
Here is the definition of `torch.types.Device`:
ab997d9ff5/torch/types.py (L74)
So `_Optional[_Union["torch.device", str, builtins.int]]` is equivalent to it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153028
Approved by: https://github.com/Skylion007
This commit is contained in:
parent
dd7d231ed3
commit
f5f8f637a5
|
|
@ -36,7 +36,7 @@ from typing_extensions import ParamSpec as _ParamSpec
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .types import IntLikeType
|
from .types import Device, IntLikeType
|
||||||
|
|
||||||
|
|
||||||
# multipy/deploy is setting this import before importing torch, this is the most
|
# multipy/deploy is setting this import before importing torch, this is the most
|
||||||
|
|
@ -1154,9 +1154,7 @@ def get_default_device() -> "torch.device":
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
def set_default_device(
|
def set_default_device(device: "Device") -> None:
|
||||||
device: _Optional[_Union["torch.device", str, builtins.int]],
|
|
||||||
) -> None:
|
|
||||||
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
|
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
|
||||||
does not affect factory function calls which are called with an explicit
|
does not affect factory function calls which are called with an explicit
|
||||||
``device`` argument. Factory calls will be performed as if they
|
``device`` argument. Factory calls will be performed as if they
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user