Add pack support and use micro gemm for Half flex attention on CPU (#151530)
Add pack support and use micro gemm for the second gemm to improve the performance for Half flex attention on CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151530
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
parent 41bd0c900a
commit dcd9a444b3
@@ -201,7 +201,7 @@ GEMM_DEFINE
 """

 ALLOCATE_BUFFER = r"""
-  int64_t {{buffer_name}}_dtype_itemsize = std::is_same_v<{{buffer_dtype}}, at::BFloat16> ? 2 : 4;
+  int64_t {{buffer_name}}_dtype_itemsize = c10::is_reduced_floating_point_v<{{buffer_dtype}}> ? 2 : 4;
   auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
   auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize);
   void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get();
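The hunk above sizes the scratch buffer by element type: previously only at::BFloat16 was given a 2-byte itemsize, so a Half buffer would have been sized as if it held 4-byte floats; with c10::is_reduced_floating_point_v both reduced-precision types map to 2 bytes. A minimal standalone sketch of that mapping, using local stand-in structs instead of the real at::BFloat16/at::Half and a hand-rolled trait in place of the c10 one:

// Sketch only: stand-in types and a local trait mimicking the itemsize rule.
#include <cstdint>
#include <cstdio>
#include <type_traits>

struct BFloat16 { uint16_t bits; };  // stand-in for at::BFloat16
struct Half     { uint16_t bits; };  // stand-in for at::Half

template <typename T>
constexpr bool is_reduced_floating_point_v =
    std::is_same_v<T, BFloat16> || std::is_same_v<T, Half>;

template <typename T>
constexpr int64_t dtype_itemsize() {
  // Reduced-precision floats take 2 bytes, everything else here is 4-byte float.
  return is_reduced_floating_point_v<T> ? 2 : 4;
}

int main() {
  std::printf("BFloat16: %lld bytes\n", (long long)dtype_itemsize<BFloat16>());
  std::printf("Half:     %lld bytes\n", (long long)dtype_itemsize<Half>());
  std::printf("float:    %lld bytes\n", (long long)dtype_itemsize<float>());
}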
@@ -318,15 +318,16 @@ extern "C"
   int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;

   bool need_pack = false;
-  // Whether pack is needed for BFloat16
-  if (std::is_same_v<scalar_t, at::BFloat16>) {
+  // Whether pack is needed for BFloat16/Half
+  if (is_reduced_type) {
     // check platform ability
-    need_pack = at::native::cpublas::could_pack(at::kBFloat16);
+    need_pack = std::is_same_v<scalar_t, at::BFloat16> ? at::native::cpublas::could_pack(at::kBFloat16)
+                                                       : at::native::cpublas::could_pack(at::kHalf);
   }
   if (need_pack) {
     // When the number of gemm is greater than the number of pack,
     // the pack overhead can be overlaped.
-    int64_t thresh_size = 64 ;
+    int64_t thresh_size = 64;
     need_pack = kvSize >= thresh_size && qSize >= thresh_size;
     if (need_pack) {
       double pack_size = batchSize * num_head * kvSize * headSize;
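This hunk extends the packing decision from BFloat16-only to both reduced-precision types: packing is attempted when the scalar type is reduced precision, the platform reports pack support for that dtype via at::native::cpublas::could_pack, and the problem is large enough (both qSize and kvSize at least 64) for the pack overhead to be amortized; the hunk then continues with a pack-size versus gemm-size comparison. A rough sketch of that gating logic, with could_pack replaced by a hypothetical stand-in since the real check probes CPU ISA support:

// Sketch of the pack decision; could_pack() is a stand-in, not the real
// at::native::cpublas::could_pack.
#include <cstdint>
#include <cstdio>

enum class Dtype { BFloat16, Half, Float };

// Hypothetical: assume the machine can pack for both reduced-precision types.
bool could_pack(Dtype dt) { return dt == Dtype::BFloat16 || dt == Dtype::Half; }

bool should_pack(Dtype dt, int64_t qSize, int64_t kvSize) {
  bool need_pack = false;
  if (dt == Dtype::BFloat16 || dt == Dtype::Half) {
    need_pack = could_pack(dt);  // platform ability per dtype
  }
  if (need_pack) {
    // Only pack when there is enough GEMM work to hide the packing overhead.
    constexpr int64_t thresh_size = 64;
    need_pack = kvSize >= thresh_size && qSize >= thresh_size;
  }
  return need_pack;
}

int main() {
  std::printf("Half, 128x128: %d\n", should_pack(Dtype::Half, 128, 128));  // 1
  std::printf("Half,  32x128: %d\n", should_pack(Dtype::Half, 32, 128));   // 0
}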
@@ -489,7 +490,7 @@ extern "C"
                 (*kv_logical_data * kvBlockSize + kv_block_offset) * kStrideN;
           }

-          {{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
+          {{kernel.kernel_name}}_kernel_micro_gemm_transpose_b<static_cast<bool>(false)>(
               q_data + i * qStrideB + j * qStrideH +
                   m * qStrideM,
               k_addr,
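Here the first gemm (Q @ K^T) switches to a dedicated _kernel_micro_gemm_transpose_b variant, which reads the B operand row-major as [N, K] instead of [K, N]. A naive scalar illustration of that access pattern; the generated kernel is vectorized and blocked, this only shows the indexing idea:

// C[m][n] = (or +=) sum_k A[m][k] * B[n][k]: B is indexed as if transposed,
// matching Q times K^T while K stays in its [N, K] layout.
#include <cstdint>
#include <cstdio>

template <bool accum>
void micro_gemm_transpose_b(const float* A, const float* B, float* C,
                            int64_t M, int64_t N, int64_t K,
                            int64_t lda, int64_t ldb, int64_t ldc) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = accum ? C[m * ldc + n] : 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        acc += A[m * lda + k] * B[n * ldb + k];  // note B[n][k], not B[k][n]
      }
      C[m * ldc + n] = acc;
    }
  }
}

int main() {
  const float A[2] = {1.0f, 2.0f};  // 1 x 2
  const float B[2] = {3.0f, 4.0f};  // 1 x 2, consumed as B^T
  float C[1] = {0.0f};
  micro_gemm_transpose_b<false>(A, B, C, 1, 1, 2, 2, 2, 1);
  std::printf("C[0][0] = %.1f\n", C[0]);  // 1*3 + 2*4 = 11.0
}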
@@ -614,7 +615,9 @@ extern "C"
               v_data + i_kv * vStrideB + j_kv * vStrideH +
               (*kv_logical_data * kvBlockSize + kv_block_offset) * vStrideN;
           }
-          at::native::cpublas::brgemm(
+          // Fallback Half brgemm is slower than micro gemm
+          if (!std::is_same_v<scalar_t, at::Half>) {
+          at::native::cpublas::brgemm(
               cur_qSplitSize,
               headSize_v,
               cur_ekvSplitSize,
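The brgemm call for the second gemm is now taken only for non-Half scalars; per the added comment, the Half brgemm fallback is slower than the generated micro gemm, so the new else branch (next hunk) routes Half to the micro kernel instead. A minimal sketch of that type-based split, with stand-in types and placeholder path functions (run_brgemm_path / run_micro_gemm_path are illustrative names, not PyTorch APIs):

#include <cstdio>
#include <type_traits>

struct Half {};      // stand-in for at::Half
struct BFloat16 {};  // stand-in for at::BFloat16

void run_brgemm_path()     { std::puts("brgemm path"); }
void run_micro_gemm_path() { std::puts("micro gemm path"); }

template <typename scalar_t>
void second_gemm() {
  // Mirrors the "!std::is_same_v<scalar_t, at::Half>" guard in the hunk:
  // both branches compile, the condition is resolved per instantiation.
  if (!std::is_same_v<scalar_t, Half>) {
    run_brgemm_path();       // e.g. BFloat16
  } else {
    run_micro_gemm_path();   // Half: micro gemm is faster here
  }
}

int main() {
  second_gemm<BFloat16>();  // -> brgemm path
  second_gemm<Half>();      // -> micro gemm path
}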
@@ -626,6 +629,31 @@ extern "C"
               v_addr,
               dst_data,
               need_pack);
+          } else {
+            if (n > 0) {
+              {{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(true)>(
+                  {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
+                  v_addr,
+                  dst_data,
+                  cur_qSplitSize,
+                  headSize_v,
+                  cur_ekvSplitSize,
+                  cur_ekvSplitSize,
+                  vStrideN,
+                  headSize_v);
+            } else {
+              {{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
+                  {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
+                  v_addr,
+                  dst_data,
+                  cur_qSplitSize,
+                  headSize_v,
+                  cur_ekvSplitSize,
+                  cur_ekvSplitSize,
+                  vStrideN,
+                  headSize_v);
+            }
+          }
         } else {
           int64_t psize = n / kvSplitSize * ekvSplitSize;
           at::native::cpublas::brgemm(
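The new Half branch calls the generated micro gemm with a boolean template argument that selects accumulation: the first KV block (n == 0) overwrites dst_data, later blocks (n > 0) add into it, so the per-block P @ V products sum across the KV loop. A small scalar sketch of that overwrite-then-accumulate pattern; the real kernel is the codegen'd CppMicroGemmFP32Vec kernel and its A operand comes through conditional_data_ptr:

// accum=false: C = A*B (first KV block); accum=true: C += A*B (later blocks).
#include <cstdint>
#include <cstdio>

template <bool accum>
void micro_gemm(const float* A, const float* B, float* C,
                int64_t M, int64_t N, int64_t K,
                int64_t lda, int64_t ldb, int64_t ldc) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = accum ? C[m * ldc + n] : 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        acc += A[m * lda + k] * B[k * ldb + n];
      }
      C[m * ldc + n] = acc;
    }
  }
}

int main() {
  // One output element, two "KV blocks" of depth K = 2 each.
  const float P[2] = {1.0f, 1.0f};
  const float V[2] = {1.0f, 1.0f};
  float dst[1] = {0.0f};
  micro_gemm<false>(P, V, dst, 1, 1, 2, 2, 1, 1);  // n == 0: overwrite
  micro_gemm<true>(P, V, dst, 1, 1, 2, 2, 1, 1);   // n  > 0: accumulate
  std::printf("dst = %.1f\n", dst[0]);             // 2 + 2 = 4.0
}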
@@ -1024,8 +1052,8 @@ class CppFlexAttentionTemplate(CppTemplate):
         from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
         from torch._inductor.virtualized import V

-        micro_gemm = CppMicroGemmFP32Vec(
-            kernel_name + "_kernel_micro_gemm",
+        micro_gemm_trans = CppMicroGemmFP32Vec(
+            kernel_name + "_kernel_micro_gemm_transpose_b",
             self.input_dtype,
             self.input_dtype,
             self.accumulate_dtype,
@@ -1036,10 +1064,23 @@ class CppFlexAttentionTemplate(CppTemplate):
             True,
         )

+        micro_gemm = CppMicroGemmFP32Vec(
+            kernel_name + "_kernel_micro_gemm",
+            self.input_dtype,
+            self.input_dtype,
+            self.accumulate_dtype,
+            self.accumulate_dtype,
+            GemmBlocking(1, 16, 1),
+            1,
+            True,
+            False,
+        )
+
         with V.set_graph_handler(V.graph):
             kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
+            code_trans = micro_gemm_trans.codegen_define(kernel)
             code = micro_gemm.codegen_define(kernel)
-            return code
+            return code + code_trans

     def codegen_micro_gemm(self, kernel_name: str):
         micro_gemm = self.micro_gemm_define(kernel_name)