Add pack support and use micro gemm for Half flex attention on CPU (#151530)

Add pack support and use a micro gemm for the second gemm to improve the performance of Half flex attention on CPU.
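For context, here is a minimal sketch (not part of this PR; shapes and the score_mod are illustrative) of the path this change speeds up: flex attention compiled through Inductor with Half (float16) inputs on CPU, which lowers to the CPP flex-attention template modified below.

```python
# Minimal sketch, assuming a PyTorch build where compiled CPU flex attention
# supports float16. Compiling flex_attention with Half inputs on CPU exercises
# the Inductor CPP flex-attention template that this change touches.
import torch
from torch.nn.attention.flex_attention import flex_attention

def causal(score, b, h, q_idx, kv_idx):
    # Standard causal score_mod: mask out keys that come after the query.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

q, k, v = (torch.randn(1, 8, 512, 64, dtype=torch.half) for _ in range(3))
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=causal)  # (1, 8, 512, 64), torch.half
```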

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151530
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
Author: CaoE
Date: 2025-04-29 07:23:57 +00:00
Committed by: PyTorch MergeBot
Parent: 41bd0c900a
Commit: dcd9a444b3


@@ -201,7 +201,7 @@ GEMM_DEFINE
"""
ALLOCATE_BUFFER = r"""
int64_t {{buffer_name}}_dtype_itemsize = std::is_same_v<{{buffer_dtype}}, at::BFloat16> ? 2 : 4;
int64_t {{buffer_name}}_dtype_itemsize = c10::is_reduced_floating_point_v<{{buffer_dtype}}> ? 2 : 4;
auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize);
void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get();
@@ -318,15 +318,16 @@ extern "C"
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
bool need_pack = false;
// Whether pack is needed for BFloat16
if (std::is_same_v<scalar_t, at::BFloat16>) {
// Whether pack is needed for BFloat16/Half
if (is_reduced_type) {
// check platform ability
need_pack = at::native::cpublas::could_pack(at::kBFloat16);
need_pack = std::is_same_v<scalar_t, at::BFloat16> ? at::native::cpublas::could_pack(at::kBFloat16)
: at::native::cpublas::could_pack(at::kHalf);
}
if (need_pack) {
// When the number of gemms is greater than the number of packs,
// the pack overhead can be overlapped.
int64_t thresh_size = 64 ;
int64_t thresh_size = 64;
need_pack = kvSize >= thresh_size && qSize >= thresh_size;
if (need_pack) {
double pack_size = batchSize * num_head * kvSize * headSize;
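Restated as a hedged Python sketch (names are made up for illustration; the pack-size versus GEMM-work comparison that follows in the template is omitted), the pack decision above is roughly:

```python
def should_pack(is_reduced_type: bool, platform_can_pack: bool,
                kv_size: int, q_size: int, thresh_size: int = 64) -> bool:
    # Packing is only considered for BFloat16/Half, and only when the
    # platform's cpublas backend reports it can pack for that dtype.
    if not (is_reduced_type and platform_can_pack):
        return False
    # Small problems cannot amortize the packing overhead, so both the
    # query and key/value sequence lengths must reach the threshold.
    return kv_size >= thresh_size and q_size >= thresh_size
```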
@@ -489,7 +490,7 @@ extern "C"
(*kv_logical_data * kvBlockSize + kv_block_offset) * kStrideN;
}
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
{{kernel.kernel_name}}_kernel_micro_gemm_transpose_b<static_cast<bool>(false)>(
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
k_addr,
@@ -614,7 +615,9 @@ extern "C"
v_data + i_kv * vStrideB + j_kv * vStrideH +
(*kv_logical_data * kvBlockSize + kv_block_offset) * vStrideN;
}
at::native::cpublas::brgemm(
// The fallback Half brgemm is slower than the micro gemm
if (!std::is_same_v<scalar_t, at::Half>) {
at::native::cpublas::brgemm(
cur_qSplitSize,
headSize_v,
cur_ekvSplitSize,
@@ -626,6 +629,31 @@ extern "C"
v_addr,
dst_data,
need_pack);
} else {
if (n > 0) {
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(true)>(
{{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
v_addr,
dst_data,
cur_qSplitSize,
headSize_v,
cur_ekvSplitSize,
cur_ekvSplitSize,
vStrideN,
headSize_v);
} else {
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
{{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
v_addr,
dst_data,
cur_qSplitSize,
headSize_v,
cur_ekvSplitSize,
cur_ekvSplitSize,
vStrideN,
headSize_v);
}
}
} else {
int64_t psize = n / kvSplitSize * ekvSplitSize;
at::native::cpublas::brgemm(
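In rough pseudocode (a sketch of the dispatch added above, not the template's generated code; helper names are hypothetical), the second GEMM now picks its kernel per dtype:

```python
import torch

def second_gemm(dtype, n, micro_gemm, brgemm, attn, v, dst):
    # Sketch only: Half goes through the generated micro gemm because the
    # fallback Half brgemm is slower; other dtypes keep using brgemm.
    if dtype == torch.half:
        # Accumulate into dst for every KV block after the first (n > 0).
        micro_gemm(attn, v, dst, accumulate=(n > 0))
    else:
        brgemm(attn, v, dst)
```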
@@ -1024,8 +1052,8 @@ class CppFlexAttentionTemplate(CppTemplate):
from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
from torch._inductor.virtualized import V
micro_gemm = CppMicroGemmFP32Vec(
kernel_name + "_kernel_micro_gemm",
micro_gemm_trans = CppMicroGemmFP32Vec(
kernel_name + "_kernel_micro_gemm_transpose_b",
self.input_dtype,
self.input_dtype,
self.accumulate_dtype,
@@ -1036,10 +1064,23 @@ class CppFlexAttentionTemplate(CppTemplate):
True,
)
micro_gemm = CppMicroGemmFP32Vec(
kernel_name + "_kernel_micro_gemm",
self.input_dtype,
self.input_dtype,
self.accumulate_dtype,
self.accumulate_dtype,
GemmBlocking(1, 16, 1),
1,
True,
False,
)
with V.set_graph_handler(V.graph):
kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
code_trans = micro_gemm_trans.codegen_define(kernel)
code = micro_gemm.codegen_define(kernel)
return code
return code + code_trans
def codegen_micro_gemm(self, kernel_name: str):
micro_gemm = self.micro_gemm_define(kernel_name)
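Taken together, the template now generates two micro gemm kernels. A conceptual Python sketch (illustrative only; the function parameters stand in for the generated kernels) of how they map onto the two GEMMs in the attention kernel:

```python
def flex_attention_block(q_blk, k_blk, v_blk, softmax,
                         micro_gemm_transpose_b, micro_gemm):
    # First GEMM: scores = Q @ K^T, using the transpose-B micro gemm variant.
    scores = micro_gemm_transpose_b(q_blk, k_blk)
    # score_mod / masking / online-softmax bookkeeping omitted in this sketch.
    p = softmax(scores)
    # Second GEMM: out = P @ V; for Half this is the plain micro gemm
    # instead of the brgemm fallback.
    return micro_gemm(p, v_blk)
```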