Add type information to torch.cuda (#47134)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/47133

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47134

Reviewed By: smessmer

Differential Revision: D24955031

Pulled By: ezyang

fbshipit-source-id: 87f4623643715baa6ac0627383f009956f80cd46
parent 2eb1e866e8
commit 4f9d0757f3
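Before the per-file hunks, a quick illustration of what the new annotations buy downstream code. This is a hedged sketch, not code from the PR; it only relies on signatures visible in the diff below (`is_available() -> bool`, `get_device_properties()` accepting a device index), while the `name` and `total_memory` attributes of the returned properties object are assumed here.

    import torch

    def describe_gpu(index: int) -> str:
        # is_available() is annotated as returning bool, so this guard is type-checked
        if not torch.cuda.is_available():
            return "CUDA not available"
        # get_device_properties() accepts an index and returns a _CudaDeviceProperties
        # object; name/total_memory are assumed attributes for this sketch
        props = torch.cuda.get_device_properties(index)
        return f"{props.name}, {props.total_memory} bytes of memory"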
mypy.ini (6 lines changed)
@@ -128,9 +128,6 @@ ignore_errors = True
 [mypy-torch.nn.quantized.modules.conv]
 ignore_errors = True
 
-[mypy-torch.cuda]
-ignore_errors = True
-
 [mypy-torch._lobpcg]
 ignore_errors = True
 
@@ -140,9 +137,6 @@ ignore_errors = True
 [mypy-torch._utils]
 ignore_errors = True
 
-[mypy-torch._overrides]
-ignore_errors = True
-
 [mypy-torch.utils.tensorboard._caffe2_graph]
 ignore_errors = True
 
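Both hunks above only delete entries. Dropping a `[mypy-<module>]` section with `ignore_errors = True` opts that module back into normal type checking, so mistakes inside it are reported again. A minimal, hypothetical illustration of the kind of error such an override suppresses (not repository code):

    # stand-in for any module listed with "ignore_errors = True" in mypy.ini
    def bad_default(count: int = "one") -> int:
        # mypy: Incompatible default for argument "count" (str vs. int);
        # silenced while the override is in place, reported once it is removed
        return count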
@@ -667,11 +667,14 @@ def gen_pyi(declarations_path, out):
 
     # TODO: These are deprecated, maybe we shouldn't type hint them
     legacy_storage_base_hints = []
-    for c in ('Double', 'Float', 'Long', 'Int',
-              'Short', 'Char', 'Byte', 'Bool',
-              'Half', 'BFloat16', 'ComplexDouble',
-              'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2'):
+    dt = ('Double', 'Float', 'Long', 'Int',
+          'Short', 'Char', 'Byte', 'Bool',
+          'Half', 'BFloat16', 'ComplexDouble',
+          'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2')
+    for c in dt:
         legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
+    for c in dt:
+        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))
 
     legacy_class_hints = []
     for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
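The rewrite keeps the existing per-dtype `...StorageBase` hints and adds a second pass that emits a `Cuda`-prefixed hint for each dtype. A small, self-contained reproduction of what the two loops above generate (abbreviated to two dtypes):

    dt = ('Double', 'Float')  # abbreviated; the real tuple lists every dtype above
    legacy_storage_base_hints = []
    for c in dt:
        legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
    for c in dt:
        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))
    # legacy_storage_base_hints ==
    # ['class DoubleStorageBase(object): ...',
    #  'class FloatStorageBase(object): ...',
    #  'class CudaDoubleStorageBase(object): ...',
    #  'class CudaFloatStorageBase(object): ...']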
@@ -561,6 +561,13 @@ def _cuda_getCurrentStream(device: _int) -> _int: ...
 def _cuda_getDefaultStream(device: _int) -> _int: ...
 def _cuda_getCurrentBlasHandle() -> _int: ...
 def _cuda_setDevice(device: _int) -> None: ...
+def _cuda_getDevice() -> _int: ...
+def _cuda_getDeviceCount() -> _int: ...
+def _cuda_sleep(cycles: _int) -> None: ...
+def _cuda_synchronize() -> None: ...
+def _cuda_ipc_collect() -> None: ...
+def _cuda_getArchFlags() -> Optional[str]: ...
+def _cuda_init() -> None: ...
 def _cuda_setStream(cuda_stream: _int) -> None: ...
 def _cuda_getCompiledVersion() -> _int: ...
 def _cuda_cudaHostAllocator() -> _int: ...
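The block above adds stubs for the private C-extension entry points that torch.cuda wraps (`_int` is the stub file's alias for the builtin `int`). With the declarations in place, wrappers around these bindings type-check; the following is only a hedged sketch of such a wrapper, not the actual implementation in torch/cuda/__init__.py:

    import torch

    def device_count_sketch() -> int:
        # On CPU-only builds the binding does not exist, so guard with hasattr;
        # otherwise _cuda_getDeviceCount() is declared above as "() -> _int".
        if not hasattr(torch._C, '_cuda_getDeviceCount'):
            return 0
        return torch._C._cuda_getDeviceCount()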
@@ -21,7 +21,7 @@ from .. import device as _device
 import torch._C
 
 try:
-    from torch._C import _cudart
+    from torch._C import _cudart  # type: ignore
 except ImportError:
     _cudart = None
 
@@ -30,18 +30,18 @@ _tls = threading.local()
 _initialization_lock = threading.Lock()
 _queued_calls = []  # don't invoke these until initialization occurs
 _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
-_device_t = Union[_device, str, int]
+_device_t = Union[_device, str, int, None]
 
 # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
 if hasattr(torch._C, '_CudaDeviceProperties'):
     _CudaDeviceProperties = torch._C._CudaDeviceProperties
 else:
-    _CudaDeviceProperties = _dummy_type('_CudaDeviceProperties')
+    _CudaDeviceProperties = _dummy_type('_CudaDeviceProperties')  # type: ignore
 
 # Global variables dynamically populated by native code
 has_magma: bool = False
 has_half: bool = False
-default_generators: Tuple[torch._C.Generator] = ()
+default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
 
 def is_available() -> bool:
     r"""Returns a bool indicating if CUDA is currently available."""
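`_device_t` now admits `None`, matching the convention (also visible in the `_get_device_index` hunk at the end of this diff) that `device=None` means "use the current device". A hedged sketch of a helper annotated with the widened alias; the helper itself is hypothetical:

    from typing import Union
    import torch

    _device_t = Union[torch.device, str, int, None]

    def normalize_device(device: _device_t = None) -> int:
        # None falls back to the current device, mirroring torch.cuda's helpers
        if device is None:
            return torch.cuda.current_device()
        if isinstance(device, int):
            return device
        index = torch.device(device).index
        return index if index is not None else torch.cuda.current_device()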
@@ -297,7 +297,7 @@ def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
     device = _get_device_index(device, optional=True)
     if device < 0 or device >= device_count():
         raise AssertionError("Invalid device id")
-    return _get_device_properties(device)
+    return _get_device_properties(device)  # type: ignore[name-defined]
 
 
 @contextlib.contextmanager
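The `# type: ignore[name-defined]` suggests that `_get_device_properties` is bound into the module namespace at runtime rather than defined statically, so mypy never sees a definition for it. One generic pattern that produces this error code, purely for illustration (hypothetical names, not the repository code):

    def properties_sketch(index: int):
        # the checker sees no static definition of _props_binding, hence [name-defined]
        return _props_binding(index)  # type: ignore[name-defined]

    # injected at import time, e.g. from a C extension; invisible to static analysis
    globals()['_props_binding'] = lambda index: {"device": index}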
@@ -356,8 +356,8 @@ def get_gencode_flags() -> str:
     arch_list = get_arch_list()
     if len(arch_list) == 0:
         return ""
-    arch_list = [arch.split("_") for arch in arch_list]
-    return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list])
+    arch_list_ = [arch.split("_") for arch in arch_list]
+    return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list_])
 
 
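The only change in this hunk is the rename to `arch_list_`: mypy fixes `arch_list` to `List[str]` from its first assignment, so re-binding it to the `List[List[str]]` produced by the comprehension is an incompatible-assignment error, and the second value gets a fresh name instead. A minimal reproduction of the pattern (illustrative):

    from typing import List

    def gencode_sketch(arch_list: List[str]) -> str:
        # arch_list = [arch.split("_") for arch in arch_list]  # error: List[List[str]] vs List[str]
        arch_list_ = [arch.split("_") for arch in arch_list]   # fresh name, fresh inferred type
        return " ".join(f"-gencode compute=compute_{arch},code={kind}_{arch}"
                        for (kind, arch) in arch_list_)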
@@ -454,7 +454,7 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
     torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')
 
 
-@staticmethod
+@staticmethod  # type: ignore[misc]
 def _lazy_new(cls, *args, **kwargs):
     _lazy_init()
     # We may need to call lazy init again if we are a forked child
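mypy reports `@staticmethod` on a module-level function as a misc-category error ("staticmethod used with a non-method"); the decorator is there because the function is later attached to a class, as the `__new__ = _lazy_new` line in the next hunk shows. A condensed sketch of that pattern with hypothetical names:

    @staticmethod  # type: ignore[misc]  # flagged because the decorator sits outside a class body
    def _lazy_new_sketch(cls, *args, **kwargs):
        # stand-in for torch.cuda's _lazy_new: do lazy setup here, then defer to object.__new__
        return super(_SketchBase, cls).__new__(cls)

    class _SketchBase(object):
        __new__ = _lazy_new_sketch  # the module-level staticmethod becomes this class's constructor hook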
@@ -467,8 +467,11 @@ class _CudaBase(object):
     is_sparse = False
 
     def type(self, *args, **kwargs):
-        with device(self.get_device()):
-            return super(_CudaBase, self).type(*args, **kwargs)
+        # We could use a Protocol here to tell mypy that self has `get_device` method
+        # but it is only available in the typing module on Python >= 3.8
+        # or on typing_extensions module on Python >= 3.6
+        with device(self.get_device()):  # type: ignore
+            return super(_CudaBase, self).type(*args, **kwargs)  # type: ignore[misc]
 
     __new__ = _lazy_new
 
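The new comments name the alternative to these blanket ignores: a `Protocol` declaring that `self` has a `get_device()` method, available from `typing` on Python >= 3.8 or from `typing_extensions` on 3.6+, exactly as the comment says. A rough sketch of what that would look like (hypothetical helper, not part of the PR):

    from typing import Protocol  # Python >= 3.8; otherwise: from typing_extensions import Protocol

    class _HasGetDevice(Protocol):
        def get_device(self) -> int: ...

    def device_index_of(obj: _HasGetDevice) -> int:
        # any object exposing get_device() satisfies the protocol structurally,
        # which is what the ignores above paper over without Protocol support
        return obj.get_device()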
@@ -5,7 +5,7 @@ from torch.types import Device
 from torch._utils import _get_device_index as _torch_get_device_index
 
 
-def _get_device_index(device: Union[Device, str, int], optional: bool = False,
+def _get_device_index(device: Union[Device, str, int, None], optional: bool = False,
                       allow_cpu: bool = False) -> int:
     r"""Gets the device index from :attr:`device`, which can be a torch.device
     object, a Python integer, or ``None``.
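For completeness, a hedged sketch of how the widened `_get_device_index` signature is exercised. The docstring above covers torch.device, int, and None; the concrete return values in the comments are assumptions about behavior beyond what the diff shows:

    import torch
    from torch.cuda._utils import _get_device_index

    _get_device_index(torch.device('cuda', 0))   # -> 0
    _get_device_index('cuda:1')                  # -> 1
    _get_device_index(0)                         # -> 0
    _get_device_index(None, optional=True)       # -> index of the current CUDA device (assumed)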