diff --git a/pyrefly.toml b/pyrefly.toml
index b643be2265e..ad74e4df084 100644
--- a/pyrefly.toml
+++ b/pyrefly.toml
@@ -23,7 +23,7 @@ project-excludes = [
     # ==== below will be enabled directory by directory ====
     # ==== to test Pyrefly on a specific directory, simply comment it out ====
     "torch/_inductor/runtime",
-    "torch/_inductor/codegen",
+    "torch/_inductor/codegen/triton.py",
     # formatting issues, will turn on after adjusting where suppressions can be
     # in import statements
     "torch/linalg/__init__.py",
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 2f6efb03165..36ded3aea2f 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -950,6 +950,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
             or _all_in_parens(string)
         ):
             # don't put extra parens for strings that are already wrapped in parens
+            # pyrefly: ignore # bad-return
             return string
         return f"({string})"
@@ -1736,7 +1737,9 @@ class KernelArgs:
                 )
             )
         for outer, inner in chain(
-            self.input_buffers.items(), self.output_buffers.items()
+            # pyrefly: ignore # bad-argument-type
+            self.input_buffers.items(),
+            self.output_buffers.items(),
         ):
             if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
                 continue
@@ -2047,6 +2050,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     ) -> None:
         super().__init__()
         if increase_kernel_count:
+            # pyrefly: ignore # bad-assignment
             metrics.generated_kernel_count += 1
         self.args = args or KernelArgs()
         self.loads = IndentedBuffer()
@@ -2113,6 +2117,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             self.compute = compute
             self.stores = stores
             self.cse = cse
+            # pyrefly: ignore # unbound-name
             if disallow_stores:
                 assert not sb, "unexpected store inside swap_buffers"
@@ -2384,6 +2389,7 @@ class KernelTemplate:
     class DetailedTemplateSyntaxError(TemplateSyntaxError):
         def __init__(self, original_error: TemplateSyntaxError) -> None:
             super().__init__(
+                # pyrefly: ignore # bad-argument-type
                 original_error.message,
                 original_error.lineno,
                 original_error.name,
@@ -2395,6 +2401,7 @@ class KernelTemplate:
             error_info = f"Error in template at line {self.lineno}\n"
             error_info += f"Error message: {self.message}\n"
             if hasattr(self.original_error, "source"):
+                # pyrefly: ignore # missing-attribute
                 lines = self.original_error.source.split("\n")
                 error_info += "Context:\n"
                 start = max(0, self.lineno - 2)
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 64e0fa196d6..1b8b0a9b9e2 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -504,6 +504,7 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
         if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
             return cls(
                 node1.scheduler,
+                # pyrefly: ignore # bad-argument-type
                 (
                     list(node1.get_outer_nodes())
                     if type(node1) is OuterLoopFusedSchedulerNode
@@ -1716,6 +1717,7 @@ class CppVecOverrides(CppOverrides):
             body_vec_var.dtype = dtype
             other_vec_var.dtype = dtype
             overrides: type[Union[CppOverrides, CppVecOverrides]] = (
+                # pyrefly: ignore # bad-assignment
                 V.kernel.overrides
             )  # type: ignore[has-type]
             code.writeline(
@@ -1759,6 +1761,7 @@ class CppVecOverrides(CppOverrides):
            csevar = V.kernel._load_or_store_non_contiguous(  # type: ignore[assignment]
                None, index, dtype, V.kernel.compute
            )
+        # pyrefly: ignore # missing-attribute
        csevar.update_on_args("index_expr", (expr, dtype), {})
        return csevar
@@ -2036,6 +2039,7 @@ class CppKernel(Kernel):
                 # mask's dtype should be bool
                 mask.dtype = torch.bool

+        # pyrefly: ignore # bad-assignment
         self._load_mask = mask
         try:
             yield mask
@@ -2363,6 +2367,7 @@ class CppKernel(Kernel):
             sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
             for n in range(len(self.ranges))
         ]
+        # pyrefly: ignore # bad-assignment
         self.reduction_depth = len(lengths)
         return (
             self.itervars[: self.reduction_depth],
@@ -2610,7 +2615,9 @@ class CppKernel(Kernel):
                 and end == self.ranges[var_id]
             ):
                 end = 1
+            # pyrefly: ignore # bad-argument-type
             conditions.append(f"{var} >= {cexpr_index(start)}")
+            # pyrefly: ignore # bad-argument-type
             conditions.append(f"{var} < {cexpr_index(end)}")
         return True
@@ -4085,6 +4092,7 @@ class CppKernelProxy(CppKernel):
                     and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP
                 ):
                     # No need to promote to float if all users are ops that accepts lowp fp input
+                    # pyrefly: ignore # bad-argument-type
                     if all(is_lowp_fp_sink(user, dt) for user in _node.users):
                         continue
                     ops = _node.args[0]
@@ -4095,12 +4103,14 @@ class CppKernelProxy(CppKernel):
                    _node.replace_all_uses_with(
                        to_type_node, lambda n: n is not to_type_node
                    )
+                    # pyrefly: ignore # bad-assignment
                    metrics.cpp_to_dtype_count += 1
                elif (
                    _node.target == "store"
                    and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP
                ):
                    ops, name, _, value_var, _ = _node.args
+                    # pyrefly: ignore # bad-argument-type
                    if is_lowp_fp_source_no_promote(value_var, dt):
                        continue
                    dtype = V.graph.get_dtype(name)
@@ -4109,6 +4119,7 @@ class CppKernelProxy(CppKernel):
                        "to_dtype", args=(ops, value_var, dtype)
                    )
                    _node.replace_input_with(value_var, to_type_node)
+                    # pyrefly: ignore # bad-assignment
                    metrics.cpp_to_dtype_count += 1
                elif _node.target == "reduction":
                    (
@@ -4178,6 +4189,7 @@ class CppKernelProxy(CppKernel):
                            "to_dtype", args=(ops, value_var, src_dtype)
                        )
                        _node.replace_input_with(value_var, to_type_node)
+                        # pyrefly: ignore # bad-assignment
                        metrics.cpp_to_dtype_count += 1

                # to_dtype_bitcast act as a lowp fp source:
@@ -4196,6 +4208,7 @@ class CppKernelProxy(CppKernel):
                    _node.replace_all_uses_with(
                        to_type_node, lambda n: n is not to_type_node
                    )
+                    # pyrefly: ignore # bad-assignment
                    metrics.cpp_to_dtype_count += 1

        def eliminate_to_dtype(sub_graph: torch.fx.Graph):
@@ -4289,6 +4302,7 @@ class CppKernelProxy(CppKernel):
            with kernel_group.new_kernel(cls, *args) as kernel:
                # Ugly hack to maintain the metrics kernel count since
                # we only count in CppKernelProxy, not those contained in it
+                # pyrefly: ignore # bad-assignment
                metrics.generated_kernel_count -= 1
                run(kernel)
@@ -4360,6 +4374,7 @@ class CppKernelProxy(CppKernel):
            )

            if len(tiling_indices) == 1:
