Dynamic kernel:
```cpp
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    constant long& r0_numel,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    int x0 = xindex;
    threadgroup float tmp_acc_0[32];
    float tmp_acc_1 = 0;
    for(auto r0_1_cnt = 0; r0_1_cnt < static_cast<int>(metal::floor(static_cast<float>(0.99902343750000000 + 0.00097656250000000000*r0_numel))); ++r0_1_cnt) {
        int r0_1 = 1024 * r0_1_cnt + r0_index;
        if (r0_1 >= r0_numel) break;
        auto tmp0 = in_ptr0[x0 + 5*r0_1];
        tmp_acc_1 += tmp0;
    }
    auto tmp1 = c10::metal::threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, metal::min(static_cast<decltype(1024+r0_numel)>(1024), static_cast<decltype(1024+r0_numel)>(r0_numel)));
    if (r0_index == 0) out_ptr0[x0] = static_cast<float>(tmp1);
}
void AOTInductorModel::run_impl(...) {
    ...
    auto arg0_1_size = arg0_1.sizes();
    int64_t s77 = arg0_1_size[0];
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        aoti_torch_mps_set_arg_int(mps_lib_0_func_handle, 2, s77);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))}, {static_cast<uint64_t>(1), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))});
    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
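For context, a minimal Python sketch (not taken from this PR) of the kind of export that produces a reduction kernel like the one above: a sum over a dynamic leading dimension of a `(batch, 5)` float tensor on MPS. The module name, the `Dim` bounds, and the `aoti_compile_and_package` call are illustrative assumptions about the torch.export / AOTInductor packaging API, not code from this change.

```python
import torch
from torch._inductor import aoti_compile_and_package


class SumOverBatch(torch.nn.Module):
    def forward(self, x):
        # x: (batch, 5) float tensor; reduce over the batch dimension
        return x.sum(dim=0)


model = SumOverBatch().to("mps")
example_inputs = (torch.randn(64, 5, device="mps"),)

# Marking dim 0 as dynamic is what yields the looped, threadgroup-reduced
# kernel above, with r0_numel supplied at dispatch time (bound to s77 in run_impl).
batch = torch.export.Dim("batch", min=2, max=4096)
ep = torch.export.export(model, example_inputs, dynamic_shapes={"x": {0: batch}})

package_path = aoti_compile_and_package(ep)  # emits the AOTI package for MPS
```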
Static kernel:
```cpp
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = xindex;
    auto tmp0 = in_ptr0[x0];
    auto tmp1 = in_ptr0[5 + x0];
    auto tmp3 = in_ptr0[10 + x0];
    auto tmp5 = in_ptr0[15 + x0];
    auto tmp2 = tmp0 + tmp1;
    auto tmp4 = tmp2 + tmp3;
    auto tmp6 = tmp4 + tmp5;
    out_ptr0[x0] = static_cast<float>(tmp6);
}
void AOTInductorModel::run_impl(...) {
    ...
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL)});
    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
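And a hedged sketch of the static counterpart: exporting the same module without `dynamic_shapes` on a fixed `(4, 5)` input bakes the reduction extent into the kernel, which is why the generated code above is a fully unrolled straight-line sum and the dispatch needs no size argument. The `aoti_load_package` round trip shown here is an assumption about the packaging API, added only to illustrate running the compiled artifact.

```python
import torch
from torch._inductor import aoti_compile_and_package, aoti_load_package


class SumOverBatch(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=0)


model = SumOverBatch().to("mps")
x = torch.randn(4, 5, device="mps")

# No dynamic_shapes: the reduction extent (4) is a compile-time constant, so
# Inductor unrolls it and the kernel only takes the two tensor arguments.
ep = torch.export.export(model, (x,))
package_path = aoti_compile_and_package(ep)

compiled = aoti_load_package(package_path)
torch.testing.assert_close(compiled(x), x.sum(dim=0))
```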
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159355
Approved by: https://github.com/malfet