From 1cb4e2df659059ac85cfe6f64a2363a12ae1071a Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 28 Feb 2025 15:35:13 +0800 Subject: [PATCH] [BE][PYFMT] migrate PYFMT for `torch._inductor` to `ruff format` (#144550) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550 Approved by: https://github.com/jansel --- tools/linter/adapters/pyfmt_linter.py | 1 - torch/_inductor/__init__.py | 23 +++- .../_inductor/analyze_preserves_zero_mask.py | 5 +- torch/_inductor/async_compile.py | 3 +- torch/_inductor/choices.py | 7 +- torch/_inductor/codecache.py | 29 ++-- torch/_inductor/codegen/common.py | 10 +- torch/_inductor/codegen/cpp.py | 72 +++++----- .../codegen/cpp_flex_attention_template.py | 4 +- torch/_inductor/codegen/cpp_gemm_template.py | 30 ++--- .../codegen/cpp_grouped_gemm_template.py | 18 +-- torch/_inductor/codegen/cpp_micro_gemm.py | 6 +- torch/_inductor/codegen/cpp_template.py | 12 +- .../_inductor/codegen/cpp_template_kernel.py | 5 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 93 +++++++------ .../codegen/cpp_wrapper_cpu_array_ref.py | 30 ++--- torch/_inductor/codegen/cpp_wrapper_gpu.py | 36 ++--- .../codegen/cuda/cuda_cpp_scheduling.py | 6 +- torch/_inductor/codegen/cuda/cuda_kernel.py | 4 +- torch/_inductor/codegen/cuda/cuda_template.py | 15 ++- .../gemm_operation_extensions.py | 8 +- torch/_inductor/codegen/cuda/cutlass_utils.py | 12 +- torch/_inductor/codegen/cuda/gemm_template.py | 11 +- torch/_inductor/codegen/halide.py | 24 ++-- torch/_inductor/codegen/memory_planning.py | 3 +- torch/_inductor/codegen/mps.py | 5 +- torch/_inductor/codegen/multi_kernel.py | 10 +- .../codegen/rocm/ck_conv_template.py | 7 +- .../rocm/ck_universal_gemm_template.py | 9 +- .../codegen/rocm/rocm_cpp_scheduling.py | 6 +- torch/_inductor/codegen/rocm/rocm_kernel.py | 4 +- torch/_inductor/codegen/rocm/rocm_template.py | 15 ++- torch/_inductor/codegen/simd.py | 19 +-- torch/_inductor/codegen/triton.py | 71 +++++----- .../_inductor/codegen/triton_combo_kernel.py | 40 +++--- torch/_inductor/codegen/triton_split_scan.py | 6 +- torch/_inductor/codegen/triton_utils.py | 3 +- torch/_inductor/codegen/wrapper.py | 54 ++++---- torch/_inductor/comms.py | 11 +- torch/_inductor/compile_fx.py | 127 ++++++++++-------- torch/_inductor/compiler_bisector.py | 3 +- torch/_inductor/config.py | 18 +-- torch/_inductor/constant_folding.py | 3 +- torch/_inductor/cpp_builder.py | 6 +- torch/_inductor/cudagraph_trees.py | 94 +++++++------ torch/_inductor/cudagraph_utils.py | 11 +- torch/_inductor/debug.py | 2 +- torch/_inductor/decomposition.py | 2 +- torch/_inductor/dependencies.py | 24 +++- torch/_inductor/dtype_propagation.py | 3 +- torch/_inductor/fuzzer.py | 7 +- torch/_inductor/fx_passes/b2b_gemm.py | 6 +- torch/_inductor/fx_passes/ddp_fusion.py | 6 +- .../fx_passes/efficient_conv_bn_eval.py | 4 +- torch/_inductor/fx_passes/fuse_attention.py | 48 ++++--- .../_inductor/fx_passes/group_batch_fusion.py | 22 +-- .../_inductor/fx_passes/micro_pipeline_tp.py | 4 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 24 ++-- torch/_inductor/fx_passes/pad_mm.py | 2 +- torch/_inductor/fx_passes/post_grad.py | 6 +- torch/_inductor/fx_passes/pre_grad.py | 6 +- torch/_inductor/fx_passes/quantization.py | 18 +-- torch/_inductor/fx_passes/split_cat.py | 87 ++++++++---- torch/_inductor/fx_utils.py | 6 +- torch/_inductor/graph.py | 60 +++++---- torch/_inductor/index_propagation.py | 28 ++-- torch/_inductor/ir.py | 106 ++++++++------- torch/_inductor/kernel/flex_attention.py | 54 ++++---- torch/_inductor/kernel/flex_decoding.py 
| 12 +- torch/_inductor/kernel/mm_common.py | 12 +- torch/_inductor/loop_body.py | 23 ++-- torch/_inductor/lowering.py | 105 ++++++++------- torch/_inductor/memory.py | 20 +-- torch/_inductor/metrics.py | 18 +-- torch/_inductor/mkldnn_ir.py | 6 +- torch/_inductor/mkldnn_lowerings.py | 12 +- torch/_inductor/ops_handler.py | 9 +- torch/_inductor/package/package.py | 10 +- torch/_inductor/pattern_matcher.py | 30 +++-- torch/_inductor/runtime/benchmarking.py | 2 +- torch/_inductor/runtime/runtime_utils.py | 6 +- torch/_inductor/runtime/triton_heuristics.py | 62 +++++---- torch/_inductor/scheduler.py | 110 +++++++-------- torch/_inductor/select_algorithm.py | 81 +++++------ torch/_inductor/sizevars.py | 6 +- torch/_inductor/utils.py | 52 +++---- torch/_inductor/virtualized.py | 6 +- torch/_inductor/wrapper_benchmark.py | 15 +-- 88 files changed, 1157 insertions(+), 954 deletions(-) diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index cb8157fac62..42564d119fb 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -53,7 +53,6 @@ USE_BLACK_FILELIST = re.compile( # torch/_[e-h]*/** "torch/_[e-h]*/**", # torch/_i*/** - "torch/_i*/**", # torch/_[j-z]*/** "torch/_[j-z]*/**", # torch/[a-c]*/** diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 3cef26f31e8..4160d7107d4 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -66,7 +66,9 @@ def aoti_compile_and_package( .. code-block:: python ep = torch.export.export(M(), ...) - aoti_file = torch._inductor.aoti_compile_and_package(ep, package_path="my_package.pt2") + aoti_file = torch._inductor.aoti_compile_and_package( + ep, package_path="my_package.pt2" + ) compiled_model = torch._inductor.aoti_load_package("my_package.pt2") To compile and save multiple models into a single ``.pt2`` artifact, you can do @@ -75,11 +77,16 @@ def aoti_compile_and_package( .. code-block:: python ep1 = torch.export.export(M1(), ...) - aoti_file1 = torch._inductor.aot_compile(ep1, ..., options={"aot_inductor.package": True}) + aoti_file1 = torch._inductor.aot_compile( + ep1, ..., options={"aot_inductor.package": True} + ) ep2 = torch.export.export(M2(), ...) - aoti_file2 = torch._inductor.aot_compile(ep2, ..., options={"aot_inductor.package": True}) + aoti_file2 = torch._inductor.aot_compile( + ep2, ..., options={"aot_inductor.package": True} + ) from torch._inductor.package import package_aoti, load_package + package_aoti("my_package.pt2", {"model1": aoti_file1, "model2": aoti_file2}) compiled_model1 = load_package("my_package.pt2", "model1") @@ -123,7 +130,9 @@ def aoti_compile_and_package( isinstance(package_path, (str, os.PathLike)) and os.fspath(package_path).endswith(".pt2") ) - ), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}" + ), ( + f"Expect package path to be a file ending in .pt2, is None, or is a buffer. 
Instead got {package_path}" + ) inductor_configs = inductor_configs or {} inductor_configs["aot_inductor.package"] = True @@ -168,9 +177,9 @@ def _aoti_compile_and_package_inner( """ if check_accuracy: - assert ( - kwargs is None or len(kwargs) == 0 - ), "when checking for accuracy, the inputs must have been flattened and kwargs is None" + assert kwargs is None or len(kwargs) == 0, ( + "when checking for accuracy, the inputs must have been flattened and kwargs is None" + ) from .package import package_aoti diff --git a/torch/_inductor/analyze_preserves_zero_mask.py b/torch/_inductor/analyze_preserves_zero_mask.py index 771ecb8c3ff..90d0ff80c5f 100644 --- a/torch/_inductor/analyze_preserves_zero_mask.py +++ b/torch/_inductor/analyze_preserves_zero_mask.py @@ -156,8 +156,9 @@ def can_codegen_without_upcasts( low_prec_analysis = RecordLowPrecisionOps(disallow_fp32_ops) # Need to turn off upcasting to do analysis of whether we can turn it off - with config.patch("triton.codegen_upcast_to_fp32", False), V.set_ops_handler( - low_prec_analysis + with ( + config.patch("triton.codegen_upcast_to_fp32", False), + V.set_ops_handler(low_prec_analysis), ): prologue._body(*prologue.get_ranges()) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 91d470c0409..e2090917861 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -245,8 +245,7 @@ class AsyncCompile: def use_process_pool(self): return ( - get_compile_threads() > 1 - and self.process_pool().ready_future.done() # type: ignore[union-attr] + get_compile_threads() > 1 and self.process_pool().ready_future.done() # type: ignore[union-attr] ) def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index b629d56ee62..ce7e941ee1f 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -24,8 +24,7 @@ if TYPE_CHECKING: class Sortable(typing.Protocol): """Anything that can be used as a list.sort() key (int/tuple/etc)""" - def __lt__(self, other: typing.Self) -> bool: - ... + def __lt__(self, other: typing.Self) -> bool: ... class InductorChoices: @@ -100,7 +99,9 @@ class InductorChoices: # to pick the faster one. 
if config.triton.multi_kernel: threshold *= 16 - return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold) # type: ignore[arg-types] + return V.graph.sizevars.statically_known_leq( + features.reduction_numel, threshold + ) # type: ignore[arg-types] @staticmethod def want_no_x_dim(features: SIMDKernelFeatures) -> bool: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index d19da33b9bb..d0f8546a290 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -417,9 +417,9 @@ def write_atomic( ) -> None: # Write into temporary file first to avoid conflicts between threads # Avoid using a named temporary file, as those have restricted permissions - assert isinstance( - content, (str, bytes) - ), "Only strings and byte arrays can be saved in the cache" + assert isinstance(content, (str, bytes)), ( + "Only strings and byte arrays can be saved in the cache" + ) path = Path(path_) if make_dirs: path.parent.mkdir(parents=True, exist_ok=True) @@ -975,9 +975,9 @@ class FxGraphCache: symints = FxGraphCache._filter_backed_symints(example_inputs) hints = [hint_int(s) for s in symints] - def iterate_over_candidates() -> ( - Generator[tuple[CompiledFxGraph, bytes], None, None] - ): + def iterate_over_candidates() -> Generator[ + tuple[CompiledFxGraph, bytes], None, None + ]: if local: subdir = FxGraphCache._get_tmp_dir_for_key(key) if os.path.exists(subdir): @@ -1123,9 +1123,9 @@ class FxGraphCache: """ from .compile_fx import CompiledFxGraph - assert isinstance( - compiled_graph, CompiledFxGraph - ), f"serialization for {type(compiled_graph)} NYI" + assert isinstance(compiled_graph, CompiledFxGraph), ( + f"serialization for {type(compiled_graph)} NYI" + ) disk_compiled_graph = copy(compiled_graph) disk_compiled_graph.prepare_for_serialization() @@ -1315,9 +1315,8 @@ class FxGraphCache: "distributed_ephemeral_timeout_us", time_saved_ns // 1000 ) if ( - ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( - time_saved_ns - ) + ephemeral_increase + := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) ) != 0: cache_info["ephemeral_timeout_increase"] = ephemeral_increase else: @@ -1556,9 +1555,9 @@ class AotCodeCompiler: cpp_path_operator.with_name(f"{cpp_path_operator.stem}_metadata.json") ) for k, v in config.aot_inductor.metadata.items(): - assert isinstance(k, str) and isinstance( - v, (str) - ), "Metadata must only contain strings" + assert isinstance(k, str) and isinstance(v, (str)), ( + "Metadata must only contain strings" + ) with open(meta_json, "w") as f: f.write(json.dumps(config.aot_inductor.metadata)) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index da09d34c983..c1e9053b90d 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -341,7 +341,7 @@ class BackendFeature(Enum): def get_backend_features( - device: Union[torch.device, str, None] + device: Union[torch.device, str, None], ) -> OrderedSet[BackendFeature]: if device is None: return OrderedSet() @@ -986,9 +986,9 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]): if cls._is_unimplemented(funcname): setattr(cls, funcname, cls._unimplemented(funcname)) else: - assert ( - funcname not in cls.__dict__ - ), f"multiple definitions of {funcname} on {cls.__name__}" + assert funcname not in cls.__dict__, ( + f"multiple definitions of {funcname} on {cls.__name__}" + ) impl.__name__ = funcname setattr(cls, funcname, staticmethod(impl)) @@ -2229,7 +2229,7 @@ 
class KernelTemplate: @staticmethod def _fake_get_dtype( - fake_outs: Union[list[Buffer], Buffer] + fake_outs: Union[list[Buffer], Buffer], ) -> Callable[[str], torch.dtype]: _get_dtype_real = V.graph.get_dtype if isinstance(fake_outs, (list, tuple)): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 70790602f05..ccf6a7d0804 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -483,9 +483,9 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode): outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]], outer_loop_fusion_depth, ): - self.outer_fused_nodes: list[ - Union[FusedSchedulerNode, SchedulerNode] - ] = outer_fused_nodes + self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = ( + outer_fused_nodes + ) self.outer_loop_fusion_depth = outer_loop_fusion_depth flatten_snodes = [] for _node in self.outer_fused_nodes: @@ -1361,9 +1361,9 @@ class CppVecOverrides(CppOverrides): @staticmethod def remainder(a, b): - assert ( - a.dtype == b.dtype - ), "remainder vec implementation expect the same inputs' dtype." + assert a.dtype == b.dtype, ( + "remainder vec implementation expect the same inputs' dtype." + ) return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}" @staticmethod @@ -1468,9 +1468,9 @@ class CppVecOverrides(CppOverrides): @staticmethod def floordiv(a, b): if is_float_dtype(a.dtype): - assert ( - a.dtype == b.dtype - ), "div_floor_floating_vec implementation expect the same inputs' dtype." + assert a.dtype == b.dtype, ( + "div_floor_floating_vec implementation expect the same inputs' dtype." + ) return f"div_floor_floating_vec({a}, {b})" else: assert all(is_integer_dtype(item.dtype) for item in [a, b]) @@ -1629,9 +1629,9 @@ class CppVecOverrides(CppOverrides): assert isinstance(other_vec_var, CppCSEVariable), other_vec_var body_vec_var.dtype = dtype other_vec_var.dtype = dtype - overrides: type[ - Union[CppOverrides, CppVecOverrides] - ] = V.kernel.overrides # type: ignore[has-type] + overrides: type[Union[CppOverrides, CppVecOverrides]] = ( + V.kernel.overrides + ) # type: ignore[has-type] code.writeline( f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};" ) @@ -2108,9 +2108,9 @@ class CppKernel(Kernel): def set_ranges(self, lengths, reduction_lengths): if self.call_ranges: - assert self.call_ranges == tuple(lengths) + tuple( - reduction_lengths - ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), ( + f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + ) assert self.reduction_depth == len(lengths) else: self.call_ranges = tuple(lengths) + tuple(reduction_lengths) @@ -2787,9 +2787,9 @@ class CppVecKernel(CppKernel): self.weight_recps_val = self.weight_recps_cse.generate( self.compute, f"reduction {self.weight_recp_vec_range}", write=False ) - self.weight_recps_cse.reduction_cache[ - self.weight_recp_vec_range - ] = self.weight_recps_val + self.weight_recps_cse.reduction_cache[self.weight_recp_vec_range] = ( + self.weight_recps_val + ) self.non_parallel_reduction_prefix.writeline( self.welford_weight_reciprocal_vec(dtype) ) @@ -4969,9 +4969,9 @@ class CppScheduling(BaseScheduling): ] counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) - assert self.is_cpp_template( - template_node - ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + assert 
self.is_cpp_template(template_node), ( + "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + ) template_node = cast(SchedulerNode, template_node) _, (_, rnumel) = template_node.group assert rnumel == () @@ -4979,9 +4979,9 @@ class CppScheduling(BaseScheduling): epilogue_ir_nodes: list[Optional[ir.Operation]] = [ n.node for n in epilogue_nodes ] - assert all( - isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes - ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), ( + "Epilogue nodes must all be instances of ir.ComputedBuffer" + ) def template_buffer_has_other_users( template_buffer, outputs_by_name, epilogue_nodes @@ -5019,16 +5019,16 @@ class CppScheduling(BaseScheduling): if is_multi_outputs_template(template_node.node): # For multi outputs template, allocate buffers for each output after the epilogue # codegen to which determines if the buffer has been removed. - assert ( - len(template_node.outputs) == 1 - ), "Multi outputs template should be with 1 output template buffer of MultiOutputLayout" + assert len(template_node.outputs) == 1, ( + "Multi outputs template should be with 1 output template buffer of MultiOutputLayout" + ) for user in template_node.outputs[0].users: - assert isinstance( - user.node, ExternKernelSchedulerNode - ), "Multi outputs template should be with ExternKernelSchedulerNode" - assert isinstance( - user.node.node, ir.MultiOutput - ), "Multi outputs template has multi users with MultiOutput" + assert isinstance(user.node, ExternKernelSchedulerNode), ( + "Multi outputs template should be with ExternKernelSchedulerNode" + ) + assert isinstance(user.node.node, ir.MultiOutput), ( + "Multi outputs template has multi users with MultiOutput" + ) user.node.mark_run() kernel.call_kernel(kernel_name, ctb) @@ -5347,9 +5347,9 @@ class LoopNest: return self.loops is not None and self.loops[0].is_reduction def mark_parallel(self, par_depth): - assert ( - par_depth <= self.max_parallel_depth() - ), "Parallel depth cannot exceed the maximal allowed parallel depth" + assert par_depth <= self.max_parallel_depth(), ( + "Parallel depth cannot exceed the maximal allowed parallel depth" + ) assert self.loops is not None assert len(self.loops) >= par_depth loop = self.loops[0] diff --git a/torch/_inductor/codegen/cpp_flex_attention_template.py b/torch/_inductor/codegen/cpp_flex_attention_template.py index 4d17976bb1b..e615a93fdc9 100644 --- a/torch/_inductor/codegen/cpp_flex_attention_template.py +++ b/torch/_inductor/codegen/cpp_flex_attention_template.py @@ -862,7 +862,9 @@ class CppFlexAttentionTemplate(CppTemplate): assert all( mem.buffer_name in kernel_group.args.input_buffers for mem in body.memory_usage[MemoryUsageType.LOAD] - ), "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers" + ), ( + "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers" + ) bodies.append(body) var_sizes_list.append((var_sizes, ())) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 3b2ffb98dcb..24851e66a1b 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -557,9 +557,9 @@ class CppGemmTemplate(CppTemplate): thread_block_m = math.ceil(m_blocks / m_factor) return GemmBlocking(thread_block_m, thread_block_n, thread_block_k) - assert ( - not 
self.is_dynamic_M - ), "Unable to determine thread blocking for dynamic M." + assert not self.is_dynamic_M, ( + "Unable to determine thread blocking for dynamic M." + ) register_blocking = self.register_blocking m_blocks = math.ceil(self.m / register_blocking.block_m) n_blocks = math.ceil(self.n / register_blocking.block_n) @@ -673,17 +673,17 @@ class CppGemmTemplate(CppTemplate): L1_cache_size = ( torch._C._cpu._L1d_cache_size() ) # per core cache size in Bytes - assert ( - L1_cache_size > 0 - ), f"Expect L1_cache_size > 0 but got {L1_cache_size}" + assert L1_cache_size > 0, ( + f"Expect L1_cache_size > 0 but got {L1_cache_size}" + ) L1 = L1_cache_size * L1_limit_factor L2_cache_size = ( torch._C._cpu._L2_cache_size() ) # per core cache size in Bytes - assert ( - L2_cache_size > 0 - ), f"Expect L2_cache_size > 0 but got {L2_cache_size}" + assert L2_cache_size > 0, ( + f"Expect L2_cache_size > 0 but got {L2_cache_size}" + ) L2 = L2_cache_size * L2_limit_factor def get_num_byte(dtype): @@ -744,9 +744,9 @@ class CppGemmTemplate(CppTemplate): return Mc_blocks, Nc_blocks, Kc_blocks - assert ( - not self.is_dynamic_M - ), "Unable to determine cache blocking for dynamic M." + assert not self.is_dynamic_M, ( + "Unable to determine cache blocking for dynamic M." + ) register_blocking = self.register_blocking thread_blocking = self.thread_blocking(num_threads) @@ -1114,9 +1114,9 @@ class CppGemmTemplate(CppTemplate): LayoutType.VNNI4, ], f"We only support {layout_str} for now" vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 - assert ( - k % vnni_size == 0 - ), f"k should be divisible by vnni_size for {layout_str} layout" + assert k % vnni_size == 0, ( + f"k should be divisible by vnni_size for {layout_str} layout" + ) vnni_view_size = list(new_size) vnni_view_size[-2] = k // vnni_size vnni_view_size.insert(-1, vnni_size) diff --git a/torch/_inductor/codegen/cpp_grouped_gemm_template.py b/torch/_inductor/codegen/cpp_grouped_gemm_template.py index 75c4ef3318c..4b973522227 100644 --- a/torch/_inductor/codegen/cpp_grouped_gemm_template.py +++ b/torch/_inductor/codegen/cpp_grouped_gemm_template.py @@ -309,9 +309,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate): for W_node in W_nodes: assert W_node.get_name() in V.graph.constants W_tensor.append(V.graph.constants[W_node.get_name()]) - new_input_nodes[ - wgt_start_idx : wgt_start_idx + gemm_grouped_num - ] = W_tensor # type: ignore[assignment] + new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = ( + W_tensor # type: ignore[assignment] + ) new_input_nodes, _ = pack_weight( *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) ) @@ -321,9 +321,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate): W_packed = new_input_nodes[idx] assert isinstance(W_packed, torch.Tensor) W_packed_constant = V.graph.add_tensor_constant(W_packed) - template_buffer.inputs[ - idx - ] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant) + template_buffer.inputs[idx] = ( + ir.InputsKernel.unwrap_storage_for_input(W_packed_constant) + ) return output template = DataProcessorTemplateWrapper( @@ -419,9 +419,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate): ir.Buffer(name=gemm_output_name, layout=template_buffer.layout) ) - assert ( - not self.epilogue_creator - ), "epilogue_creator is not supported yet in Grouped GEMM Template" + assert not self.epilogue_creator, ( + "epilogue_creator is not supported yet in Grouped GEMM Template" + ) kernel_args: dict[str, Optional[ir.IRNode]] = {} for x_idx in range(wgt_start_idx): diff --git 
a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index e02171fd325..bbfc54f6a66 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -231,9 +231,9 @@ micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {} def register_micro_gemm(*configs): def inner(cls): - assert ( - cls not in micro_gemm_configs - ), f"Duplicate micro_gemm registration for {cls}" + assert cls not in micro_gemm_configs, ( + f"Duplicate micro_gemm registration for {cls}" + ) assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" micro_gemm_configs[cls] = list(configs) return cls diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index b94d0b54228..3c01c5a398c 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -44,11 +44,13 @@ class CppTemplate(KernelTemplate): def generate(self, **kwargs): kernel_name = f"cpp_{self.name}" - with patch.object( - V.graph, "get_dtype", self._fake_get_dtype(self.output_node) - ), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel( - kernel_name=kernel_name, num_threads=self.num_threads - ) as kernel: + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + patch.object(ir.FlexibleLayout, "allow_indexing", True), + CppTemplateKernel( + kernel_name=kernel_name, num_threads=self.num_threads + ) as kernel, + ): code = kernel.render(self, **kwargs) _, call_args, _, _ = kernel.args.python_argdefs() log.debug("Generated Code:\n%s", code) diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 137825a7934..c0ff25f7efc 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -377,7 +377,10 @@ class CppTemplateKernel(CppKernel): ) epilogue_nodes = scope.localize_nodes(epilogue_nodes) return self.store_pointwise_nodes( - dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type] + dst, + epilogue_nodes, # type: ignore[arg-type] + offsets, + reindexers, ) else: if dst.get_name() != src.get_name(): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 48579f83337..75c3a4d03cc 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -110,9 +110,9 @@ class CppWrapperCpu(PythonWrapperCodegen): Only valid when cuda == True. 
""" assert not gpu, "CppWrapperCpu.generate_kernel_call does not support GPU" - assert arg_types is not None and len(call_args) == len( - arg_types - ), "Mismatch call_args and arg_types in generate_kernel_call" + assert arg_types is not None and len(call_args) == len(arg_types), ( + "Mismatch call_args and arg_types in generate_kernel_call" + ) new_args = [] for idx, arg in enumerate(call_args): if "*" in arg_types[idx]: @@ -506,9 +506,9 @@ class CppWrapperCpu(PythonWrapperCodegen): dtype = may_get_constant_buffer_dtype( V.graph.graph_inputs[input_key] # type: ignore[arg-type] ) - assert ( - dtype is not None - ), "Fails to get the dtype of the sympy.Expr" + assert dtype is not None, ( + "Fails to get the dtype of the sympy.Expr" + ) self.codegen_tensor_item( dtype, f"inputs[{idx}]", input_key, self.prefix ) @@ -555,8 +555,7 @@ class CppWrapperCpu(PythonWrapperCodegen): def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name): code.writeline(f"int32_t {name}_dtype;") code.writeline( - "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype" - f"({name}, &{name}_dtype));" + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));" ) def codegen_input_size_var_decl(self, code: IndentedBuffer, name): @@ -570,9 +569,9 @@ class CppWrapperCpu(PythonWrapperCodegen): # Tell compiler we need to link with the non-mangled symbols for kernel in self.initialized_kernels.values(): - assert hasattr( - kernel, "get_signature" - ), f"{kernel} must have get_signature implemented" + assert hasattr(kernel, "get_signature"), ( + f"{kernel} must have get_signature implemented" + ) signature = kernel.get_signature() self.prefix.writeline(f'extern "C" {signature};') @@ -597,9 +596,9 @@ class CppWrapperCpu(PythonWrapperCodegen): ) ) for name, kernel in self.initialized_kernels.items(): - assert hasattr( - kernel, "get_signature" - ), f"{kernel} must have get_signature implemented" + assert hasattr(kernel, "get_signature"), ( + f"{kernel} must have get_signature implemented" + ) kernel_ptr = f"(*{name})" signature = kernel.get_signature().replace(name, kernel_ptr) self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};") @@ -645,9 +644,9 @@ class CppWrapperCpu(PythonWrapperCodegen): with self.prefix.indent(): for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()): - assert not isinstance( - inp, sympy.Expr - ), f"input {name=} cannot be symbolic" + assert not isinstance(inp, sympy.Expr), ( + f"input {name=} cannot be symbolic" + ) self.write_input_output_info("inputs_info_", idx, name) all_cuda = all( @@ -718,9 +717,9 @@ class CppWrapperCpu(PythonWrapperCodegen): opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md( tensor ) - assert ( - opaque_metadata_tensor.dim() == 1 - ), "Expect opaque_metadata_tensor to be 1-D" + assert opaque_metadata_tensor.dim() == 1, ( + "Expect opaque_metadata_tensor to be 1-D" + ) opaque_metadata_list = opaque_metadata_tensor.tolist() opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list) @@ -757,9 +756,9 @@ class CppWrapperCpu(PythonWrapperCodegen): ) for idx, output in enumerate(V.graph.graph_outputs): - assert not isinstance( - output, sympy.Expr - ), f"output {name=} cannot be symbolic" + assert not isinstance(output, sympy.Expr), ( + f"output {name=} cannot be symbolic" + ) name = f"output{idx}" self.write_input_output_info("outputs_info_", idx, name) @@ -816,9 +815,9 @@ class CppWrapperCpu(PythonWrapperCodegen): for idx, (name, _) in enumerate(V.graph.constants.items()): if name in 
V.graph.const_output_index: const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload] - assert ( - None not in const_index_mapping - ), "Not all constant gets mapped for constant folding graph." + assert None not in const_index_mapping, ( + "Not all constant gets mapped for constant folding graph." + ) self.prefix.writeline( f""" @@ -1117,9 +1116,9 @@ class CppWrapperCpu(PythonWrapperCodegen): name = f"{output.get_name()}" output_handle_name = f"{name}_handle" if output.indices: - assert ( - output.indices[0][1] == idx - ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}" + assert output.indices[0][1] == idx, ( + f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}" + ) self.writeline(f"AtenTensorHandle {output_handle_name};") output_args.append(f"&{output_handle_name}") output_raii_handles.append( @@ -1140,7 +1139,9 @@ class CppWrapperCpu(PythonWrapperCodegen): args = args + output_args device = d.type if (d := fallback_kernel.get_device()) else self.device self.generate_c_shim_extern_kernel_call( - fallback_kernel.cpp_kernel_name, args, device # type: ignore[arg-type] + fallback_kernel.cpp_kernel_name, # type: ignore[arg-type] + args, + device, ) for raii_handle in output_raii_handles: self.writeline(raii_handle) @@ -1189,9 +1190,9 @@ class CppWrapperCpu(PythonWrapperCodegen): if reduce: line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" else: - assert ( - reduce is None - ), "Expect reduce to be None for aten.scatter_ with scalar src" + assert reduce is None, ( + "Expect reduce to be None for aten.scatter_ with scalar src" + ) line += ");" self.writeline(line) @@ -1841,18 +1842,24 @@ class CppWrapperCpu(PythonWrapperCodegen): # Only treat int Scalar as dynamic is_int_type = [isinstance(a, int) for a in arg] if any(is_int_type): - assert all( - is_int_type - ), "AOTInductor only supports int scalars of the same type" + assert all(is_int_type), ( + "AOTInductor only supports int scalars of the same type" + ) new_int_args.extend([str(a) for a in arg]) else: assert isinstance( - arg_type.getElementType(), static_arg_types # type: ignore[arg-type] - ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + arg_type.getElementType(), + static_arg_types, # type: ignore[arg-type] + ), ( + f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + ) else: assert isinstance( - arg_type, static_arg_types # type: ignore[arg-type] - ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + arg_type, + static_arg_types, # type: ignore[arg-type] + ), ( + f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + ) for arg, arg_type in zip(raw_args, arg_types): if arg is not None: @@ -2378,9 +2385,9 @@ if (custom_op_wrapper.get() == NULL) { return f"&{var_name}" if isinstance(type_, torch.ListType): - assert isinstance( - val, (list, tuple) - ), f"{val} does not match with arg type {type_}" + assert isinstance(val, (list, tuple)), ( + f"{val} does not match with arg type {type_}" + ) element_type = type_.getElementType() var_name = f"var_array_{next(self.var_array_id)}" if len(val) == 0: diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index f39bad14baa..51ba610b04f 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -56,9 +56,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu): 
self.cached_output_id = count() self.scalar_to_tensor_id = count() self.custom_op_wrapper_loaded = False - self.allow_stack_allocation: Optional[ - bool - ] = config.aot_inductor.allow_stack_allocation + self.allow_stack_allocation: Optional[bool] = ( + config.aot_inductor.allow_stack_allocation + ) self.stack_allocated_buffers: dict[BufferName, BufferLike] = {} @staticmethod @@ -126,12 +126,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu): Otherwise it uses the CUDA language for codegen. Only valid when cuda == True. """ - assert ( - not gpu - ), "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU" - assert arg_types is not None and len(call_args) == len( - arg_types - ), "Mismatch call_args and arg_types in generate_kernel_call" + assert not gpu, ( + "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU" + ) + assert arg_types is not None and len(call_args) == len(arg_types), ( + "Mismatch call_args and arg_types in generate_kernel_call" + ) new_args = [] for idx, arg in enumerate(call_args): if "*" in arg_types[idx]: @@ -328,9 +328,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu): dtype = may_get_constant_buffer_dtype( V.graph.graph_inputs[input_key] # type: ignore[arg-type] ) - assert ( - dtype is not None - ), "Fails to get the dtype of the sympy.Expr" + assert dtype is not None, ( + "Fails to get the dtype of the sympy.Expr" + ) self.codegen_tensor_item( dtype, f"inputs[{idx}]", input_key, self.prefix ) @@ -724,9 +724,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu): if reduce: line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" else: - assert ( - reduce is None - ), "Expect reduce to be None for aten.scatter_ with scalar src" + assert reduce is None, ( + "Expect reduce to be None for aten.scatter_ with scalar src" + ) line += ");" self.writeline(line) diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 6c27882e84c..e2f76cea9f0 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -60,13 +60,13 @@ class DeferredGpuKernelLine(DeferredLineBase): # MultiKernel will select one kernel after running the autotune block self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) params = CudaKernelParamCache.get(self.kernel_name) - assert ( - params is not None - ), f"{self.kernel_name} not found in CudaKernelParamCache" + assert params is not None, ( + f"{self.kernel_name} not found in CudaKernelParamCache" + ) for key in self.keys: - assert ( - key in params - ), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]" + assert key in params, ( + f"{key} not found in CudaKernelParamCache[{self.kernel_name}]" + ) if key == get_cpp_wrapper_cubin_path_name(): assert os.path.exists(params[key]), f"{params[key]} does not exist" self.additional_files.append(params[key]) @@ -122,9 +122,9 @@ class DeferredGpuDefaultGrid: grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs) params = CudaKernelParamCache.get(self.kernel_name) - assert ( - params is not None - ), f"{self.kernel_name} not found in CudaKernelParamCache" + assert params is not None, ( + f"{self.kernel_name} not found in CudaKernelParamCache" + ) return grid_fn(params["meta"]) @@ -153,9 +153,9 @@ class DeferredGpuGridLine(DeferredLineBase): self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) params = CudaKernelParamCache.get(self.kernel_name) - assert ( - params is not None - ), f"{self.kernel_name} not found in CudaKernelParamCache" + assert params is not None, ( + 
f"{self.kernel_name} not found in CudaKernelParamCache" + ) if self.autotune_configs is not None: # This indicates the Triton kernel is a user-defined one. @@ -248,13 +248,13 @@ class CppWrapperGpu(CppWrapperCpu): if V.graph.aot_mode and V.graph.inputs_to_check: for idx in V.graph.inputs_to_check: input_name = V.graph.graph_input_names[idx] - assert ( - input_name in V.graph.graph_inputs - ), f"{input_name} not found in graph inputs" + assert input_name in V.graph.graph_inputs, ( + f"{input_name} not found in graph inputs" + ) value = V.graph.graph_inputs[input_name] - assert isinstance( - value, TensorBox - ), f"{input_name} is expected to be tensor but found as {type(value)}" + assert isinstance(value, TensorBox), ( + f"{input_name} is expected to be tensor but found as {type(value)}" + ) warn_msg = ( f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, " "but it is not aligned at run time. Copying to an aligned tensor " diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index f6bd8615fda..f8be71fa64d 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -87,9 +87,9 @@ class CUDACPPScheduling(BaseScheduling): Codegen a CUDA template, possibly with fused epilogues """ counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) - assert self.is_cuda_cpp_template( - template_node - ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + assert self.is_cuda_cpp_template(template_node), ( + "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + ) template_node = cast(SchedulerNode, template_node) _, (_numel, rnumel) = template_node.group assert rnumel == 1 diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index d760716b646..33aa039553f 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -496,7 +496,9 @@ class CUDATemplateCaller(ChoiceCaller): make_kernel_render: Callable[[CUDATemplateBuffer, Optional[list[IRNode]]], str], bmreq: CUDABenchmarkRequest, template: "CUDATemplate", # type: ignore[name-defined] - info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] description: str, ) -> None: super().__init__(name, input_nodes, layout, description) diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 2e4281065aa..3de6f20bb6a 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -71,13 +71,14 @@ class CUDATemplate(KernelTemplate): A CUDATemplateCaller object representing the generated CUDA template caller. 
""" kernel_name = f"cuda_{self.name}" - with patch.object( - V.graph, "get_dtype", self._fake_get_dtype(self.output_node) - ), CUDATemplateKernel( - kernel_name=kernel_name, - runtime_arg_info=self.get_runtime_arg_info(), - runtime_arg_values=self.get_runtime_arg_values(**kwargs), - ) as kernel: + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + CUDATemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel, + ): code = self.render(kernel=kernel, **kwargs) _, call_args, _, _ = kernel.args.python_argdefs() autotuning_log.debug("Generated Code:\n%s", code) diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py index 39200b8f103..bdbe9f8e0d2 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -147,7 +147,9 @@ if try_import_cutlass(): "element_d": DataTypeTag[operation.D.element], # type: ignore[name-defined] "layout_d": LayoutTag[instance_layout_D], # type: ignore[name-defined] "element_accumulator": DataTypeTag[operation.accumulator_type()], # type: ignore[name-defined] - "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], # type: ignore[name-defined] # noqa: B950 + "opcode_class": OpcodeClassTag[ # type: ignore[name-defined] + operation.tile_description.math_instruction.opcode_class + ], "arch": f"cutlass::arch::Sm{operation.arch:d}", "tile_shape_m": str(operation.tile_description.tile_shape[0]), "tile_shape_n": str(operation.tile_description.tile_shape[1]), @@ -168,7 +170,9 @@ if try_import_cutlass(): operation.tile_description.math_instruction.instruction_shape[2] ), "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), # type: ignore[name-defined] - "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]), # type: ignore[name-defined] + "epilogue_schedule": str( + EpilogueScheduleTag[operation.epilogue_schedule] # type: ignore[name-defined] + ), "epilogue_functor": epilogue_functor, "stages": stage_count_string, "align_a": str(operation.A.alignment), diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index 131affc9cd3..c631558fbe4 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -56,9 +56,9 @@ def try_import_cutlass() -> bool: "Found cutlass_library in python search path, overriding config.cuda.cutlass_dir" ) cutlass_library_dir = os.path.dirname(cutlass_library.__file__) - assert os.path.isdir( - cutlass_library_dir - ), f"{cutlass_library_dir} is not a directory" + assert os.path.isdir(cutlass_library_dir), ( + f"{cutlass_library_dir} is not a directory" + ) config.cuda.cutlass_dir = os.path.abspath( os.path.join( cutlass_library_dir, @@ -86,9 +86,9 @@ def try_import_cutlass() -> bool: if os.path.isdir(cutlass_py_full_path): if tmp_cutlass_py_full_path not in sys.path: if os.path.exists(dst_link): - assert os.path.islink( - dst_link - ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." + assert os.path.islink(dst_link), ( + f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." 
+ ) assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( cutlass_py_full_path ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}" diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 6d0187d0319..9f843199262 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -949,9 +949,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC): import cutlass_library.gemm_operation as cutlass_gemm_op import cutlass_library.library as cutlass_lib - assert isinstance( - op, cutlass_gemm_op.GemmOperation - ), "op argument is required and has to be an instance of GemmOperation" + assert isinstance(op, cutlass_gemm_op.GemmOperation), ( + "op argument is required and has to be an instance of GemmOperation" + ) assert len(self.input_nodes) >= 2 and self.output_node is not None X, W = self.input_nodes[0], self.input_nodes[1] @@ -977,7 +977,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC): else: input_reorder = None kernel_call_signature = kernel.def_kernel( - inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type] + inputs=inputs, # type: ignore[arg-type] + outputs=[Y], + names_str=names_str, + input_reorder=input_reorder, ) test_call_statement = self.test_call_statement(kernel, inputs, names_str) # The layouts might have changed between autotuning and this call if they were FlexibleLayout diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 65e47122102..94f8d54340b 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -198,7 +198,7 @@ class HalidePrinter(PythonPrinter): val, n = expr.args val = self._print(val) n = int(n) - return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))" + return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))" texpr = HalidePrinter().doprint @@ -856,11 +856,11 @@ class HalideKernel(SIMDKernel): for sym, size in added_sym_size: full_index += stride * sym stride *= size - self.index_replacements[ - node.symbol() - ] = V.graph.sizevars.simplify_with_ranges( - ModularIndexing(full_index, node.divisor, node.length), - self.halide_vars, # type: ignore[arg-type] + self.index_replacements[node.symbol()] = ( + V.graph.sizevars.simplify_with_ranges( + ModularIndexing(full_index, node.divisor, node.length), + self.halide_vars, # type: ignore[arg-type] + ) ) # codegen the variable definitions @@ -1183,9 +1183,9 @@ class HalideKernel(SIMDKernel): if isinstance(value, tuple): assert reduction_type == "welford_combine" - self.cse.reduction_cache[ - cache_key - ] = result_tuple = self.welford_combine_impl(*value) + self.cse.reduction_cache[cache_key] = result_tuple = ( + self.welford_combine_impl(*value) + ) return result_tuple assert isinstance(value, HalideCSEVariable) and value.used_dims is not None @@ -1304,9 +1304,9 @@ class HalideKernel(SIMDKernel): scan = f"{scan_dom}.x" self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])") - assert ( - len(self.reduction_renames) == 1 - ), "multi-dimensional scan not implemented" + assert len(self.reduction_renames) == 1, ( + "multi-dimensional scan not implemented" + ) (scan_var,) = [*self.reduction_renames] # type: ignore[misc] scan_renames_cur = {scan_var: sympy_index_symbol(scan)} scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1} diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 
476646c123e..8efec7eeca9 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -214,8 +214,7 @@ class MemorySplitProtocol(Protocol): get_size_hint: CachedMethod[[], int] get_symbolic_size: CachedMethod[[], sympy.Expr] - def _allocate(self, block: Allocation, is_last: bool) -> bool: - ... + def _allocate(self, block: Allocation, is_last: bool) -> bool: ... class ClearCacheOnAllocateMixin(MemorySplitProtocol): diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 08a93325542..ba1ad65e8dd 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -560,7 +560,10 @@ class MetalKernel(SIMDKernel): threads = [self.pexpr(v.numel) for v in self.active_range_trees()] # type: ignore[misc] args += [f"threads=[{', '.join(threads)}]"] if self.inside_reduction: - threads = [self.pexpr(v.numel) if v.is_reduction else "1" for v in self.active_range_trees()] # type: ignore[misc] + threads = [ + self.pexpr(v.numel) if v.is_reduction else "1" # type: ignore[misc] + for v in self.active_range_trees() + ] args += [f"group_size=[{', '.join(threads)}]"] wrapper.generate_kernel_call( diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 59eba17ec24..a186598fab1 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -33,9 +33,9 @@ def _get_all_args(args_list, arg_types_list=None): all_args = max(args_list, key=len)[:] arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None for args in args_list: - assert OrderedSet(args).issubset( - OrderedSet(all_args) - ), f"{args} v.s. {all_args}" + assert OrderedSet(args).issubset(OrderedSet(all_args)), ( + f"{args} v.s. {all_args}" + ) return all_args, arg_types @@ -149,7 +149,9 @@ class MultiKernel: Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2. 
The generated definition for the multi-kernel will looks like: ``` - multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code) + multi_kernel_kernel1 = MultiKernelCall( + [kernel1, kernel2], multi_kernel_definition_code + ) ``` Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 diff --git a/torch/_inductor/codegen/rocm/ck_conv_template.py b/torch/_inductor/codegen/rocm/ck_conv_template.py index 38e2c283d09..7065b0aceb0 100644 --- a/torch/_inductor/codegen/rocm/ck_conv_template.py +++ b/torch/_inductor/codegen/rocm/ck_conv_template.py @@ -516,7 +516,12 @@ class CKGroupedConvFwdTemplate(CKTemplate): template_params=(",\n" + 12 * " ").join(template_params), ), self._template_from_string(template_type).render(operation_name=op.name()) - def render(self, kernel: ROCmTemplateKernel, op: "CKGroupedConvFwdOp", **kwargs) -> str: # type: ignore[override, name-defined] + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGroupedConvFwdOp", # type: ignore[name-defined] + **kwargs, + ) -> str: template_buffer_node = kwargs.get("template_buffer_node", None) if template_buffer_node is not None: self.output_node = template_buffer_node diff --git a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py index edb042c2011..2aad124e50e 100644 --- a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py +++ b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -602,7 +602,12 @@ class CKGemmTemplate(CKTemplate): operation_name=operation_name ) - def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str: # type: ignore[override] + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGemmOperation", + **kwargs, + ) -> str: """ The primary entry point for the code rendering process used in this template. 
""" @@ -706,7 +711,7 @@ class CKGemmTemplate(CKTemplate): * Template instance {op} * * {torch.__version__=} -* torch.version.git_version={getattr(torch.version, 'git_version', 'None')} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} */ """ epilogue = None diff --git a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py index 0b8b16d36cb..720509f2826 100644 --- a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py +++ b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -79,9 +79,9 @@ class ROCmCPPScheduling(BaseScheduling): """ Codegen a ROCm template, possibly with fused epilogues """ - assert self.is_rocm_cpp_template( - template_node - ), "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" + assert self.is_rocm_cpp_template(template_node), ( + "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" + ) template_node = cast(SchedulerNode, template_node) _, (_numel, rnumel) = template_node.group assert rnumel == 1 diff --git a/torch/_inductor/codegen/rocm/rocm_kernel.py b/torch/_inductor/codegen/rocm/rocm_kernel.py index 3b0882f0271..6ff84df7447 100644 --- a/torch/_inductor/codegen/rocm/rocm_kernel.py +++ b/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -232,7 +232,9 @@ class ROCmTemplateCaller(ChoiceCaller): ], bmreq: ROCmBenchmarkRequest, template: "ROCmTemplate", # type: ignore[name-defined] - info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] ) -> None: super().__init__(name, input_nodes, layout, description="") self.category = category diff --git a/torch/_inductor/codegen/rocm/rocm_template.py b/torch/_inductor/codegen/rocm/rocm_template.py index f15623f7469..9f9659eca1a 100644 --- a/torch/_inductor/codegen/rocm/rocm_template.py +++ b/torch/_inductor/codegen/rocm/rocm_template.py @@ -70,13 +70,14 @@ class ROCmTemplate(KernelTemplate): """ kernel_name = f"rocm_{self.name}" kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}" - with patch.object( - V.graph, "get_dtype", self._fake_get_dtype(self.output_node) - ), ROCmTemplateKernel( - kernel_name=kernel_name, - runtime_arg_info=self.get_runtime_arg_info(), - runtime_arg_values=self.get_runtime_arg_values(**kwargs), - ) as kernel: + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + ROCmTemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel, + ): code = self.render(kernel=kernel, **kwargs) _, call_args, _, _ = kernel.args.python_argdefs() log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 3d5589cd45b..e0ae8a0cf57 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -638,7 +638,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]): continue while current_group < len(remaining) and sv.statically_known_equals( - remaining[current_group], 1 # type: ignore[arg-type] + remaining[current_group], + 1, # type: ignore[arg-type] ): # scroll to next group with remaining elements current_group += 1 @@ -666,9 +667,9 @@ class SIMDKernel(Kernel[CSEVariableType], 
Generic[CSEVariableType]): ) return_getters_groups.append(return_getters) - assert all( - V.graph.sizevars.size_hint(s) == 1 for s in remaining - ), f"failed to set ranges {remaining} {lengths}" + assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), ( + f"failed to set ranges {remaining} {lengths}" + ) return new_ranges, return_getters_groups @@ -836,7 +837,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]): replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) if len(replacements) > 0: self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] - self.range_tree_nodes[sym].expr, replacements # type: ignore[index] + self.range_tree_nodes[sym].expr, + replacements, # type: ignore[index] ) self.range_tree_nodes[sym].codegen() # type: ignore[index] return expr @@ -2071,9 +2073,10 @@ class SIMDScheduling(BaseScheduling): features=SIMDKernelFeatures(node_schedule, numel, rnumel), ) self.codegen_node_schedule_with_kernel(node_schedule, kernel) - with config.patch( - "benchmark_kernel", benchmark_kernel - ), V.set_kernel_handler(kernel): + with ( + config.patch("benchmark_kernel", benchmark_kernel), + V.set_kernel_handler(kernel), + ): src_code = kernel.codegen_kernel() else: prologue, template, epilogue = nodes[0].get_prologue_template_epilogue( diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 8d1a1fc9ae0..311dcaa7879 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1579,9 +1579,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): self.block_ptr_id = itertools.count() self.block_ptr_to_buffer = dict[str, str]() self.helper_functions = HelperFunctions() - self.pointer_advancements: dict[ - SymT, dict[str, list[sympy.Expr]] - ] = collections.defaultdict(dict) + self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = ( + collections.defaultdict(dict) + ) self._load_counts: collections.Counter[str] = collections.Counter() # A set of autotuning hints to pass as part of triton_meta @@ -2053,9 +2053,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): continue advancements = self.pointer_advancements[symt] - assert ( - block_ptr not in advancements - ), "duplicate advancement for pointer '{block_ptr}' at type '{symt}'" + assert block_ptr not in advancements, ( + "duplicate advancement for pointer '{block_ptr}' at type '{symt}'" + ) advancements[block_ptr] = advance_offsets else: block_ptr = indexing.format(var) @@ -2476,7 +2476,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): buffer.splice( f"""\ {result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) - {result_var} = {self.reduction_resize(f'{result_var}_idx')} + {result_var} = {self.reduction_resize(f"{result_var}_idx")} """ ) @@ -2576,8 +2576,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index ) - {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} - {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)} + {accumulator} = {where_cond(f"{accumulator}_next", accumulator)} + {accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)} """ ) final_argreduce( @@ -2751,9 +2751,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): ) self.compute.splice( f"""\ - {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} - 
{accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)} - {accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)} + {accumulator} = {where_cond(f"{accumulator}_next", accumulator)} + {accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)} + {accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)} """ ) result_mean = result_var @@ -3040,9 +3040,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): self.filter_masks(masks) masks = sorted(masks) assert not self._load_mask, "ops.sort not supported inside ops.masked" - assert ( - self.persistent_reduction - ), "ops.sort is only supported in persistent reductions" + assert self.persistent_reduction, ( + "ops.sort is only supported in persistent reductions" + ) cse_compute = functools.partial(self.cse.generate, self.compute) dim = self.triton_tensor_ndim() - self.num_reduction_dims @@ -3302,9 +3302,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): {} import torch from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid - """.format( - V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") - ) + """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")) ) def _get_heuristic(self): @@ -3344,19 +3342,19 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): inductor_meta["profile_bandwidth"] = config.profile_bandwidth inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output - inductor_meta[ - "profile_bandwidth_with_do_bench_using_profiling" - ] = config.profile_bandwidth_with_do_bench_using_profiling + inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = ( + config.profile_bandwidth_with_do_bench_using_profiling + ) if config.coordinate_descent_tuning: - inductor_meta[ - "coordinate_descent_tuning" - ] = config.coordinate_descent_tuning - inductor_meta[ - "coordinate_descent_search_radius" - ] = config.coordinate_descent_search_radius - inductor_meta[ - "coordinate_descent_check_all_directions" - ] = config.coordinate_descent_check_all_directions + inductor_meta["coordinate_descent_tuning"] = ( + config.coordinate_descent_tuning + ) + inductor_meta["coordinate_descent_search_radius"] = ( + config.coordinate_descent_search_radius + ) + inductor_meta["coordinate_descent_check_all_directions"] = ( + config.coordinate_descent_check_all_directions + ) return inductor_meta def codegen_kernel(self, name=None): @@ -4046,9 +4044,10 @@ class TritonScheduling(SIMDScheduling): ) -> tuple[float, str]: """Benchmark an already compiled module""" device_interface = get_interface_for_device(V.graph.device_type) - with preserve_rng_state(), device_interface.device( - V.graph.get_current_device_or_throw() - ): # type: ignore[attr-defined] + with ( + preserve_rng_state(), + device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined] + ): ms = None def cache_file_path(): @@ -4322,9 +4321,9 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]: device = node.get_device() assert device is not None backend = node.scheduler.get_backend(device) - assert isinstance( - backend, (SIMDScheduling, CUDACombinedScheduling) - ), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}" + assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), ( + f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton 
strings, got: {type(backend)}" + ) with V.graph.set_current_device(device): # Don't increment kernel count when generating debug string. diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index a5f99731bee..b2d70c6d59f 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -86,7 +86,9 @@ def _default_custom_combo_kernel_horizontal_partition( # rnumel > 2048 usually has long execution time # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes long_reduction = [ - n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type] + n + for n in reduction + if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type] ] short_reduction = [n for n in reduction if n not in long_reduction] if long_reduction: @@ -138,7 +140,7 @@ def set_custom_combo_kernel_horizontal_partition( dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], ], list[list[BaseSchedulerNode]], - ] + ], ) -> None: """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions are implemented in different combo kernels. Nodes in the same partition are likely to be implemented @@ -593,9 +595,9 @@ class ComboKernel(Kernel): num_persistent_reduction = len( [e for e in heuristics_list if e == "persistent_reduction"] ) - assert ( - num_reduction == 0 - ), "combining pointwise and reduction are not supported yet." + assert num_reduction == 0, ( + "combining pointwise and reduction are not supported yet." + ) heuristics = ( "pointwise_with_reduction" if num_persistent_reduction > 0 @@ -784,13 +786,13 @@ class ComboKernel(Kernel): name, tree, suffix=str(num) ) if not tree.is_reduction: - assert isinstance( - grid[i][num], str - ), f"Grid {grid[i][num]} should be a dynamic shape." + assert isinstance(grid[i][num], str), ( + f"Grid {grid[i][num]} should be a dynamic shape." + ) numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" - assert ( - grid[i][num] == numel_sign + numel_name - ), f"numel args mismatch: {grid[i][num]} vs {numel_name}" + assert grid[i][num] == numel_sign + numel_name, ( + f"numel args mismatch: {grid[i][num]} vs {numel_name}" + ) grid[i][num] = -expr if numel_sign == "-" else expr if not tree.is_reduction or sub_kernel.inside_reduction: @@ -807,13 +809,13 @@ class ComboKernel(Kernel): continue expr = V.graph.sizevars.size_hint(tree.numel) if not tree.is_reduction: - assert isinstance( - grid[i][num], str - ), f"Grid {grid[i][num]} should be a dynamic shape." + assert isinstance(grid[i][num], str), ( + f"Grid {grid[i][num]} should be a dynamic shape." 
+ ) numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" - assert ( - grid[i][num] == numel_sign + numel_name - ), f"grid mismatch: {grid[i][num]} vs {numel_name}" + assert grid[i][num] == numel_sign + numel_name, ( + f"grid mismatch: {grid[i][num]} vs {numel_name}" + ) grid[i][num] = -expr if numel_sign == "-" else expr if not tree.is_reduction or sub_kernel.inside_reduction: extra_args.append(expr) @@ -1015,9 +1017,7 @@ class ComboKernel(Kernel): {} import torch from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels - """.format( - V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") - ) + """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")) ) def uniquify_block_sizes( diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index d025a8120d1..ce008228a5f 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -57,9 +57,9 @@ class TritonSplitScanKernel(TritonKernel): def initialize_range_tree(self, pid_cache): prefixes = ["y", "x", "r0_"] - assert len(self.numels) <= len( - prefixes - ), "z dimension not supported for split scan" + assert len(self.numels) <= len(prefixes), ( + "z dimension not supported for split scan" + ) active_prefixes = prefixes[len(prefixes) - len(self.numels) :] grid_dims = {"r0_": 0, "x": 1, "y": 2} diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 193080d360c..2d5f6a55b4c 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -184,7 +184,8 @@ def config_of( if isinstance(x, TensorArg): if include_tensor: offset_aligned = V.graph.sizevars.statically_known_multiple_of( - x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type] + x.offset * x.dtype.itemsize, + alignment, # type: ignore[arg-type] ) return offset_aligned and not is_unaligned_buffer(x) else: diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 8872b049b8a..e3062f73169 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -104,8 +104,7 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike): # NB: this is symbolic so that we don't try to reuse a buffer # for s0 for s1, just because they happen to share the same # size hint - sympy_str(input_size) - == sympy_str(output_size) + sympy_str(input_size) == sympy_str(output_size) ) or ( # statically known that 0.95 * input_size <= output_size <= input_size V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size) @@ -138,9 +137,9 @@ def convert_arg_type(arg: torch.Argument) -> str: container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type) if len(container_match) == 1: contained_type = container_match[0] - assert ( - contained_type in PYTHON_TO_CPP - ), f"unsupported {py_container} type in convert_arg_type: {contained_type}" + assert contained_type in PYTHON_TO_CPP, ( + f"unsupported {py_container} type in convert_arg_type: {contained_type}" + ) cpp_contained_type = PYTHON_TO_CPP[contained_type] return f"{cpp_container}<{cpp_contained_type}>" @@ -367,9 +366,9 @@ class SymbolicCallArg: class MemoryPlanningState: def __init__(self): super().__init__() - self.reuse_pool: dict[ - ReuseKey, list[FreeIfNotReusedLine] - ] = collections.defaultdict(list) + self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = ( + collections.defaultdict(list) + ) 
self.total_allocated_buffer_size: int = 0 def __contains__(self, key: ReuseKey) -> bool: @@ -431,9 +430,9 @@ class EnterDeviceContextManagerLine(WrapperLine): f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" ) else: - assert ( - self.last_seen_device_guard_index == self.device_idx - ), "AOTInductor only supports running on one CUDA device" + assert self.last_seen_device_guard_index == self.device_idx, ( + "AOTInductor only supports running on one CUDA device" + ) else: if self.last_seen_device_guard_index is None: code.writeline( @@ -1794,7 +1793,8 @@ class PythonWrapperCodegen(CodeGen): equals_1 = isinstance( arg, (int, sympy.Integer) ) and V.graph.sizevars.statically_known_equals( - arg, 1 # type: ignore[arg-type] + arg, + 1, # type: ignore[arg-type] ) add_arg(idx, SizeArg(key, arg), equals_1=equals_1) @@ -2052,9 +2052,9 @@ class PythonWrapperCodegen(CodeGen): buf_name = arg buf = V.graph.get_buffer(arg) else: - assert ( - raw_arg is not None - ), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" + assert raw_arg is not None, ( + "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" + ) buf_name = f"tmp_arg_{index}" buf = raw_arg @@ -2181,9 +2181,9 @@ class PythonWrapperCodegen(CodeGen): and kernel_name not in self.kernel_autotune_names ): # Create example args for autotune in a separate epilogue - assert arg_types is not None and len(call_args) == len( - arg_types - ), "call_args and arg_types do not match" + assert arg_types is not None and len(call_args) == len(arg_types), ( + "call_args and arg_types do not match" + ) tensor_args = {} all_args = [] @@ -2191,9 +2191,9 @@ class PythonWrapperCodegen(CodeGen): # create a dummy raw_args for uniform behavior in the following loop raw_args = [None] * len(call_args) else: - assert len(raw_args) == len( - call_args - ), "call_args and raw_args do not match" + assert len(raw_args) == len(call_args), ( + "call_args and raw_args do not match" + ) for i, (arg, arg_type, raw_arg) in enumerate( zip(call_args, arg_types, raw_args) @@ -2411,9 +2411,9 @@ class PythonWrapperCodegen(CodeGen): if isinstance(layout, ir.NoneLayout): return if isinstance(layout, ir.NonOwningLayout): - assert isinstance( - layout.view, ir.ReinterpretView - ), f"unexpected {type(layout.view)}: {layout.view}" + assert isinstance(layout.view, ir.ReinterpretView), ( + f"unexpected {type(layout.view)}: {layout.view}" + ) assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data) assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data) self.codegen_allocation(layout.view.data.data) @@ -2535,9 +2535,9 @@ class PythonWrapperCodegen(CodeGen): def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): # All inputs of hops must be explicitly passed in. # Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo. 
- assert len(outer_inputs) == len( - subgraph.graph.graph_input_names - ), f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}" + assert len(outer_inputs) == len(subgraph.graph.graph_input_names), ( + f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}" + ) for inner_input, outer_input in zip( subgraph.graph.graph_input_names, outer_inputs ): diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index f6135ba1a5c..deb4ca2a22b 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -219,8 +219,7 @@ def _schedule_for_comm( for snode, deps in unmet_deps.items(): assert len(deps) == 0, ( - "Detected unscheduled nodes. " - f"Nodes with unmet dependencies: {unmet_deps}" + f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}" ) return scheduled @@ -354,9 +353,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph): node.op == "call_function" and node.target == torch.ops.inductor.resize_storage_bytes_.default ): - assert ( - node.args[0].op == "placeholder" - ), f"""\ + assert node.args[0].op == "placeholder", f"""\ Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]} """ graph_input = node.args[0] @@ -408,9 +405,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default: fsdp_copy_node = node unsharded_param = node.args[0] - assert ( - unsharded_param.op == "placeholder" - ), f""" + assert unsharded_param.op == "placeholder", f""" Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true! Offending node: {unsharded_param}. Graph: {graph} """ diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 1ff90566d10..4516afee89d 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -281,9 +281,9 @@ def _unlift_graph( elif node_name in graph_signature.inputs_to_buffers: buffer_name = graph_signature.inputs_to_buffers[node_name] lifted_inputs.append(buffer_name) - gm.meta[ - get_cloned_parameter_buffer_name(buffer_name) - ] = clone_preserve_strides(state_dict[buffer_name]) + gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = ( + clone_preserve_strides(state_dict[buffer_name]) + ) else: assert node_name in graph_signature.user_inputs lifted_inputs.append(None) @@ -542,7 +542,7 @@ def fake_tensor_prop( # pass config dict back to user def get_patched_config_dict( - config_patches: Optional[Union[str, dict[str, Any]]] = None + config_patches: Optional[Union[str, dict[str, Any]]] = None, ) -> dict[str, Any]: with config.patch(config_patches): return config.get_config_copy() @@ -579,8 +579,7 @@ class _CompileFxCallable(Protocol): gm: GraphModule, example_inputs: Sequence[InputType], **kwargs: Unpack[_CompileFxKwargs], - ) -> OutputCode: - ... + ) -> OutputCode: ... 
def compile_fx_inner( @@ -662,9 +661,9 @@ def _compile_fx_inner( static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs) inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) - assert isinstance( - next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list) - ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" + assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), ( + f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" + ) if (cudagraphs := graph_kwargs.get("cudagraphs")) is None: graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs) @@ -679,9 +678,10 @@ def _compile_fx_inner( fx_graph_remote_cache = should_use_remote_fx_graph_cache() - with _WaitCounter( - "pytorch.wait_counter.fx_codegen_and_compile" - ).guard() as _, _WaitCounter("pytorch.wait_counter.all_compilation_types").guard(): + with ( + _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _, + _WaitCounter("pytorch.wait_counter.all_compilation_types").guard(), + ): use_cache = ( not config.force_disable_caches and (config.fx_graph_cache or fx_graph_remote_cache) @@ -865,8 +865,7 @@ class FxCompile(ABC): example_inputs: Sequence[InputType], inputs_to_check: Sequence[int], graph_kwargs: _CompileFxKwargs, - ) -> OutputCode: - ... + ) -> OutputCode: ... class _InProcessFxCompile(FxCompile): @@ -890,16 +889,17 @@ class _InProcessFxCompile(FxCompile): cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False) aot_mode: bool = V.aot_compilation is_inference: bool = graph_kwargs.get("is_inference", False) - extern_node_serializer: Optional[ - Callable[[list[ExternKernelNode]], Any] - ] = graph_kwargs.get("extern_node_serializer", None) + extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = ( + graph_kwargs.get("extern_node_serializer", None) + ) boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get( "boxed_forward_device_index", None ) - with _WaitCounter( - "pytorch.wait_counter.actual_codegen_and_compile" - ).guard(), dynamo_utils.preserve_rng_state(): + with ( + _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(), + dynamo_utils.preserve_rng_state(), + ): if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None: import time @@ -1038,9 +1038,11 @@ class _InProcessFxCompile(FxCompile): # See details in vllm/compilation/pass_manager.py. 
log.warning("failed to log pt2_configs") - with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding( - example_inputs - ), maybe_disable_graph_partition(cpp_wrapper, aot_mode): + with ( + V.set_fake_mode(fake_mode), + maybe_disable_comprehensive_padding(example_inputs), + maybe_disable_graph_partition(cpp_wrapper, aot_mode), + ): const_output_index = None const_graph = None const_code = None @@ -1123,9 +1125,9 @@ class _InProcessFxCompile(FxCompile): if graph.aot_mode: from .codecache import AotCodeCompiler - assert ( - graph.cpp_wrapper - ), "AOT mode only supports C++ wrapper" + assert graph.cpp_wrapper, ( + "AOT mode only supports C++ wrapper" + ) code, linemap = graph.codegen_with_cpp_wrapper() output_code_log.debug("Output code: \n%s", code) @@ -1509,10 +1511,13 @@ def cudagraphify( def run(new_inputs: Sequence[InputType]) -> Any: nonlocal compiled_fn if compiled_fn is None: - with dynamo_utils.dynamo_timed( - "cudagraphify", - log_pt2_compile_event=True, - ), dynamo_utils.preserve_rng_state(): + with ( + dynamo_utils.dynamo_timed( + "cudagraphify", + log_pt2_compile_event=True, + ), + dynamo_utils.preserve_rng_state(), + ): compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) return compiled_fn(new_inputs) @@ -1669,13 +1674,16 @@ def compile_fx_aot( extern_node_serializer = config_patches.pop("extern_node_serializer", None) saved_compile_id = model_.meta.get("dynamo_compile_id", None) saved_compile_context = torch._guards.CompileContext(saved_compile_id) - with V.set_aot_compilation(True), torch._guards.compile_context( - saved_compile_context - ), chromium_event_timed( - "compile_fx_aot", - log_pt2_compile_event=True, - reset_event_log_on_exit=True, - ), get_metrics_context(): + with ( + V.set_aot_compilation(True), + torch._guards.compile_context(saved_compile_context), + chromium_event_timed( + "compile_fx_aot", + log_pt2_compile_event=True, + reset_event_log_on_exit=True, + ), + get_metrics_context(), + ): compiled_artifacts = compile_fx( model_, example_inputs_, @@ -1875,12 +1883,15 @@ def compile_fx( # TODO: This probably shouldn't be a recursive call if config.cpp_wrapper: - with config.patch( - { - "cpp_wrapper": False, # reset to break recursive call to compile_fx - **get_cpp_wrapper_config(), - } - ), V.set_real_inputs(example_inputs_): + with ( + config.patch( + { + "cpp_wrapper": False, # reset to break recursive call to compile_fx + **get_cpp_wrapper_config(), + } + ), + V.set_real_inputs(example_inputs_), + ): inputs_: Sequence[InputType] = example_inputs_ if isinstance(model_, GraphModule): @@ -1940,10 +1951,10 @@ def compile_fx( # Do the actual work - with _use_lazy_graph_module( - dynamo_config.use_lazy_graph_module - ), enable_python_dispatcher(), torch.fx.traceback.preserve_node_meta( - config.trace.enabled + with ( + _use_lazy_graph_module(dynamo_config.use_lazy_graph_module), + enable_python_dispatcher(), + torch.fx.traceback.preserve_node_meta(config.trace.enabled), ): # Pre-grad passes cannot be run if we weren't given a GraphModule. 
# Dynamo will always produce a GraphModule, but this handles cases @@ -2085,9 +2096,9 @@ def compile_fx( boxed_forward_device_index=forward_device, ) - fw_compiler: Callable[ - [GraphModule, Sequence[InputType]], OutputCode - ] = functools.partial(fw_compiler_base, is_inference=False) + fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = ( + functools.partial(fw_compiler_base, is_inference=False) + ) fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler) if config.freezing and not torch.is_grad_enabled(): @@ -2124,9 +2135,10 @@ def compile_fx( ) -> OutputCode: from torch._dynamo.convert_frame import compile_lock - with dynamo_utils.dynamo_timed( - "compile_fx..bw_compiler" - ), compile_lock: + with ( + dynamo_utils.dynamo_timed("compile_fx..bw_compiler"), + compile_lock, + ): model_outputs_node = output_node(gm) if config.bw_outputs_user_visible: model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) @@ -2194,10 +2206,11 @@ def compile_fx( with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context(): return inference_compiler(unlifted_gm, example_inputs_) - with V.set_fake_mode(fake_mode), torch._guards.tracing( - tracing_context - ), compiled_autograd._disable(), functorch_config.patch( - unlift_effect_tokens=True + with ( + V.set_fake_mode(fake_mode), + torch._guards.tracing(tracing_context), + compiled_autograd._disable(), + functorch_config.patch(unlift_effect_tokens=True), ): try: return aot_autograd( diff --git a/torch/_inductor/compiler_bisector.py b/torch/_inductor/compiler_bisector.py index 17fe92c12aa..844af495793 100644 --- a/torch/_inductor/compiler_bisector.py +++ b/torch/_inductor/compiler_bisector.py @@ -530,7 +530,8 @@ class CompilerBisector: ) if result: curr_subsystem = cls.get_subsystem_object( - curr_backend, cls.get_subsystem() # type: ignore[arg-type] + curr_backend, + cls.get_subsystem(), # type: ignore[arg-type] ) if isinstance(curr_subsystem, BinarySubsystem): diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 18d25cf342c..789f727d62c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -80,9 +80,9 @@ fx_graph_cache: bool = Config( fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default() # should we bundle triton caching into fx graph cache -bundle_triton_into_fx_graph_cache: Optional[ - bool -] = bundle_triton_into_fx_graph_cache_default() +bundle_triton_into_fx_graph_cache: Optional[bool] = ( + bundle_triton_into_fx_graph_cache_default() +) # Enable autotune local cache. 
# @@ -1390,12 +1390,12 @@ class halide: # Halide autoscheduler to use, choices are: # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only) - scheduler_cuda: Literal[ - "Anderson2021", "Li2018", "Adams2019", "Mullapudi2016" - ] = "Anderson2021" - scheduler_cpu: Literal[ - "Anderson2021", "Li2018", "Adams2019", "Mullapudi2016" - ] = "Adams2019" + scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = ( + "Anderson2021" + ) + scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = ( + "Adams2019" + ) # Controls `no_asserts` flag passed to Halide target (warning: can false positive) asserts = False diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 893b9faff89..1972bcc3583 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -125,7 +125,8 @@ class ConstantFolder(torch.fx.Interpreter): and is_woq_int8_pattern(next(iter(node.users))) ) ) and is_const_source( - node.args[0], self.lifted_constant_names # type: ignore[arg-type] + node.args[0], # type: ignore[arg-type] + self.lifted_constant_names, ): # Case 1: int8_weight -> dq -> bf16_weight # Case 2: int8_weight -> permute -> dq -> bf16_weight diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 403a7358d66..2643f30ff22 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1633,8 +1633,8 @@ class CppBuilder: """ ) - assert os.path.exists( - cmake_path - ), f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist" + assert os.path.exists(cmake_path), ( + f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist" + ) with open(cmake_path, "a") as f: f.write(contents) diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index a1bbbb1f39d..c2b3485fd7e 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -119,6 +119,7 @@ from . 
import config @dataclasses.dataclass(frozen=True) class GraphID: "Unique counter of a cuda graph recording" + id: int @@ -622,11 +623,15 @@ class CUDAWarmupNode: refs = list(self.path_live_weakrefs()) check_memory_pool(self.device_index, self.cuda_graphs_pool, refs) - with torch.cuda.device( - self.device_index - ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager( - self.device_index, self.cuda_graphs_pool, self.stream - ), get_history_recording(): + with ( + torch.cuda.device(self.device_index), + disable_conv_cache_emptying(), + clear_cublas_manager(), + _use_cuda_memory_pool_manager( + self.device_index, self.cuda_graphs_pool, self.stream + ), + get_history_recording(), + ): out = self.wrapped_function.model(new_inputs) # We need to know which outputs are allocated within the cudagraph pool @@ -713,6 +718,7 @@ UnaliasedStorage = _UnaliasedStorage() class AliasesPriorGraphOutput(OutputAliasInfo): "Marks that the graph output aliases an output of a prior graph" + __slots__ = ["index"] index: PathOutputIndex @@ -1200,14 +1206,18 @@ class CUDAGraphNode: ] check_memory_pool(self.device, self.cuda_graphs_pool, memory) - with preserve_rng_state(), torch.cuda.device( - self.device - ), clear_cublas_manager(), torch.cuda.graph( - self.graph, - stream=self.stream, - pool=self.cuda_graphs_pool, - capture_error_mode="thread_local", - ), get_history_recording(): + with ( + preserve_rng_state(), + torch.cuda.device(self.device), + clear_cublas_manager(), + torch.cuda.graph( + self.graph, + stream=self.stream, + pool=self.cuda_graphs_pool, + capture_error_mode="thread_local", + ), + get_history_recording(), + ): static_outputs = model(inputs) # running model should reclaim memory @@ -1247,11 +1257,13 @@ class CUDAGraphNode: self.output_storage_alias.append(UnaliasedStorage) continue - torch._check( - o.is_cuda or o.untyped_storage().data_ptr() == 0, - lambda: ( - "Expected all cuda outputs in cuda graph recording. Non cuda output " - f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ( + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. 
Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ), ), ) @@ -1291,9 +1303,9 @@ class CUDAGraphNode: if self.stack_traces is None: self.stack_traces = [None for _ in range(len(outputs))] else: - assert len(self.stack_traces) == len( - outputs - ), "Wrong number of stack traces passed in" + assert len(self.stack_traces) == len(outputs), ( + "Wrong number of stack traces passed in" + ) assert not self.outputs_weakrefs for out, static_output_tensor in zip(outputs, self.static_output_tensors): @@ -1599,12 +1611,14 @@ class CUDAGraphNode: self.stream.wait_stream(torch.cuda.current_stream()) recording_inputs: list[InputType] = [] - with warnings.catch_warnings(record=True), torch.cuda.device( - self.device - ), _use_cuda_memory_pool_manager( - self.device, - mem_pool=self.cuda_graphs_pool, - stream=self.stream, + with ( + warnings.catch_warnings(record=True), + torch.cuda.device(self.device), + _use_cuda_memory_pool_manager( + self.device, + mem_pool=self.cuda_graphs_pool, + stream=self.stream, + ), ): for i, inp in enumerate(inputs): if not isinstance(inp, torch.Tensor): @@ -1736,12 +1750,8 @@ def check_memory_pool( pool_id: tuple[int, int], live_storages_ptrs: list[StorageWeakRefWrapper], ) -> None: - assert all( - isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs - ) # noqa: C419 - unique_storages = { - stor.data_ptr() for stor in live_storages_ptrs if stor() - } # noqa: set_linter + assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) # noqa: C419 + unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} # noqa: set_linter # check if there is a divergence first, then do the expensive snapshot call after # we know it will error @@ -1864,11 +1874,14 @@ class CUDAGraphTreeManager: self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle() - with warnings.catch_warnings(record=True), torch.cuda.graph( - self.graph, - pool=self.cuda_graphs_thread_pool, - stream=self.stream, - capture_error_mode="thread_local", + with ( + warnings.catch_warnings(record=True), + torch.cuda.graph( + self.graph, + pool=self.cuda_graphs_thread_pool, + stream=self.stream, + capture_error_mode="thread_local", + ), ): pass @@ -2230,7 +2243,10 @@ class CUDAGraphTreeManager: constants: tuple[torch.Tensor, ...], placeholders: tuple[PlaceholderInfo, ...], mutated_input_idxs: tuple[int, ...], - ) -> tuple[ModelType, OutputType,]: + ) -> tuple[ + ModelType, + OutputType, + ]: id = self.new_func_id() self.ids_to_stack_traces[id] = stack_traces self.ids_to_funcs[id] = WrappedFunction( diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 6bf731f60b7..9a3e160e19c 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -28,6 +28,7 @@ ModelType = Callable[[list[InputType]], OutputType] @dataclasses.dataclass(frozen=True) class FunctionID: "Unique counter of a function wrapped in cudagraphify_impl" + id: int @@ -164,7 +165,7 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]: def check_multiple_devices_or_any_cpu_nodes( - device_node_mapping: dict[torch.device, torch.fx.Node] + device_node_mapping: dict[torch.device, torch.fx.Node], ) -> Optional[str]: if cpu_node := device_node_mapping.get(torch.device("cpu")): msg = f"cpu device ({cpu_node.name})" @@ -184,7 +185,7 @@ def check_multiple_devices_or_any_cpu_nodes( def check_lowering_disable_cudagraph( - 
device_node_mapping: dict[torch.device, torch.fx.Node] + device_node_mapping: dict[torch.device, torch.fx.Node], ) -> Optional[str]: return check_multiple_devices_or_any_cpu_nodes(device_node_mapping) @@ -276,9 +277,9 @@ def log_data_ptr_mismatch( Logs the mismatch between input data pointers and recorded data pointers. This checks only idxs in target_idxs. """ - assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len( - placeholders - ), "length mismatch between inputs, recorded_data_ptr, and placeholders" + assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(placeholders), ( + "length mismatch between inputs, recorded_data_ptr, and placeholders" + ) t_tensors = [inputs[i] for i in target_idxs] t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs] diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index b6a4f7bdca9..81c986e62ca 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -240,7 +240,7 @@ def update_orig_fx_node_name_to_buf_name( def get_node_name_to_buf_meta( - node_name_to_buf_name: dict[str, str] + node_name_to_buf_name: dict[str, str], ) -> dict[str, BufMeta]: buf_name_to_n_node = {} for node_name, buf_name in node_name_to_buf_name.items(): diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 19ceafc5e76..a5d7bf7a8cc 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude) def register_decomposition( - ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]] + ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]], ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined] if op in decompositions: diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 39ae0efb0a1..a9860186c84 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -194,7 +194,9 @@ class MemoryDep(Dep): ) new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR - out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())) # type: ignore[arg-type] + out = MemoryDep( + self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values()) + ) # type: ignore[arg-type] return out @property @@ -649,11 +651,16 @@ def extract_loop_body_with_args( inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type] for entry in fn.memory_usage[MemoryUsageType.STORE]: inner.store( - entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type] + entry.buffer_name, + name_to_index[entry.index_name], + None, # type: ignore[arg-type] + entry.mode, ) for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: inner.store_reduction( - entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type] + entry.buffer_name, + name_to_index[entry.index_name], + None, # type: ignore[arg-type] ) for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: inner.index_expr(name_to_index[entry.index_name], None) @@ -661,7 +668,11 @@ def extract_loop_body_with_args( # All that matters is that we record the buffer name, so place it in the # "boundaries" name position to ensure that it's recorded. 
inner.bucketize( - None, (entry.buffer_name, None, None, None), None, None, None # type: ignore[arg-type] + None, + (entry.buffer_name, None, None, None), + None, + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] ) # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped return inner @@ -801,8 +812,9 @@ def extract_free_unbacked_symbols( handler = FreeUnbackedSymbolsOpsHandler() # NB: I cargo culted the allow_indexing patch here, I don't understand why # people do this all over - with V.set_ops_handler(handler), patch.object( - FlexibleLayout, "allow_indexing", True + with ( + V.set_ops_handler(handler), + patch.object(FlexibleLayout, "allow_indexing", True), ): fn(*args) return handler.symbols diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index 256079c8071..0bcc120af3c 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -19,8 +19,7 @@ T = TypeVar("T") class DTypeVar(Protocol): @property - def dtype(self) -> torch.dtype: - ... + def dtype(self) -> torch.dtype: ... DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue] diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 07432704b18..dabe7e8d4e1 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -526,6 +526,7 @@ class ConfigFuzzer: ```python import torch._inductor.config as cfg + def create_simple_test_model_gpu() -> FactoryOutputType: batch_size = 32 seq_length = 50 @@ -539,6 +540,8 @@ class ConfigFuzzer: return True return test_fn + + fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2) # Test every pair of configs: @@ -550,7 +553,9 @@ class ConfigFuzzer: ret = fuzzer.bisect(num_attempts=10) # reproduce a failing config - fuzzer.reproduce([{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}]) + fuzzer.reproduce( + [{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}] + ) ``` The list of known failures on inductor config are: diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py index 4ee23d90bdc..325698dcb7b 100644 --- a/torch/_inductor/fx_passes/b2b_gemm.py +++ b/torch/_inductor/fx_passes/b2b_gemm.py @@ -531,7 +531,11 @@ def tuned_b2b_gemm( A.realize() B.realize() C.realize() - layout = FixedLayout(A.get_device_or_error(), A.get_dtype(), [A.shape[0], C.shape[1]]) # type: ignore[index] + layout = FixedLayout( + A.get_device_or_error(), + A.get_dtype(), + [A.shape[0], C.shape[1]], # type: ignore[index] + ) subgraph_buffer = build_subgraph_buffer( [create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())], subgraph, diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 5c7fc03103e..2d9409523c1 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -545,9 +545,9 @@ def schedule_comm_wait(graph: fx.Graph) -> None: node_indices = {node: i for i, node in enumerate(graph.nodes)} for allreduce in comm_blocks: # Find the earliest/first user -- target_node. - assert ( - len(allreduce.outputs) >= 1 - ), f"Found a allreduce that has zero outputs/users -- {allreduce}." + assert len(allreduce.outputs) >= 1, ( + f"Found a allreduce that has zero outputs/users -- {allreduce}." + ) # Initialize the target node to avoid typing issues. 
target_node = next(iter(next(iter(allreduce.outputs)).users)) target_node_index = 2**31 diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py index bc6ebbcd5ce..0e647e37cd3 100644 --- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -380,7 +380,9 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs): # argument. `graph.get_attr` and # `graph.call_function` does not allow the `name` argument. conv_get_node = graph.create_node( - op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr] + op="get_attr", + target=conv_node.target, # type: ignore[union-attr] + name="get_conv", ) bn_get_node = graph.create_node( op="get_attr", target=bn_node.target, name="get_bn" diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 59d90a13078..1894eb628c8 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -866,15 +866,18 @@ def _get_sfdp_patterns(): name += "_bs1" training_name = name + "_training" - yield training_name, { - "search_fn": pattern, - "replace_fn": replacement, - "example_inputs": args, - "trace_fn": joint_fwd_bwd, - "pass_dicts": patterns, - "extra_check": extra_check, - "scalar_workaround": workaround, - } + yield ( + training_name, + { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + }, + ) if workaround: assert len(workaround) == 1 and "dropout_p" in workaround @@ -886,18 +889,21 @@ def _get_sfdp_patterns(): workaround = {} inference_name = name + "_inference" - yield inference_name, { - "search_fn": pattern, - "replace_fn": replacement, - "example_inputs": args, - "trace_fn": fwd_only, - "pass_dicts": patterns, - "extra_check": extra_check, - "scalar_workaround": workaround, - # with dropout turned into clone, we end up with a number of - # semantically identical graphs - "skip_duplicates": True, - } + yield ( + inference_name, + { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + # with dropout turned into clone, we end up with a number of + # semantically identical graphs + "skip_duplicates": True, + }, + ) @functools.lru_cache(None) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 7df5b6f2120..baa9a8cb660 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -271,7 +271,9 @@ class PostGradBatchLinearFusion(BatchFusion): args=(batch_biases[i],), kwargs={"size": broadcast_shape}, ) - broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment] + broadcast_bias.meta["val"] = aten.broadcast_to( + batch_biases_meta[i]["val"], broadcast_shape + ) # type: ignore[assignment] new_bias_add = graph.call_function( # type: ignore[operator] aten.add.Tensor, args=((broadcast_bias, new_mm)) ) @@ -803,9 +805,9 @@ class BatchLayernormFusion(BatchFusion): group_biases = None # type: ignore[assignment] if all(weight is None for weight in group_weights): group_weights = None # type: ignore[assignment] - assert all( - eps == group_epss[0] for 
eps in group_epss - ), "all epsilon values must be equal" + assert all(eps == group_epss[0] for eps in group_epss), ( + "all epsilon values must be equal" + ) with graph.inserting_before(subset[0]): # type: ignore[operator] stack_input = graph.call_function( # type: ignore[operator] @@ -996,7 +998,11 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): # for relu op, we also use the inplace to construct the key # we batch the ops with same parent to enable followup split cat parent = node.args[0] - parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr] + parent = ( + parent.target # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False) + else "" + ) group_key = ( "batch_aten_" + self.op.__name__.lower().split(".")[0], str(input.meta["val"].shape), @@ -1293,9 +1299,9 @@ def get_fusion_candidates( """ q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque() - candidate_dict: collections.defaultdict[ - Any, list[torch.fx.Node] - ] = collections.defaultdict(list) + candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = ( + collections.defaultdict(list) + ) if root_node.target in SEARCH_EXCLUSIONS: return candidate_dict diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 87e76c9fbd4..7490f013426 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -763,9 +763,7 @@ def _get_node_to_ancestors( """ Compute the ancestors for all nodes in a graph. """ - node_to_ancestors = defaultdict( - OrderedSet[torch.fx.Node] - ) # type: ignore[var-annotated] + node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated] for node in graph.nodes: node_to_ancestors[node] = OrderedSet(node.all_input_nodes) for dep in node.all_input_nodes: diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index db20ba2a49d..9e69f96d27f 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -558,9 +558,9 @@ if torch._C._has_mkldnn: binary_nodes = filter_nodes(match.nodes, binary_op) def _get_compute_node(_binary_node, _other_index): - assert ( - len(_binary_node.all_input_nodes) == 2 - ), "Binary node should have 2 input nodes." + assert len(_binary_node.all_input_nodes) == 2, ( + "Binary node should have 2 input nodes." + ) _compute_index = 1 if (_other_index == 0) else 0 return _binary_node.args[_compute_index] @@ -614,9 +614,9 @@ if torch._C._has_mkldnn: else: computation_args += [1.0, None, [], None] counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 - counters["inductor"][ - "mkldnn_conv_binary_unary_fusion_matcher_nodes" - ] += len(match.nodes) + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += ( + len(match.nodes) + ) return L[fusion_op](*computation_args) return fn @@ -659,9 +659,9 @@ if torch._C._has_mkldnn: else: computation_args += [1.0, None, [], None] counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 - counters["inductor"][ - "mkldnn_conv_binary_unary_fusion_matcher_nodes" - ] += len(match.nodes) + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += ( + len(match.nodes) + ) # Make sure the other is not an alias or mutation(fx side doesn't has such info). 
other.realize() if not _can_be_inplace(other) or other.data.shape != list( @@ -1310,9 +1310,9 @@ if torch._C._has_mkldnn: ) batch_size = input.meta.get("val").shape[0] if has_free_symbols(batch_size): - assert ( - is_lp_weight or mkldnn._is_mkldnn_acl_supported() - ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" + assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), ( + f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" + ) # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. packed_weight_inputs = ( transpose_weight_node, diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index ef6353a0c88..a42296fe68a 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -437,7 +437,7 @@ def _should_pad_bench( return False def realize_symbols( - ds: Union[torch.Size, tuple[torch.SymInt, ...]] + ds: Union[torch.Size, tuple[torch.SymInt, ...]], ) -> list[int]: return [d if isinstance(d, int) else d.node.hint for d in ds] diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index a239a8b2025..89a69d6db57 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -137,9 +137,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): pattern_matcher_pass.apply ) if not is_same_dict(counters["inductor"], inductor_before_change): - optimus_scuba_log[ - f"{pattern_matcher_pass.pass_name}_post_grad" - ] = upload_graph(gm.graph) + optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_post_grad"] = ( + upload_graph(gm.graph) + ) if config.b2b_gemm_pass: B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type] diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index e719c7499e2..aefac0bb39e 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -277,9 +277,9 @@ def pre_grad_passes( for _ in range(counter): pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] if not is_same_dict(counters["inductor"], inductor_before_change): - optimus_scuba_log[ - f"{pattern_matcher_pass.pass_name}_pre_grad" - ] = upload_graph(gm.graph) + optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_pre_grad"] = ( + upload_graph(gm.graph) + ) # TODO: move efficient_conv_bn_eval_pass to the fusions dict too. efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type] diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 20d50d8bd3c..7d7fac2ff5d 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -763,9 +763,9 @@ def _register_quantized_conv_binary_lowering( accum.realize() from .mkldnn_fusion import _can_be_inplace - assert _can_be_inplace( - accum - ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation." + assert _can_be_inplace(accum), ( + "QConv Binary Inplace Fusion requires accum is not an alias or mutation." 
+ ) computation_args = ( x, @@ -1307,9 +1307,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): def clone_to_new_node(graph, source_node, user_node): # Clone the source_node to a new node # Replace user_node's input from source_node to new_node - assert ( - source_node.op == "call_function" - ), "clone_to_new_node only support node.op call_function" + assert source_node.op == "call_function", ( + "clone_to_new_node only support node.op call_function" + ) with graph.inserting_before(user_node): new_node = graph.call_function( source_node.target, @@ -1343,9 +1343,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): # For a dequant pattern, we expect the start node is a dequantize_per_tensor node return _node else: - assert ( - len(_node.args) >= 1 - ), "In in dequant pattern, each node should have more than 1 arg." + assert len(_node.args) >= 1, ( + "In in dequant pattern, each node should have more than 1 arg." + ) return _find_first_node_in_dequant_pattern(_node.args[0]) dequant_pattern_start_node = _find_first_node_in_dequant_pattern( diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index ed951efa0da..3a3df02bdba 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -616,7 +616,8 @@ def merge_splits( dim=first_split_dim, ) first_split_num_to_user = { - user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr] + user.args[1]: user + for user in first_split.users.keys() # type: ignore[union-attr] } new_split_num = 0 @@ -706,7 +707,11 @@ class SplitCatSimplifier: graph, split_node, split_sections, user_inputs_list, simplified_split_ranges ) self.replace_cat( - graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type] + graph, + split_node, + next_users, + user_inputs_list_new, + transform_params_list, # type: ignore[arg-type] ) self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] counters["inductor"]["unbind_stack_pass"] += 1 @@ -913,7 +918,9 @@ class SplitCatSimplifier: ) if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr] new_split.meta["example_value"] = torch.split( - split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr] + split_input.meta["example_value"], # type: ignore[union-attr] + [r[1] - r[0] for r in split_ranges], + dim=split_dim, ) counters["inductor"]["scmerge_split_added"] += 1 split_items = [] @@ -1005,7 +1012,10 @@ class SplitCatSimplifier: stacked_input = graph.call_function( torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} ) - stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] + stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type] + to_stack_meta, + dim=stack_dim, # type: ignore[arg-type] + ) to_stack, to_stack_meta = [], [] stack_dim = None user_inputs_new_transformed.append(stacked_input) @@ -1023,19 +1033,28 @@ class SplitCatSimplifier: user_input_new = graph.call_function( torch.unflatten, args=(user_input_new, *unflatten_params) ) - user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] + user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type] + user_input_new_meta, # type: ignore[arg-type] + *unflatten_params, # type: 
ignore[arg-type] + ) if movedim_params: user_input_new_meta = user_input_new.meta["example_value"] user_input_new = graph.call_function( torch.movedim, args=(user_input_new, *movedim_params) ) - user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr] + user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type] + user_input_new_meta, # type: ignore[arg-type] + *movedim_params, # type: ignore[arg-type] + ) if flatten_params: user_input_new_meta = user_input_new.meta["example_value"] user_input_new = graph.call_function( torch.flatten, args=(user_input_new, *flatten_params) ) - user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] + user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type] + user_input_new_meta, + *flatten_params, # type: ignore[arg-type] + ) user_inputs_new_transformed.append(user_input_new) user_inputs_new_transformed_meta.append( user_input_new.meta["example_value"] @@ -1044,7 +1063,10 @@ class SplitCatSimplifier: stacked_input = graph.call_function( torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} ) - stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] + stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type] + to_stack_meta, + dim=stack_dim, # type: ignore[arg-type] + ) user_inputs_new_transformed.append(stacked_input) user_inputs_new_transformed_meta.append( stacked_input.meta["example_value"] @@ -1058,14 +1080,15 @@ class SplitCatSimplifier: kwargs={"dim": cat_dim}, ) new_cat_node.meta["example_value"] = torch.cat( - user_inputs_new_transformed_meta, dim=cat_dim + user_inputs_new_transformed_meta, + dim=cat_dim, ) counters["inductor"]["scmerge_cat_added"] += 1 else: new_cat_node = user_inputs_new_transformed[-1] - new_cat_node.meta[ - "example_value" - ] = user_inputs_new_transformed_meta[-1] + new_cat_node.meta["example_value"] = ( + user_inputs_new_transformed_meta[-1] + ) if ( user_node.target == torch.cat @@ -1077,7 +1100,11 @@ class SplitCatSimplifier: new_cat_node = graph.call_function( torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) ) - new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr] + new_cat_node.meta["example_value"] = torch.flatten( + new_cat_node_meta, + cat_dim, + cat_dim + 1, + ) user_node.replace_all_uses_with(new_cat_node) new_cats.append(new_cat_node) @@ -1123,9 +1150,7 @@ class UnbindCatRemover(SplitCatSimplifier): ] if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] getitem_indices - ) != len( - unbind_node.meta["example_value"] - ): + ) != len(unbind_node.meta["example_value"]): return num_unbind = len(getitem_indices) split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] @@ -1510,7 +1535,8 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int): fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] # update the split sections split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index] - split_node, indices # type: ignore[arg-type] + split_node, + indices, # type: ignore[arg-type] ) # padding others with zeros to keep the same dict size for i in indices[1:]: @@ -1613,10 +1639,12 @@ def 
mutate_cat_node(match: Match, split_sections: list[int], dim: int): elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] # check the split dim, and construct the slice tuple start_fused_size = calculate_fused_tensor_size( - split_node, list(range(indices[0])) # type: ignore[arg-type] + split_node, + list(range(indices[0])), # type: ignore[arg-type] ) end_fused_size = start_fused_size + calculate_fused_tensor_size( - split_node, indices # type: ignore[arg-type] + split_node, + indices, # type: ignore[arg-type] ) slice_list = [] for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] @@ -1714,7 +1742,10 @@ def merge_split_cat_aten(match: Match, *args, **kwargs): continue # check the cat node has consecutive indices indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr] - if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs): # type: ignore[arg-type] + if ( + not is_sorted_and_consecutive(indices) # type: ignore[arg-type] + and len(getitem_nodes) != len(cat_inputs) + ): continue # replace the users of the cat node to be the input of the split node cat_node.replace_all_uses_with(split_input) @@ -1764,7 +1795,10 @@ def merge_select_cat_aten(match: Match, *args, **kwargs): continue # check the cat node has consecutive indices indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr] - if not is_sorted_and_consecutive(indices) or len(select_nodes) != len(cat_inputs): # type: ignore[arg-type] + if ( + not is_sorted_and_consecutive(indices) # type: ignore[arg-type] + or len(select_nodes) != len(cat_inputs) + ): continue # check all the select nodes can be merged to the cat node input if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr] @@ -2318,7 +2352,9 @@ def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): args=(new_cat_args,), kwargs={"dim": cat_dim}, ) - new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type] + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) # type: ignore[arg-type] cat_node.replace_all_uses_with(new_cat_node) new_cat_node.meta.update(cat_node.meta) # remove inputs of cat_node if they have no users @@ -2411,7 +2447,8 @@ def convert_reshape_cat_arg_to_stack( args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type] ) reshape_node.meta["example_value"] = torch.Tensor.view( - permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type] + permute_node.meta["example_value"], + tuple(stack_node_shape), # type: ignore[arg-type] ) return reshape_node @@ -2687,7 +2724,9 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): cat_inputs.append(decomposed_stack_node) # cat_arg must be the split input view_shape_list = get_view_shape_list(cat_arg, stack_dim) - stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr] + stack_node_shape = torch.reshape( + cat_arg.meta["example_value"], tuple(view_shape_list) + ).shape # type: ignore[union-attr] cat_inputs.append( convert_reshape_cat_arg_to_stack( graph, diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py index 96b713822b0..280bf47a0b4 100644 --- a/torch/_inductor/fx_utils.py +++ b/torch/_inductor/fx_utils.py @@ -105,9 +105,9 @@ class FakeTensorUpdater: if new is None: return old is None if not isinstance(new, 
torch.Tensor): - assert isinstance( - new, (torch.SymInt, torch.SymBool, torch.SymFloat) - ), f"Unknown type {type(new)} in {self.graph}" + assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), ( + f"Unknown type {type(new)} in {self.graph}" + ) return ( new.node.shape_env._maybe_evaluate_static( sympy.Eq(new.node.expr, old.node.expr) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 1c77f68b0b9..72e4567f085 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -136,7 +136,9 @@ else: def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]: assert isinstance( constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) - ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" + ), ( + "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" + ) if isinstance(constant_buffer, sympy.core.numbers.Integer): return torch.int64 @@ -308,9 +310,9 @@ class GraphLowering(torch.fx.Interpreter): self.reuse_shape_env = True self._shape_env = shape_env # We're going to mutate ras_by_symbol as we finish generating them - self.ras_by_symbol: dict[ - Optional[sympy.Symbol], list[RuntimeAssert] - ] = shape_env.deferred_runtime_asserts.copy() + self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = ( + shape_env.deferred_runtime_asserts.copy() + ) self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]() self.sizevars = SizeVarAllocator(shape_env) self.graph_input_names: list[str] = [] @@ -400,9 +402,7 @@ class GraphLowering(torch.fx.Interpreter): self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored self.cache_linemap: list[ tuple[int, str] - ] = ( - [] - ) # This is the linemap used by the profiler to mark custom compiled kernels getting run + ] = [] # This is the linemap used by the profiler to mark custom compiled kernels getting run # Used if lowering encounters cases where cudagraphs are not supported self.disable_cudagraphs_reason: Optional[str] = None @@ -1012,7 +1012,10 @@ class GraphLowering(torch.fx.Interpreter): ) def placeholder( - self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[object], # type: ignore[override] + kwargs: dict[str, object], ) -> Union[Expr, TensorBox, None]: self.placeholder_idx += 1 example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] @@ -1118,9 +1121,9 @@ class GraphLowering(torch.fx.Interpreter): return target(*args, **kwargs) if target not in lowerings: - assert isinstance( - target, torch._ops.OpOverload - ), f"{target} is not an OpOverload" + assert isinstance(target, torch._ops.OpOverload), ( + f"{target} is not an OpOverload" + ) base_name = target.name().split(".")[0] if base_name in FALLBACK_ALLOW_LIST: make_fallback(target, warn=False, override_decomp=True) @@ -1189,7 +1192,10 @@ class GraphLowering(torch.fx.Interpreter): return len(t.shape) == 1 and t.shape[0] <= 8 def get_attr( - self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[()], # type: ignore[override] + kwargs: dict[str, object], ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]: # this is a constant value = getattr_recursive(self.module, target) # type: ignore[arg-type] @@ -1241,7 +1247,10 @@ class 
GraphLowering(torch.fx.Interpreter): raise AssertionError def output( - self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[object], # type: ignore[override] + kwargs: dict[str, object], ) -> None: result = super().output(target, args, kwargs) # type: ignore[arg-type] if not isinstance(result, (tuple, list)): @@ -1439,9 +1448,11 @@ class GraphLowering(torch.fx.Interpreter): if is_call_function: args, kwargs = self.fetch_args_kwargs_from_env(n) origins |= gather_origins(args, kwargs) - with ir.IRNode.current_origins(origins), self.set_current_node( - n - ), V.set_current_node(n): + with ( + ir.IRNode.current_origins(origins), + self.set_current_node(n), + V.set_current_node(n), + ): if ( n.op == "call_function" and n.target is not operator.getitem @@ -1454,7 +1465,8 @@ class GraphLowering(torch.fx.Interpreter): ): debug("fallback_handler") result = fallback_handler(n.target, add_to_fallback_set=False)( - *args, **kwargs # type: ignore[possibly-undefined] + *args, # type: ignore[possibly-undefined] + **kwargs, # type: ignore[possibly-undefined] ) elif ( n.op == "call_function" @@ -1833,9 +1845,9 @@ class GraphLowering(torch.fx.Interpreter): wrapper_code_gen_cls = get_wrapper_codegen_for_device( self.device_type, self.cpp_wrapper ) - assert ( - wrapper_code_gen_cls is not None - ), f"Device {self.device_type} not supported" + assert wrapper_code_gen_cls is not None, ( + f"Device {self.device_type} not supported" + ) self.wrapper_code = wrapper_code_gen_cls.create( is_subgraph, subgraph_name, @@ -1866,7 +1878,7 @@ class GraphLowering(torch.fx.Interpreter): compiled = self.compile_to_module().call def materialize( - x: Union[torch.SymInt, torch.SymFloat, torch.Tensor] + x: Union[torch.SymInt, torch.SymFloat, torch.Tensor], ) -> Union[int, float, torch.Tensor]: if x is None: return None @@ -1876,9 +1888,9 @@ class GraphLowering(torch.fx.Interpreter): elif isinstance(x, FakeTensor): return defake(x) else: - assert isinstance( - x, torch.Tensor - ), "Unknown type when creating real inputs" + str(type(x)) + assert isinstance(x, torch.Tensor), ( + "Unknown type when creating real inputs" + str(type(x)) + ) return x tracing_context = torch._guards.TracingContext.try_get() diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 2e564041340..16430ced7e6 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -5,21 +5,22 @@ propagation of sympy expressions downstream of ops.index_expr calls. For example, say we have the IR: - tmp0 = ops.index_expr(x, torch.int32) - tmp1 = ops.constant(2, torch.int32) - tmp2 = ops.mul(tmp0, tmp1) - tmp3 = ops.indirect_indexing(tmp2, x_size) - tmp4 = ops.load("buf0", tmp3) + tmp0 = ops.index_expr(x, torch.int32) + tmp1 = ops.constant(2, torch.int32) + tmp2 = ops.mul(tmp0, tmp1) + tmp3 = ops.indirect_indexing(tmp2, x_size) + tmp4 = ops.load("buf0", tmp3) The underlying handler would just see: - ops.load("buf0", x * 2) + ops.load("buf0", x * 2) This is limited by the set of operators handled in the sympy expression printers. So simple operations like minimum and maximum cannot be translated to SymPy expressions yet, despite sympy.Min and sympy.Max existing. 
""" + import itertools from collections.abc import Sequence from dataclasses import dataclass @@ -179,9 +180,9 @@ class IndexPropVar: return IndexPropVar(expr, is_symbolic=True) def __post_init__(self): - assert not self.is_symbolic or isinstance( - self.value, TypedExpr - ), "Symbolic IndexPropVar must contain a TypedExpr" + assert not self.is_symbolic or isinstance(self.value, TypedExpr), ( + "Symbolic IndexPropVar must contain a TypedExpr" + ) IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]] @@ -251,14 +252,12 @@ class IndexPropagation(DefaultHandler): name: Literal["indirect_indexing"], args: Sequence[Any], kwargs: dict[str, Any], - ) -> IndexPropVar: - ... + ) -> IndexPropVar: ... @overload def fallback( self, name: str, args: Sequence[Any], kwargs: dict[str, Any] - ) -> IndexPropResult: - ... + ) -> IndexPropResult: ... def fallback( self, name: str, args: Sequence[Any], kwargs: dict[str, Any] @@ -283,8 +282,7 @@ class IndexPropagation(DefaultHandler): is_valid_expr = new_expr is not NotImplemented and ( # Inductor doesn't expect floating point in sympy expressions, but # allow floating point constants to be propagated - new_expr.is_constant() - or new_expr.expr.is_integer + new_expr.is_constant() or new_expr.expr.is_integer ) if not is_valid_expr: return self.fallback(name, args, kwargs) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 03205c43295..e740f5f6bbc 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -211,7 +211,9 @@ def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None: int, EffectfulKernel, ), - ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]" + ), ( + f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]" + ) # Be picky about the accepted data structure (don't use pytree here) _check_tensorbox(node_or_nodes) @@ -298,13 +300,11 @@ def get_stride_order( @overload -def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: - ... +def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ... @overload -def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: - ... +def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ... def ir_node_to_tensor( @@ -346,7 +346,7 @@ def may_convert_to_optional( def get_device_type( - x: Union[IRNode, OutputSpec, torch.device, None, str] + x: Union[IRNode, OutputSpec, torch.device, None, str], ) -> Optional[str]: if isinstance(x, str) or x is None: return x @@ -698,8 +698,7 @@ class IRNode: if TYPE_CHECKING: @property - def dtype(self) -> torch.dtype: - ... + def dtype(self) -> torch.dtype: ... @ir_dataclass(frozen=False) @@ -839,8 +838,9 @@ class Loops(IRNode): @cache_on_self def inner_fn_opcount(self) -> OpCountResult: opcounter = OpCounterCSE(V.MockHandler()) - with V.set_ops_handler(opcounter), patch.object( - FlexibleLayout, "allow_indexing", True + with ( + V.set_ops_handler(opcounter), + patch.object(FlexibleLayout, "allow_indexing", True), ): self.inner_fn(*self.inner_fn_args()) return opcounter.getvalue() @@ -1364,9 +1364,9 @@ class Reduction(Loops): # "all" is desugared to `!any(!val)` } - assert ( - reduction_type in rtypes_to_inits.keys() - ), f"{reduction_type} not supported for zero-dimension tensors!" + assert reduction_type in rtypes_to_inits.keys(), ( + f"{reduction_type} not supported for zero-dimension tensors!" 
+ ) def const_fn(index: int) -> OpsValue: return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) @@ -1575,9 +1575,9 @@ class Reduction(Loops): new_ranges: Sequence[Integer], new_reduction_ranges: Sequence[Integer], ) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]: - assert all( - r == 1 for r in original_ranges - ), f"Only enabled for numel_hint == 1, found {original_ranges=}" + assert all(r == 1 for r in original_ranges), ( + f"Only enabled for numel_hint == 1, found {original_ranges=}" + ) reindex = View.dynamic_reshape_indexer( original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) ) @@ -1828,7 +1828,7 @@ class WelfordReduction(Reduction): if reduction_numel == 1: def copy( - loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue] + loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue], ) -> TensorBox: def inner_fn(idx: Sequence[Expr]) -> OpsValue: reduction_index = [sympy.S.Zero for _ in reduction_ranges] @@ -2571,9 +2571,9 @@ class ExpandView(BaseView): # NB: new_size[i] == old_size[i] is expected to already be # guarded because the meta formula was expected to have taught # us this equality. - assert ( - sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0 - ), "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}" + assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, ( + "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}" + ) return new_size @classmethod @@ -3382,9 +3382,9 @@ class Layout(OutputSpec): ) def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: - assert ( - FlexibleLayout.allow_indexing - ), f"convert {type(self).__name__} to FixedLayout first" + assert FlexibleLayout.allow_indexing, ( + f"convert {type(self).__name__} to FixedLayout first" + ) return self.as_fixed().make_indexer() def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] @@ -3684,9 +3684,9 @@ class MutationLayoutSHOULDREMOVE(Layout): return target result = unwrap_views(self.target) - assert isinstance( - result, Buffer - ), "MutationLayoutSHOULDREMOVE must refer to a buffer" + assert isinstance(result, Buffer), ( + "MutationLayoutSHOULDREMOVE must refer to a buffer" + ) return result def real_layout(self): # type: ignore[no-untyped-def] @@ -3803,7 +3803,9 @@ class Buffer(IRNode): assert isinstance(self.layout, FlexibleLayout) self.layout = self.layout.as_same_order(stride) - def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False) -> None: # type: ignore[no-untyped-def] + def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def] + self, exact_strides, allow_padding=False + ) -> None: assert isinstance(self.layout, FlexibleLayout) self.layout = self.layout.as_exact_strides( exact_strides, allow_padding=allow_padding @@ -4365,9 +4367,9 @@ class TritonTemplateBuffer(TemplateBuffer): torch.ops.higher_order.flex_attention_backward, ) current_node = V.graph.current_node.target - assert ( - current_node in allowed_set - ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" + assert current_node in allowed_set, ( + f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" + ) device = self.inputs[0].get_device() self.outputs += [ MutationOutput(NoneLayout(device=device), buf, self) @@ -5106,7 +5108,8 @@ class ExternKernel(InputsKernel): x_unwrap_view.freeze_layout() index_args, var_ranges = dependencies.index_vars_squeeze( - x.get_size(), prefix="r" # type: ignore[arg-type] + 
x.get_size(), + prefix="r", # type: ignore[arg-type] ) range_vars = index_args[0] index = x.make_indexer()(range_vars) @@ -5404,9 +5407,9 @@ class ExternKernel(InputsKernel): # pass in a list of const arg names for arg_properties lookup. name_to_arg_properties = None if names and self.arg_properties: - assert len(self.constant_args) == len( - names - ), "names passed to codegen_const_args does not match self.constant_args" + assert len(self.constant_args) == len(names), ( + "names passed to codegen_const_args does not match self.constant_args" + ) name_to_arg_properties = { arg.get("name"): arg for arg in self.arg_properties } @@ -5442,9 +5445,9 @@ class ExternKernel(InputsKernel): args = [] for i, x in enumerate(inputs): if V.graph.cpp_wrapper: - assert self.arg_properties and i < len( - self.arg_properties - ), "Invalid access to ExternKernel.arg_properties" + assert self.arg_properties and i < len(self.arg_properties), ( + "Invalid access to ExternKernel.arg_properties" + ) type_ = self.arg_properties[i].get("type") args.append(V.graph.wrapper_code.val_to_arg_str(x, type_)) else: @@ -5914,7 +5917,9 @@ class UserDefinedTritonKernel(ExternKernel): def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() - def __init__(self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args) -> None: # type: ignore[no-untyped-def] + def __init__( # type: ignore[no-untyped-def] + self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args + ) -> None: inputs = [] kwargs = {} constant_args = [] @@ -6835,9 +6840,9 @@ class FallbackKernel(ExternKernelAlloc): elif isinstance(output, torch.SymInt): return output.node.expr else: - assert ( - output is None - ), f"FallbackKernel output type {type(output)} is not supported" + assert output is None, ( + f"FallbackKernel output type {type(output)} is not supported" + ) return None outputs = generate_output(example_output, []) @@ -6919,7 +6924,12 @@ class MultiOutput(ExternKernel): ) self.codegen_size_asserts(wrapper) - def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def] + def __init__( # type: ignore[no-untyped-def] + self, + layout: OutputSpec, + input, + indices: list[tuple[Any, ...]], + ) -> None: super().__init__(None, layout, [input], ()) self.name = V.graph.register_buffer(self) V.graph.register_operation(self) @@ -7496,9 +7506,9 @@ class WhileLoop(ExternKernel): assert p.get_dtype() == torch.bool, p assert len(p.get_size()) == 0, p - assert ( - len(all_inputs) > 0 - ), "torch.while_loop is assumed to have at least one operand." + assert len(all_inputs) > 0, ( + "torch.while_loop is assumed to have at least one operand." + ) device = all_inputs[0].get_device() @@ -7669,9 +7679,9 @@ class _CollectiveKernel(FallbackKernel): # This is identical to FallbackKernel.set_cpp_kernel(), minus the # part that checks against input aliasing and mutation. 
def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: - assert ( - type(self.op_overload) is torch._ops.OpOverload - ), "Setting cpp kernel needs a valid op_overload" + assert type(self.op_overload) is torch._ops.OpOverload, ( + "Setting cpp kernel needs a valid op_overload" + ) kernel = self.op_overload self.cpp_kernel_name = kernel._schema.name diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 674f48e7c9e..c06ebf6fa3f 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -""" Triton Implementation of the flex_attention Kernel""" +"""Triton Implementation of the flex_attention Kernel""" import copy import logging @@ -60,9 +60,9 @@ def construct_strides( ) -> Sequence[int]: """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" # Initialize strides - assert len(sizes) == len( - fill_order - ), "Length of sizes must match the length of the fill order" + assert len(sizes) == len(fill_order), ( + "Length of sizes must match the length of the fill order" + ) strides = [0] * len(sizes) # Start with stride 1 for the innermost dimension @@ -1151,10 +1151,14 @@ def lower_cpu( SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) assert V.graph.sizevars.evaluate_expr( sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) - ), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + ), ( + "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + ) assert V.graph.sizevars.evaluate_expr( sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) - ), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." + ), ( + "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." + ) CppFlexAttentionTemplate.add_choices( choices=_choices, input_nodes=input_nodes, @@ -1364,15 +1368,15 @@ def flex_attention( Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert V.graph.sizevars.evaluate_expr( - sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) - ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" - assert V.graph.sizevars.evaluate_expr( - sympy.Gt(seq_len_q, 0) - ), "Query length must be greater than 0" - assert V.graph.sizevars.evaluate_expr( - sympy.Gt(seq_len_kv, 0) - ), "Key length must be greater than 0" + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), ( + "Query length must be greater than 0" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), ( + "Key length must be greater than 0" + ) B = Bq @@ -2291,9 +2295,9 @@ def process_joint_outputs( JointOutputResult containing processed buffers and gradients """ assert isinstance(all_joint_outputs, list) - assert ( - all_joint_outputs[0] is not None - ), "joint_subgraph_buffer is None - this is a bug!" + assert all_joint_outputs[0] is not None, ( + "joint_subgraph_buffer is None - this is a bug!" 
+ ) joint_buffer = all_joint_outputs[0] other_grads = all_joint_outputs[num_placeholders - 1 :] @@ -2392,9 +2396,9 @@ def flex_attention_backward(*args, **kwargs): Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert V.graph.sizevars.evaluate_expr( - sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) - ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) kernel_options = dict(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. @@ -2639,9 +2643,11 @@ def flex_attention_backward(*args, **kwargs): grad_key = broadcasted_grad_key grad_value = broadcasted_grad_value else: - assert V.graph.sizevars.evaluate_expr( - sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1) - ), f"Bq and Bkv must broadcastable. Got Bq={V.graph.sizevars.evaluate_expr(Bq)} and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" # noqa: B950 + assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. " + f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} " + f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" + ) grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index bcae7ac689c..12d69971ad2 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" +"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" + from typing import Any import sympy @@ -367,9 +368,9 @@ def create_flex_decoding_kernel(*args, **kwargs): Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert V.graph.sizevars.evaluate_expr( - sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) - ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. 
Got Bq={Bq} and Bkv={Bkv}" + ) B = Bq kernel_options = dict(kernel_options) @@ -481,7 +482,8 @@ def create_flex_decoding_kernel(*args, **kwargs): max( next_power_of_2( V.graph.sizevars.size_hint( - seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + seq_len_q, + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] ) * gqa_shared_heads ), diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 34e3c7cf9dd..064666082a7 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -65,7 +65,8 @@ def filtered_configs( m = max( next_power_of_2( V.graph.sizevars.size_hint( - m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + m, + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] ) ), min_block_size, @@ -73,7 +74,8 @@ def filtered_configs( n = max( next_power_of_2( V.graph.sizevars.size_hint( - n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + n, + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] ) ), min_block_size, @@ -81,7 +83,8 @@ def filtered_configs( k = max( next_power_of_2( V.graph.sizevars.size_hint( - k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + k, + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] ) ), min_block_size_k, @@ -467,8 +470,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): """ even_k_symbolic = ( # it isn't worth guarding on this - sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) - == config.kwargs["BLOCK_K"] + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] ) allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( not inductor_config.force_same_precision diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index c3a3ab7133e..ffa51d3ad75 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -194,11 +194,12 @@ class LoopBody: # There is indeed an issue due to symbol name conflicting. # y0 maybe reused for the y dimension later. 
( - iter_vars, - reduce_vars, - ), var_ranges = dependencies.index_vars_no_squeeze( - iter_sizes, reduce_sizes, prefix="t" - ) + ( + iter_vars, + reduce_vars, + ), + var_ranges, + ) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="t") new_body = LoopBody( old_body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], @@ -234,7 +235,8 @@ class LoopBody: new_sizes = (new_iter_size, reduce_size) (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( - *new_sizes, prefix="t" # type: ignore[arg-type] + *new_sizes, + prefix="t", # type: ignore[arg-type] ) inverse_order = {b: a for a, b in enumerate(new_order)} @@ -254,7 +256,8 @@ class LoopBody: # use the original symbol prefix so we can do multiple round of reordering (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( - *new_sizes, prefix="p" # type: ignore[arg-type] + *new_sizes, + prefix="p", # type: ignore[arg-type] ) new_body = LoopBody( loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 @@ -385,9 +388,9 @@ class LoopBody: def indexing_from_args(self, indices): index = [*itertools.chain.from_iterable(indices)] assert len(index) == len(self.var_ranges), (index, self.var_ranges) - assert all( - v not in self.var_ranges for v in index - ), f"{self.var_ranges=}, {indices=}" + assert all(v not in self.var_ranges for v in index), ( + f"{self.var_ranges=}, {indices=}" + ) replacements = dict(zip(self.var_ranges.keys(), index)) return { name: sympy_subs(expr, replacements) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 628d0f14576..2589ae3b1d8 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -346,7 +346,8 @@ def transform_args( # only consider tensor kwargs for promotion, for now promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype")) dtype = get_promoted_dtype( - *promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type] + *promoting_args, + type_promotion_kind=type_promotion_kind, # type: ignore[arg-type] ) device = ( @@ -448,9 +449,9 @@ def _register_lowering( (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn ): # explicitly assert for "out=" ops for better error messages - assert not any( - x == "out" for x in kwargs.keys() - ), "out= ops aren't yet supported" + assert not any(x == "out" for x in kwargs.keys()), ( + "out= ops aren't yet supported" + ) args, kwargs = transform_args( args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool @@ -517,9 +518,9 @@ def broadcast_symbolic_shapes(a, b): def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): - assert ( - override_return_dtype is None or type_promotion_kind is None - ), "only one of override_return_dtype or type_promotion_kind may be given" + assert override_return_dtype is None or type_promotion_kind is None, ( + "only one of override_return_dtype or type_promotion_kind may be given" + ) if override_return_dtype is None and type_promotion_kind is None: type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT @@ -674,9 +675,9 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False): if isinstance(input, (list, tuple)): a_list_input = input break - assert ( - a_list_input is not None - ), "at least one input must be a list to a foreach op" + assert a_list_input is not None, ( + "at least one input must be a list to a foreach op" + ) # broadcast scalar inputs to match length of list inputs broadcast_inputs = [] @@ -1321,12 
+1322,12 @@ def quantized_decomposed_quantize_per_channel( if input.get_dtype() == torch.bfloat16: input = to_dtype(input, torch.float32) - assert ( - input.get_dtype() == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" - assert axis < len( - input.get_size() - ), f"Expecting axis to be < {len(input.get_size())}" + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) + assert axis < len(input.get_size()), ( + f"Expecting axis to be < {len(input.get_size())}" + ) input_loader = input.make_loader() scales_loader = scales.make_loader() @@ -1373,12 +1374,12 @@ def quantized_decomposed_dequantize_per_channel( ) -> TensorBox: assert len(scales.get_size()) == 1, "expect scales 1 dim" assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" - assert ( - input.get_dtype() == dtype - ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" - assert axis < len( - input.get_size() - ), f"Expecting axis to be < {len(input.get_size())}" + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) + assert axis < len(input.get_size()), ( + f"Expecting axis to be < {len(input.get_size())}" + ) if out_dtype is None: out_dtype = torch.float32 @@ -1423,9 +1424,9 @@ def quantized_decomposed_quantize_per_tensor_default( ) -> TensorBox: if input.get_dtype() == torch.bfloat16: input = to_dtype(input, torch.float32) - assert ( - input.get_dtype() == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) input_loader = input.make_loader() @@ -1462,9 +1463,9 @@ def quantized_decomposed_dequantize_per_tensor_default( *, out_dtype: Optional[torch.dtype] = None, ) -> TensorBox: - assert ( - input.get_dtype() == dtype - ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) if out_dtype is None: out_dtype = torch.float32 @@ -1501,9 +1502,9 @@ def quantized_decomposed_quantize_per_tensor_tensor( ) -> TensorBox: if input.get_dtype() == torch.bfloat16: input = to_dtype(input, torch.float32) - assert ( - input.get_dtype() == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) assert len(scale.get_size()) == 0 or ( len(scale.get_size()) == 1 and scale.get_size()[0] == 1 ), "expect scale as scalar tensor" @@ -1555,9 +1556,9 @@ def quantized_decomposed_dequantize_per_tensor_tensor( assert len(zero_point.get_size()) == 0 or ( len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 ), "expect zero_point as scalar tensor" - assert ( - input.get_dtype() == dtype - ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) if out_dtype is None: out_dtype = torch.float32 @@ -1973,9 +1974,9 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs= def make_fallback(op, layout_constraint=None, warn=True, 
override_decomp=False): - assert ( - op not in decompositions or override_decomp - ), f"both a fallback and a decomp for same op: {op}" + assert op not in decompositions or override_decomp, ( + f"both a fallback and a decomp for same op: {op}" + ) if ( warn and bool(os.getenv("CI")) @@ -2086,9 +2087,9 @@ def native_dropout(x, p, train): @register_lowering(aten.bernoulli_, type_promotion_kind=None) def bernoulli_(x, *args): - assert config.fallback_random or x.get_device() == torch.device( - "cpu" - ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + assert config.fallback_random or x.get_device() == torch.device("cpu"), ( + "this should be handled in decomps unless config.fallback_random or the device is CPU" + ) x.realize() op_overload = ( aten.bernoulli_.float @@ -2101,9 +2102,9 @@ def bernoulli_(x, *args): @register_lowering(aten.bernoulli.p, type_promotion_kind=None) def bernoulli_p(x, *args): - assert config.fallback_random or x.get_device() == torch.device( - "cpu" - ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + assert config.fallback_random or x.get_device() == torch.device("cpu"), ( + "this should be handled in decomps unless config.fallback_random or the device is CPU" + ) return bernoulli_(clone(x), *args) @@ -3376,7 +3377,9 @@ def check_and_broadcast_indices(indices, device): i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) for i in indices if i is not None - ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}" + ), ( + f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}" + ) if any( i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None ): @@ -5668,7 +5671,8 @@ def make_reduction(reduction_type: ReductionType, override_return_dtype=None): ) result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) if isinstance( - result.data.data, Reduction # type: ignore[attr-defined] + result.data.data, # type: ignore[attr-defined] + Reduction, ): # Only realize if reduction isn't unrolled result.realize() return result @@ -6008,8 +6012,9 @@ def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]: return None handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device()) - with V.set_ops_handler(handler), patch.object( - ir.FlexibleLayout, "allow_indexing", True + with ( + V.set_ops_handler(handler), + patch.object(ir.FlexibleLayout, "allow_indexing", True), ): out = x.inner_fn(*x.inner_fn_args()) @@ -6898,9 +6903,9 @@ def force_fallback(op: torch._ops.OpOverload): A context manager to force fallback an op. Used in unit test for FallbackKernel. 
""" - assert isinstance( - op, torch._ops.OpOverload - ), "Only OpOverload to make the clean up easier" + assert isinstance(op, torch._ops.OpOverload), ( + "Only OpOverload to make the clean up easier" + ) old_handler = lowerings.get(op) try: register_lowering(op)(fallback_handler(op)) diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 95d103e0a5f..83a927e8c5f 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -35,9 +35,9 @@ class MemoryPlanningInfoForBuffer: class MemoryPlanningInfoForNode: index: int = 0 size: int = 0 - pred_buffers: OrderedSet[ - Union[SchedulerBuffer, FreeableInputBuffer] - ] = dataclasses.field(default_factory=OrderedSet) + pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = ( + dataclasses.field(default_factory=OrderedSet) + ) pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( default_factory=OrderedSet ) @@ -87,9 +87,9 @@ def get_freeable_input_buf( # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys - dep_name_to_succ_nodes: dict[ - str, OrderedSet[BaseSchedulerNode] - ] = collections.defaultdict(OrderedSet) + dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) dep_name_to_size: dict[str, int] = dict() for node in nodes: for dep in node.read_writes.reads: @@ -112,7 +112,7 @@ def get_freeable_input_buf( def compute_size_for_scheduler_buffer( - name_to_buf: dict[str, SchedulerBuffer] + name_to_buf: dict[str, SchedulerBuffer], ) -> dict[str, tuple[int, int]]: """ Compute the size of each scheduler buffer, including (1) memory allocated when @@ -187,9 +187,9 @@ def assign_memory_planning_info_for_scheduler_buffers( # get buffer's successor nodes # note that different deps can have the same name, so we use name as keys - dep_name_to_succ_nodes: dict[ - str, OrderedSet[BaseSchedulerNode] - ] = collections.defaultdict(OrderedSet) + dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) for node in nodes: for dep in node.unmet_dependencies: dep_name_to_succ_nodes[dep.name].add(node) diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 1407ae09b88..590a7419ad6 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -138,12 +138,12 @@ class MetricTable: return row_dict = row_fn() - assert len(self.column_names) == len( - row_dict - ), f"{len(self.column_names)} v.s. {len(row_dict)}" - assert OrderedSet(self.column_names) == OrderedSet( - row_dict.keys() - ), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}" + assert len(self.column_names) == len(row_dict), ( + f"{len(self.column_names)} v.s. {len(row_dict)}" + ) + assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), ( + f"{OrderedSet(self.column_names)} v.s. 
{OrderedSet(row_dict.keys())}" + ) bn = get_benchmark_name() # assert bn is not None @@ -433,9 +433,9 @@ def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]: name = name.strip() if not name: continue - assert ( - name in REGISTERED_METRIC_TABLES - ), f"Metric table name {name} is not registered" + assert name in REGISTERED_METRIC_TABLES, ( + f"Metric table name {name} is not registered" + ) enabled.add(name) return enabled diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 688353ddf91..11ed7710e37 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -751,9 +751,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc): unary_algorithm, ] - assert ( - binary_attr == "sum" - ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + assert binary_attr == "sum", ( + "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + ) V.graph.mark_buffer_mutated(qaccum.get_name()) packed = QConvPointWiseBinaryPT2E( diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index b3e26d091b9..e53c1b09ee9 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -575,9 +575,9 @@ def register_onednn_fusion_ops(): algorithm, layout=None, ): - assert ( - packed_weight.get_dtype() is torch.int8 - ), "Only int8 weights are supported by oneDNN qlinear." + assert packed_weight.get_dtype() is torch.int8, ( + "Only int8 weights are supported by oneDNN qlinear." + ) x_size = x.get_size() if len(x_size) > 2: # GEMM template needs 2D input, normalize input shape here @@ -928,9 +928,9 @@ def register_onednn_fusion_ops(): # we will do accum dtype convertion here. x2 = to_dtype(x2, output_dtype) else: - assert ( - x2.get_dtype() == output_dtype - ), "dtype of accum for qlinear post op sum should be the same as output" + assert x2.get_dtype() == output_dtype, ( + "dtype of accum for qlinear post op sum should be the same as output" + ) x2_dtype = x2.get_dtype() bias_dtype = bias.get_dtype() if bias is not None else None choices: list[ChoiceCaller] = [] diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 0118d29368c..692857f2609 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -806,8 +806,8 @@ class DefaultHandler(OpsHandler[Any]): assert self_arg == "self" code.write( f""" - def {target}(self, {', '.join(args)}): - return self._default({target!r}, ({', '.join(args)}, ), {{}}) + def {target}(self, {", ".join(args)}): + return self._default({target!r}, ({", ".join(args)}, ), {{}}) """.strip() ) code.write("\n\n") @@ -994,8 +994,9 @@ class KernelFormatterHandler(DefaultHandler): ) formatter._output.writeline(f"{lhs} = {name}") - with V.set_ops_handler(formatter), patch.object( - FlexibleLayout, "allow_indexing", True + with ( + V.set_ops_handler(formatter), + patch.object(FlexibleLayout, "allow_indexing", True), ): result = ir_fn(*args) return formatter.getvalue(result) diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index 478b5cfca6a..d14ac142271 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -188,7 +188,9 @@ def package_aoti( ) or ( isinstance(archive_file, (str, os.PathLike)) and os.fspath(archive_file).endswith(".pt2") - ), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}" + ), ( + f"Expect archive file to be a file ending in .pt2, or is a buffer. 
Instead got {archive_file}" + ) # Save using the PT2 packaging format # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) @@ -285,9 +287,9 @@ class AOTICompiledModel: def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg] assert ( isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable() - ) or ( - isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2") - ), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}" + ) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), ( + f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}" + ) if isinstance(path, (io.IOBase, IO)): with tempfile.NamedTemporaryFile(suffix=".pt2") as f: diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 44bfccd6ee1..84c42d61f61 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -90,20 +90,17 @@ NodeOrConstant = Union[Constant, torch.fx.Node] class SearchFn(Protocol): __name__: str - def __call__(self, *args: Any, **kwargs: Any) -> Any: - ... + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... class ReplaceFn(Protocol): - def __call__(self, *args: Any, **kwargs: Any) -> Any: - ... + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... class TraceFn(Protocol): def __call__( self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any - ) -> torch.fx.GraphModule: - ... + ) -> torch.fx.GraphModule: ... T = TypeVar("T") @@ -365,8 +362,7 @@ class PatternExpr(ABC): """ @abstractmethod - def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: - ... + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: ... def match(self, node: torch.fx.Node) -> MatchResult: try: @@ -489,8 +485,7 @@ class _TargetExpr(PatternExpr): @property @abstractmethod - def op(self) -> str: - ... + def op(self) -> str: ... def fns_repr(self) -> str: first_repr = self.fns[0] @@ -997,8 +992,9 @@ class PatternPrettyPrinter: class _PassDictsType(Protocol): - def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]: - ... + def __getitem__( + self, k: tuple[str, torch.fx.node.Target] + ) -> list[PatternEntry]: ... 
@dataclasses.dataclass @@ -1925,7 +1921,10 @@ def fx_to_pattern( get_attr = _not_implemented def placeholder( - self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: Sequence[Any], + kwargs: Mapping[str, Any], ) -> Union[ExclusiveKeywordArg, KeywordArg]: n = next(argnum) if n < len(argnames): @@ -1942,7 +1941,10 @@ def fx_to_pattern( return KeywordArg(name) def call_function( - self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: Sequence[Any], + kwargs: Mapping[str, Any], ) -> PatternExpr: process_arg_fn = process_arg # Indexing is critical for matching getitem nodes, so we can't ignore int args here diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index fb50a385e54..057d59bb855 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -24,7 +24,7 @@ T = TypeVar("T") def time_and_count( - fn: Callable[Concatenate[Any, P], T] + fn: Callable[Concatenate[Any, P], T], ) -> Callable[Concatenate[Any, P], T]: """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo counters. It is expected that `fn` is a method of `Benchmarker` or one of its diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 05bb3ddf0b9..4c16842a775 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -77,9 +77,9 @@ def validate_triton_config(cfg: Config) -> None: # right now, if a pre-hook is attached to the config, it will not be saved; # and then it won't be used when the config is loaded from cache. # So we assert - if we do get a pre_hook, it might get ignored after caching. 
- assert ( - getattr(cfg, "pre_hook", None) is None - ), "triton configs with pre_hooks not supported" + assert getattr(cfg, "pre_hook", None) is None, ( + "triton configs with pre_hooks not supported" + ) def create_bandwidth_info_str( diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c3b231b614e..2d3c2ca4d07 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -450,9 +450,9 @@ class CachingAutotuner(KernelInterface): self.launchers = [] def __getstate__(self) -> dict[str, Any]: - assert ( - not self.launchers - ), "pickle should not be called with after make_launchers()" + assert not self.launchers, ( + "pickle should not be called with after make_launchers()" + ) return { **self.__dict__, "lock": None, @@ -678,7 +678,9 @@ class CachingAutotuner(KernelInterface): assert isinstance( arg, torch.Tensor, - ), "self.reset_to_zero_arg_names should only contain valid argument names" + ), ( + "self.reset_to_zero_arg_names should only contain valid argument names" + ) arg.zero_() for name, arg in kwargs.items(): @@ -686,7 +688,9 @@ class CachingAutotuner(KernelInterface): assert isinstance( arg, torch.Tensor, - ), "self.reset_to_zero_arg_names should only contain valid argument names" + ), ( + "self.reset_to_zero_arg_names should only contain valid argument names" + ) arg.zero_() def maybe_clone_args( @@ -866,7 +870,9 @@ class CachingAutotuner(KernelInterface): assert not ( self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION and "R0_BLOCK" in launcher.config.kwargs - ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK" + ), ( + "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK" + ) start_time = time.time_ns() best_config = self.coordesc_tuner.autotune( benchmark_one_config, launcher.config, None @@ -882,9 +888,7 @@ class CachingAutotuner(KernelInterface): ) return config2launcher.get(best_config) - def run( - self, *args, grid, stream, benchmark_run=False, **kwargs - ): # type:ignore[override] + def run(self, *args, grid, stream, benchmark_run=False, **kwargs): # type:ignore[override] if self.triton_interpret: return self.fn[grid]( *args, @@ -1192,12 +1196,12 @@ class TritonCompileResult: exec( f""" - def launcher({', '.join(def_args)}, grid, stream): + def launcher({", ".join(def_args)}, grid, stream): if callable(grid): grid_0, grid_1, grid_2 = grid(grid_meta) else: grid_0, grid_1, grid_2 = grid - runner({', '.join(runner_args)}) + runner({", ".join(runner_args)}) return bin """.lstrip(), scope, @@ -1503,9 +1507,9 @@ def check_max_block(cfg: dict[str, int]): if block_suffix in var: prefix = var.removesuffix(block_suffix) max_block = TRITON_MAX_BLOCK[prefix] - assert ( - val <= max_block - ), f"'{var}' too large. Maximum: {max_block}. Actual: {val}." + assert val <= max_block, ( + f"'{var}' too large. Maximum: {max_block}. Actual: {val}." 
+ ) def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False): @@ -1657,20 +1661,20 @@ def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, in prefix = f"r{idx}_" max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()]) dim = min(max_size, remaining) - assert ( - remaining % dim == 0 - ), f"Expected dimension '{dim}' to divide remaining size '{remaining}'" + assert remaining % dim == 0, ( + f"Expected dimension '{dim}' to divide remaining size '{remaining}'" + ) rnumels[prefix] = dim remaining //= dim # Sanity check the results. final_numel = conditional_product(*rnumels.values()) - assert ( - r == final_numel - ), f"Expected ND reduction size ({rnumels}) to have {r} elements." - assert all( - rnumels[prefix] <= size_hints[prefix] for prefix in rnumels - ), f"rnumels exceed size_hints. {rnumels} > {size_hints}" + assert r == final_numel, ( + f"Expected ND reduction size ({rnumels}) to have {r} elements." + ) + assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), ( + f"rnumels exceed size_hints. {rnumels} > {size_hints}" + ) return rnumels @@ -1967,9 +1971,9 @@ def cooperative_reduction( size_hints["x"] = 1 # Cooperative reductions currently only support a single reduction dimension. - assert ( - len(size_hints) == 2 - ), "Cooperative reductions don't support tiling reduction dims" + assert len(size_hints) == 2, ( + "Cooperative reductions don't support tiling reduction dims" + ) xnumel, rnumel = size_hints["x"], size_hints["r0_"] # TODO(jansel): we should base target on the SM count of the local GPU @@ -2274,9 +2278,9 @@ def grid_combo_kernels( assert min_blocks_d is not None min_blocks = min_blocks_d else: - assert ( - min_blocks_d is None or min_blocks == min_blocks_d - ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}" + assert min_blocks_d is None or min_blocks == min_blocks_d, ( + f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}" + ) else: # sequential dispatch seq_numels = list(numels) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 24938017459..1a739f3513a 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -200,9 +200,9 @@ class BaseSchedulerNode: def __init__(self, scheduler: Scheduler) -> None: self.scheduler: Scheduler = scheduler - self.debug_device_str: Callable[ - [BaseSchedulerNode], list[str] - ] = lambda *args, **kwargs: [] + self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = ( + lambda *args, **kwargs: [] + ) def _init_from_node(self, node: ir.Operation) -> None: self.node: Optional[ir.Operation] = node @@ -232,7 +232,7 @@ class BaseSchedulerNode: buf = IndentedBuffer() buf.splice( f"""\ -{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__}) +{name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__}) {name}.writes = {pformat(self.read_writes.writes)} {name}.unmet_dependencies = {pformat(self.unmet_dependencies)} {name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} @@ -525,9 +525,9 @@ class BaseSchedulerNode: V.kernel.mutations.add(input_buf.get_name()) V.kernel.mutations.add(buf.get_name()) - V.kernel.inplace_update_buffers[ - buf.get_name() - ] = input_buf.get_name() + V.kernel.inplace_update_buffers[buf.get_name()] = ( + input_buf.get_name() + ) break def codegen_originating_info( @@ -693,7 +693,7 @@ class BaseSchedulerNode: continue def get_buf_bytes( - buf: Optional[Union[ir.Buffer, ir.TensorBox, 
ir.TorchBindObject]] + buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]], ) -> int: if not buf: return 0 @@ -794,12 +794,11 @@ class BaseSchedulerNode: # runtime for that today return 0 - with FakeTensorMode() as fake_mode, FlopCounterMode( - display=False - ) as flop_counter_mode, V.set_current_node( - self.node.fx_node - ), V.set_fake_mode( - fake_mode + with ( + FakeTensorMode() as fake_mode, + FlopCounterMode(display=False) as flop_counter_mode, + V.set_current_node(self.node.fx_node), + V.set_fake_mode(fake_mode), ): from .ir import ir_node_to_tensor @@ -1123,15 +1122,15 @@ class SchedulerNode(BaseSchedulerNode): return self._sizes def is_reduction(self) -> bool: - assert isinstance( - self.node, (ir.ComputedBuffer, ir.TemplateBuffer) - ), f"{type(self.node)=}" + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), ( + f"{type(self.node)=}" + ) return bool(self.node.get_reduction_type()) def is_split_scan(self) -> bool: - assert isinstance( - self.node, (ir.ComputedBuffer, ir.TemplateBuffer) - ), f"{type(self.node)=}" + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), ( + f"{type(self.node)=}" + ) return isinstance(self.node, ir.ComputedBuffer) and isinstance( self.node.data, ir.SplitScan ) @@ -1163,9 +1162,10 @@ class SchedulerNode(BaseSchedulerNode): def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: var_ranges = self.ranges_from_index_vars(index_vars) try: - with V.set_ops_handler( - SimplifyIndexing(V.get_ops_handler(), var_ranges) - ), V.kernel.set_current_node(self): + with ( + V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)), + V.kernel.set_current_node(self), + ): self._body(*index_vars) except Exception: log.fatal("Error in codegen for %s", self.node) @@ -1231,7 +1231,7 @@ class SchedulerNode(BaseSchedulerNode): def refresh_group_node_dependencies( - group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode] + group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode], ) -> None: snodes = group_snode.snodes group_snode.set_read_writes( @@ -1754,7 +1754,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): @staticmethod def set_group_algorithm_for_combo_kernels( - custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]] + custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]], ) -> None: ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = ( custom_group_algorithm @@ -1975,9 +1975,9 @@ class Scheduler: for node in self.nodes: node.prune_deps() - self.name_to_donated_buffer: dict[ - str, SchedulerDonatedBuffer - ] = self.get_donated_buffers() + self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = ( + self.get_donated_buffers() + ) self.name_to_node: dict[str, BaseSchedulerNode] = { n.get_name(): n for n in self.nodes } @@ -2099,9 +2099,9 @@ class Scheduler: node.log_details() def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode: - assert ( - node.get_origins() is not None - ), "All nodes passed to scheduling must have an origin" + assert node.get_origins() is not None, ( + "All nodes passed to scheduling must have an origin" + ) if node.is_no_op(): return NopKernelSchedulerNode(self, node) elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): @@ -2260,9 +2260,9 @@ class Scheduler: ) # if a kernel takes unbacked symints, register dependencies for s in unbacked_symbol_uses: - assert ( - s in unbacked_symbol_to_origin_node - ), f"{s} not in {unbacked_symbol_to_origin_node}" + assert s in 
unbacked_symbol_to_origin_node, ( + f"{s} not in {unbacked_symbol_to_origin_node}" + ) if (r := unbacked_symbol_to_origin_node[s]) is not None: for buf in self.name_to_node[r].get_outputs(): node.add_fake_dep(StarDep(buf.get_name())) @@ -2310,9 +2310,9 @@ class Scheduler: for alt_name in buf.get_mutations(): self.mutation_renames[rename(alt_name)] = buf.get_name() self.mutation_renames[alt_name] = buf.get_name() - self.mutation_real_name[ - buf.get_name() - ] = self.mutation_real_name.get(alt_name, alt_name) + self.mutation_real_name[buf.get_name()] = ( + self.mutation_real_name.get(alt_name, alt_name) + ) # make sure outputs aren't dead-code-eliminated for buf_name in V.graph.get_output_names(): @@ -2322,9 +2322,9 @@ class Scheduler: # make sure unbacked symints aren't dead-code-eliminated for out in V.graph.graph_outputs: for s in out.get_unbacked_symbol_uses(): - assert ( - s in unbacked_symbol_to_origin_node - ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + assert s in unbacked_symbol_to_origin_node, ( + f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + ) if r := unbacked_symbol_to_origin_node[s]: for buf_name in self.name_to_node[r].get_buffer_names(): log.debug( @@ -3304,15 +3304,15 @@ class Scheduler: rhs_dep = node2_name2dep[buf_name] if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): - reasons[ - buf_name - ] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}" + reasons[buf_name] = ( + f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}" + ) continue if lhs_dep.get_numel() != rhs_dep.get_numel(): - reasons[ - buf_name - ] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}" + reasons[buf_name] = ( + f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}" + ) continue # same numel but different MemoryDep.size. Should be broadcasting @@ -3340,9 +3340,9 @@ class Scheduler: layout_str = "" if not isinstance(buf, ir.TorchBindObject): layout_str = f"Layout: {buf.layout}" - reasons[ - buf_name - ] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}" + reasons[buf_name] = ( + f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. 
{layout_str}" + ) return str(reasons) @@ -3903,9 +3903,9 @@ class Scheduler: self.free_buffers() def create_backend(self, device: torch.device) -> BaseScheduling: - assert ( - not is_gpu(device.type) or device.index is not None - ), f"{device} should have been normalized in lowering" + assert not is_gpu(device.type) or device.index is not None, ( + f"{device} should have been normalized in lowering" + ) V.graph.add_device_info(device) device_scheduling = get_scheduling_for_device(device.type) @@ -4135,9 +4135,9 @@ class Scheduler: partitions, signatures = self.graph_partition() for partition, signature in zip(partitions, signatures): - assert ( - len(partition) >= 1 - ), f"Each partition must have at least one node but found {len(partition)}" + assert len(partition) >= 1, ( + f"Each partition must have at least one node but found {len(partition)}" + ) if signature.skip_cudagraph: self._codegen(partition) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 7ccb474d33d..414d1f97c29 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -168,9 +168,9 @@ class PartialRender: ) else: return - assert ( - self.replacement_hooks[hook_key] is not None - ), "hook_key can only be called once" + assert self.replacement_hooks[hook_key] is not None, ( + "hook_key can only be called once" + ) self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]()) self.replacement_hooks[hook_key] = None @@ -257,9 +257,9 @@ class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined] This is used by flex_attention's backwards grad for captured buffers, see zeros_and_scatter lowering """ - assert ( - self.mask is not None - ), "Mask is required for inner stores in modifications" + assert self.mask is not None, ( + "Mask is required for inner stores in modifications" + ) assert mode == "atomic_add", "Only atomic_add is supported for inner stores" buf_name = self._add_kernel_input(name) @@ -573,12 +573,12 @@ class TritonTemplateKernel(TritonKernel): def _get_subgraph(self, subgraph_number: int): assert isinstance(subgraph_number, int) assert isinstance(self.subgraphs, list) - assert subgraph_number < len( - self.subgraphs - ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" - assert ( - self.body.getvalue() == "" - ), "Body should be clear before adding a modification" + assert subgraph_number < len(self.subgraphs), ( + f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + ) + assert self.body.getvalue() == "", ( + "Body should be clear before adding a modification" + ) return self.subgraphs[subgraph_number] def _handle_scatter_graph(self, scatter_graph): @@ -587,9 +587,9 @@ class TritonTemplateKernel(TritonKernel): Args: scatter_graph: The scatter graph to process """ - assert isinstance( - scatter_graph, ir.ComputedBuffer - ), f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}" + assert isinstance(scatter_graph, ir.ComputedBuffer), ( + f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}" + ) def contiguous_strides(x): # We always create a fresh contiguous grad for scattering into @@ -597,7 +597,9 @@ class TritonTemplateKernel(TritonKernel): x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride()) ) - return scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, []) # type: ignore[attr-defined] + return 
scatter_graph.data.store_output( # type: ignore[attr-defined] + scatter_graph.name, contiguous_strides, [] + ) def modification( self, @@ -626,9 +628,9 @@ class TritonTemplateKernel(TritonKernel): self, subgraph_number, fixed_inputs, mask ) with V.set_ops_handler(modification_handler): - assert isinstance( - subgraph, (ir.ComputedBuffer, list) - ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" + assert isinstance(subgraph, (ir.ComputedBuffer, list)), ( + f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" + ) # Handle scatter stores if isinstance(subgraph, list): for scatter_graph in subgraph: @@ -1123,15 +1125,17 @@ class TritonTemplate(KernelTemplate): "subgraphs": subgraphs, } - with patch.object( - V.graph, "get_dtype", self._fake_get_dtype(fake_out) - ), V.graph.set_current_device(layout.device), TritonTemplateKernel( - kernel_name=kernel_name, - output_node=fake_out, - workspace_arg=workspace_arg, - use_jit=False, - **kernel_options, - ) as kernel: + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)), + V.graph.set_current_device(layout.device), + TritonTemplateKernel( + kernel_name=kernel_name, + output_node=fake_out, + workspace_arg=workspace_arg, + use_jit=False, + **kernel_options, + ) as kernel, + ): try: template = kernel.render(self.template, kwargs) with kernel.set_subgraph_body(""): @@ -1442,9 +1446,9 @@ class ExternKernelCaller(ChoiceCaller): def output_node(self): if self.choice.use_fallback_kernel: - assert ( - self.choice.op_overload is not None - ), "Please provide an op_overload to use ir.FallbackKernel" + assert self.choice.op_overload is not None, ( + "Please provide an op_overload to use ir.FallbackKernel" + ) inner = ir.FallbackKernel.create( self.choice.op_overload, *self.input_nodes, **self.kwargs ) @@ -1979,7 +1983,7 @@ class AlgorithmSelectorCache(PersistentCache): input_gen_fns = {} def get_inputs( - choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]] + choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]], ) -> AutotuneArgs: # de-duplicate args unique_example_inputs = { @@ -2099,7 +2103,7 @@ class AlgorithmSelectorCache(PersistentCache): return timings def benchmark_in_sub_process( - choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]] + choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]], ): from . 
import autotune_process @@ -2139,7 +2143,8 @@ class AlgorithmSelectorCache(PersistentCache): map( str, V.graph.sizevars.size_hints( - n.get_size(), fallback=config.unbacked_symint_fallback # type: ignore[arg-type] + n.get_size(), + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] ), ) ) @@ -2313,15 +2318,15 @@ def autotune_select_algorithm(*args, **kwargs): _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() if "return_multi_template" not in kwargs: - kwargs[ - "return_multi_template" - ] = torch._inductor.config.benchmark_epilogue_fusion + kwargs["return_multi_template"] = ( + torch._inductor.config.benchmark_epilogue_fusion + ) return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) def add_feedback_saver( - fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None] + fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None], ): global _ALGORITHM_SELECTOR_CACHE if _ALGORITHM_SELECTOR_CACHE is None: diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index fcc549bf652..4d7f24f9649 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -905,9 +905,9 @@ class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] def __init__(self, inner, var_ranges: VarRanges) -> None: super().__init__(inner) self.name = "SimplifyIndexing" - self._simplify: Callable[ - [Expr], Expr - ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) + self._simplify: Callable[[Expr], Expr] = ( + lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) + ) def load(self, name: str, index: sympy.Expr): return self._inner.load(name, self._simplify(index)) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index a2ad12f0e6e..130074f9d3f 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -283,9 +283,9 @@ def ceildiv( # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes - assert isinstance(numer, int) and isinstance( - denom, int - ), f"{numer}: {type(numer)}, {denom}: {type(denom)}" + assert isinstance(numer, int) and isinstance(denom, int), ( + f"{numer}: {type(numer)}, {denom}: {type(denom)}" + ) return runtime_ceildiv(numer, denom) @@ -325,7 +325,7 @@ def _type_of(key: Optional[torch.dtype]) -> str: def convert_shape_to_inductor( - lst: Iterable[Union[int, torch.SymInt]] + lst: Iterable[Union[int, torch.SymInt]], ) -> list[sympy.Expr]: """ Gets the shape and stride of a tensor. For non-symbolic tensors, this is @@ -502,11 +502,9 @@ RV = TypeVar("RV", covariant=True) class CachedMethod(Protocol, Generic[P, RV]): @staticmethod - def clear_cache(cache: Any) -> None: - ... + def clear_cache(cache: Any) -> None: ... - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: - ... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ... 
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature @@ -1359,9 +1357,9 @@ def _rocm_native_device_arch_name(device: str) -> str: @functools.lru_cache(None) -def try_import_ck_lib() -> ( - tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]] -): +def try_import_ck_lib() -> tuple[ + Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any] +]: try: import ck4inductor # type: ignore[import] from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] @@ -1610,9 +1608,12 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]: return DummyModule() - with mock.patch.object( - GraphLowering, "compile_to_module", patched_compile_to_module - ), mock.patch.object(GraphLowering, "save_output_code", save_output_code): + with ( + mock.patch.object( + GraphLowering, "compile_to_module", patched_compile_to_module + ), + mock.patch.object(GraphLowering, "save_output_code", save_output_code), + ): torch._dynamo.reset() # Note the return here is None _ = fn(*args, **kwargs) @@ -1623,18 +1624,18 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]: def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str: source_codes = get_code(fn, *args, **kwargs) # Can have two outputs if backwards was eagerly compiled - assert ( - 1 <= len(source_codes) <= 2 - ), f"expected one or two code outputs got {len(source_codes)}" + assert 1 <= len(source_codes) <= 2, ( + f"expected one or two code outputs got {len(source_codes)}" + ) return source_codes[0] def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str: _, source_codes = run_and_get_code(fn, *args, **kwargs) # Can have two outputs if backwards was eagerly compiled - assert ( - 1 <= len(source_codes) <= 2 - ), f"expected one or two code outputs got {len(source_codes)}" + assert 1 <= len(source_codes) <= 2, ( + f"expected one or two code outputs got {len(source_codes)}" + ) return source_codes[0] @@ -1760,9 +1761,9 @@ def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool: def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: - assert isinstance( - val, sympy.Expr - ), "only support sympy.Expr as input to get_sympy_Expr_dtype" + assert isinstance(val, sympy.Expr), ( + "only support sympy.Expr as input to get_sympy_Expr_dtype" + ) if val.is_integer: # type: ignore[attr-defined] return torch.int64 else: @@ -1932,7 +1933,7 @@ def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) -> def is_output_of_multi_outputs_template( - input_buf: Optional[Union[Buffer, Operation]] + input_buf: Optional[Union[Buffer, Operation]], ) -> bool: """ Check if input buffer is a output of multi-outputs template buffer @@ -2633,7 +2634,8 @@ def set_kernel_post_grad_provenance_tracing( if node not in (EnableReduction, DisableReduction): if node.node is not None: V.debug._inductor_triton_kernel_to_post_grad_node_info[kernel_name] = [ - origin.name for origin in node.node.origins # type: ignore[attr-defined] + origin.name + for origin in node.node.origins # type: ignore[attr-defined] ] diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 1ee1ef5a744..02876f09940 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -314,9 +314,9 @@ class _V: KernelFormatterHandler = KernelFormatterHandler WrapperHandler = WrapperHandler - set_ops_handler: Callable[ - [OpsHandler[Any]], 
AbstractContextManager[None] - ] = _ops._set_handler + set_ops_handler: Callable[[OpsHandler[Any]], AbstractContextManager[None]] = ( + _ops._set_handler + ) get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 736326f4940..ac7d10e8a0e 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -14,8 +14,7 @@ from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes class BenchmarkCallableType(Protocol): - def __call__(self, times: int, repeat: int) -> float: - ... + def __call__(self, times: int, repeat: int) -> float: ... _kernel_category_choices = [ @@ -138,9 +137,9 @@ def benchmark_all_kernels( ) else: ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40) - assert ( - len(triton_kernel.launchers) == 1 - ), "Autotuner should have selected the best config" + assert len(triton_kernel.launchers) == 1, ( + "Autotuner should have selected the best config" + ) launcher = triton_kernel.launchers[0] print( get_info_str( @@ -256,9 +255,9 @@ def parse_profile_event_list( "triton_unknown", "unknown", ] - assert OrderedSet(all_events.keys()).issubset( - OrderedSet(category_list) - ), f"{list(all_events.keys())}" + assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), ( + f"{list(all_events.keys())}" + ) per_category_wall_time = {} total_device_ms = 0.0
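
Note for reviewers skimming the hunks above: the changes are mechanical restyling, and the same few patterns recur throughout, parenthesized multi-item `with` blocks, assert messages wrapped in parentheses while the condition stays on one line, long right-hand sides wrapped in parentheses instead of splitting the assignment target, and `...`-bodied stubs collapsed onto one line. The snippet below is a small illustrative sketch written in the target style; it is not taken from this patch, and the names `fake_mode`, `flop_counter`, and `run` are hypothetical stand-ins rather than inductor APIs.

    # Illustrative only: a hypothetical module showing the formatting patterns
    # applied in the hunks above, not code from the PR itself.
    from contextlib import contextmanager


    @contextmanager
    def fake_mode():
        # Stand-in for a mode-style context manager (e.g. a fake-tensor mode).
        yield "fake"


    @contextmanager
    def flop_counter(display: bool = False):
        # Stand-in for a counter-style context manager.
        yield {"flops": 0}


    def run(node_name: str, sizes: dict[str, int]) -> int:
        # Pattern 1: one context manager per line inside parentheses, rather
        # than wrapping each call expression across multiple lines.
        with (
            fake_mode(),
            flop_counter(display=False) as counter,
        ):
            # Pattern 2: when the statement exceeds the line limit, the
            # condition stays intact and only the message is parenthesized.
            assert node_name in sizes, (
                f"{node_name} not in {sorted(sizes)}"
            )

            # Pattern 3: wrap the right-hand side in parentheses instead of
            # breaking the subscripted assignment target across lines.
            counter["flops"] = (
                sizes[node_name] * 2  # hypothetical cost model
            )
            return counter["flops"]


    # Pattern 4: Protocol/stub methods with `...` bodies stay on one line.
    class Benchmarkable:
        def __call__(self, times: int, repeat: int) -> float: ...


    if __name__ == "__main__":
        print(run("matmul", {"matmul": 128}))  # prints 256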