Enable torch.utils typechecks (#42960)
Summary:
Fix typos in torch/utils/_benchmark/README.md.
Add an empty __init__.py to the examples folder so that the example invocations from README.md work as written.
Fix the distribution generation logic in FuzzedParameter._loguniform when minval or maxval is None.

Fixes https://github.com/pytorch/pytorch/issues/42984

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42960
Reviewed By: seemethere
Differential Revision: D23095399
Pulled By: malfet
fbshipit-source-id: 0546ce7299b157d9a1f8634340024b10c4b7e7de
This commit is contained in: parent eb47940c0a, commit 6753157c5a
mypy.ini (7 lines changed)

@@ -246,8 +246,8 @@ ignore_errors = True
 [mypy-torch.utils.data.distributed]
 ignore_errors = True
 
-[mypy-torch.utils.checkpoint]
-ignore_errors = True
+#[mypy-torch.utils.checkpoint]
+#ignore_errors = True
 
 [mypy-torch.utils.collect_env]
 ignore_errors = True
@@ -261,9 +261,6 @@ ignore_errors = True
 [mypy-torch.nn.cpp]
 ignore_errors = True
 
-[mypy-torch.utils]
-ignore_errors = True
-
 [mypy-torch.utils.show_pickle]
 ignore_errors = True
 
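With the blanket [mypy-torch.utils] override removed and the torch.utils.checkpoint override commented out, mypy now type-checks those modules. A minimal sketch of running the same check programmatically, assuming mypy is installed (mypy.api.run is mypy's documented programmatic entry point and returns a (report, errors, exit_status) tuple):

import sys
from mypy import api

# Type-check one of the newly enabled modules against the repo's config.
report, errors, exit_status = api.run(
    ['--config-file', 'mypy.ini', 'torch/utils/checkpoint.py'])
sys.stdout.write(report)
print('mypy exit status:', exit_status)

The CLI form, mypy --config-file mypy.ini torch/utils/checkpoint.py, is equivalent.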
torch/utils/_benchmark/README.md

@@ -43,7 +43,7 @@ performance differences. Grouping and layout is based on metadata passed to
 table will be generated per unique label.
 
 * `sub_label`: This is the label for a given configuration. Multiple statements
-  may be logically equivilent differ in implementation. Assigning separate
+  may be logically equivalent differ in implementation. Assigning separate
   sub_labels will result in a row per sub_label. If a sublabel is not provided,
   `stmt` is used instead. Statistics (such as computing the fastest
   implementation) are use all sub_labels.
@@ -54,7 +54,7 @@ own `description`, which allows them to appear in separate columns.
 Statistics do not mix values of different descriptions, since comparing the
 run time of drastically different inputs is generally not meaningful.
 
-* `env`: An optional descripton of the torch environment. (e.g. `master` or
+* `env`: An optional description of the torch environment. (e.g. `master` or
   `my_branch`). Like sub_labels, statistics are calculated across envs. (Since
   comparing a branch to master or a stable release is a common use case.)
   However `Compare` will visually group rows which are run with the same `env`.
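To make the metadata concrete, here is a minimal sketch of how label, sub_label, description, and env drive the rendered table, assuming the Timer and Compare classes this README documents (import path as used by the examples in this commit):

from torch.utils._benchmark import Timer, Compare

results = []
for num_threads in (1, 4):
    for stmt, sub_label in (("x.mul(2)", "method"), ("x * 2", "operator")):
        m = Timer(
            stmt=stmt,
            setup="import torch; x = torch.ones((64, 64))",
            num_threads=num_threads,
            label="multiply",     # one table per unique label
            sub_label=sub_label,  # one row per sub_label
            description="64x64",  # one column per description
            env="master",         # rows sharing an env are grouped visually
        ).timeit(number=1000)
        results.append(m)

Compare(results).print()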
torch/utils/_benchmark/examples/__init__.py (new empty file, 0 lines)
torch/utils/_benchmark/examples/: example script (exact filename not preserved in this view)

@@ -28,6 +28,7 @@ import numpy as np
 import torch
 from torch.utils._benchmark.op_fuzzers import unary
 from torch.utils._benchmark import Timer, Measurement
+from typing import Dict, Tuple, List
 
 
 _MAIN, _SUBPROCESS = "main", "subprocess"
@@ -64,7 +65,7 @@ _DEVICES_TO_TEST = {
     "39744": {_CPU: True, _GPU: True},
 }
 
-_AVAILABLE_GPUS = queue.Queue()
+_AVAILABLE_GPUS = queue.Queue[int]()
 _DTYPES_TO_TEST = {
     "39850": ("int8", "float32", "float64"),
     "39967": ("float32", "float64"),
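One caveat on the new queue.Queue[int]() line: subscripting queue.Queue at runtime requires an interpreter where the class implements __class_getitem__ (Python 3.9+, per PEP 585). On older interpreters the same intent can be expressed as an annotation, which the type checker reads but the interpreter never evaluates:

import queue

# Typed equivalently, without subscripting the class at runtime.
_AVAILABLE_GPUS: "queue.Queue[int]" = queue.Queue()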
@@ -226,7 +227,7 @@ def merge(measurements):
 
 
 def process_results(results, test_variance):
-    paired_results = {}
+    paired_results: Dict[Tuple[str, str, int, bool, int], List] = {}
     for (seed, use_gpu), result_batch in results:
         for r in result_batch:
             key = (r.label, r.description, r.num_threads, use_gpu, seed)
@@ -235,7 +236,7 @@ def process_results(results, test_variance):
             paired_results[key][index].append(r)
 
     paired_results = {
-        key: (merge(r_ref_list), merge(r_pr_list))
+        key: [merge(r_ref_list), merge(r_pr_list)]
         for key, (r_ref_list, r_pr_list) in paired_results.items()
     }
 
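The tuple-to-list change follows from the new annotation: once paired_results is declared Dict[..., List], re-binding it to a dict whose values are tuples no longer type-checks. A minimal repro of the mypy behavior, using hypothetical names:

from typing import Dict, List

d: Dict[str, List] = {"k": [1, 2]}
# d = {k: (v[0], v[-1]) for k, v in d.items()}  # mypy: Tuple is not List
d = {k: [v[0], v[-1]] for k, v in d.items()}    # OK: values remain lists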
@ -366,7 +367,7 @@ def run(cmd, cuda_visible_devices=""):
|
|||
cmd,
|
||||
env={
|
||||
"CUDA_VISIBLE_DEVICES": str(cuda_visible_devices),
|
||||
"PATH": os.getenv("PATH"),
|
||||
"PATH": os.getenv("PATH", ""),
|
||||
},
|
||||
stdout=subprocess.PIPE,
|
||||
shell=True
|
||||
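The added default is a typing fix: subprocess expects the env mapping to be str-to-str, while os.getenv("PATH") is Optional[str]. A two-line sketch of the distinction:

import os
from typing import Optional

maybe_path: Optional[str] = os.getenv("PATH")  # None when PATH is unset
env_path: str = os.getenv("PATH", "")          # the default narrows it to str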
torch/utils/_benchmark/examples/: second example script (exact filename not preserved in this view)

@@ -27,7 +27,7 @@ def assert_dicts_equal(dict_0, dict_1):
 def run(n, stmt, fuzzer_cls):
     float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
     int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n)
-    results = []
+    raw_results = []
     for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter)):
         float_tensors, float_tensor_params, float_params = float_values
         int_tensors, int_tensor_params, int_params = int_values
@@ -58,13 +58,13 @@ def run(n, stmt, fuzzer_cls):
             steps = float_tensor_params[name]["steps"]
             steps_str = str(steps) if sum(steps) > len(steps) else ""
             descriptions.append((name, shape_str, order_str, steps_str))
-        results.append((float_measurement, int_measurement, descriptions))
+        raw_results.append((float_measurement, int_measurement, descriptions))
 
         print(f"\r{i + 1} / {n}", end="")
     print()
 
     parsed_results, name_len, shape_len, order_len, steps_len = [], 0, 0, 0, 0
-    for float_measurement, int_measurement, descriptions in results:
+    for float_measurement, int_measurement, descriptions in raw_results:
         t_float = float_measurement.median * 1e6
         t_int = int_measurement.median * 1e6
         rel_diff = abs(t_float - t_int) / (t_float + t_int) * 2
torch/utils/_benchmark/utils/fuzzer.py (class FuzzedParameter)

@@ -103,13 +103,12 @@ class FuzzedParameter(object):
 
     def _loguniform(self, state):
         output = int(2 ** state.uniform(
-            low=np.log2(self._minval),
-            high=np.log2(self._maxval)
+            low=np.log2(self._minval) if self._minval is not None else None,
+            high=np.log2(self._maxval) if self._maxval is not None else None,
         ))
-        # `or 0` is to appease MyPy
-        if output < (self._minval or 0.0):
+        if self._minval is not None and output < self._minval:
             return self._minval
-        if output > (self._maxval or 0.0):
+        if self._maxval is not None and output > self._maxval:
             return self._maxval
         return output
 
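This is the minval/maxval fix from the summary: the old `(self._minval or 0.0)` fallback, added to appease mypy, silently compared against 0 whenever a bound was None. A standalone sketch of the corrected clamping logic; the fallback exponents 0.0 and 20.0 are arbitrary assumptions for this sketch, not taken from the diff:

import numpy as np
from typing import Optional

def loguniform(state: np.random.RandomState,
               minval: Optional[int] = None,
               maxval: Optional[int] = None) -> int:
    # Sample an exponent, falling back to an assumed range when a bound is absent.
    low = np.log2(minval) if minval is not None else 0.0
    high = np.log2(maxval) if maxval is not None else 20.0
    output = int(2 ** state.uniform(low=low, high=high))
    # Clamp only against bounds that were actually provided.
    if minval is not None and output < minval:
        return minval
    if maxval is not None and output > maxval:
        return maxval
    return output

print(loguniform(np.random.RandomState(0), minval=4, maxval=1024))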
torch/utils/checkpoint.py

@@ -1,9 +1,10 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 import torch
 import warnings
+from typing import Any, Iterable, List, Tuple
 
 
-def detach_variable(inputs):
+def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
     if isinstance(inputs, tuple):
         out = []
         for inp in inputs:
@@ -20,7 +21,7 @@ def detach_variable(inputs):
         "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
 
 
-def check_backward_validity(inputs):
+def check_backward_validity(inputs: Iterable[Any]) -> None:
     if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
         warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
 
@@ -32,7 +33,7 @@ def check_backward_validity(inputs):
 # the device of all Tensor args.
 #
 # To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
-def get_device_states(*args):
+def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
     # This will not error out if "arg" is a CPU tensor or a non-tensor type because
     # the conditionals short-circuit.
     fwd_gpu_devices = list(set(arg.get_device() for arg in args
@@ -46,7 +47,7 @@ def get_device_states(*args):
     return fwd_gpu_devices, fwd_gpu_states
 
 
-def set_device_states(devices, states):
+def set_device_states(devices, states) -> None:
     for device, state in zip(devices, states):
         with torch.cuda.device(device):
             torch.cuda.set_rng_state(state)
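These annotated helpers are the internals of gradient checkpointing; the public entry point is torch.utils.checkpoint.checkpoint. A minimal usage sketch (the layer sizes are arbitrary):

import torch
from torch.utils.checkpoint import checkpoint

linear1 = torch.nn.Linear(16, 16)
linear2 = torch.nn.Linear(16, 4)

def block(x):
    return linear2(torch.relu(linear1(x)))

# requires_grad=True avoids the check_backward_validity warning shown above.
x = torch.randn(8, 16, requires_grad=True)
y = checkpoint(block, x)  # activations are recomputed during backward
y.sum().backward()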
torch/utils/cpp_extension.py

@@ -18,13 +18,14 @@ from .file_baton import FileBaton
 from ._cpp_extension_versioner import ExtensionVersioner
 from .hipify import hipify_python
 from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
+from typing import List, Optional
 
 from setuptools.command.build_ext import build_ext
 
 
 IS_WINDOWS = sys.platform == 'win32'
 
-def _find_cuda_home():
+def _find_cuda_home() -> Optional[str]:
     r'''Finds the CUDA install path.'''
     # Guess #1
     cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
@@ -53,7 +54,7 @@ def _find_cuda_home():
         print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
     return cuda_home
 
-def _find_rocm_home():
+def _find_rocm_home() -> Optional[str]:
     r'''Finds the ROCm install path.'''
     # Guess #1
     rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
@@ -76,7 +77,7 @@ def _find_rocm_home():
     return rocm_home
 
 
-def _join_rocm_home(*paths):
+def _join_rocm_home(*paths) -> str:
     r'''
     Joins paths with ROCM_HOME, or raises an error if it ROCM_HOME is not set.
 
@@ -172,16 +173,16 @@ PLAT_TO_VCVARS = {
 }
 
 
-def _is_binary_build():
+def _is_binary_build() -> bool:
     return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
 
 
-def _accepted_compilers_for_platform():
+def _accepted_compilers_for_platform() -> List[str]:
     # gnu-c++ and gnu-cc are the conda gcc compilers
     return ['clang++', 'clang'] if sys.platform.startswith('darwin') else ['g++', 'gcc', 'gnu-c++', 'gnu-cc']
 
 
-def get_default_build_root():
+def get_default_build_root() -> str:
     r'''
     Returns the path to the root folder under which extensions will built.
 
@@ -196,7 +197,7 @@ def get_default_build_root():
     return os.path.realpath(torch._appdirs.user_cache_dir(appname='torch_extensions'))
 
 
-def check_compiler_ok_for_platform(compiler):
+def check_compiler_ok_for_platform(compiler: str) -> bool:
     r'''
     Verifies that the compiler is the expected one for the current platform.
 
@@ -231,7 +232,7 @@ def check_compiler_ok_for_platform(compiler):
     return False
 
 
-def check_compiler_abi_compatibility(compiler):
+def check_compiler_abi_compatibility(compiler) -> bool:
     r'''
     Verifies that the given compiler is ABI-compatible with PyTorch.
 
@@ -326,7 +327,7 @@ class BuildExtension(build_ext, object):
 
         return cls_with_options
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(BuildExtension, self).__init__(*args, **kwargs)
         self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)
 
@@ -339,12 +340,12 @@ class BuildExtension(build_ext, object):
             warnings.warn(msg.format('we could not find ninja.'))
             self.use_ninja = False
 
-    def finalize_options(self):
+    def finalize_options(self) -> None:
         super().finalize_options()
         if self.use_ninja:
             self.force = True
 
-    def build_extensions(self):
+    def build_extensions(self) -> None:
         self._check_abi()
         for extension in self.extensions:
             self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H')
@@ -361,7 +362,7 @@ class BuildExtension(build_ext, object):
         else:
             original_compile = self.compiler._compile
 
-        def append_std14_if_no_std_present(cflags):
+        def append_std14_if_no_std_present(cflags) -> None:
             # NVCC does not allow multiple -std to be passed, so we avoid
             # overriding the option if the user explicitly passed it.
             cpp_format_prefix = '/{}:' if self.compiler.compiler_type == 'msvc' else '-{}='
@@ -382,15 +383,13 @@ class BuildExtension(build_ext, object):
             if not os.path.isabs(paths[i]):
                 paths[i] = os.path.abspath(paths[i])
 
-        def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
+        def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
             # Copy before we make any modifications.
             cflags = copy.deepcopy(extra_postargs)
             try:
                 original_compiler = self.compiler.compiler_so
                 if _is_cuda_file(src):
-                    nvcc = (_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc'))
-                    if not isinstance(nvcc, list):
-                        nvcc = [nvcc]
+                    nvcc = [_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')]
                     self.compiler.set_executable('compiler_so', nvcc)
                     if isinstance(cflags, dict):
                         cflags = cflags['nvcc']
@@ -808,7 +807,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
     return setuptools.Extension(name, sources, *args, **kwargs)
 
 
-def include_paths(cuda=False):
+def include_paths(cuda: bool = False) -> List[str]:
     '''
     Get the include paths required to build a C++ or CUDA extension.
 
@@ -846,7 +845,7 @@ def include_paths(cuda=False):
     return paths
 
 
-def library_paths(cuda=False):
+def library_paths(cuda: bool = False) -> List[str]:
     r'''
     Get the library paths required to build a C++ or CUDA extension.
 
@@ -886,14 +885,14 @@ def library_paths(cuda=False):
 
 
 def load(name,
-         sources,
+         sources: List[str],
          extra_cflags=None,
          extra_cuda_cflags=None,
          extra_ldflags=None,
         extra_include_paths=None,
         build_directory=None,
         verbose=False,
-         with_cuda=None,
+         with_cuda: Optional[bool] = None,
         is_python_module=True,
         keep_intermediates=True):
     r'''
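load is the public JIT-compilation entry point whose signature gains annotations here. A minimal usage sketch; my_op.cpp is a hypothetical source file that defines a module via PYBIND11_MODULE(TORCH_EXTENSION_NAME, m):

from torch.utils.cpp_extension import load

my_op = load(
    name="my_op",
    sources=["my_op.cpp"],   # sources: List[str]
    extra_cflags=["-O2"],
    verbose=True,            # prints the 'Emitting ninja build file ...' messages below
    with_cuda=None,          # Optional[bool]; inferred from source extensions when None
)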
@@ -1136,11 +1135,11 @@ def _jit_compile(name,
                  extra_cuda_cflags,
                  extra_ldflags,
                  extra_include_paths,
-                 build_directory,
-                 verbose,
-                 with_cuda,
+                 build_directory: str,
+                 verbose: bool,
+                 with_cuda: Optional[bool],
                  is_python_module,
-                 keep_intermediates=True):
+                 keep_intermediates=True) -> None:
     if with_cuda is None:
         with_cuda = any(map(_is_cuda_file, sources))
     with_cudnn = any(['cudnn' in f for f in extra_ldflags or []])
@@ -1192,20 +1191,20 @@ def _jit_compile(name,
                 'module {}, skipping build step...'.format(name))
 
     if verbose:
-        print('Loading extension module {}...'.format(name))
+        print(f'Loading extension module {name}...')
     return _import_module_from_library(name, build_directory, is_python_module)
 
 
 def _write_ninja_file_and_compile_objects(
-        sources,
+        sources: List[str],
         objects,
         cflags,
         post_cflags,
         cuda_cflags,
         cuda_post_cflags,
-        build_directory,
-        verbose,
-        with_cuda):
+        build_directory: str,
+        verbose: bool,
+        with_cuda: Optional[bool]) -> None:
     verify_ninja_availability()
     if IS_WINDOWS:
         compiler = os.environ.get('CXX', 'cl')
@@ -1216,8 +1215,7 @@
         with_cuda = any(map(_is_cuda_file, sources))
     build_file_path = os.path.join(build_directory, 'build.ninja')
     if verbose:
-        print(
-            'Emitting ninja build file {}...'.format(build_file_path))
+        print(f'Emitting ninja build file {build_file_path}...')
     _write_ninja_file(
         path=build_file_path,
         cflags=cflags,
@@ -1241,14 +1239,14 @@
 
 def _write_ninja_file_and_build_library(
         name,
-        sources,
+        sources: List[str],
         extra_cflags,
         extra_cuda_cflags,
         extra_ldflags,
         extra_include_paths,
-        build_directory,
-        verbose,
-        with_cuda):
+        build_directory: str,
+        verbose: bool,
+        with_cuda: Optional[bool]) -> None:
     verify_ninja_availability()
     if IS_WINDOWS:
         compiler = os.environ.get('CXX', 'cl')
@@ -1263,8 +1261,7 @@
         verbose)
     build_file_path = os.path.join(build_directory, 'build.ninja')
     if verbose:
-        print(
-            'Emitting ninja build file {}...'.format(build_file_path))
+        print(f'Emitting ninja build file {build_file_path}...')
     # NOTE: Emitting a new ninja build file does not cause re-compilation if
     # the sources did not change, so it's ok to re-emit (and it's fast).
     _write_ninja_file_to_build_library(
@@ -1362,7 +1359,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
     return extra_ldflags
 
 
-def _get_cuda_arch_flags(cflags=None):
+def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
     r'''
     Determine CUDA arch flags to use.
 
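_get_cuda_arch_flags consults the TORCH_CUDA_ARCH_LIST environment variable when no explicit flags are passed. A sketch of steering it before building an extension; the architecture list shown is just an example:

import os

# Target Volta and Turing, plus PTX for forward compatibility.
os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5+PTX"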
@@ -1430,7 +1427,7 @@
     return list(set(flags))
 
 
-def _get_rocm_arch_flags(cflags=None):
+def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
     # If cflags is given, there may already be user-provided arch flags in it
     # (from `extra_compile_args`)
     if cflags is not None:
@@ -1446,7 +1443,7 @@
     ]
 
 
-def _get_build_directory(name, verbose):
+def _get_build_directory(name: str, verbose: bool) -> str:
     root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
     if root_extensions_directory is None:
         root_extensions_directory = get_default_build_root()
@@ -1458,14 +1455,14 @@
     build_directory = os.path.join(root_extensions_directory, name)
     if not os.path.exists(build_directory):
         if verbose:
-            print('Creating extension directory {}...'.format(build_directory))
+            print(f'Creating extension directory {build_directory}...')
         # This is like mkdir -p, i.e. will also create parent directories.
         os.makedirs(build_directory, exist_ok=True)
 
     return build_directory
 
 
-def _get_num_workers(verbose):
+def _get_num_workers(verbose: bool) -> Optional[int]:
     max_jobs = os.environ.get('MAX_JOBS')
     if max_jobs is not None and max_jobs.isdigit():
         if verbose:
@@ -1477,7 +1474,7 @@
     return None
 
 
-def _run_ninja_build(build_directory, verbose, error_prefix):
+def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None:
     command = ['ninja', '-v']
     num_workers = _get_num_workers(verbose)
     if num_workers is not None:
@@ -1558,7 +1555,7 @@ def _write_ninja_file_to_build_library(path,
                                        extra_cuda_cflags,
                                        extra_ldflags,
                                        extra_include_paths,
-                                       with_cuda):
+                                       with_cuda) -> None:
     extra_cflags = [flag.strip() for flag in extra_cflags]
     extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
     extra_ldflags = [flag.strip() for flag in extra_ldflags]
@@ -1617,7 +1614,7 @@
     else:
         cuda_flags = None
 
-    def object_file_path(source_file):
+    def object_file_path(source_file: str) -> str:
         # '/path/to/file.cpp' -> 'file'
         file_name = os.path.splitext(os.path.basename(source_file))[0]
         if _is_cuda_file(source_file) and with_cuda:
@@ -1665,7 +1662,7 @@ def _write_ninja_file(path,
                       objects,
                       ldflags,
                       library_target,
-                      with_cuda):
+                      with_cuda) -> None:
     r"""Write a ninja file that does the desired compiling and linking.
 
     `path`: Where to write this file
@@ -1783,7 +1780,7 @@
         build_file.write('{}\n\n'.format(lines))
 
 
-def _join_cuda_home(*paths):
+def _join_cuda_home(*paths) -> str:
     r'''
     Joins paths with CUDA_HOME, or raises an error if it CUDA_HOME is not set.
 
@@ -1796,7 +1793,7 @@
     return os.path.join(CUDA_HOME, *paths)
 
 
-def _is_cuda_file(path):
+def _is_cuda_file(path: str) -> bool:
     valid_ext = ['.cu', '.cuh']
     if IS_HIP_EXTENSION:
         valid_ext.append('.hip')