Add type informations to torch.cuda (#47134)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/47133

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47134

Reviewed By: smessmer

Differential Revision: D24955031

Pulled By: ezyang

fbshipit-source-id: 87f4623643715baa6ac0627383f009956f80cd46
This commit is contained in:
Guilherme Leobas 2020-11-13 21:18:00 -08:00 committed by Facebook GitHub Bot
parent 2eb1e866e8
commit 4f9d0757f3
5 changed files with 28 additions and 21 deletions

View File

@ -128,9 +128,6 @@ ignore_errors = True
[mypy-torch.nn.quantized.modules.conv]
ignore_errors = True
[mypy-torch.cuda]
ignore_errors = True
[mypy-torch._lobpcg]
ignore_errors = True
@ -140,9 +137,6 @@ ignore_errors = True
[mypy-torch._utils]
ignore_errors = True
[mypy-torch._overrides]
ignore_errors = True
[mypy-torch.utils.tensorboard._caffe2_graph]
ignore_errors = True

View File

@ -667,11 +667,14 @@ def gen_pyi(declarations_path, out):
# TODO: These are deprecated, maybe we shouldn't type hint them
legacy_storage_base_hints = []
for c in ('Double', 'Float', 'Long', 'Int',
'Short', 'Char', 'Byte', 'Bool',
'Half', 'BFloat16', 'ComplexDouble',
'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2'):
dt = ('Double', 'Float', 'Long', 'Int',
'Short', 'Char', 'Byte', 'Bool',
'Half', 'BFloat16', 'ComplexDouble',
'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2')
for c in dt:
legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
for c in dt:
legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))
legacy_class_hints = []
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',

View File

@ -561,6 +561,13 @@ def _cuda_getCurrentStream(device: _int) -> _int: ...
def _cuda_getDefaultStream(device: _int) -> _int: ...
def _cuda_getCurrentBlasHandle() -> _int: ...
def _cuda_setDevice(device: _int) -> None: ...
def _cuda_getDevice() -> _int: ...
def _cuda_getDeviceCount() -> _int: ...
def _cuda_sleep(cycles: _int) -> None: ...
def _cuda_synchronize() -> None: ...
def _cuda_ipc_collect() -> None: ...
def _cuda_getArchFlags() -> Optional[str]: ...
def _cuda_init() -> None: ...
def _cuda_setStream(cuda_stream: _int) -> None: ...
def _cuda_getCompiledVersion() -> _int: ...
def _cuda_cudaHostAllocator() -> _int: ...

View File

@ -21,7 +21,7 @@ from .. import device as _device
import torch._C
try:
from torch._C import _cudart
from torch._C import _cudart # type: ignore
except ImportError:
_cudart = None
@ -30,18 +30,18 @@ _tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
_device_t = Union[_device, str, int]
_device_t = Union[_device, str, int, None]
# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
if hasattr(torch._C, '_CudaDeviceProperties'):
_CudaDeviceProperties = torch._C._CudaDeviceProperties
else:
_CudaDeviceProperties = _dummy_type('_CudaDeviceProperties')
_CudaDeviceProperties = _dummy_type('_CudaDeviceProperties') # type: ignore
# Global variables dynamically populated by native code
has_magma: bool = False
has_half: bool = False
default_generators: Tuple[torch._C.Generator] = ()
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
def is_available() -> bool:
r"""Returns a bool indicating if CUDA is currently available."""
@ -297,7 +297,7 @@ def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
device = _get_device_index(device, optional=True)
if device < 0 or device >= device_count():
raise AssertionError("Invalid device id")
return _get_device_properties(device)
return _get_device_properties(device) # type: ignore[name-defined]
@contextlib.contextmanager
@ -356,8 +356,8 @@ def get_gencode_flags() -> str:
arch_list = get_arch_list()
if len(arch_list) == 0:
return ""
arch_list = [arch.split("_") for arch in arch_list]
return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list])
arch_list_ = [arch.split("_") for arch in arch_list]
return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list_])
@ -454,7 +454,7 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')
@staticmethod
@staticmethod # type: ignore[misc]
def _lazy_new(cls, *args, **kwargs):
_lazy_init()
# We may need to call lazy init again if we are a forked child
@ -467,8 +467,11 @@ class _CudaBase(object):
is_sparse = False
def type(self, *args, **kwargs):
with device(self.get_device()):
return super(_CudaBase, self).type(*args, **kwargs)
# We could use a Protocol here to tell mypy that self has `get_device` method
# but it is only available in the typing module on Python >= 3.8
# or on typing_extensions module on Python >= 3.6
with device(self.get_device()): # type: ignore
return super(_CudaBase, self).type(*args, **kwargs) # type: ignore[misc]
__new__ = _lazy_new

View File

@ -5,7 +5,7 @@ from torch.types import Device
from torch._utils import _get_device_index as _torch_get_device_index
def _get_device_index(device: Union[Device, str, int], optional: bool = False,
def _get_device_index(device: Union[Device, str, int, None], optional: bool = False,
allow_cpu: bool = False) -> int:
r"""Gets the device index from :attr:`device`, which can be a torch.device
object, a Python integer, or ``None``.