Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
This PR adds code generation for CK-tile based universal gemm kernels to the CK backend for Inductor, and adds these kernels to the autotune choices. Unlike legacy-CK based kernels (which are generated by parsing the CK instances from the CK library), we generate the set of instances by manually specifying the tuning parameters. This PR introduces a new template for code generation; compilation and autotuning are handled by the existing infrastructure.

Points of discussion:

* For simplicity and reduced coupling with CK, the instance filter checks only data type and layout and doesn't check the alignment requirement. This means more instances will be compiled than necessary, but it keeps the code generation independent of internal CK logic, which checks alignment validity at runtime.
* CK-tile instances are enabled whenever legacy-CK instances are enabled. A config knob could be introduced to differentiate between the instance types if that's needed.
* Whether the gemm problem size K is ever dynamic: whenever it is not a compile-time constant, we need to perform a runtime dispatch between several kernels.

**Testing**

Use the existing tests in `test/inductor/test_ck_backend.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152341
Approved by: https://github.com/chenyang78
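The instance-generation and filtering approach described in the PR can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's template code: `GemmInstance`, `generate_instances`, `filter_instances`, and the tile sizes and pipeline names are hypothetical placeholders. The sketch enumerates a hand-specified tuning-parameter space and keeps only the instances whose data type and layouts match the problem, deliberately skipping any alignment check.

```python
from dataclasses import dataclass
from itertools import product


# Hypothetical stand-in for one CK-tile gemm instance description.
@dataclass(frozen=True)
class GemmInstance:
    dtype: str      # element type shared by A, B and C (simplification)
    layout_a: str   # "Row" or "Col"
    layout_b: str
    layout_c: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str


# Hypothetical tuning-parameter space, specified by hand rather than
# parsed from a library. Values are illustrative placeholders only.
DTYPES = ["fp16", "bf16"]
LAYOUTS = ["Row", "Col"]
TILES = [(256, 256, 32), (256, 128, 64), (128, 128, 64)]
PIPELINES = ["mem", "compv3"]


def generate_instances():
    """Enumerate every combination of the hand-specified tuning parameters."""
    return [
        GemmInstance(dtype, la, lb, lc, tm, tn, tk, pipe)
        for dtype, la, lb, lc, (tm, tn, tk), pipe in product(
            DTYPES, LAYOUTS, LAYOUTS, LAYOUTS, TILES, PIPELINES
        )
    ]


def filter_instances(instances, dtype, layout_a, layout_b, layout_c):
    """Keep instances whose dtype and layouts match the problem.

    Alignment is deliberately NOT checked here; an instance whose alignment
    requirement the problem violates is assumed to be rejected at runtime
    by the kernel's own argument check.
    """
    wanted = (layout_a, layout_b, layout_c)
    return [
        op
        for op in instances
        if op.dtype == dtype and (op.layout_a, op.layout_b, op.layout_c) == wanted
    ]
```

With this placeholder space, `generate_instances()` yields 96 candidates and an fp16 Row/Col/Row problem keeps 6 of them. Every survivor would be compiled and benchmarked even if some later fail the runtime alignment check, which is the trade-off the first discussion point accepts in exchange for not mirroring CK's internal validity logic.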
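The last discussion point concerns a dynamic K. Below is a minimal sketch of compile-time versus runtime kernel selection, assuming (hypothetically) that the generated kernels are specialized on the shape of their K loop: whether there are more K-tile iterations than prefetch stages (a "hot loop") and how many iterations are left over. `variant_key`, `build_gemm_launcher`, `VARIANTS`, and the stub kernels are illustrative names, not code from the PR.

```python
import math
from typing import Callable, Dict, Optional, Tuple

# (has_hot_loop, tail_iterations): a hypothetical specialization key.
VariantKey = Tuple[bool, int]


def variant_key(k: int, tile_k: int, prefetch_stages: int) -> VariantKey:
    """Classify K by the shape of the K loop (illustrative criterion)."""
    num_loops = math.ceil(k / tile_k)
    has_hot_loop = num_loops > prefetch_stages
    tail = num_loops % prefetch_stages
    return has_hot_loop, tail


def build_gemm_launcher(
    static_k: Optional[int],
    tile_k: int,
    prefetch_stages: int,
    variants: Dict[VariantKey, Callable[[int], str]],
) -> Callable[[int], str]:
    if static_k is not None:
        # K is a compile-time constant: resolve the single variant up front.
        kernel = variants[variant_key(static_k, tile_k, prefetch_stages)]

        def launch_static(k: int) -> str:
            assert k == static_k, "K was specialized at compile time"
            return kernel(k)

        return launch_static

    # K is dynamic: all variants must be available and one is chosen per call.
    def launch_dynamic(k: int) -> str:
        return variants[variant_key(k, tile_k, prefetch_stages)](k)

    return launch_dynamic


# Stub "kernels" that just report which specialization ran.
# range(2) covers every tail value when prefetch_stages == 2.
VARIANTS: Dict[VariantKey, Callable[[int], str]] = {
    (hot, tail): (lambda k, hot=hot, tail=tail: f"gemm_hot{int(hot)}_tail{tail}")
    for hot in (False, True)
    for tail in range(2)
}
```

The cost of the dynamic path is visible in the sketch: every entry of `VARIANTS` has to exist (in a real backend, be compiled), whereas the static path only ever touches one. For example, with `tile_k=64` and two prefetch stages, a dynamic launcher routes `K=512` to the hot-loop, no-tail variant and `K=576` to the hot-loop variant with one tail iteration.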
| Name |
|---|
| aoti_runtime |
| cuda |
| rocm |
| xpu |
| __init__.py |
| aoti_hipify_utils.py |
| block_analysis.py |
| common.py |
| cpp_bmm_template.py |
| cpp_flex_attention_template.py |
| cpp_gemm_template.py |
| cpp_grouped_gemm_template.py |
| cpp_micro_gemm.py |
| cpp_template_kernel.py |
| cpp_template.py |
| cpp_utils.py |
| cpp_wrapper_cpu_array_ref.py |
| cpp_wrapper_cpu.py |
| cpp_wrapper_gpu.py |
| cpp_wrapper_mps.py |
| cpp.py |
| cpu_device_op_overrides.py |
| cuda_combined_scheduling.py |
| debug_utils.py |
| halide.py |
| memory_planning.py |
| mps_device_op_overrides.py |
| mps.py |
| multi_kernel.py |
| simd_kernel_features.py |
| simd.py |
| subgraph.py |
| triton_combo_kernel.py |
| triton_split_scan.py |
| triton_utils.py |
| triton.py |
| wrapper_fxir.py |
| wrapper.py |