Add type suppressions to _inductor/runtime (#165918)

The original PR that made this change was reverted due to merge conflicts.

This relands it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165918
Approved by: https://github.com/oulgen
Authored by Maggie Moss on 2025-10-21 02:54:18 +00:00, committed by PyTorch MergeBot
parent 7406d2e665
commit 0be0de4ffa
11 changed files with 47 additions and 2 deletions
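
For reference, every suppression added below uses pyrefly's inline ignore comment, placed on the line directly above the expression the checker flags, with the error code repeated after a second `#`. A minimal sketch of the pattern, using a hypothetical function and a deliberately loosely-typed object (only the comment syntax is taken from this diff):

    import types


    def read_num_warps(cfg: object) -> int:
        # As far as the type checker knows, `object` has no `num_warps` attribute,
        # so the access below would be reported as missing-attribute; the comment
        # on the preceding line silences exactly that error code on that line.
        # pyrefly: ignore # missing-attribute
        return cfg.num_warps


    print(read_num_warps(types.SimpleNamespace(num_warps=4)))  # prints 4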

View File

@@ -23,10 +23,12 @@ project-includes = [
 project-excludes = [
     # ==== below will be enabled directory by directory ====
     # ==== to test Pyrefly on a specific directory, simply comment it out ====
-    "torch/_inductor/runtime",
     "torch/_inductor/codegen/triton.py",
     "tools/linter/adapters/test_device_bias_linter.py",
     "tools/code_analyzer/gen_operators_yaml.py",
+    "torch/_inductor/runtime/triton_heuristics.py",
+    "torch/_inductor/runtime/triton_helpers.py",
+    "torch/_inductor/runtime/halide_helpers.py",
     # formatting issues, will turn on after adjusting where suppressions can be
     # in import statements
     "tools/flight_recorder/components/types.py",

View File

@@ -591,6 +591,7 @@ def run_until_failure(
         pbar.set_postfix_str(f"{total_successful}/{total_ignored}")
         def write_func(msg):
+            # pyrefly: ignore # missing-attribute
             pbar.write(msg)
     else:
         pbar = None

View File

@@ -275,8 +275,11 @@ class AutotuneCache:
         triton_cache_hash: str | None = None,
     ) -> None:
         data = {
+            # pyrefly: ignore # missing-attribute
             **config.kwargs,
+            # pyrefly: ignore # missing-attribute
             "num_warps": config.num_warps,
+            # pyrefly: ignore # missing-attribute
             "num_stages": config.num_stages,
             "configs_hash": self.configs_hash,
             "found_by_coordesc": found_by_coordesc,
@@ -570,15 +573,20 @@ def _load_cached_autotuning(
         )
         # Create the triton_config with the appropriate arguments
+        # pyrefly: ignore # bad-argument-count
         triton_config = Config(best_config, **config_args)
+        # pyrefly: ignore # missing-attribute
         triton_config.found_by_coordesc = True
         return triton_config
     matching_configs = [
         cfg
         for cfg in configs
+        # pyrefly: ignore # missing-attribute
         if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
+        # pyrefly: ignore # missing-attribute
         and cfg.num_warps == best_config.get("num_warps")
+        # pyrefly: ignore # missing-attribute
         and cfg.num_stages == best_config.get("num_stages")
     ]
     if len(matching_configs) != 1:

View File

@@ -123,6 +123,7 @@ class Benchmarker:
             - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
         inferred_device = None
+        # pyrefly: ignore # bad-assignment
         for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
             if not isinstance(arg_or_kwarg, torch.Tensor):
                 continue
@@ -196,6 +197,7 @@ class TritonBenchmarker(Benchmarker):
     @may_distort_benchmarking_result
     @time_and_count
+    # pyrefly: ignore # bad-override
     def benchmark_gpu(
         self: Self,
         _callable: Callable[[], Any],

View File

@@ -190,6 +190,7 @@ class _OnDiskCacheImpl(_CacheImpl):
             Defaults to empty string if not specified.
         """
         self._cache_dir: Path = self._base_dir / (sub_dir or "")
+        # pyrefly: ignore # bad-assignment
         self._flock: FileLock = FileLock(str(self._cache_dir / "dir.lock"))
     @property

View File

@@ -186,6 +186,7 @@ class CoordescTuner:
     def check_all_tuning_directions(
         self,
+        # pyrefly: ignore # missing-attribute
         func: Callable[["triton.Config"], float],
         best_config,
         best_timing,
@@ -255,10 +256,12 @@ class CoordescTuner:
     def autotune(
         self,
+        # pyrefly: ignore # missing-attribute
         func: Callable[["triton.Config"], float],
+        # pyrefly: ignore # missing-attribute
         baseline_config: "triton.Config",
         baseline_timing: float | None = None,
-    ) -> "triton.Config":
+    ) -> "triton.Config":  # pyrefly: ignore # missing-attribute
         if baseline_timing is None:
             baseline_timing = self.call_func(func, baseline_config)

View File

@@ -89,11 +89,13 @@ if has_triton_package():
             divisible_by_16=None,
             equal_to_1=None,
         ):
+            # pyrefly: ignore # not-iterable
             return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16}
 else:
     # Define a namedtuple as a fallback when AttrsDescriptor is not available
     AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
+        # pyrefly: ignore # invalid-argument
         "AttrsDescriptor",
         ["divisible_by_16", "equal_to_1"],
         defaults=[(), ()],

View File

@@ -68,8 +68,11 @@ def triton_config_to_hashable(cfg: Config) -> Hashable:
     Convert triton config to a tuple that can uniquely identify it. We can use
     the return value as a dictionary key.
     """
+    # pyrefly: ignore # missing-attribute
     items = sorted(cfg.kwargs.items())
+    # pyrefly: ignore # missing-attribute
     items.append(("num_warps", cfg.num_warps))
+    # pyrefly: ignore # missing-attribute
     items.append(("num_stages", cfg.num_stages))
     return tuple(items)
@@ -103,6 +106,7 @@ def get_max_y_grid() -> int:
 try:
+    # pyrefly: ignore # import-error
     import colorama
     HAS_COLORAMA = True
@@ -114,6 +118,7 @@ except ModuleNotFoundError:
 if HAS_COLORAMA:
     def _color_text(msg: str, color: str) -> str:
+        # pyrefly: ignore # missing-attribute
         return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
 else:

View File

@@ -34,21 +34,29 @@ class StaticallyLaunchedCudaKernel:
     """
     def __init__(self, kernel: CompiledKernel) -> None:
+        # pyrefly: ignore # missing-attribute
         self.name = kernel.src.fn.__name__
+        # pyrefly: ignore # missing-attribute
         self.cubin_raw = kernel.asm.get("cubin", None)
+        # pyrefly: ignore # missing-attribute
         self.cubin_path = kernel._cubin_path
         # Used by torch.compile to filter constants in older triton versions
+        # pyrefly: ignore # missing-attribute
         self.arg_names = kernel.src.fn.arg_names
         # Const exprs that are declared by the triton kernel directly
         # Used to generate the kernel launcher's def args
+        # pyrefly: ignore # missing-attribute
         self.declared_constexprs = kernel.src.fn.constexprs
+        # pyrefly: ignore # missing-attribute
         self.hash = kernel.hash
         if triton_knobs is None:
+            # pyrefly: ignore # missing-attribute
             launch_enter = kernel.__class__.launch_enter_hook
+            # pyrefly: ignore # missing-attribute
             launch_exit = kernel.__class__.launch_exit_hook
         else:
             launch_enter = triton_knobs.runtime.launch_enter_hook
@@ -70,12 +78,15 @@ class StaticallyLaunchedCudaKernel:
             raise NotImplementedError(
                 "We don't support launch enter or launch exit hooks"
             )
+        # pyrefly: ignore # missing-attribute
         self.num_warps = kernel.metadata.num_warps
         self.shared = (
+            # pyrefly: ignore # missing-attribute
             kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
         )
         def needs_scratch_arg(scratch_name: str, param_name: str) -> bool:
+            # pyrefly: ignore # missing-attribute
             if hasattr(kernel.metadata, param_name):
                 if getattr(kernel.metadata, param_name) > 0:
                     raise NotImplementedError(
@@ -91,6 +102,7 @@ class StaticallyLaunchedCudaKernel:
         # same situation for profile scratch - triton-lang/triton#7258
         self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")
+        # pyrefly: ignore # missing-attribute
         self.arg_tys = self.arg_ty_from_signature(kernel.src)
         self.function: int | None = None  # Loaded by load_kernel(on the parent process)
         num_ctas = 1
@@ -170,6 +182,7 @@ class StaticallyLaunchedCudaKernel:
     def arg_ty_from_signature(self, src: ASTSource) -> str:
         def index_key(i: Any) -> int:
             if isinstance(i, str):
+                # pyrefly: ignore # missing-attribute
                 return src.fn.arg_names.index(i)
             elif isinstance(i, tuple):
                 # In triton 3.3, src.fn.constants has tuples as a key
@@ -177,6 +190,7 @@ class StaticallyLaunchedCudaKernel:
             else:
                 return i
+        # pyrefly: ignore # missing-attribute
         signature = {index_key(key): value for key, value in src.signature.items()}
         # Triton uses these as the main way to filter out constants passed to their cubin
         constants = [index_key(key) for key in getattr(src, "constants", dict())]
@@ -198,6 +212,7 @@ class StaticallyLaunchedCudaKernel:
             if ty == "constexpr" or i in constants:
                 pass
             else:
+                # pyrefly: ignore # bad-argument-type
                 params.append(self.extract_type(ty))
         return "".join(params)
@@ -235,6 +250,7 @@ class StaticallyLaunchedCudaKernel:
         if has_scratch:
             arg_tys = arg_tys + "O"
             args = (*args, None)
+        # pyrefly: ignore # bad-argument-type
         assert len(args) == len(arg_tys)
         # TODO: can handle grid functions here or in C++, so
@@ -247,6 +263,7 @@ class StaticallyLaunchedCudaKernel:
             self.num_warps,
             self.shared,
             arg_tys,
+            # pyrefly: ignore # bad-argument-type
             args,
             stream,
         )

View File

@@ -21,6 +21,7 @@ def is_available():
 def is_acl_available():
     r"""Return whether PyTorch is built with MKL-DNN + ACL support."""
+    # pyrefly: ignore # missing-attribute
     return torch._C._has_mkldnn_acl

View File

@@ -1074,6 +1074,7 @@ def _set_memory_metadata(metadata: str):
         metadata (str): Custom metadata string to attach to allocations.
             Pass an empty string to clear the metadata.
     """
+    # pyrefly: ignore # missing-attribute
     torch._C._cuda_setMemoryMetadata(metadata)
@@ -1084,6 +1085,7 @@ def _get_memory_metadata() -> str:
     Returns:
         str: The current metadata string, or empty string if no metadata is set.
     """
+    # pyrefly: ignore # missing-attribute
     return torch._C._cuda_getMemoryMetadata()
@@ -1106,6 +1108,7 @@ def _save_memory_usage(filename="output.svg", snapshot=None):
     category=FutureWarning,
 )
 def _set_allocator_settings(env: str):
+    # pyrefly: ignore # missing-attribute
     return torch._C._accelerator_setAllocatorSettings(env)