+                # pyrefly: ignore # bad-assignment
                metrics.generated_cpp_vec_kernel_count += 1
                loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0])
                vec_kernel = codegen_kernel(
@@ -4386,6 +4401,7 @@ class CppKernelProxy(CppKernel):
                    and tiling_factors[0] == tiling_factors[1]
                )

+                # pyrefly: ignore # bad-assignment
                metrics.generated_cpp_vec_kernel_count += 2
                outer_loop = self.loop_nest.tile(
                    tiling_indices[0], factor=tiling_factors[0]
@@ -5134,10 +5150,12 @@ class CppScheduling(BaseScheduling):
                contiguous_index_expr = 0
                stride = 1
                for var, range in reversed(
+                    # pyrefly: ignore # missing-attribute
                    scheduler_node._body.var_ranges.items()
                ):
                    contiguous_index_expr += stride * var
                    stride *= range
+                # pyrefly: ignore # missing-attribute
                write_index_expr = scheduler_node._body.get_write_expr(
                    scheduler_buffer.get_name()
                )
@@ -5206,6 +5224,7 @@ class CppScheduling(BaseScheduling):
                    )
                    local_buffers.append(local_buffer_used)
                    local_to_global_buffers[local_buffer_used.name] = []  # type: ignore[index]
+                    # pyrefly: ignore # index-error
                    local_to_global_buffers[local_buffer_used.name].append(
                        global_buffer,
                    )
@@ -5450,6 +5469,7 @@ class CppScheduling(BaseScheduling):
            wrapper = V.graph.wrapper_code
            debug_handle = set_kernel_post_grad_provenance_tracing(
                node_schedule,  # type: ignore[arg-type]
+                # pyrefly: ignore # bad-argument-type
                kernel_name,
            )
            wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
@@ -5771,6 +5791,7 @@ class LoopNest:
            loop = self.loops[par_depth.start_depth]
            loop.parallel = par_depth.parallel_depth
            if loop.is_reduction:
+                # pyrefly: ignore # bad-assignment
                metrics.parallel_reduction_count += 1
            for i in range(par_depth.start_depth + 1, par_depth.parallel_depth):
                self.loops[i].collapsed = True
diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py
index 6dbf1c8ad69..9b26105bab1 100644
--- a/torch/_inductor/codegen/cpp_gemm_template.py
+++ b/torch/_inductor/codegen/cpp_gemm_template.py
@@ -396,12 +396,15 @@ def transpose_w(W: _T, trans_w: bool) -> _T:
     if isinstance(W, ir.IRNode):
         if trans_w:
             if not isinstance(W, ir.TensorBox):
+                # pyrefly: ignore # bad-assignment
                 W = ir.TensorBox(W)
             W = L.permute(W, [1, 0])
     else:
         if trans_w:
             assert isinstance(W, torch.Tensor)
+            # pyrefly: ignore # bad-assignment
             W = W.transpose(0, 1)

+    # pyrefly: ignore # bad-return
     return W
@@ -412,12 +415,15 @@ def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
     if B is not None:
         if isinstance(B, ir.IRNode):
             if not isinstance(B, ir.TensorBox):
+                # pyrefly: ignore # bad-assignment
                 B = ir.TensorBox(B)
             assert hasattr(X, "get_size")
+            # pyrefly: ignore # missing-attribute
             B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
         else:
             assert isinstance(B, torch.Tensor)
             assert isinstance(X, torch.Tensor)
+            # pyrefly: ignore # bad-assignment
             B = B.expand(X.shape[0], B.shape[-1])
     return B
