Revert "Pyrefly suppressions 2 (#165692)"
This reverts commit 43d78423ac.
Reverted https://github.com/pytorch/pytorch/pull/165692 on behalf of https://github.com/seemethere due to This is causing merge conflicts when attempting to land internally, see D84890919 for more details ([comment](https://github.com/pytorch/pytorch/pull/165692#issuecomment-3416397240))
This commit is contained in:
parent 630520b346
commit 2928c5c572
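For context, the suppressions removed by this revert are pyrefly inline-ignore comments of the form "# pyrefly: ignore # <error-code>", placed on the line directly above the statement or argument that triggers the diagnostic. A minimal sketch of the pattern, using hypothetical code rather than anything from this PR (the error codes shown, bad-argument-type and bad-assignment, do appear in the hunks below):

    # Hypothetical illustration of the inline-suppression style this revert removes;
    # each comment silences a single pyrefly diagnostic on the line that follows it.

    def scale(x: int) -> int:
        return x * 2

    # pyrefly: ignore # bad-argument-type
    doubled = scale("4")  # a str passed where an int is expected

    # pyrefly: ignore # bad-assignment
    num_warps: int = "8"  # a str assigned to an int-annotated name

Besides deleting such comments from the affected sources, the revert swaps the file-level pyrefly excludes under torch/_inductor/runtime back to the single directory-level "torch/_inductor/runtime" entry, as shown in the first hunk below.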
@@ -22,10 +22,8 @@ project-includes = [
 project-excludes = [
     # ==== below will be enabled directory by directory ====
     # ==== to test Pyrefly on a specific directory, simply comment it out ====
+    "torch/_inductor/runtime",
     "torch/_inductor/codegen/triton.py",
-    "torch/_inductor/runtime/triton_helpers.py",
-    "torch/_inductor/runtime/triton_heuristics.py",
-    "torch/_inductor/runtime/halide_helpers.py",
     # formatting issues, will turn on after adjusting where suppressions can be
     # in import statements
     "torch/linalg/__init__.py",

@@ -1739,7 +1739,6 @@ class KernelArgs:
         for outer, inner in chain(
             # pyrefly: ignore # bad-argument-type
             self.input_buffers.items(),
-            # pyrefly: ignore # bad-argument-type
             self.output_buffers.items(),
         ):
             if outer in self.inplace_buffers or isinstance(inner, RemovedArg):

@@ -1480,7 +1480,6 @@ class CppGemmTemplate(CppTemplate):
         gemm_output_buffer = ir.Buffer(
             # pyrefly: ignore # missing-attribute
             name=gemm_output_name,
-            # pyrefly: ignore # missing-attribute
             layout=template_buffer.layout,
         )
         current_input_buffer = gemm_output_buffer

@@ -1504,7 +1503,6 @@ class CppGemmTemplate(CppTemplate):
         current_input_buffer = ir.Buffer(
             # pyrefly: ignore # missing-attribute
             name=buffer_name,
-            # pyrefly: ignore # missing-attribute
             layout=template_buffer.layout,
         )

@@ -824,7 +824,6 @@ class CppWrapperGpu(CppWrapperCpu):
         call_args, arg_types = self.prepare_triton_wrapper_args(
             # pyrefly: ignore # bad-argument-type
             call_args,
-            # pyrefly: ignore # bad-argument-type
             arg_types,
         )
         wrapper_name = f"call_{kernel_name}"

@@ -683,7 +683,6 @@ class MetalKernel(SIMDKernel):
             # pyrefly: ignore # missing-argument
             t
             for t in self.range_tree_nodes.values()
-            # pyrefly: ignore # missing-argument
             if t.is_reduction
         )
         cmp_op = ">" if reduction_type == "argmax" else "<"

@@ -866,7 +865,6 @@ class MetalKernel(SIMDKernel):
             # pyrefly: ignore # missing-argument
             t.numel
             for t in self.range_trees
-            # pyrefly: ignore # missing-argument
             if t.is_reduction
         )
         # If using dynamic shapes, set the threadgroup size to be the

@@ -968,7 +968,6 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
             # pyrefly: ignore # missing-argument
             t
             for t in self.range_trees
-            # pyrefly: ignore # missing-argument
             if not t.is_reduction or self.inside_reduction
         ]

@@ -1004,7 +1004,6 @@ class FxConverter:
             # pyrefly: ignore # missing-attribute
             call_kwargs[key]
             for key in signature
-            # pyrefly: ignore # missing-attribute
             if key not in cfg.kwargs
         ]

@@ -275,11 +275,8 @@ class AutotuneCache:
         triton_cache_hash: str | None = None,
     ) -> None:
         data = {
-            # pyrefly: ignore # missing-attribute
             **config.kwargs,
-            # pyrefly: ignore # missing-attribute
             "num_warps": config.num_warps,
-            # pyrefly: ignore # missing-attribute
             "num_stages": config.num_stages,
             "configs_hash": self.configs_hash,
             "found_by_coordesc": found_by_coordesc,

@@ -573,20 +570,15 @@ def _load_cached_autotuning(
         )

         # Create the triton_config with the appropriate arguments
-        # pyrefly: ignore # bad-argument-count
         triton_config = Config(best_config, **config_args)
-        # pyrefly: ignore # missing-attribute
         triton_config.found_by_coordesc = True
         return triton_config

     matching_configs = [
         cfg
         for cfg in configs
-        # pyrefly: ignore # missing-attribute
         if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
-        # pyrefly: ignore # missing-attribute
         and cfg.num_warps == best_config.get("num_warps")
-        # pyrefly: ignore # missing-attribute
         and cfg.num_stages == best_config.get("num_stages")
     ]
     if len(matching_configs) != 1:

@@ -123,7 +123,6 @@ class Benchmarker:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
         inferred_device = None
-        # pyrefly: ignore # bad-assignment
         for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
             if not isinstance(arg_or_kwarg, torch.Tensor):
                 continue

@@ -197,7 +196,6 @@ class TritonBenchmarker(Benchmarker):

     @may_distort_benchmarking_result
     @time_and_count
-    # pyrefly: ignore # bad-override
     def benchmark_gpu(
         self: Self,
         _callable: Callable[[], Any],

@@ -190,7 +190,6 @@ class _OnDiskCacheImpl(_CacheImpl):
             Defaults to empty string if not specified.
         """
         self._cache_dir: Path = self._base_dir / (sub_dir or "")
-        # pyrefly: ignore # bad-assignment
         self._flock: FileLock = FileLock(str(self._cache_dir / "dir.lock"))

     @property

@@ -186,7 +186,6 @@ class CoordescTuner:

     def check_all_tuning_directions(
         self,
-        # pyrefly: ignore # missing-attribute
         func: Callable[["triton.Config"], float],
         best_config,
         best_timing,

@@ -256,12 +255,10 @@ class CoordescTuner:

     def autotune(
         self,
-        func: Callable[
-            ["triton.Config"], float  # pyrefly: ignore # missing-attribute
-        ],
-        baseline_config: "triton.Config",  # pyrefly: ignore # missing-attribute
-        baseline_timing: float | None = None,  # pyrefly: ignore # missing-attribute
-    ) -> "triton.Config":  # pyrefly: ignore # missing-attribute
+        func: Callable[["triton.Config"], float],
+        baseline_config: "triton.Config",
+        baseline_timing: float | None = None,
+    ) -> "triton.Config":
         if baseline_timing is None:
             baseline_timing = self.call_func(func, baseline_config)

@@ -88,13 +88,11 @@ if has_triton_package():
         divisible_by_16=None,
         equal_to_1=None,
     ):
-        # pyrefly: ignore # not-iterable
         return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16}

 else:
     # Define a namedtuple as a fallback when AttrsDescriptor is not available
     AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
-        # pyrefly: ignore # invalid-argument
         "AttrsDescriptor",
         ["divisible_by_16", "equal_to_1"],
         defaults=[(), ()],

@@ -68,11 +68,8 @@ def triton_config_to_hashable(cfg: Config) -> Hashable:
     Convert triton config to a tuple that can uniquely identify it. We can use
     the return value as a dictionary key.
     """
-    # pyrefly: ignore # missing-attribute
     items = sorted(cfg.kwargs.items())
-    # pyrefly: ignore # missing-attribute
     items.append(("num_warps", cfg.num_warps))
-    # pyrefly: ignore # missing-attribute
     items.append(("num_stages", cfg.num_stages))
     return tuple(items)

@@ -106,7 +103,6 @@ def get_max_y_grid() -> int:


 try:
-    # pyrefly: ignore # import-error
     import colorama

     HAS_COLORAMA = True

@@ -118,7 +114,6 @@ except ModuleNotFoundError:
 if HAS_COLORAMA:

     def _color_text(msg: str, color: str) -> str:
-        # pyrefly: ignore # missing-attribute
         return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET

 else:

@@ -34,29 +34,21 @@ class StaticallyLaunchedCudaKernel:
     """

     def __init__(self, kernel: CompiledKernel) -> None:
-        # pyrefly: ignore # missing-attribute
         self.name = kernel.src.fn.__name__
-        # pyrefly: ignore # missing-attribute
         self.cubin_raw = kernel.asm.get("cubin", None)
-        # pyrefly: ignore # missing-attribute
         self.cubin_path = kernel._cubin_path

         # Used by torch.compile to filter constants in older triton versions
-        # pyrefly: ignore # missing-attribute
         self.arg_names = kernel.src.fn.arg_names

         # Const exprs that are declared by the triton kernel directly
         # Used to generate the kernel launcher's def args
-        # pyrefly: ignore # missing-attribute
         self.declared_constexprs = kernel.src.fn.constexprs

-        # pyrefly: ignore # missing-attribute
         self.hash = kernel.hash

         if triton_knobs is None:
-            # pyrefly: ignore # missing-attribute
             launch_enter = kernel.__class__.launch_enter_hook
-            # pyrefly: ignore # missing-attribute
             launch_exit = kernel.__class__.launch_exit_hook
         else:
             launch_enter = triton_knobs.runtime.launch_enter_hook

@@ -78,15 +70,12 @@ class StaticallyLaunchedCudaKernel:
             raise NotImplementedError(
                 "We don't support launch enter or launch exit hooks"
             )
-        # pyrefly: ignore # missing-attribute
         self.num_warps = kernel.metadata.num_warps
         self.shared = (
-            # pyrefly: ignore # missing-attribute
             kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
         )

         def needs_scratch_arg(scratch_name: str, param_name: str) -> bool:
-            # pyrefly: ignore # missing-attribute
             if hasattr(kernel.metadata, param_name):
                 if getattr(kernel.metadata, param_name) > 0:
                     raise NotImplementedError(

@@ -102,7 +91,6 @@ class StaticallyLaunchedCudaKernel:
         # same situation for profile scratch - triton-lang/triton#7258
         self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")

-        # pyrefly: ignore # missing-attribute
         self.arg_tys = self.arg_ty_from_signature(kernel.src)
         self.function: int | None = None  # Loaded by load_kernel(on the parent process)
         num_ctas = 1

@@ -182,7 +170,6 @@ class StaticallyLaunchedCudaKernel:
     def arg_ty_from_signature(self, src: ASTSource) -> str:
         def index_key(i: Any) -> int:
             if isinstance(i, str):
-                # pyrefly: ignore # missing-attribute
                 return src.fn.arg_names.index(i)
             elif isinstance(i, tuple):
                 # In triton 3.3, src.fn.constants has tuples as a key

@@ -190,7 +177,6 @@ class StaticallyLaunchedCudaKernel:
             else:
                 return i

-        # pyrefly: ignore # missing-attribute
         signature = {index_key(key): value for key, value in src.signature.items()}
         # Triton uses these as the main way to filter out constants passed to their cubin
         constants = [index_key(key) for key in getattr(src, "constants", dict())]

@@ -212,7 +198,6 @@ class StaticallyLaunchedCudaKernel:
             if ty == "constexpr" or i in constants:
                 pass
             else:
-                # pyrefly: ignore # bad-argument-type
                 params.append(self.extract_type(ty))
         return "".join(params)

@@ -250,7 +235,6 @@ class StaticallyLaunchedCudaKernel:
         if has_scratch:
             arg_tys = arg_tys + "O"
             args = (*args, None)
-        # pyrefly: ignore # bad-argument-type
         assert len(args) == len(arg_tys)

         # TODO: can handle grid functions here or in C++, so

@@ -263,7 +247,6 @@ class StaticallyLaunchedCudaKernel:
             self.num_warps,
             self.shared,
             arg_tys,
-            # pyrefly: ignore # bad-argument-type
             args,
             stream,
         )

@@ -421,7 +421,6 @@ def get_proxy_slot(
     else:
         # Attempt to build it from first principles.
         _build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
-        # pyrefly: ignore # no-matching-overload
         value = tracker.get(obj)

     if value is None: