Add type annotations to torch/__init__.py (#106214)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106214
Approved by: https://github.com/Skylion007
This commit is contained in:
Richard Barnes 2023-08-02 19:13:28 +00:00 committed by PyTorch MergeBot
parent bd84651e19
commit 1534af2a5c

View File

@ -161,7 +161,7 @@ def _preload_cuda_deps(lib_folder, lib_name):
# See Note [Global dependencies]
def _load_global_deps():
def _load_global_deps() -> None:
if _running_with_deploy() or platform.system() == 'Windows':
return
@ -659,7 +659,7 @@ def set_default_dtype(d):
"""
_C._set_default_dtype(d)
def use_deterministic_algorithms(mode, *, warn_only=False):
def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.bool = False) -> None:
r""" Sets whether PyTorch operations must use "deterministic"
algorithms. That is, algorithms which, given the same input, and when
run on the same software and hardware, always produce the same output.
@ -799,13 +799,13 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
"""
_C._set_deterministic_algorithms(mode, warn_only=warn_only)
def are_deterministic_algorithms_enabled():
def are_deterministic_algorithms_enabled() -> builtins.bool:
r"""Returns True if the global deterministic flag is turned on. Refer to
:func:`torch.use_deterministic_algorithms` documentation for more details.
"""
return _C._get_deterministic_algorithms()
def is_deterministic_algorithms_warn_only_enabled():
def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
r"""Returns True if the global deterministic flag is set to warn only.
Refer to :func:`torch.use_deterministic_algorithms` documentation for more
details.
@ -874,7 +874,7 @@ def get_float32_matmul_precision() -> builtins.str:
"""
return _C._get_float32_matmul_precision()
def set_float32_matmul_precision(precision):
def set_float32_matmul_precision(precision: str) -> None:
r"""Sets the internal precision of float32 matrix multiplications.
Running float32 matrix multiplications in lower precision may significantly increase
@ -919,7 +919,7 @@ def set_float32_matmul_precision(precision):
"""
_C._set_float32_matmul_precision(precision)
def set_warn_always(b):
def set_warn_always(b: builtins.bool) -> None:
r"""When this flag is False (default) then some PyTorch warnings may only
appear once per process. This helps avoid excessive warning information.
Setting it to True causes these warnings to always appear, which may be
@ -931,7 +931,7 @@ def set_warn_always(b):
"""
_C._set_warnAlways(b)
def is_warn_always_enabled():
def is_warn_always_enabled() -> builtins.bool:
r"""Returns True if the global warn_always flag is turned on. Refer to
:func:`torch.set_warn_always` documentation for more details.
"""
@ -1439,7 +1439,7 @@ from . import _torch_docs, _tensor_docs, _storage_docs
del _torch_docs, _tensor_docs, _storage_docs
def compiled_with_cxx11_abi():
def compiled_with_cxx11_abi() -> builtins.bool:
r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
return _C._GLIBCXX_USE_CXX11_ABI