@@ -1043,6 +1049,7 @@ class CppGemmTemplate(CppTemplate):
        return cls.prep_weight(
            new_inputs,
            new_layout,
+            # pyrefly: ignore # bad-argument-type
            micro_gemm,
            pre_block_weights,
            use_int8_fast_compensation_path,
@@ -1066,6 +1073,7 @@ class CppGemmTemplate(CppTemplate):
            new_input_nodes, _ = cls.prep_weight(
                new_input_nodes,
                new_layout,
+                # pyrefly: ignore # bad-argument-type
                micro_gemm,
                pre_block_weights,
                use_int8_fast_compensation_path,
@@ -1470,7 +1478,9 @@ class CppGemmTemplate(CppTemplate):
            assert isinstance(template_buffer, ir.IRNode)
            gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
            gemm_output_buffer = ir.Buffer(
-                name=gemm_output_name, layout=template_buffer.layout
+                # pyrefly: ignore # missing-attribute
+                name=gemm_output_name,
+                layout=template_buffer.layout,
            )
            current_input_buffer = gemm_output_buffer
            for i, creator in enumerate(epilogue_creators):
@@ -1481,6 +1491,7 @@ class CppGemmTemplate(CppTemplate):
                epilogues.append(
                    ir.ComputedBuffer(
                        name=buffer_name,
+                        # pyrefly: ignore # missing-attribute
                        layout=template_buffer.layout,
                        data=creator(current_input_buffer),
                    )
@@ -1490,7 +1501,9 @@ class CppGemmTemplate(CppTemplate):
                reindexers.append(None)
                if i < len(epilogue_creators) - 1:
                    current_input_buffer = ir.Buffer(
-                        name=buffer_name, layout=template_buffer.layout
+                        # pyrefly: ignore # missing-attribute
+                        name=buffer_name,
+                        layout=template_buffer.layout,
                    )

        assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
@@ -1521,6 +1534,7 @@ class CppGemmTemplate(CppTemplate):
            self.n,
            self.k,
            input_dtype=X.get_dtype(),
+            # pyrefly: ignore # missing-attribute
            input2_dtype=W.get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
diff --git a/torch/_inductor/codegen/cpp_grouped_gemm_template.py b/torch/_inductor/codegen/cpp_grouped_gemm_template.py
index 4b973522227..ed554d28004 100644
--- a/torch/_inductor/codegen/cpp_grouped_gemm_template.py
+++ b/torch/_inductor/codegen/cpp_grouped_gemm_template.py
@@ -183,12 +183,14 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
        )
        self.act_mapping = act_mapping
        self.gemm_grouped_num = gemm_grouped_num
+        # pyrefly: ignore # bad-override
        self.output_node: list[ir.Buffer] = [
            ir.Buffer(name="buf_out" + str(idx), layout=layout)
            for idx in range(gemm_grouped_num)
        ]

    @classmethod
+    # pyrefly: ignore # bad-override
    def add_choices(
        cls,
        choices: list[ChoiceCaller],
@@ -231,6 +233,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                if isinstance(inputs[idx], torch.Tensor):
                    W = inputs[idx]
                    assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
+                    # pyrefly: ignore # unsupported-operation
                    new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
            return new_inputs, layout_or_out
@@ -246,8 +249,10 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                new_input = new_inputs[wgt_idx]
                new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
            for bias_idx in range(bias_start_idx, len(new_inputs)):
+                # pyrefly: ignore # bad-argument-type
                new_bias = expand_bias(new_inputs[bias_idx], X)
                assert new_bias is not None
+                # pyrefly: ignore # unsupported-operation
                new_inputs[bias_idx] = new_bias
            return new_inputs, layout_or_out
@@ -308,6 +313,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                W_tensor = []
                for W_node in W_nodes:
                    assert W_node.get_name() in V.graph.constants
+                    # pyrefly: ignore # bad-argument-type
                    W_tensor.append(V.graph.constants[W_node.get_name()])
                new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
                    W_tensor  # type: ignore[assignment]
@@ -324,6 +330,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                    template_buffer.inputs[idx] = (
                        ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
                    )
+            # pyrefly: ignore # bad-return
            return output

        template = DataProcessorTemplateWrapper(
@@ -362,6 +369,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
        cur_idx = bias_start_idx
        for inp_idx in range(self.gemm_grouped_num):
            inp = None
+            # pyrefly: ignore # index-error
            if self.has_bias[inp_idx]:
                inp = self.input_nodes[cur_idx]
                cur_idx += 1
@@ -390,6 +398,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
            self.n,
            self.k,
            input_dtype=X_list[0].get_dtype(),
+            # pyrefly: ignore # missing-attribute
            input2_dtype=W_list[0].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
@@ -427,6 +436,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
        for x_idx in range(wgt_start_idx):
            kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
        for w_idx in range(self.gemm_grouped_num):
+            # pyrefly: ignore # unsupported-operation
            kernel_args["W" + str(w_idx)] = W_list[w_idx]
        for inp_idx in range(self.gemm_grouped_num):
            kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]
diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py
index d72f13a3e3f..c2fcaeadebf 100644
--- a/torch/_inductor/codegen/cpp_template.py
+++ b/torch/_inductor/codegen/cpp_template.py
@@ -85,6 +85,7 @@ class CppTemplate(KernelTemplate):
        bmreq = CppBenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
+            # pyrefly: ignore # bad-argument-type
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
@@ -112,6 +113,7 @@ class CppTemplate(KernelTemplate):
            kernel_hash_name,
            self.name,
            self.input_nodes,
+            # pyrefly: ignore # index-error
            self.output_node[0].get_layout()
            if isinstance(self.output_node, Iterable)
            else self.output_node.get_layout(),
diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py
index b0dee69b012..a077ab394db 100644
--- a/torch/_inductor/codegen/cpp_template_kernel.py
+++ b/torch/_inductor/codegen/cpp_template_kernel.py
@@ -411,6 +411,7 @@ class CppTemplateKernel(CppKernel):
            )
            epilogue_nodes = scope.localize_nodes(epilogue_nodes)
            return self.store_pointwise_nodes(
+                # pyrefly: ignore # bad-argument-type
                dst,
                epilogue_nodes,  # type: ignore[arg-type]
                offsets,
@@ -422,6 +423,7 @@ class CppTemplateKernel(CppKernel):
                copy = L.copy(dst, src).data.data
                with LocalBufferContext(self.args) as scope:
                    scope.add_local_buffer(src)
+                    # pyrefly: ignore # bad-argument-type
                    return self.store_pointwise_nodes(dst, [copy])
            else:
                assert dst.layout == src.layout, f"{dst=}, {src=}"
diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py
index a2d9878f222..de70481a3c3 100644
--- a/torch/_inductor/codegen/cpp_utils.py
+++ b/torch/_inductor/codegen/cpp_utils.py
@@ -311,6 +311,7 @@ class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
        return res

    def store_reduction(self, name, index, value):
+        # pyrefly: ignore # bad-argument-count
        return self._inner.store_reduction(*self.localize(name, index), value)
diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py
index 05825907dd1..d1ddc7e1cd4 100644
--- a/torch/_inductor/codegen/cpp_wrapper_gpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py
@@ -307,6 +307,7 @@ class DeferredTritonCallWrapper:
                        f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
                    ]
                )
+                # pyrefly: ignore # bad-argument-type
                total_args.append(f"tmp_{arg_name}")

        def process_args_for_input_shape(arg, arg_type, arg_signature=None):
@@ -331,6 +332,7 @@ class DeferredTritonCallWrapper:
                        f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
                    ]
                )
