From 0be0de4ffa055d098116db2a370cbe2eb403570c Mon Sep 17 00:00:00 2001
From: Maggie Moss
Date: Tue, 21 Oct 2025 02:54:18 +0000
Subject: [PATCH] Add type suppressions to _inductor/runtime (#165918)

Original PR that did this was reverted due to merge conflicts. Trying it again

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165918
Approved by: https://github.com/oulgen
---
 pyrefly.toml                                    |  4 +++-
 .../torchfuzz/multi_process_fuzzer.py           |  1 +
 torch/_inductor/runtime/autotune_cache.py       |  8 ++++++++
 torch/_inductor/runtime/benchmarking.py         |  2 ++
 .../runtime/caching/implementations.py          |  1 +
 .../runtime/coordinate_descent_tuner.py         |  5 ++++-
 torch/_inductor/runtime/hints.py                |  2 ++
 torch/_inductor/runtime/runtime_utils.py        |  5 +++++
 torch/_inductor/runtime/static_cuda_launcher.py | 17 +++++++++++++++++
 torch/backends/mkldnn/__init__.py               |  1 +
 torch/cuda/memory.py                            |  3 +++
 11 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/pyrefly.toml b/pyrefly.toml
index 5516963d262..3dd62368184 100644
--- a/pyrefly.toml
+++ b/pyrefly.toml
@@ -23,10 +23,12 @@ project-includes = [
 project-excludes = [
     # ==== below will be enabled directory by directory ====
     # ==== to test Pyrefly on a specific directory, simply comment it out ====
-    "torch/_inductor/runtime",
     "torch/_inductor/codegen/triton.py",
     "tools/linter/adapters/test_device_bias_linter.py",
     "tools/code_analyzer/gen_operators_yaml.py",
+    "torch/_inductor/runtime/triton_heuristics.py",
+    "torch/_inductor/runtime/triton_helpers.py",
+    "torch/_inductor/runtime/halide_helpers.py",
     # formatting issues, will turn on after adjusting where suppressions can be
     # in import statements
     "tools/flight_recorder/components/types.py",
diff --git a/tools/experimental/torchfuzz/multi_process_fuzzer.py b/tools/experimental/torchfuzz/multi_process_fuzzer.py
index bbaf7d669b5..37573f5940c 100644
--- a/tools/experimental/torchfuzz/multi_process_fuzzer.py
+++ b/tools/experimental/torchfuzz/multi_process_fuzzer.py
@@ -591,6 +591,7 @@ def run_until_failure(
                 pbar.set_postfix_str(f"{total_successful}/{total_ignored}")
 
             def write_func(msg):
+                # pyrefly: ignore # missing-attribute
                 pbar.write(msg)
         else:
             pbar = None
diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py
index 3c55a9cd1b0..63d7a52ff7d 100644
--- a/torch/_inductor/runtime/autotune_cache.py
+++ b/torch/_inductor/runtime/autotune_cache.py
@@ -275,8 +275,11 @@ class AutotuneCache:
         triton_cache_hash: str | None = None,
     ) -> None:
         data = {
+            # pyrefly: ignore # missing-attribute
             **config.kwargs,
+            # pyrefly: ignore # missing-attribute
             "num_warps": config.num_warps,
+            # pyrefly: ignore # missing-attribute
             "num_stages": config.num_stages,
             "configs_hash": self.configs_hash,
             "found_by_coordesc": found_by_coordesc,
@@ -570,15 +573,20 @@ def _load_cached_autotuning(
         )
 
         # Create the triton_config with the appropriate arguments
+        # pyrefly: ignore # bad-argument-count
         triton_config = Config(best_config, **config_args)
+        # pyrefly: ignore # missing-attribute
         triton_config.found_by_coordesc = True
         return triton_config
 
     matching_configs = [
         cfg
         for cfg in configs
+        # pyrefly: ignore # missing-attribute
         if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
+        # pyrefly: ignore # missing-attribute
         and cfg.num_warps == best_config.get("num_warps")
+        # pyrefly: ignore # missing-attribute
         and cfg.num_stages == best_config.get("num_stages")
     ]
     if len(matching_configs) != 1:
diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py
index 698484658dd..ee504b1a057 100644
--- a/torch/_inductor/runtime/benchmarking.py
+++ b/torch/_inductor/runtime/benchmarking.py
@@ -123,6 +123,7 @@ class Benchmarker:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
         inferred_device = None
+        # pyrefly: ignore # bad-assignment
         for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
             if not isinstance(arg_or_kwarg, torch.Tensor):
                 continue
@@ -196,6 +197,7 @@ class TritonBenchmarker(Benchmarker):
 
     @may_distort_benchmarking_result
     @time_and_count
+    # pyrefly: ignore # bad-override
     def benchmark_gpu(
         self: Self,
         _callable: Callable[[], Any],
diff --git a/torch/_inductor/runtime/caching/implementations.py b/torch/_inductor/runtime/caching/implementations.py
index abc113caae9..8292b957f56 100644
--- a/torch/_inductor/runtime/caching/implementations.py
+++ b/torch/_inductor/runtime/caching/implementations.py
@@ -190,6 +190,7 @@ class _OnDiskCacheImpl(_CacheImpl):
                 Defaults to empty string if not specified.
         """
         self._cache_dir: Path = self._base_dir / (sub_dir or "")
+        # pyrefly: ignore # bad-assignment
         self._flock: FileLock = FileLock(str(self._cache_dir / "dir.lock"))
 
     @property
diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index 68db68ca11c..526669d7ca8 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -186,6 +186,7 @@ class CoordescTuner:
 
     def check_all_tuning_directions(
         self,
+        # pyrefly: ignore # missing-attribute
         func: Callable[["triton.Config"], float],
         best_config,
         best_timing,
@@ -255,10 +256,12 @@ class CoordescTuner:
 
     def autotune(
         self,
+        # pyrefly: ignore # missing-attribute
         func: Callable[["triton.Config"], float],
+        # pyrefly: ignore # missing-attribute
         baseline_config: "triton.Config",
         baseline_timing: float | None = None,
-    ) -> "triton.Config":
+    ) -> "triton.Config":  # pyrefly: ignore # missing-attribute
         if baseline_timing is None:
             baseline_timing = self.call_func(func, baseline_config)
 
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 10a5a9749a5..98357daacab 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -89,11 +89,13 @@ if has_triton_package():
         divisible_by_16=None,
         equal_to_1=None,
     ):
+        # pyrefly: ignore # not-iterable
         return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16}
 
 else:
     # Define a namedtuple as a fallback when AttrsDescriptor is not available
     AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
+        # pyrefly: ignore # invalid-argument
         "AttrsDescriptor",
         ["divisible_by_16", "equal_to_1"],
         defaults=[(), ()],
diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py
index 21cd5987f8f..30087d95663 100644
--- a/torch/_inductor/runtime/runtime_utils.py
+++ b/torch/_inductor/runtime/runtime_utils.py
@@ -68,8 +68,11 @@ def triton_config_to_hashable(cfg: Config) -> Hashable:
     Convert triton config to a tuple that can uniquely identify it. We
     can use the return value as a dictionary key.
     """
+    # pyrefly: ignore # missing-attribute
     items = sorted(cfg.kwargs.items())
+    # pyrefly: ignore # missing-attribute
     items.append(("num_warps", cfg.num_warps))
+    # pyrefly: ignore # missing-attribute
     items.append(("num_stages", cfg.num_stages))
     return tuple(items)
 
@@ -103,6 +106,7 @@ def get_max_y_grid() -> int:
 
 
 try:
+    # pyrefly: ignore # import-error
     import colorama
 
     HAS_COLORAMA = True
@@ -114,6 +118,7 @@ except ModuleNotFoundError:
 if HAS_COLORAMA:
 
     def _color_text(msg: str, color: str) -> str:
+        # pyrefly: ignore # missing-attribute
         return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
 
 else:
diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py
index a5e511052b2..e7d4705740e 100644
--- a/torch/_inductor/runtime/static_cuda_launcher.py
+++ b/torch/_inductor/runtime/static_cuda_launcher.py
@@ -34,21 +34,29 @@ class StaticallyLaunchedCudaKernel:
     """
 
     def __init__(self, kernel: CompiledKernel) -> None:
+        # pyrefly: ignore # missing-attribute
         self.name = kernel.src.fn.__name__
+        # pyrefly: ignore # missing-attribute
         self.cubin_raw = kernel.asm.get("cubin", None)
+        # pyrefly: ignore # missing-attribute
         self.cubin_path = kernel._cubin_path
 
         # Used by torch.compile to filter constants in older triton versions
+        # pyrefly: ignore # missing-attribute
         self.arg_names = kernel.src.fn.arg_names
 
         # Const exprs that are declared by the triton kernel directly
        # Used to generate the kernel launcher's def args
+        # pyrefly: ignore # missing-attribute
         self.declared_constexprs = kernel.src.fn.constexprs
 
+        # pyrefly: ignore # missing-attribute
         self.hash = kernel.hash
 
         if triton_knobs is None:
+            # pyrefly: ignore # missing-attribute
             launch_enter = kernel.__class__.launch_enter_hook
+            # pyrefly: ignore # missing-attribute
             launch_exit = kernel.__class__.launch_exit_hook
         else:
             launch_enter = triton_knobs.runtime.launch_enter_hook
@@ -70,12 +78,15 @@ class StaticallyLaunchedCudaKernel:
             raise NotImplementedError(
                 "We don't support launch enter or launch exit hooks"
             )
+        # pyrefly: ignore # missing-attribute
         self.num_warps = kernel.metadata.num_warps
         self.shared = (
+            # pyrefly: ignore # missing-attribute
             kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
         )
 
         def needs_scratch_arg(scratch_name: str, param_name: str) -> bool:
+            # pyrefly: ignore # missing-attribute
             if hasattr(kernel.metadata, param_name):
                 if getattr(kernel.metadata, param_name) > 0:
                     raise NotImplementedError(
@@ -91,6 +102,7 @@ class StaticallyLaunchedCudaKernel:
 
         # same situation for profile scratch - triton-lang/triton#7258
         self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")
+        # pyrefly: ignore # missing-attribute
         self.arg_tys = self.arg_ty_from_signature(kernel.src)
         self.function: int | None = None  # Loaded by load_kernel(on the parent process)
         num_ctas = 1
@@ -170,6 +182,7 @@ class StaticallyLaunchedCudaKernel:
     def arg_ty_from_signature(self, src: ASTSource) -> str:
         def index_key(i: Any) -> int:
             if isinstance(i, str):
+                # pyrefly: ignore # missing-attribute
                 return src.fn.arg_names.index(i)
             elif isinstance(i, tuple):
                 # In triton 3.3, src.fn.constants has tuples as a key
@@ -177,6 +190,7 @@ class StaticallyLaunchedCudaKernel:
             else:
                 return i
 
+        # pyrefly: ignore # missing-attribute
         signature = {index_key(key): value for key, value in src.signature.items()}
         # Triton uses these as the main way to filter out constants passed to their cubin
         constants = [index_key(key) for key in getattr(src, "constants", dict())]
@@ -198,6 +212,7 @@ class StaticallyLaunchedCudaKernel:
             if ty == "constexpr" or i in constants:
                 pass
             else:
+                # pyrefly: ignore # bad-argument-type
                 params.append(self.extract_type(ty))
         return "".join(params)
 
@@ -235,6 +250,7 @@ class StaticallyLaunchedCudaKernel:
         if has_scratch:
             arg_tys = arg_tys + "O"
             args = (*args, None)
+        # pyrefly: ignore # bad-argument-type
         assert len(args) == len(arg_tys)
 
         # TODO: can handle grid functions here or in C++, so
@@ -247,6 +263,7 @@ class StaticallyLaunchedCudaKernel:
             self.num_warps,
             self.shared,
             arg_tys,
+            # pyrefly: ignore # bad-argument-type
             args,
             stream,
         )
diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py
index a98c2cb64df..e4b30125144 100644
--- a/torch/backends/mkldnn/__init__.py
+++ b/torch/backends/mkldnn/__init__.py
@@ -21,6 +21,7 @@ def is_available():
 
 def is_acl_available():
     r"""Return whether PyTorch is built with MKL-DNN + ACL support."""
+    # pyrefly: ignore # missing-attribute
     return torch._C._has_mkldnn_acl
 
 
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index 5844f53da84..dc4c3827c8a 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -1074,6 +1074,7 @@ def _set_memory_metadata(metadata: str):
         metadata (str): Custom metadata string to attach to allocations.
             Pass an empty string to clear the metadata.
     """
+    # pyrefly: ignore # missing-attribute
     torch._C._cuda_setMemoryMetadata(metadata)
 
 
@@ -1084,6 +1085,7 @@ def _get_memory_metadata() -> str:
     Returns:
         str: The current metadata string, or empty string if no metadata is set.
     """
+    # pyrefly: ignore # missing-attribute
     return torch._C._cuda_getMemoryMetadata()
 
 
@@ -1106,6 +1108,7 @@ def _save_memory_usage(filename="output.svg", snapshot=None):
     category=FutureWarning,
 )
 def _set_allocator_settings(env: str):
+    # pyrefly: ignore # missing-attribute
     return torch._C._accelerator_setAllocatorSettings(env)
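
Note on the suppression pattern used throughout this patch: each suppression is a "# pyrefly: ignore # <error-code>" comment placed on the line directly above the expression Pyrefly flags (or, as with the return annotation in coordinate_descent_tuner.py, trailing on the same line), where the text after the second "#" names the silenced diagnostic (missing-attribute, bad-assignment, bad-override, and so on); files listed in pyrefly.toml's project-excludes are skipped entirely instead of being suppressed line by line. Below is a minimal, self-contained sketch of the same idea; the get_config helper and nested Config class are hypothetical stand-ins for the dynamically typed triton.Config values the real code handles, not part of this patch.

    from typing import Any

    def get_config() -> Any:
        # Hypothetical stand-in for an object whose attributes the checker cannot resolve.
        class Config:
            kwargs = {"XBLOCK": 64}
            num_warps = 4
        return Config()

    cfg: object = get_config()
    # pyrefly: ignore # missing-attribute
    print(cfg.num_warps)  # statically, `object` has no num_warps; it exists at runtime

The error code in the comment is matched against the diagnostic being reported, so an unrelated error on the same line would still surface.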