Move torch.cuda annotations inline (#40075)
Summary: Also enable `torch.cuda` typechecking.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40075
Differential Revision: D22121275
Pulled By: malfet
fbshipit-source-id: dbecef09911334e8f3d87f5ecab66349da9f2325
parent c1958de49d
commit 8b5732e8ad

mypy.ini (20 lines changed):
@@ -227,7 +227,25 @@ ignore_errors = True
 [mypy-torch.nn.quantized.modules.functional_modules]
 ignore_errors = True
 
-[mypy-torch.cuda.*]
+[mypy-torch.cuda]
 ignore_errors = True
 
+[mypy-torch.cuda.amp.*]
+ignore_errors = True
+
+[mypy-torch.cuda.comm]
+ignore_errors = True
+
+[mypy-torch.cuda.memory]
+ignore_errors = True
+
+[mypy-torch.cuda.nccl]
+ignore_errors = True
+
+[mypy-torch.cuda.nvtx]
+ignore_errors = True
+
+[mypy-torch.cuda.streams]
+ignore_errors = True
+
 [mypy-torch._lobpcg]
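
mypy applies per-module options to the most specific matching section, and a `foo.*` pattern matches the package itself as well as its submodules. The effect of this hunk is that the blanket torch.cuda ignore is narrowed: only the listed submodules keep ignore_errors, so modules no longer covered (for example torch.cuda.random and torch.cuda.profiler, both annotated later in this commit) are now typechecked. A minimal sketch of driving such a check from Python, assuming mypy is installed (mypy.api is mypy's programmatic entry point; the target path is illustrative):

# Run mypy on one newly typechecked module using the repo's mypy.ini.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--config-file", "mypy.ini", "torch/cuda/random.py"]
)
print(stdout)
print("exit status:", exit_status)  # 0 if the module typechecks cleanly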

torch/_C/__init__.pyi.in:

@@ -216,3 +216,13 @@ class _TensorBase(object):
     imag: Tensor
     _version: _bool
     ${tensor_method_hints}
+
+# Defined in torch/csrc/cuda/Module.cpp
+class _CudaDeviceProperties:
+    name: str
+    major: _int
+    minor: _int
+    multi_processor_count: _int
+    total_memory: _int
+    is_integrated: _int
+    is_multi_gpu_board: _int
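
This stub block mirrors the pybind11 class registered in torch/csrc/cuda/Module.cpp (next file), so attribute access on device properties is now visible to mypy. A hedged usage sketch (requires a CUDA build; the printed values are illustrative):

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)  # _CudaDeviceProperties
    # Every attribute below is declared in the stub, so a typo would be caught.
    print(props.name)                              # e.g. "Tesla V100-SXM2-16GB"
    print(props.major, props.minor)                # compute capability
    print(props.total_memory // (1024 ** 2), "MB")
    print(props.multi_processor_count, "SMs")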

torch/csrc/cuda/Module.cpp:

@@ -396,8 +396,8 @@ PyObject * THCPModule_memorySnapshot(PyObject *_unused, PyObject *noargs)
 // Cuda module initialization
 ////////////////////////////////////////////////////////////////////////////////
 
-static void bindCudaDeviceProperties(PyObject* module) {
-  // Add class and method to torch.cuda
+static void registerCudaDeviceProperties(PyObject* module) {
+  // Add _CudaDeviceProperties class to torch._C
   auto m = py::handle(module).cast<py::module>();
   py::class_<cudaDeviceProp>(m, "_CudaDeviceProperties")
     .def_readonly("name", &cudaDeviceProp::name)

@@ -414,6 +414,11 @@ static void bindCudaDeviceProperties(PyObject* module) {
        << "MB, multi_processor_count=" << prop.multiProcessorCount << ")";
     return stream.str();
   });
 }
 
+static void bindGetDeviceProperties(PyObject* module) {
+  // Add method to torch.cuda
+  auto m = py::handle(module).cast<py::module>();
+  m.def("_get_device_properties", [](int device) -> cudaDeviceProp * {
+    return at::cuda::getDeviceProperties(device);
+  }, py::return_value_policy::reference);

@@ -469,8 +474,7 @@ static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
     PyTuple_SetItem(default_cuda_generators, i, (PyObject*)cast_gen);
   }
   set_module_attr("default_generators", default_cuda_generators);
-
-  bindCudaDeviceProperties(m);
+  bindGetDeviceProperties(m);
 
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS

@@ -551,6 +555,7 @@ void initModule(PyObject *module) {
 #if defined(USE_CUDNN) || defined(__HIP_PLATFORM_HCC__)
   shared::initCudnnBindings(module);
 #endif
+  registerCudaDeviceProperties(module);
 }
 
 }}
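
The rename splits one helper in two: registerCudaDeviceProperties is called unconditionally from initModule and registers the _CudaDeviceProperties class on torch._C, while bindGetDeviceProperties is called from THCPModule_initExtension and exposes the raw _get_device_properties getter only once CUDA is initialized. A sketch of the observable Python-side behavior (hedged; the exact attribute locations are implementation details):

import torch

# The class exists as soon as torch is imported, even before any CUDA init:
print(hasattr(torch._C, "_CudaDeviceProperties"))  # True on a CUDA build

if torch.cuda.is_available():
    torch.cuda.init()  # runs THCPModule_initExtension, binding the raw getter
    # The supported entry point remains the public wrapper:
    print(torch.cuda.get_device_properties(0))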

torch/cuda/__init__.py:

@@ -14,8 +14,11 @@ import torch
 import traceback
 import warnings
 import threading
+from typing import Optional, Tuple, Union
 from torch._six import raise_from
 from ._utils import _get_device_index, _dummy_type
+from .streams import Stream, Event
 from .. import device as _device
+import torch._C
 
 try:

@@ -28,9 +31,20 @@ _tls = threading.local()
 _initialization_lock = threading.Lock()
 _queued_calls = []  # don't invoke these until initialization occurs
 _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
+_device_t = Union[_device, str, int]
+
+# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
+if hasattr(torch._C, '_CudaDeviceProperties'):
+    _CudaDeviceProperties = torch._C._CudaDeviceProperties
+else:
+    _CudaDeviceProperties = _dummy_type('_CudaDeviceProperties')
 
-def is_available():
+# Global variables dynamically populated by native code
+has_magma: bool = False
+has_half: bool = False
+default_generators: Tuple[torch._C.Generator] = ()
+
+def is_available() -> bool:
     r"""Returns a bool indicating if CUDA is currently available."""
     if (not hasattr(torch._C, '_cuda_isDriverSufficient') or
             not torch._C._cuda_isDriverSufficient()):
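
The hasattr/_dummy_type dance keeps torch/cuda/__init__.py importable, and the _CudaDeviceProperties annotation resolvable, on builds compiled without CUDA. The diff does not show _dummy_type's body; a minimal reimplementation of the same pattern might look like this (hypothetical, for illustration only):

def _dummy_type(name: str) -> type:
    # Placeholder class that imports fine but fails loudly if instantiated.
    def init_err(self, *args, **kwargs):
        raise RuntimeError(
            "Tried to instantiate dummy type {}; "
            "PyTorch was compiled without CUDA support.".format(name))
    return type(name, (object,), {"__init__": init_err})

_CudaDeviceProperties = _dummy_type("_CudaDeviceProperties")  # CPU-only fallback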
@@ -175,17 +189,16 @@ def cudart():
 
 
 class cudaStatus(object):
-    SUCCESS = 0
-    ERROR_NOT_READY = 34
+    SUCCESS: int = 0
+    ERROR_NOT_READY: int = 34
 
 class CudaError(RuntimeError):
-    def __init__(self, code):
+    def __init__(self, code: int) -> None:
         msg = _cudart.cudaGetErrorString(code).decode('utf-8')
         super(CudaError, self).__init__('{0} ({1})'.format(msg, code))
 
 
-def check_error(res):
+def check_error(res: int) -> None:
     if res != _cudart.cudaError.success:
         raise CudaError(res)
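
With the cudaStatus constants and check_error annotated, passing the wrong type becomes a mypy error instead of a runtime surprise. A small usage sketch (hedged: _cudart must be loaded, i.e. a CUDA runtime must actually be present):

from torch.cuda import check_error, cudaStatus

check_error(cudaStatus.SUCCESS)             # success code: no exception
# check_error(cudaStatus.ERROR_NOT_READY)  # would raise CudaError("... (34)")
# check_error("34")                        # mypy: expected "int", got "str"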
@@ -231,7 +244,7 @@ class device_of(device):
         super(device_of, self).__init__(idx)
 
 
-def set_device(device):
+def set_device(device: _device_t) -> None:
     r"""Sets the current device.
 
     Usage of this function is discouraged in favor of :any:`device`. In most
@@ -246,7 +259,7 @@ def set_device(device):
     torch._C._cuda_setDevice(device)
 
 
-def get_device_name(device=None):
+def get_device_name(device: Optional[_device_t] = None) -> str:
     r"""Gets the name of a device.
 
     Arguments:
@@ -258,7 +271,7 @@ def get_device_name(device=None):
     return get_device_properties(device).name
 
 
-def get_device_capability(device=None):
+def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
     r"""Gets the cuda capability of a device.
 
     Arguments:
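
The Tuple[int, int] return type makes the common unpacking pattern checkable. For example (assuming a CUDA device is present; (7, 0) is the capability of a V100):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) >= (7, 0):
        print("device supports tensor cores")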
@@ -275,8 +288,8 @@ def get_device_capability(device=None):
     return prop.major, prop.minor
 
 
-def get_device_properties(device):
-    _lazy_init()  # will define _get_device_properties and _CudaDeviceProperties
+def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
+    _lazy_init()  # will define _get_device_properties
     device = _get_device_index(device, optional=True)
     if device < 0 or device >= device_count():
         raise AssertionError("Invalid device id")
@@ -318,7 +331,7 @@ def stream(stream):
         torch._C._cuda_setStream(src_prev_stream._cdata)
 
 
-def device_count():
+def device_count() -> int:
     r"""Returns the number of GPUs available."""
     if is_available():
         return torch._C._cuda_getDeviceCount()
@@ -326,13 +339,13 @@ def device_count():
         return 0
 
 
-def current_device():
+def current_device() -> int:
     r"""Returns the index of a currently selected device."""
     _lazy_init()
     return torch._C._cuda_getDevice()
 
 
-def synchronize(device=None):
+def synchronize(device: _device_t = None) -> None:
     r"""Waits for all kernels in all streams on a CUDA device to complete.
 
     Arguments:
@@ -358,7 +371,7 @@ def ipc_collect():
     return torch._C._cuda_ipc_collect()
 
 
-def current_stream(device=None):
+def current_stream(device: Optional[_device_t] = None) -> Stream:
     r"""Returns the currently selected :class:`Stream` for a given device.
 
     Arguments:
@@ -368,11 +381,11 @@ def current_stream(device=None):
         (default).
     """
     _lazy_init()
-    return torch.cuda.Stream(_cdata=torch._C._cuda_getCurrentStream(
+    return Stream(_cdata=torch._C._cuda_getCurrentStream(
         _get_device_index(device, optional=True)))
 
 
-def default_stream(device=None):
+def default_stream(device: Optional[_device_t] = None) -> Stream:
     r"""Returns the default :class:`Stream` for a given device.
 
     Arguments:
@@ -382,7 +395,7 @@ def default_stream(device=None):
         (default).
     """
     _lazy_init()
-    return torch.cuda.Stream(_cdata=torch._C._cuda_getDefaultStream(
+    return Stream(_cdata=torch._C._cuda_getDefaultStream(
         _get_device_index(device, optional=True)))
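
Returning the locally imported Stream instead of looking up torch.cuda.Stream is behaviorally identical, but it gives mypy a concrete return type. Usage sketch (requires CUDA):

import torch

if torch.cuda.is_available():
    s = torch.cuda.current_stream()   # typed as Stream
    d = torch.cuda.default_stream()
    print(s == d)  # True until some other stream is made current
    with torch.cuda.stream(torch.cuda.Stream()):
        print(torch.cuda.current_stream() == d)  # False inside the context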
@@ -500,5 +513,4 @@ torch._storage_classes.add(ComplexFloatStorage)
 from . import sparse
 from . import profiler
 from . import nvtx
-from .streams import Stream, Event
 from . import amp

torch/cuda/__init__.pyi (deleted; its contents now live inline above):

@@ -1,51 +0,0 @@
-from typing import Optional, Tuple, Union
-from .. import device as _device
-from .random import (
-    get_rng_state as get_rng_state,
-    get_rng_state_all as get_rng_state_all,
-    set_rng_state as set_rng_state,
-    set_rng_state_all as set_rng_state_all,
-    manual_seed as manual_seed,
-    manual_seed_all as manual_seed_all,
-    seed as seed,
-    seed_all as seed_all,
-    initial_seed as initial_seed,
-)
-
-
-def is_available() -> bool: ...
-def init() -> None: ...
-
-class cudaStatus:
-    SUCCESS: int
-    ERROR_NOT_READY: int
-
-class CudaError:
-    def __init__(self, code: int) -> None: ...
-
-class _CudaDeviceProperties:
-    name: str
-    major: int
-    minor: int
-    multi_processor_count: int
-    total_memory: int
-    is_integrated: int
-    is_multi_gpu_board: int
-
-_device_t = Union[_device, int]
-
-def check_error(res: int) -> None: ...
-def device_count() -> int: ...
-def empty_cache() -> None: ...
-def synchronize(device: _device_t) -> None: ...
-def set_device(device: _device_t) -> None: ...
-def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ...
-def get_device_name(device: Optional[_device_t]=...) -> str: ...
-def get_device_properties(device: _device_t) -> _CudaDeviceProperties: ...
-def current_device() -> int: ...
-def memory_allocated(device: Optional[_device_t]=...) -> int: ...
-def max_memory_allocated(device: Optional[_device_t]=...) -> int: ...
-def reset_max_memory_allocated(device: Optional[_device_t]=...) -> None: ...
-def memory_cached(device: Optional[_device_t]=...) -> int: ...
-def max_memory_cached(device: Optional[_device_t]=...) -> int: ...
-def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...

torch/cuda/_utils.py:

@@ -1,9 +1,10 @@
 import torch
 import torch._six
-from typing import Union
+from typing import Optional, Union
+from torch.types import Device
 
 
-def _get_device_index(device: Union[str, torch.device, int, None], optional=False) -> int:
+def _get_device_index(device: Union[Device, int], optional=False) -> int:
     r"""Gets the device index from :attr:`device`, which can be a torch.device
     object, a Python integer, or ``None``.
@@ -19,6 +20,7 @@ def _get_device_index(device: Union[str, torch.device, int, None], optional=Fals
     """
     if isinstance(device, torch._six.string_classes):
         device = torch.device(device)
+    device_idx: Optional[int]
     if isinstance(device, torch.device):
         dev_type = device.type
         if device.type != 'cuda':
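
_get_device_index normalizes the accepted device spellings to a plain integer index, and the new `device_idx: Optional[int]` declaration gives the variable a single declared type across all branches. Illustrative behavior (private helper, so subject to change; the values shown are hypothetical):

import torch
from torch.cuda._utils import _get_device_index

print(_get_device_index(torch.device("cuda", 1)))   # 1
print(_get_device_index("cuda:0"))                  # 0
print(_get_device_index(3))                         # 3
# With optional=True, None resolves to the current device index (needs CUDA):
# print(_get_device_index(None, optional=True))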

torch/cuda/memory.py:

@@ -1,10 +1,11 @@
 import collections
 import contextlib
 import warnings
+from typing import Any, Dict, Union
 
 import torch
 from . import is_initialized, _get_device_index
+from torch.types import Device
 
 
 def _host_allocator():
     _lazy_init()
@@ -20,7 +21,7 @@ def _free_mutex():
     torch._C._cuda_unlock_mutex()
 
 
-def caching_allocator_alloc(size, device=None, stream=None):
+def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
     r"""Performs a memory allocation using the CUDA memory allocator.
 
     Memory is allocated for a given device and a stream, this
@@ -71,7 +72,7 @@ def caching_allocator_delete(mem_ptr):
     torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)
 
 
-def empty_cache():
+def empty_cache() -> None:
     r"""Releases all unoccupied cached memory currently held by the caching
     allocator so that those can be used in other GPU application and visible in
     `nvidia-smi`.
@@ -86,7 +87,7 @@ def empty_cache():
     torch._C._cuda_emptyCache()
 
 
-def memory_stats(device=None):
+def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
     r"""Returns a dictionary of CUDA memory allocator statistics for a
     given device.
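
memory_stats returns a flat OrderedDict keyed by dotted stat names; most of the wrappers below are one-line lookups into it. Sketch (requires CUDA; the two keys shown are the ones used later in this file):

import torch

if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")
    stats = torch.cuda.memory_stats()               # Dict[str, Any]
    print(stats["allocated_bytes.all.current"])     # == memory_allocated()
    print(stats["reserved_bytes.all.peak"])         # == max_memory_reserved()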
@@ -163,13 +164,13 @@ def memory_stats(device=None):
     return collections.OrderedDict(result)
 
 
-def memory_stats_as_nested_dict(device=None):
+def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
     r"""Returns the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
     device = _get_device_index(device, optional=True)
     return torch._C._cuda_memoryStats(device)
 
 
-def reset_accumulated_memory_stats(device=None):
+def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
     r"""Resets the "accumulated" (historical) stats tracked by the CUDA memory allocator.
 
     See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
@@ -189,7 +190,7 @@
     return torch._C._cuda_resetAccumulatedMemoryStats(device)
 
 
-def reset_peak_memory_stats(device=None):
+def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
     r"""Resets the "peak" stats tracked by the CUDA memory allocator.
 
     See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
@@ -208,7 +209,7 @@
     return torch._C._cuda_resetPeakMemoryStats(device)
 
 
-def reset_max_memory_allocated(device=None):
+def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
     r"""Resets the starting point in tracking maximum GPU memory occupied by
     tensors for a given device.
@@ -234,7 +235,7 @@
     return reset_peak_memory_stats(device=device)
 
 
-def reset_max_memory_cached(device=None):
+def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
     r"""Resets the starting point in tracking maximum GPU memory managed by the
     caching allocator for a given device.
@@ -260,7 +261,7 @@
     return reset_peak_memory_stats(device=device)
 
 
-def memory_allocated(device=None):
+def memory_allocated(device: Union[Device, int] = None) -> int:
     r"""Returns the current GPU memory occupied by tensors in bytes for a given
     device.
@@ -278,7 +279,7 @@
     return memory_stats(device=device)["allocated_bytes.all.current"]
 
 
-def max_memory_allocated(device=None):
+def max_memory_allocated(device: Union[Device, int] = None) -> int:
     r"""Returns the maximum GPU memory occupied by tensors in bytes for a given
     device.
@@ -300,7 +301,7 @@
     return memory_stats(device=device)["allocated_bytes.all.peak"]
 
 
-def memory_reserved(device=None):
+def memory_reserved(device: Union[Device, int] = None) -> int:
     r"""Returns the current GPU memory managed by the caching allocator in bytes
     for a given device.
@@ -316,7 +317,7 @@
     return memory_stats(device=device)["reserved_bytes.all.current"]
 
 
-def max_memory_reserved(device=None):
+def max_memory_reserved(device: Union[Device, int] = None) -> int:
     r"""Returns the maximum GPU memory managed by the caching allocator in bytes
     for a given device.
@@ -338,7 +339,7 @@
     return memory_stats(device=device)["reserved_bytes.all.peak"]
 
 
-def memory_cached(device=None):
+def memory_cached(device: Union[Device, int] = None) -> int:
     r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
     warnings.warn(
         "torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved",
@@ -346,7 +347,7 @@
     return memory_reserved(device=device)
 
 
-def max_memory_cached(device=None):
+def max_memory_cached(device: Union[Device, int] = None) -> int:
     r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
     warnings.warn(
         "torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved",
@@ -367,7 +368,7 @@ def memory_snapshot():
     return torch._C._cuda_memorySnapshot()
 
 
-def memory_summary(device=None, abbreviated=False):
+def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
     r"""Returns a human-readable printout of the current memory allocator
     statistics for a given device.
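
memory_summary renders those same statistics as a table, which is handy in logs. Quick usage (requires CUDA):

import torch

if torch.cuda.is_available():
    print(torch.cuda.memory_summary(abbreviated=True))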

torch/cuda/profiler.py:

@@ -15,9 +15,9 @@ DEFAULT_FLAGS = [
 
 
 def init(output_file, flags=None, output_mode='key_value'):
+    rt = cudart()
     if not hasattr(rt, 'cudaOutputMode'):
         raise AssertionError("HIP does not support profiler initialization!")
-    rt = cudart()
     flags = DEFAULT_FLAGS if flags is None else flags
     if output_mode == 'key_value':
         output_mode_enum = rt.cudaOutputMode.KeyValuePair
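
This hunk is a genuine bug fix that the new typechecking makes visible: the old code read `rt` in the hasattr check before the `rt = cudart()` assignment, so the HIP guard could never run without raising UnboundLocalError. A distilled repro of the bug class (here `cudart` stands in for torch.cuda.cudart):

def init_broken(cudart):
    if not hasattr(rt, "cudaOutputMode"):  # UnboundLocalError: 'rt' not assigned yet
        raise AssertionError("HIP does not support profiler initialization!")
    rt = cudart()  # this assignment makes 'rt' local to the whole function

def init_fixed(cudart):
    rt = cudart()  # assign first, then inspect
    if not hasattr(rt, "cudaOutputMode"):
        raise AssertionError("HIP does not support profiler initialization!")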

torch/cuda/random.py:

@@ -1,5 +1,7 @@
 import torch
+from typing import cast, Iterable, List, Union
 from . import _lazy_init, _lazy_call, device_count, current_device
+from .. import Tensor
 
 __all__ = ['get_rng_state', 'get_rng_state_all',
            'set_rng_state', 'set_rng_state_all',
@@ -7,7 +9,7 @@ __all__ = ['get_rng_state', 'get_rng_state_all',
            'seed', 'seed_all', 'initial_seed']
 
 
-def get_rng_state(device='cuda'):
+def get_rng_state(device: Union[int, str, torch.device] = 'cuda') -> Tensor:
     r"""Returns the random number generator state of the specified GPU as a ByteTensor.
 
     Args:
@@ -29,7 +31,7 @@ def get_rng_state(device='cuda'):
     return default_generator.get_state()
 
 
-def get_rng_state_all():
+def get_rng_state_all() -> List[Tensor]:
     r"""Returns a list of ByteTensor representing the random number states of all devices."""
 
     results = []
@@ -38,7 +40,7 @@ def get_rng_state_all():
     return results
 
 
-def set_rng_state(new_state, device='cuda'):
+def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cuda') -> None:
     r"""Sets the random number generator state of the specified GPU.
 
     Args:
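
The annotated pair spells out the round-trip contract: get_rng_state returns a ByteTensor and set_rng_state consumes one. Typical save/replay sketch (requires CUDA):

import torch

if torch.cuda.is_available():
    state = torch.cuda.get_rng_state()   # Tensor snapshot of the generator
    a = torch.randn(3, device="cuda")
    torch.cuda.set_rng_state(state)      # rewind the generator
    b = torch.randn(3, device="cuda")
    assert torch.equal(a, b)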
@@ -53,7 +55,7 @@ def set_rng_state(new_state, device='cuda'):
         device = torch.device('cuda', device)
 
     def cb():
-        idx = device.index
+        idx = cast(torch.device, device).index
         if idx is None:
             idx = current_device()
         default_generator = torch.cuda.default_generators[idx]
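
The cast is needed because mypy cannot see that `device` is always a torch.device by the time the closure runs; cast() asserts that statically at zero runtime cost. A minimal illustration of the idiom:

from typing import Union, cast
import torch

def schedule(device: Union[int, str, torch.device]) -> None:
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb() -> None:
        # mypy still sees the Union inside the closure, so narrow explicitly:
        idx = cast(torch.device, device).index
        print(idx)

    cb()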
@@ -62,7 +64,7 @@ def set_rng_state(new_state, device='cuda'):
     _lazy_call(cb)
 
 
-def set_rng_state_all(new_states):
+def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
     r"""Sets the random number generator state of all devices.
 
     Args:
@@ -71,7 +73,7 @@
         set_rng_state(state, i)
 
 
-def manual_seed(seed):
+def manual_seed(seed: int) -> None:
     r"""Sets the seed for generating random numbers for the current GPU.
     It's safe to call this function if CUDA is not available; in that
     case, it is silently ignored.
@@ -93,7 +95,7 @@
     _lazy_call(cb)
 
 
-def manual_seed_all(seed):
+def manual_seed_all(seed: int) -> None:
     r"""Sets the seed for generating random numbers on all GPUs.
     It's safe to call this function if CUDA is not available; in that
     case, it is silently ignored.
@@ -111,7 +113,7 @@
     _lazy_call(cb)
 
 
-def seed():
+def seed() -> None:
     r"""Sets the seed for generating random numbers to a random number for the current GPU.
     It's safe to call this function if CUDA is not available; in that
     case, it is silently ignored.
@@ -128,7 +130,7 @@
     _lazy_call(cb)
 
 
-def seed_all():
+def seed_all() -> None:
     r"""Sets the seed for generating random numbers to a random number on all GPUs.
     It's safe to call this function if CUDA is not available; in that
     case, it is silently ignored.
@@ -148,7 +150,7 @@
     _lazy_call(cb)
 
 
-def initial_seed():
+def initial_seed() -> int:
     r"""Returns the current random seed of the current GPU.
 
     .. warning::

torch/cuda/random.pyi (deleted; annotations moved inline above):

@@ -1,15 +0,0 @@
-from typing import Iterable, List, Optional, Union
-from .. import device as _device, Tensor
-
-
-_device_t = Union[_device, int]
-
-def get_rng_state(device: Optional[_device_t]=...) -> int: ...
-def get_rng_state_all() -> List[int]: ...
-def set_rng_state(new_state: Tensor, device: Optional[_device_t]=...) -> None: ...
-def set_rng_state_all(new_state: Iterable[Tensor]) -> None: ...
-def manual_seed(seed: int) -> None: ...
-def manual_seed_all(seed: int) -> None: ...
-def seed() -> None: ...
-def seed_all() -> None: ...
-def initial_seed() -> int: ...