+                # pyrefly: ignore # bad-argument-type
                total_args.append(f"tmp_{arg_name}")
            elif (
                isinstance(arg_type, type(SymbolicCallArg))
@@ -348,6 +350,7 @@ class DeferredTritonCallWrapper:
        for arg, arg_type, arg_signature in zip_longest(
            call_args, arg_types, arg_signatures
        ):
+            # pyrefly: ignore # bad-argument-type
            ordered_argsname.append(f'"{arg}"')
            process_args_for_input_shape(arg, arg_type, arg_signature)
@@ -819,7 +822,9 @@ class CppWrapperGpu(CppWrapperCpu):
        if triton:
            call_args, arg_types = self.prepare_triton_wrapper_args(
-                call_args, arg_types
+                # pyrefly: ignore # bad-argument-type
+                call_args,
+                arg_types,
            )
            wrapper_name = f"call_{kernel_name}"
            if wrapper_name not in self._triton_call_wrappers:
@@ -843,10 +848,12 @@ class CppWrapperGpu(CppWrapperCpu):
            self.writeline(f"{wrapper_name}({', '.join(call_args)});")
        else:
            casted = []
+            # pyrefly: ignore # no-matching-overload
            for arg_type, arg in zip(arg_types, call_args):
                new_arg = arg
                if arg_type.endswith("*") and arg != "nullptr":
                    new_arg = f"{arg}.data_ptr()"
+                # pyrefly: ignore # bad-argument-type
                casted.append(f"({arg_type}){cexpr(new_arg)}")
            call_args_str = ", ".join(casted)
            self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")
diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
index 53c8739765e..3d2ee95e523 100644
--- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
+++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
@@ -190,6 +190,7 @@ class CUDACPPScheduling(BaseScheduling):
        assert all(n.node is not None for n in nodes), (
            "All epilogue nodes should have an IRNode"
        )
+        # pyrefly: ignore # redundant-cast
        return cast(
            list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
        )
diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py
index 4aa0aeb46e0..fe764e652c0 100644
--- a/torch/_inductor/codegen/cuda/cuda_template.py
+++ b/torch/_inductor/codegen/cuda/cuda_template.py
@@ -72,6 +72,7 @@ class CUDATemplate(KernelTemplate):

    @classmethod
    @functools.lru_cache(None)
+    # pyrefly: ignore # bad-override
    def _template_from_string(cls, source: str) -> Any:
        return KernelTemplate._template_from_string(source)
diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py
index 17e5a08f51e..49b57b89236 100644
--- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py
+++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py
@@ -163,6 +163,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
        ) -> None:
            self.example_inputs = example_inputs
            self.ast = ast.parse(self.source)
+            # pyrefly: ignore # missing-attribute
            self.visit(self.ast)

            cc = int(cuda_env.get_cuda_arch())
diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py
index e10e7b7daaf..2f673e92e24 100644
--- a/torch/_inductor/codegen/cuda/cutlass_utils.py
+++ b/torch/_inductor/codegen/cuda/cutlass_utils.py
@@ -470,6 +470,7 @@ class CUDACompileSourceCapturingContext:
            self.sources.append(source_code)
            return _compile_method_orig(source_code, dst_file_ext)

+        # pyrefly: ignore # bad-assignment
        self._compile_patch = mock.patch(
            "torch._inductor.codecache.CUDACodeCache.compile", my_compile
        )
diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py
index 967607ca0e3..adec95b76c2 100644
--- a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py
+++ b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py
@@ -286,6 +286,7 @@ class CuteDSLTemplateKernel(Kernel):
        # Generate unpacking assignments: in_ptr4 = buffers[0], etc.
        unpacking_lines = []
        for i, buffer_name in enumerate(tensor_buffers):
+            # pyrefly: ignore # bad-argument-type
            unpacking_lines.append(f"{buffer_name} = buffers[{i}]")

        return "\n ".join(unpacking_lines)
@@ -493,6 +494,7 @@ class ModificationWrapperCuteDSL(V.WrapperHandler):  # type: ignore[name-defined]
        """Convert index variable to symbolic form."""
        return sympy_index_symbol(str(index_var))

+    # pyrefly: ignore # bad-override
    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> str:
diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py b/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py
index 5dd79db7bdb..a0f76ab5efb 100644
--- a/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py
+++ b/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py
@@ -274,7 +274,9 @@ class CuteDSLOpOverrides(OpOverrides):
            else "mlir_math.absi"
        )
        return CuteDSLOpOverrides._apply_unary_op(
-            x, f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)"
+            # pyrefly: ignore # bad-argument-type
+            x,
+            f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)",
        )

    @staticmethod
diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_template.py b/torch/_inductor/codegen/cutedsl/cutedsl_template.py
index b43dbd9cfd7..016edb63a35 100644
--- a/torch/_inductor/codegen/cutedsl/cutedsl_template.py
+++ b/torch/_inductor/codegen/cutedsl/cutedsl_template.py
@@ -43,6 +43,7 @@ class CuteDSLTemplate(KernelTemplate):

    @staticmethod
    @functools.lru_cache(None)
+    # pyrefly: ignore # bad-override
    def _template_from_string(source: str) -> Any:
        return KernelTemplate._template_from_string(source)
diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py
index f477d16cc76..f0a2b07b1cc 100644
--- a/torch/_inductor/codegen/halide.py
+++ b/torch/_inductor/codegen/halide.py
@@ -636,6 +636,7 @@ class DimensionInfo:
            return "hl.Var()"
        if replacements:
            replacements = {**replacements}
+            # pyrefly: ignore # missing-attribute
            for sym in expr.free_symbols:
                if symbol_is_type(sym, SymT.TMP):
                    assert isinstance(sym, sympy.Symbol)
@@ -709,8 +710,10 @@ class HalideKernel(SIMDKernel):
    def dtype_to_str(self, dtype: torch.dtype) -> str:
        return halide_type(dtype)

+    # pyrefly: ignore # bad-override
    def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
        self.body.writeline(f"{name} = hl.Func({name!r})")
+        # pyrefly: ignore # bad-argument-type
        return HalideCSEVariable(name, bounds, dtype, shape)

    def finalize_indexing(self, indices: Sequence[sympy.Expr]):
@@ -728,6 +731,7 @@ class HalideKernel(SIMDKernel):
            self.index_replacements or self.halide_vars or self.reduction_renames
        )
        size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf)  # type: ignore[arg-type]
+        # pyrefly: ignore # bad-assignment
        indices = dict.fromkeys(map(super().prepare_indexing, indices))
        all_used_symbols = OrderedSet[Any]()
        sym_to_node = {
@@ -826,6 +830,7 @@ class HalideKernel(SIMDKernel):
                handled_count = len(nodes)
                had_fallback = True
            sym = sympy_index_symbol(f"h{len(self.halide_vars)}")
+            # pyrefly: ignore # missing-argument
            if tree.is_reduction:
                self.reduction_renames[sym] = sympy_index_symbol(
                    f"hr{len(self.halide_vars)}"
@@ -1222,8 +1227,10 @@ class HalideKernel(SIMDKernel):
            parts = []
            stride = 1
            for i, sym in enumerate(self.reduction_renames):
+                # pyrefly: ignore # bad-argument-type
                parts.append(f"{index}[{i}]")
                if stride != 1:
+                    # pyrefly: ignore # unsupported-operation
                    parts[-1] += f"*{stride}"
                stride *= self.halide_vars[sym]
            self.body.writeline(f"{result_var} = {' + '.join(parts)}")
@@ -1576,6 +1583,7 @@ class HalideKernel(SIMDKernel):
                hint = self._autoscheduler_workarounds(
                    V.graph.sizevars.size_hint(dim.size, fallback=1), dims
                )
+                # pyrefly: ignore # bad-argument-type
                range_hints.append(f"hl.Range(0, {hint})")
                if "out" not in arg.name:
                    code.writeline(f"{arg.name}.dim({i}).set_min(0)")
diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py
index f68c241ca83..a74506d7247 100644
--- a/torch/_inductor/codegen/mps.py
+++ b/torch/_inductor/codegen/mps.py
@@ -516,6 +516,7 @@ class MetalKernel(SIMDKernel):
        var = self.args.output(name)
        index = self.prepare_indexing(index)
        dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
+        # pyrefly: ignore # missing-argument
        reduction_dim = next(t for t in self.range_trees if t.is_reduction)
        # Only one thread in the reduction group needs to store the results
        line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});"
@@ -582,6 +583,7 @@ class MetalKernel(SIMDKernel):
        reduction_idx = ""
        acc_buf_size = 1
        for rd in self.range_trees:
+            # pyrefly: ignore # missing-argument
            if not rd.is_reduction:
                continue
            if reduction_idx:
