Enable torch.cuda.memory typechecking (#43444)
Summary: Add a number of function prototypes defined in torch/csrc/cuda/Module.cpp to `__init__.pyi.in`.

Fixes https://github.com/pytorch/pytorch/issues/43442

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43444

Reviewed By: ezyang

Differential Revision: D23280221

Pulled By: malfet

fbshipit-source-id: 7d67dff7b24c8d7b7e72c919e6e7b847f242ef83
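As an illustration of what the change enables (a hypothetical snippet, not part of the commit): once the `_cuda_*` prototypes are in the generated `torch/_C/__init__.pyi` and the `ignore_errors` override for `torch.cuda.memory` is dropped from mypy.ini, annotated call sites like the one below can be checked by mypy against those signatures.

# Hypothetical user code, not part of the commit. The wrappers used here
# (torch.cuda.memory_stats, torch.cuda.memory_allocated) forward to the
# newly stubbed torch._C bindings, so mypy can verify the annotations.
from typing import Any, Dict

import torch

def report_cuda_memory(device: int = 0) -> Dict[str, Any]:
    stats: Dict[str, Any] = torch.cuda.memory_stats(device)
    allocated: int = torch.cuda.memory_allocated(device)
    stats["example/allocated_bytes"] = allocated
    return stats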
parent 7024ce8a2c
commit 0fa99d50bc

mypy.ini
@@ -198,9 +198,6 @@ ignore_errors = True
 [mypy-torch.cuda.comm]
 ignore_errors = True
 
-[mypy-torch.cuda.memory]
-ignore_errors = True
-
 [mypy-torch.cuda.nccl]
 ignore_errors = True
 
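Dropping the `[mypy-torch.cuda.memory]` override above means mypy stops skipping that module. A rough way to confirm this locally is sketched below; the exact invocation is an assumption on our part rather than anything prescribed by the commit, it simply points mypy's programmatic API at the module using the repository's config.

# Hypothetical local check, assuming mypy is installed and the working
# directory is the repository root. mypy.api.run returns (stdout, stderr, status).
from mypy import api

stdout, stderr, status = api.run(
    ["--config-file", "mypy.ini", "torch/cuda/memory.py"]
)
print(stdout)
print("mypy exit status:", status)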

torch/_C/__init__.pyi.in
@@ -2,7 +2,7 @@
 
 import torch
 from torch import Tensor
-from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
+from typing import (Any, BinaryIO, Callable, ContextManager, Dict, Iterator, List, NamedTuple,
                     Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
 from torch._six import inf
 
@@ -300,6 +300,24 @@ class _TensorBase(object):
     ${tensor_method_hints}
 
 # Defined in torch/csrc/cuda/Module.cpp
+def _cuda_getCurrentStream(device: _int) -> _int: ...
+def _cuda_getDefaultStream(device: _int) -> _int: ...
+def _cuda_getCurrentBlasHandle() -> _int: ...
+def _cuda_setStream(cuda_stream: _int) -> None: ...
+def _cuda_getCompiledVersion() -> _int: ...
+def _cuda_cudaHostAllocator() -> _int: ...
+def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
+def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
+def _cuda_emptyCache() -> None: ...
+def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
+def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
+def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
+def _cuda_memorySnapshot() -> List[Dict[str, Any]]: ...
+def _cuda_lock_mutex() -> None: ...
+def _cuda_unlock_mutex() -> None: ...
+def _nccl_version() -> _int: ...
+def _nccl_unique_id() -> bytes: ...
+
 class _CudaDeviceProperties:
     name: str
     major: _int
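These prototypes are what the Python wrappers in torch/cuda/memory.py are checked against. The sketch below is a simplified illustration of that relationship, not the actual wrapper code; names such as memory_snapshot_sketch are made up for this example.

# Simplified sketch, not the real torch/cuda/memory.py implementation.
# It only shows how annotated wrappers line up with the _cuda_* prototypes.
from typing import Any, Dict, List

import torch

def memory_snapshot_sketch() -> List[Dict[str, Any]]:
    # Declared in the stub as: def _cuda_memorySnapshot() -> List[Dict[str, Any]]: ...
    return torch._C._cuda_memorySnapshot()

def raw_memory_stats_sketch(device: int) -> Dict[str, Any]:
    # Declared in the stub as: def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
    return torch._C._cuda_memoryStats(device)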

torch/cuda/memory.py
@@ -4,7 +4,7 @@ import warnings
 from typing import Any, Dict, Union
 
 import torch
-from . import is_initialized, _get_device_index
+from . import is_initialized, _get_device_index, _lazy_init
 from torch.types import Device
 
 def _host_allocator():
@@ -31,7 +31,7 @@ def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None
 
     Arguments:
         size (int): number of bytes to be allocated.
         device (torch.device or int, optional): selected device. If it is
             ``None`` the default CUDA device is used.
         stream (torch.cuda.Stream or int, optional): selected stream. If is ``None`` then
             the default stream for the selected device is used.
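For context on the docstring above, here is a minimal usage sketch of the caching-allocator helpers; it assumes a CUDA-capable build, and the size and device index are arbitrary values chosen for the example.

# Minimal usage sketch, assuming CUDA is available. caching_allocator_alloc
# returns an opaque pointer (an int) that must be released with
# caching_allocator_delete.
import torch

if torch.cuda.is_available():
    ptr = torch.cuda.caching_allocator_alloc(1024, device=0)
    try:
        print("allocated 1024 bytes at", hex(ptr))
    finally:
        torch.cuda.caching_allocator_delete(ptr)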