mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
And type annotations for cpp_extension, utils.data, signal_handling (#42647)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42647 Reviewed By: ezyang Differential Revision: D22967041 Pulled By: malfet fbshipit-source-id: 35e124da0be56934faef56834a93b2b400decf66
This commit is contained in:
parent
608f99e4ea
commit
bcab2d6848
12
mypy.ini
12
mypy.ini
|
|
@ -234,18 +234,9 @@ ignore_errors = True
|
|||
[mypy-torch.contrib._tensorboard_vis]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.cpp_extension]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.bottleneck.__main__]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.data]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.data._utils.signal_handling]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.data._utils.collate]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
@ -448,6 +439,9 @@ ignore_missing_imports = True
|
|||
[mypy-setuptools.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-distutils.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-nvd3.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
|
|
|||
|
|
@ -294,7 +294,7 @@ class _TensorBase(object):
|
|||
grad_fn: Any
|
||||
${tensor_method_hints}
|
||||
|
||||
# Defined in torch/csrs/cuda/Module.cpp
|
||||
# Defined in torch/csrc/cuda/Module.cpp
|
||||
class _CudaDeviceProperties:
|
||||
name: str
|
||||
major: _int
|
||||
|
|
@ -329,3 +329,9 @@ class _CudaEventBase:
|
|||
def elapsed_time(self, other: _CudaEventBase) -> _float: ...
|
||||
def synchronize(self) -> None: ...
|
||||
def ipc_handle(self) -> bytes: ...
|
||||
|
||||
# Defined in torch/csrc/DataLoader.cpp
|
||||
def _set_worker_signal_handlers(*arg: Any) -> None: ... # THPModule_setWorkerSignalHandlers
|
||||
def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ... # THPModule_setWorkerPIDs
|
||||
def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerPIDs
|
||||
def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails
|
||||
|
|
|
|||
|
|
@ -129,7 +129,10 @@ with compiling PyTorch from source.
|
|||
ROCM_HOME = _find_rocm_home()
|
||||
MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
|
||||
IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False
|
||||
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip is not None else None
|
||||
ROCM_VERSION = None
|
||||
if torch.version.hip is not None:
|
||||
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
|
||||
|
||||
CUDA_HOME = _find_cuda_home()
|
||||
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
|
||||
# PyTorch releases have the version pattern major.minor.patch, whereas when
|
||||
|
|
@ -259,8 +262,8 @@ def check_compiler_abi_compatibility(compiler):
|
|||
try:
|
||||
if sys.platform.startswith('linux'):
|
||||
minimum_required_version = MINIMUM_GCC_VERSION
|
||||
version = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
|
||||
version = version.decode().strip().split('.')
|
||||
versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
|
||||
version = versionstr.decode().strip().split('.')
|
||||
else:
|
||||
minimum_required_version = MINIMUM_MSVC_VERSION
|
||||
compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
|
||||
|
|
@ -316,7 +319,7 @@ class BuildExtension(build_ext, object):
|
|||
Returns a subclass with alternative constructor that extends any original keyword
|
||||
arguments to the original constructor with the given options.
|
||||
'''
|
||||
class cls_with_options(cls):
|
||||
class cls_with_options(cls): # type: ignore
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.update(options)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
@ -613,7 +616,7 @@ class BuildExtension(build_ext, object):
|
|||
cuda_post_cflags = list(extra_postargs)
|
||||
cuda_post_cflags = win_cuda_flags(cuda_post_cflags)
|
||||
|
||||
from distutils.spawn import _nt_quote_args
|
||||
from distutils.spawn import _nt_quote_args # type: ignore
|
||||
cflags = _nt_quote_args(cflags)
|
||||
post_cflags = _nt_quote_args(post_cflags)
|
||||
if with_cuda:
|
||||
|
|
@ -786,6 +789,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
|||
libraries.append('torch_cpu')
|
||||
libraries.append('torch_python')
|
||||
if IS_HIP_EXTENSION:
|
||||
assert ROCM_VERSION is not None
|
||||
libraries.append('amdhip64' if ROCM_VERSION >= (3, 5) else 'hip_hcc')
|
||||
libraries.append('c10_hip')
|
||||
libraries.append('torch_hip')
|
||||
|
|
@ -1352,6 +1356,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
|
|||
if CUDNN_HOME is not None:
|
||||
extra_ldflags.append('-L{}'.format(os.path.join(CUDNN_HOME, 'lib64')))
|
||||
elif IS_HIP_EXTENSION:
|
||||
assert ROCM_VERSION is not None
|
||||
extra_ldflags.append('-L{}'.format(_join_rocm_home('lib')))
|
||||
extra_ldflags.append('-lamdhip64' if ROCM_VERSION >= (3, 5) else '-lhip_hcc')
|
||||
return extra_ldflags
|
||||
|
|
@ -1397,20 +1402,20 @@ def _get_cuda_arch_flags(cflags=None):
|
|||
# First check for an env var (same as used by the main setup.py)
|
||||
# Can be one or more architectures, e.g. "6.1" or "3.5;5.2;6.0;6.1;7.0+PTX"
|
||||
# See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
|
||||
arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
|
||||
_arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
|
||||
|
||||
# If not given, determine what's needed for the GPU that can be found
|
||||
if not arch_list:
|
||||
if not _arch_list:
|
||||
capability = torch.cuda.get_device_capability()
|
||||
arch_list = ['{}.{}'.format(capability[0], capability[1])]
|
||||
else:
|
||||
# Deal with lists that are ' ' separated (only deal with ';' after)
|
||||
arch_list = arch_list.replace(' ', ';')
|
||||
_arch_list = _arch_list.replace(' ', ';')
|
||||
# Expand named arches
|
||||
for named_arch, archval in named_arches.items():
|
||||
arch_list = arch_list.replace(named_arch, archval)
|
||||
_arch_list = _arch_list.replace(named_arch, archval)
|
||||
|
||||
arch_list = arch_list.split(';')
|
||||
arch_list = _arch_list.split(';')
|
||||
|
||||
flags = []
|
||||
for arch in arch_list:
|
||||
|
|
@ -1528,8 +1533,10 @@ def _run_ninja_build(build_directory, verbose, error_prefix):
|
|||
_, error, _ = sys.exc_info()
|
||||
# error.output contains the stdout and stderr of the build attempt.
|
||||
message = error_prefix
|
||||
if hasattr(error, 'output') and error.output:
|
||||
message += ": {}".format(error.output.decode())
|
||||
# `error` is a CalledProcessError (which has an `ouput`) attribute, but
|
||||
# mypy thinks it's Optional[BaseException] and doesn't narrow
|
||||
if hasattr(error, 'output') and error.output: # type: ignore
|
||||
message += ": {}".format(error.output.decode()) # type: ignore
|
||||
raise RuntimeError(message)
|
||||
|
||||
|
||||
|
|
@ -1580,7 +1587,7 @@ def _write_ninja_file_to_build_library(path,
|
|||
|
||||
if IS_WINDOWS:
|
||||
cflags = common_cflags + COMMON_MSVC_FLAGS + extra_cflags
|
||||
from distutils.spawn import _nt_quote_args
|
||||
from distutils.spawn import _nt_quote_args # type: ignore
|
||||
cflags = _nt_quote_args(cflags)
|
||||
else:
|
||||
cflags = common_cflags + ['-fPIC', '-std=c++14'] + extra_cflags
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def _set_SIGCHLD_handler():
|
|||
if IS_WINDOWS:
|
||||
return
|
||||
# can't set signal in child threads
|
||||
if not isinstance(threading.current_thread(), threading._MainThread):
|
||||
if not isinstance(threading.current_thread(), threading._MainThread): # type: ignore
|
||||
return
|
||||
global _SIGCHLD_handler_set
|
||||
if _SIGCHLD_handler_set:
|
||||
|
|
@ -65,6 +65,7 @@ def _set_SIGCHLD_handler():
|
|||
# Python can still get and update the process status successfully.
|
||||
_error_if_any_worker_fails()
|
||||
if previous_handler is not None:
|
||||
assert callable(previous_handler)
|
||||
previous_handler(signum, frame)
|
||||
|
||||
signal.signal(signal.SIGCHLD, handler)
|
||||
|
|
|
|||
|
|
@ -109,8 +109,8 @@ class DataLoader(Generic[T_co]):
|
|||
worker_init_fn (callable, optional): If not ``None``, this will be called on each
|
||||
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
||||
input, after seeding and before data loading. (default: ``None``)
|
||||
prefetch_factor (int, optional, keyword-only arg): Number of sample loaded
|
||||
in advance by each worker. ``2`` means there will be a total of
|
||||
prefetch_factor (int, optional, keyword-only arg): Number of sample loaded
|
||||
in advance by each worker. ``2`` means there will be a total of
|
||||
2 * num_workers samples prefetched across all workers. (default: ``2``)
|
||||
|
||||
|
||||
|
|
@ -152,9 +152,9 @@ class DataLoader(Generic[T_co]):
|
|||
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
|
||||
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
|
||||
num_workers: int = 0, collate_fn: _collate_fn_t = None,
|
||||
pin_memory: bool = False, drop_last: bool = False,
|
||||
pin_memory: bool = False, drop_last: bool = False,
|
||||
timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
|
||||
multiprocessing_context=None, generator=None,
|
||||
multiprocessing_context=None, generator=None,
|
||||
*, prefetch_factor: int = 2):
|
||||
torch._C._log_api_usage_once("python.data_loader") # type: ignore
|
||||
|
||||
|
|
@ -797,7 +797,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
else:
|
||||
self._data_queue = self._worker_result_queue
|
||||
|
||||
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
|
||||
# .pid can be None only before process is spawned (not the case, so ignore)
|
||||
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore
|
||||
_utils.signal_handling._set_SIGCHLD_handler()
|
||||
self._worker_pids_set = True
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user