@@ -678,7 +680,10 @@ class MetalKernel(SIMDKernel):
            )
            idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False)  # type: ignore[assignment]
            idx_var = next(
-                t for t in self.range_tree_nodes.values() if t.is_reduction
+                # pyrefly: ignore # missing-argument
+                t
+                for t in self.range_tree_nodes.values()
+                if t.is_reduction
            )
            cmp_op = ">" if reduction_type == "argmax" else "<"
            nan_suffix = (
@@ -745,6 +750,7 @@ class MetalKernel(SIMDKernel):
            index_expr = self.rename_indexing(entry.expr)
            index_str = self.sexpr(index_expr)  # type: ignore[misc]

+            # pyrefly: ignore # missing-argument
            if not entry.is_reduction or (
                isinstance(entry.root.numel, sympy.Integer)
                and entry.root.numel <= self.max_threadgroup_size
@@ -856,7 +862,10 @@ class MetalKernel(SIMDKernel):

        if self.inside_reduction:
            total_reduction_size = math.prod(
-                t.numel for t in self.range_trees if t.is_reduction
+                # pyrefly: ignore # missing-argument
+                t.numel
+                for t in self.range_trees
+                if t.is_reduction
            )
            # If using dynamic shapes, set the threadgroup size to be the
            # max possible size
@@ -958,6 +967,7 @@ class MetalKernel(SIMDKernel):
            else:
                expr = V.graph.wrapper_code.generate_numel_expr(name, tree).inner

+            # pyrefly: ignore # missing-argument
            if not tree.is_reduction or self.inside_reduction:
                args.append(str(expr))
                arg_types.append(int)
@@ -977,6 +987,7 @@ class MetalKernel(SIMDKernel):
            threads = [
                expr_printer(
                    sympy.Min(v.numel, self.max_threadgroup_size)  # type: ignore[misc]
+                    # pyrefly: ignore # missing-argument
                    if v.is_reduction
                    else v.numel
                )
@@ -992,6 +1003,7 @@ class MetalKernel(SIMDKernel):
        if self.inside_reduction:
            threads = [
                expr_printer(sympy.Min(v.numel, self.max_threadgroup_size))  # type: ignore[misc]
+                # pyrefly: ignore # missing-argument
                if v.is_reduction
                else "1"
                for v in self.active_range_trees()
diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py
index 01055f5cd6e..0861b218f9c 100644
--- a/torch/_inductor/codegen/multi_kernel.py
+++ b/torch/_inductor/codegen/multi_kernel.py
@@ -306,6 +306,7 @@ class MultiKernelCall:
            # manually force a subkernel to ease perf testing
            picked_by_config = config.triton.multi_kernel - 2
            assert picked_by_config < len(self._kernels)
+            # pyrefly: ignore # bad-assignment
            self.picked_kernel = picked_by_config
        elif not self.disable_cache:
            self.load_cache()
@@ -329,7 +330,9 @@ class MultiKernelCall:
        path = self.cache_file_path()
        if path.exists():
            with path.open() as fd:
+                # pyrefly: ignore # bad-assignment
                self.picked_kernel = int(fd.read())
+                # pyrefly: ignore # unsupported-operation
                assert self.picked_kernel >= 0 and self.picked_kernel < len(
                    self._kernels
                )
@@ -599,5 +602,6 @@ class SizeHintMultiKernelCall(MultiKernelCall):
            self._dist_heuristic(shape_key, key) if key is not None else 2**62
            for key in self._kernel_hints
        ]
+        # pyrefly: ignore # bad-assignment
        self.picked_kernel = dists.index(min(dists))
        self._cache_shape_choice(shape_key, self.picked_kernel)
diff --git a/torch/_inductor/codegen/rocm/ck_conv_template.py b/torch/_inductor/codegen/rocm/ck_conv_template.py
index 37d9898f6be..b8e7da3e156 100644
--- a/torch/_inductor/codegen/rocm/ck_conv_template.py
+++ b/torch/_inductor/codegen/rocm/ck_conv_template.py
@@ -513,9 +513,11 @@ class CKGroupedConvFwdTemplate(CKTemplate):
                    arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
                else:  # tile shape
                    arg = f"/* {field_name} */ S<{tuple_elements}>"
+                # pyrefly: ignore # bad-argument-type
                template_params.append(arg)
            else:
                if field_value is not None:
+                    # pyrefly: ignore # bad-argument-type
                    template_params.append(f"/* {field_name} */ {field_value}")
        return self._template_from_string(template_definition).render(
            operation_name=op.name(),
diff --git a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
index b6add1e8dbd..db2bd69b1d0 100644
--- a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
+++ b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
@@ -590,9 +590,11 @@ class CKGemmTemplate(CKTemplate):
                    arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
                else:  # tile shape
                    arg = f"/* {field_name} */ S<{tuple_elements}>"
+                # pyrefly: ignore # bad-argument-type
                template_params.append(arg)
            else:
                if field_value is not None:
+                    # pyrefly: ignore # bad-argument-type
                    template_params.append(f"/* {field_name} */ {field_value}")
        operation_name = op.name().replace("(", "").replace(",", "").replace(")", "")
        return self._template_from_string(template_definition).render(
diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index 8c3dd051cdd..e2294f05ddc 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -187,6 +187,7 @@ class IterationRangesRoot(IterationRanges):

        # True if the dimension is implemented as a single program looping over
        # the full dimension (currently only used for non-persistent reduction)
+        # pyrefly: ignore # missing-argument
        assert not is_loop or (self.is_reduction and grid_dim is None)
        self.is_loop = is_loop
        # Index of corresponding dimension on triton tensors
@@ -374,6 +375,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
    sexpr: Callable[[sympy.Expr], str] = pexpr
    kexpr: Callable[[sympy.Expr], str]
    allow_block_ptr: bool = False
+    # pyrefly: ignore # bad-override
    kernel_name: str

    def __init__(
@@ -570,6 +572,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
            if tree.tensor_dim is None:
                continue

+            # pyrefly: ignore # missing-argument
            if not tree.is_reduction or self.inside_reduction:
                sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK"
        return sizes
@@ -962,7 +965,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):

    def active_range_trees(self) -> list[IterationRangesRoot]:
        return [
-            t for t in self.range_trees if not t.is_reduction or self.inside_reduction
+            # pyrefly: ignore # missing-argument
+            t
+            for t in self.range_trees
+            if not t.is_reduction or self.inside_reduction
        ]

    def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
@@ -1110,6 +1116,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
                numel = buf_size
            dtype = V.graph.get_dtype(arg)
            dtype_size = get_dtype_size(dtype)
+            # pyrefly: ignore # bad-argument-type
            nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
        return sum(nbytes)
@@ -1130,6 +1137,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):

        argdefs, call_args, _signature, _ = self.args.python_argdefs()
        uniform_stride_order = None
+        # pyrefly: ignore # bad-assignment
        for arg_name in call_args:
            buf = V.graph.try_get_buffer(arg_name)
            if not buf:
@@ -1753,11 +1761,13 @@ class SIMDScheduling(BaseScheduling):

        for input_name in kernel.named_input_nodes.keys():
            subgraph_name = f""
+            # pyrefly: ignore # missing-attribute
            partial_code.finalize_hook(subgraph_name, strict=False)

        num_store_subgraphs = kernel.get_store_output_count()
        for i in range(num_store_subgraphs):
            subgraph_name = kernel._get_store_output_subgraph_name(i)
+            # pyrefly: ignore # missing-attribute
            partial_code.finalize_hook(subgraph_name)

        if isinstance(partial_code, str):
@@ -1879,6 +1889,7 @@ class SIMDScheduling(BaseScheduling):
                    only_gen_src_code=True,
                )
                assert isinstance(src_code, str)
+                # pyrefly: ignore # bad-argument-type
                src_codes.append(src_code)
        else:
            if size_hint is None:
@@ -2708,6 +2719,7 @@ class SIMDScheduling(BaseScheduling):
            perf_hint_log.info("possibly bad tiling: %s", ranked_tilings)

        # Optionally, prefer tiling into as many dimensions as possible.
+        # pyrefly: ignore # unbound-name
        if config.triton.prefer_nd_tiling:
            ranked_tilings = (
                cls.get_nd_tilings(node_schedule, numel, reduction_numel)
@@ -2757,6 +2769,7 @@ class SIMDScheduling(BaseScheduling):
            hint_override=hint_override,
        )

+        # pyrefly: ignore # missing-attribute
        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
        return src_code
diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py
index 1fbed50db91..a015d52d24f 100644
--- a/torch/_inductor/codegen/subgraph.py
+++ b/torch/_inductor/codegen/subgraph.py
@@ -80,6 +80,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
                bm_graph_lowering.graph_input_names.append(sym_inp.name)

        sym_inputs = [
+            # pyrefly: ignore # no-matching-overload
            int(V.graph.sizevars.shape_env.size_hint(sym_var))
            for sym_var in self.sym_inputs
        ]
diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py
index c28321923c5..c52bd1dbeee 100644
--- a/torch/_inductor/codegen/triton_combo_kernel.py
+++ b/torch/_inductor/codegen/triton_combo_kernel.py
@@ -379,6 +379,7 @@ class ComboKernel(Kernel):

    def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
        sub_kernel = triton_kernel
+        # pyrefly: ignore # bad-assignment
        metrics.generated_kernel_count -= 1
        sub_kernel.args = self.args
        sub_kernel.iter_vars_count = self.iter_vars_count
@@ -434,10 +435,12 @@ class ComboKernel(Kernel):
                assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args
                uniquify_block_sizes.append(f"{tree.prefix}numel")

+            # pyrefly: ignore # missing-argument
            if not tree.is_reduction:
                if isinstance(simplified_tree_numel, (Integer, int)):
                    grid.append(int(simplified_tree_numel))
                else:
+                    # pyrefly: ignore # bad-argument-type
                    grid.append(f"{tree.prefix}numel_{num}")

            if tree.is_reduction and sub_kernel.persistent_reduction:
@@ -475,8 +478,10 @@ class ComboKernel(Kernel):
            if sub_kernel.no_x_dim:
                min_x_blocks = x_numels
                x_numels = (
+                    # pyrefly: ignore # unsupported-operation
                    -min_x_blocks
                    if isinstance(x_numels, int)
+                    # pyrefly: ignore # redundant-cast
                    else "-" + cast(str, x_numels)
                )
            else:
@@ -606,6 +611,7 @@ class ComboKernel(Kernel):
            "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
            "constants": {},
        }
+        # pyrefly: ignore # unsupported-operation
        triton_meta["configs"] = [config_of(signature)]
        mutated_args = self.get_mutated_args_sub_kernels()
        dispatch = self.dispatch_class
@@ -684,6 +690,7 @@ class ComboKernel(Kernel):
        for sub_kernel in self.sub_kernels:
            # TODO: we assume all sub_kernels have the same block size
            for tree in sub_kernel.range_trees:
+                # pyrefly: ignore # missing-argument
                if tree.is_reduction and (
                    not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
                ):
@@ -722,6 +729,7 @@ class ComboKernel(Kernel):
                expr = V.graph.wrapper_code.generate_numel_expr(
                    name, tree, suffix=str(num)
                )
+                # pyrefly: ignore # missing-argument
                if not tree.is_reduction or sub_kernel.inside_reduction:
                    call_args.append(expr)
                    arg_types.append(type(expr))
@@ -733,6 +741,7 @@ class ComboKernel(Kernel):
                numel_name = f"{tree.prefix}numel_{num}"
                if numel_name not in self.dynamic_shape_args:
                    continue
+                # pyrefly: ignore # missing-argument
                if not tree.is_reduction or sub_kernel.inside_reduction:
                    extra_args.append(
                        str(
@@ -1012,6 +1021,7 @@ class ComboKernel(Kernel):
        for num, sub_kernel in enumerate(self.sub_kernels):
            meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
            for i, tree in enumerate(sub_kernel.range_trees):
+                # pyrefly: ignore # missing-argument
                if not tree.is_reduction:
                    numel_name = f"{tree.prefix}numel_{num}"
                    if numel_name in self.dynamic_shape_args:
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index d97988f684c..74385a4e284 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -256,4 +256,5 @@ def config_of(

    equal_to_1 = equal_1_arg_indices(args, indices=indices)

+    # pyrefly: ignore # bad-argument-type
    return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 226291f533b..dc613c46758 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -1115,6 +1115,7 @@ class PythonWrapperCodegen(CodeGen):
        return PythonWrapperCodegen()

    def set_launcher_fn_name(self) -> None:
+        # pyrefly: ignore # bad-assignment
        self.launcher_fn_name = "call"

    def write_constant(self, name: str, hashed: str) -> None:
@@ -1251,14 +1252,17 @@ class PythonWrapperCodegen(CodeGen):
        self.write_get_raw_stream_header()

    def add_meta_once(self, meta: TritonMetaParams) -> str:
+        # pyrefly: ignore # bad-assignment
        meta = repr(meta)
        if meta not in self._metas:
            var = f"meta{len(self._metas)}"
+            # pyrefly: ignore # unsupported-operation
            self._metas[meta] = var
            self.header.writeline(f"{var} = {meta}")
            if config.triton.autotune_at_compile_time:
                self.kernel_autotune_calls.writeline(f"{var} = {meta}")
            self._meta_vars.add(var)
+        # pyrefly: ignore # index-error
        return self._metas[meta]

    @cache_on_self
@@ -1694,6 +1698,7 @@ class PythonWrapperCodegen(CodeGen):
        with self.set_writeline(self.wrapper_call.writeline):
            for line in self.lines:
                if isinstance(line, WrapperLine):
+                    # pyrefly: ignore # missing-attribute
                    line.codegen(self.wrapper_call)
                else:
                    self.wrapper_call.writeline(line)
@@ -2774,13 +2779,18 @@ class PythonWrapperCodegen(CodeGen):
                self,
                kernel_name=kernel_name,
                call_args=call_args,
+                # pyrefly: ignore # bad-argument-type
                raw_keys=raw_keys,
+                # pyrefly: ignore # bad-argument-type
                raw_args=raw_args,
+                # pyrefly: ignore # bad-argument-type
                arg_types=arg_types,
                triton=triton,
+                # pyrefly: ignore # bad-argument-type
                triton_meta=triton_meta,
                device=device,
                graph_name=V.graph.name,
+                # pyrefly: ignore # bad-argument-type
                original_fxnode_name=original_fxnode_name,
            )
        )
@@ -2901,6 +2911,7 @@ class PythonWrapperCodegen(CodeGen):
        reused_args = {}

        for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
+            # pyrefly: ignore # no-matching-overload
            zip(call_args, arg_types, raw_keys, raw_args)
        ):
            key = None
@@ -3688,6 +3699,7 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
    def set_launcher_fn_name(self) -> None:
        # This sets up the name of the function containing the launcher code of
        # the subgraph.
+        # pyrefly: ignore # bad-assignment
        self.launcher_fn_name = self.subgraph_name

    def write_header(self) -> None:
diff --git a/torch/_inductor/codegen/wrapper_fxir.py b/torch/_inductor/codegen/wrapper_fxir.py
index 897b2d6e15d..72c8e033550 100644
--- a/torch/_inductor/codegen/wrapper_fxir.py
+++ b/torch/_inductor/codegen/wrapper_fxir.py
@@ -186,6 +186,7 @@ class WrapperFxCodegen(PythonWrapperCodegen):
        """
        Get the input nodes corresponding to FX graph placeholders.
        """
+        # pyrefly: ignore # missing-argument
        if V.aot_compilation and not self.is_subgraph:
            # AOT graphs must match the signature of the input module.
            return {
@@ -210,6 +211,7 @@ class WrapperFxCodegen(PythonWrapperCodegen):
            graph_inputs=self.get_fx_graph_inputs(),
            graph_outputs=self.get_graph_outputs(),
            subgms=self.subgms,
+            # pyrefly: ignore # missing-argument
            is_subgraph=self.is_subgraph,
        ).generate()
@@ -992,13 +994,17 @@ class FxConverter:
        call_kwargs = {
            key: val
            for key, val in zip(signature, call_args)
+            # pyrefly: ignore # missing-attribute
            if key not in constants and key not in cfg.kwargs
        }

        # Add constants stored as Triton metadata, in signature order.
        call_kwargs |= constants
        new_call_args = [
-            call_kwargs[key] for key in signature if key not in cfg.kwargs
+            # pyrefly: ignore # missing-attribute
+            call_kwargs[key]
+            for key in signature
+            if key not in cfg.kwargs
        ]

        # Add Inductor's extra launcher args to the end.
@@ -1014,9 +1020,11 @@ class FxConverter:
        call_args = add_constants_to_call_args(call_args, kernel_config)
        call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
        call_kwargs = dict(zip(signature, call_args))
+        # pyrefly: ignore # missing-attribute
        assert not any(kwarg in kernel_config.kwargs for kwarg in call_kwargs), (
            f"kwargs overlap config: {call_kwargs}"
        )
+        # pyrefly: ignore # missing-attribute
        call_kwargs.update(kernel_config.kwargs)

        # Replace sympy.floor with FloorDiv, to make the expression traceable.
diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py
index a0f213a1e49..d509e8c515e 100644
--- a/torch/_inductor/fx_passes/bucketing.py
+++ b/torch/_inductor/fx_passes/bucketing.py
@@ -356,7 +356,7 @@ def bucket_all_reduce(
    mode: str | None = None,
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
-        from torch._inductor.fx_passes.bucketing import (
+        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore # missing-module-attribute
            bucket_cap_mb_by_bucket_idx_default,
        )