pytorch/torch/_inductor/codegen
angelayi 25ef3d315d [aoti][mps] Dynamic reductions (#159355)
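Adds dynamic-shape support for reduction kernels in the AOTI MPS backend: the reduction numel (`r0_numel`) is passed to the generated Metal kernel as a runtime argument and the dispatch dimensions are computed from the runtime size, instead of being baked into the kernel as in the static case. Both variants are shown below.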
Dynamic kernel:
```cpp
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    constant long& r0_numel,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    int x0 = xindex;
    threadgroup float tmp_acc_0[32];
    float tmp_acc_1 = 0;
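    // Note: the floor() expression below is just ceil(r0_numel / 1024) written in
    // floating point (0.0009765625 == 1/1024, 0.9990234375 == 1023/1024), i.e. the
    // number of 1024-wide strips needed to cover the dynamic reduction size.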
    for(auto r0_1_cnt = 0; r0_1_cnt < static_cast<int>(metal::floor(static_cast<float>(0.99902343750000000 + 0.00097656250000000000*r0_numel))); ++r0_1_cnt) {
        int r0_1 = 1024 * r0_1_cnt + r0_index;
        if (r0_1 >= r0_numel) break;
        auto tmp0 = in_ptr0[x0 + 5*r0_1];
        tmp_acc_1 += tmp0;
    }
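    // Cross-thread reduction within the threadgroup; the participating-thread
    // count is min(1024, r0_numel), matching the capped threadgroup height.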
    auto tmp1 = c10::metal::threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, metal::min(static_cast<decltype(1024+r0_numel)>(1024), static_cast<decltype(1024+r0_numel)>(r0_numel)));
    if (r0_index == 0) out_ptr0[x0] = static_cast<float>(tmp1);
}

void AOTInductorModel::run_impl(...) {
    ...
    auto arg0_1_size = arg0_1.sizes();
    int64_t s77 = arg0_1_size[0];
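    // s77 is the dynamic reduction size, read from the input tensor at runtime.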
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        aoti_torch_mps_set_arg_int(mps_lib_0_func_handle, 2, s77);
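        // 2D dispatch: 5 output elements along x, min(1024, s77) reduction threads
        // along y; the threadgroup shape {1, min(1024, s77)} stays within the
        // kernel's max_total_threads_per_threadgroup(1024) cap.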
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))}, {static_cast<uint64_t>(1), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
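Compared to the static case below, the dynamic kernel takes the reduction size as a `constant long&` argument (bound via `aoti_torch_mps_set_arg_int`), loops over 1024-wide strips of the reduction dimension, and finishes with a threadgroup reduction; the dispatch dimensions are computed from the runtime size `s77`.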

Static kernel:
```cpp
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = xindex;
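    // With fully static shapes the reduction is unrolled at compile time:
    // four loads with stride 5, consistent with a [4, 5] input summed over dim 0.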
    auto tmp0 = in_ptr0[x0];
    auto tmp1 = in_ptr0[5 + x0];
    auto tmp3 = in_ptr0[10 + x0];
    auto tmp5 = in_ptr0[15 + x0];
    auto tmp2 = tmp0 + tmp1;
    auto tmp4 = tmp2 + tmp3;
    auto tmp6 = tmp4 + tmp5;
    out_ptr0[x0] = static_cast<float>(tmp6);
}

void AOTInductorModel::run_impl(...) {
    ...
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL)});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
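A minimal sketch of how the two variants might be produced (an assumed repro, not taken from the PR): exporting a dim-0 sum with a dynamic leading dimension on MPS and compiling with AOTInductor should yield the dynamic kernel above, while a fully static export yields the unrolled kernel.
```python
# Hypothetical repro sketch; the model and dim names are illustrative.
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=0)

x = torch.randn(4, 5, device="mps")

# Dynamic leading dim -> reduction kernel with a runtime r0_numel argument.
dim0 = torch.export.Dim("dim0", min=2, max=2048)
ep_dyn = torch.export.export(Model(), (x,), dynamic_shapes={"x": {0: dim0}})
pkg_dyn = torch._inductor.aoti_compile_and_package(ep_dyn)

# Static shapes -> fully unrolled reduction kernel.
ep_static = torch.export.export(Model(), (x,))
pkg_static = torch._inductor.aoti_compile_and_package(ep_static)
```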

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159355
Approved by: https://github.com/malfet
2025-07-31 23:15:02 +00:00
aoti_runtime [AOTI] Save data sizes to constants_info (#154534) 2025-05-29 06:39:13 +00:00
cuda [cutlass] rename EVT args within kernels for code caching (#159243) 2025-07-28 19:01:40 +00:00
mtia [Re-land][Inductor] Support native Inductor as backend for MTIA (#159211) 2025-07-29 17:03:24 +00:00
rocm [ROCm][CK][Inductor] enable gfx950 for max autotune with CK (#159195) 2025-07-27 20:47:13 +00:00
xpu [user triton] AOT inductor support for device-side TMA (#155896) 2025-06-27 04:28:04 +00:00
__init__.py
aoti_hipify_utils.py [BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313) 2025-06-23 02:57:12 +00:00
block_analysis.py [Inductor] Restrict block analysis to only match integer dims and strides (#149615) 2025-06-24 22:43:12 +00:00
common.py [Re-land][Inductor] Support native Inductor as backend for MTIA (#159211) 2025-07-29 17:03:24 +00:00
cpp_bmm_template.py
cpp_flex_attention_template.py [Inductor] Set the default value of min_chunk_size to 512 (#150762) 2025-07-21 12:46:05 +00:00
cpp_gemm_template.py [inductor] Add typing to _inductor/ir.py (#149958) 2025-06-30 15:56:35 +00:00
cpp_grouped_gemm_template.py
cpp_micro_gemm.py [Pyrefly][Refactor] Replace dict() calls with literal dict syntax for improved readability (#157735) 2025-07-08 18:10:33 +00:00
cpp_template_kernel.py [Inductor] Set the default value of min_chunk_size to 512 (#150762) 2025-07-21 12:46:05 +00:00
cpp_template.py codecache: Remove cpp_prefix.h duplication per build, then precompile it (#144293) 2025-05-16 17:41:36 +00:00
cpp_utils.py [aoti] Initial Metal support (#153959) 2025-05-23 05:45:35 +00:00
cpp_wrapper_cpu_array_ref.py [inductor] Add typing to _inductor/ir.py (#149958) 2025-06-30 15:56:35 +00:00
cpp_wrapper_cpu.py [AOTI] Explicitly delete wait_tensor returned tensor (#159502) 2025-07-31 15:33:36 +00:00
cpp_wrapper_gpu.py [user triton] AOT inductor support for device-side TMA (#155896) 2025-06-27 04:28:04 +00:00
cpp_wrapper_mps.py [aoti][mps] Improve tabbing in cpp generation (#158351) 2025-07-23 00:54:53 +00:00
cpp.py Refactor Provenance Tracking (#158399) 2025-07-17 00:23:00 +00:00
cpu_device_op_overrides.py
cuda_combined_scheduling.py multi-kernel matmuls based on varying hint sizes (#156628) 2025-07-12 15:08:21 +00:00
debug_utils.py [Inductor] Refactor wrapper codegen to use Wrapper IR. (#150458) 2025-04-15 17:28:36 +00:00
halide.py [inductor] more size_hint_or_throw usage (#157394) 2025-07-02 20:20:59 +00:00
memory_planning.py
mps_device_op_overrides.py [aoti] Initial Metal support (#153959) 2025-05-23 05:45:35 +00:00
mps.py [aoti][mps] Dynamic reductions (#159355) 2025-07-31 23:15:02 +00:00
multi_kernel.py multi-kernel matmuls based on varying hint sizes (#156628) 2025-07-12 15:08:21 +00:00
python_wrapper_mtia.py [Re-land][Inductor] Support native Inductor as backend for MTIA (#159211) 2025-07-29 17:03:24 +00:00
simd_kernel_features.py Replace runtime type parameterization (#155221) 2025-06-05 21:43:54 +00:00
simd.py [inductor][templates] Finalize all registered hooks (#157270) 2025-07-20 22:07:32 +00:00
subgraph.py [inductor] Add typing to _inductor/ir.py (#149958) 2025-06-30 15:56:35 +00:00
triton_combo_kernel.py [BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313) 2025-06-23 02:57:12 +00:00
triton_split_scan.py
triton_utils.py [Inductor] Fix a user-defined Triton kernel bool param codegen issue (#158845) 2025-07-24 00:19:27 +00:00
triton.py [inductor] Update to(tl.int8).to(tl.uint8) workaround from #94717 to handle entire range of torch.uint8 (#158567) 2025-07-26 19:11:37 +00:00
wrapper_fxir.py FXConverter handling of generic output in inductor fallback kernel (#159002) (#159297) 2025-07-29 18:29:01 +00:00
wrapper.py [aoti][mps] Dynamic reductions (#159355) 2025-07-31 23:15:02 +00:00