[Inductor][CPP] Avoid transpose with cpp micro-gemm for FlexAttention (#147069)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147069
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/drisspg
ghstack dependencies: #147068
jianan-gu 2025-03-03 01:22:26 -08:00 committed by PyTorch MergeBot
parent 6c089f5da3
commit d57f617844
4 changed files with 97 additions and 85 deletions
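At a glance: in the non-packed path, the explicit transpose of K into key_reorder_ptr followed by the brgemm call is replaced by a generated _kernel_micro_gemm that reads the K block row-major with ldb = kStrideN and treats it as a transposed B operand. A minimal Python reference of that equivalence (shapes and variable names are illustrative, not the template's):

import torch

q = torch.randn(32, 64)          # [qSplitSize, headSize]
k = torch.randn(48, 64)          # [kvSplitSize, headSize], row-major

# Old non-packed path: copy K into a transposed scratch buffer, then GEMM.
k_t = k.transpose(0, 1).contiguous()      # explicit [headSize, kvSplitSize] copy
qk_old = q @ k_t

# New non-packed path: no copy; the micro-gemm reads K with its native
# row stride and treats it as B^T (C[M, N] = A[M, K] @ B[N, K]^T).
qk_new = q @ k.transpose(0, 1)            # strided view only

torch.testing.assert_close(qk_old, qk_new)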

View File

@@ -135,7 +135,7 @@ else:
)
test_dtypes = (
[torch.float32, torch.bfloat16]
[torch.float32, torch.bfloat16, torch.float16]
if torch.backends.mkldnn.is_available()
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
else [torch.float32]
@@ -3677,23 +3677,6 @@ class GraphModule(torch.nn.Module):
):
attention(query, key, value, return_lse=True)
@unittest.skipIf(TEST_ON_CUDA, "Testing CPU error message")
def test_validate_cpu_dtype_error_message(self):
make_tensor = functools.partial(
torch.randn,
(2, 2, 128, 16),
device="cpu",
dtype=torch.half,
requires_grad=False,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()
attention = torch.compile(flex_attention)
with self.assertRaisesRegex(
torch._inductor.exc.InductorError,
r"`torch.float` and `torch.bfloat16` are supported in FlexAttention for CPU device. Found input tensors are `torch.float16`.",
):
attention(query, key, value)
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_device_cuda_1(self):
class TestModule(torch.nn.Module):
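With test_dtypes widened to include torch.float16 above, fp16 inputs now compile through the CPU C++ template instead of raising, which is why the dedicated error-message test is deleted. A hedged usage sketch (shapes mirror the deleted test; assumes a CPU build where the Inductor C++ backend and the required ISA support are available):

import torch
from torch.nn.attention.flex_attention import flex_attention

def make():
    return torch.randn(2, 2, 128, 16, device="cpu", dtype=torch.float16)

q, k, v = make(), make(), make()
out = torch.compile(flex_attention)(q, k, v)   # previously raised an Inductor error on CPU
print(out.shape, out.dtype)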

View File

@@ -17,6 +17,7 @@ from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import parallel_num_threads
from ..virtualized import V
from .cpp_template import CppTemplate
from .cpp_utils import GemmBlocking
log = logging.getLogger(__name__)
@@ -195,6 +196,10 @@ inline void {{kernel_name}}_copy_value_with_pad(
}
"""
MICRO_GEMM_TEMPLATE = r"""
GEMM_DEFINE
"""
ALLOCATE_BUFFER = r"""
int64_t {{buffer_name}}_dtype_itemsize = std::is_same_v<{{buffer_dtype}}, at::BFloat16> ? 2 : 4;
auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
@@ -208,6 +213,7 @@ FLEX_ATTENTION_TEMPLATE = r"""
#include <ATen/native/cpu/utils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/Context.h>
{{template.codegen_micro_gemm(kernel.kernel_name)}}
{{template.codegen_softmax_fusion(kernel.kernel_name)}}
{{template.codegen_brgemm_pack_function(kernel.kernel_name)}}
{%- set kernel_args = {"query": query, "key": key, "value": value,
@@ -329,7 +335,6 @@ extern "C"
need_pack = gemm_size_per_thread / pack_size >= 4;
}
}
// Pad is needed for packing when K is not even
bool headSize_even = headSize % 2 == 0;
int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
@@ -358,37 +363,37 @@ extern "C"
{{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}}
{{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}}
// Reorder K, V and transpose K
at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
int ompIdx = at::get_thread_num();
int64_t i = 0, j = 0, l = 0, n = 0;
scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr;
at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
n = l * kvSplitSize;
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
auto kv_block_num = n / cur_kvSplitSize;
auto kv_block_offset = n - kv_block_num * cur_kvSplitSize;
// getting kv indices by [BS, Head, 1, kv_block_num]
auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_block_num;
auto k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
auto v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
if (use_kv_indice) {
k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH +
(*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN;
v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH +
(*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN;
}
if (need_pack) {
if (need_pack) {
// Pack K, V
at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
int ompIdx = at::get_thread_num();
int64_t i = 0, j = 0, l = 0, n = 0;
scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize;
at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
n = l * kvSplitSize;
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
auto kv_block_num = n / cur_kvSplitSize;
auto kv_block_offset = n - kv_block_num * cur_kvSplitSize;
// getting kv indices by [BS, Head, 1, kv_block_num]
auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_block_num;
auto k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
auto v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
if (use_kv_indice) {
k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH +
(*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN;
v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH +
(*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN;
}
// transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize]
at::native::utils::transpose<uint16_t>(
cur_kvSplitSize,
@@ -417,23 +422,11 @@ extern "C"
/* ld_src */ vStrideN,
/* K */ cur_kvSplitSize,
/* N */ headSize_v);
} else {
using trans_t = std::conditional_t<std::is_same_v<scalar_t, at::BFloat16>, uint16_t, float>;
at::native::utils::transpose<trans_t>(
cur_kvSplitSize,
headSize,
/* src_ptr */
reinterpret_cast<const trans_t*>(k_addr),
/* ld_src */ kStrideN,
/* dst */ reinterpret_cast<trans_t*>(key_reorder_ptr + i * num_head * eheadSize * kvSize +
j * eheadSize * kvSize + n * eheadSize),
/* ld_dst */ cur_kvSplitSize);
// Move to the next query
at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
}
// Move to the next query
at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
}
});
});
}
// Attention loop below
at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, k = 0;
@@ -488,22 +481,26 @@ extern "C"
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_block_num;
if (!need_pack) {
auto k_addr_t = key_reorder_ptr + i * num_head * eheadSize * kvSize +
j * eheadSize * kvSize + n * eheadSize;
// TODO: use the micro-gemm template instead of brgemm API
at::native::cpublas::brgemm(
cur_qSplitSize,
cur_kvSplitSize,
eheadSize,
qStrideM,
cur_kvSplitSize,
cur_kvSplitSize,
false,
auto k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
if (use_kv_indice) {
k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH +
(*kv_logical_data * kvBlockSize + kv_block_offset) * kStrideN;
}
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
k_addr_t,
k_addr,
qk_data,
need_pack);
cur_qSplitSize,
cur_kvSplitSize,
headSize,
qStrideM,
kStrideN,
cur_kvSplitSize);
} else {
at::native::cpublas::brgemm(
cur_qSplitSize,
@@ -690,7 +687,7 @@ class CppFlexAttentionTemplate(CppTemplate):
kernel_input_name_to_buffer,
block_vars,
) -> None:
assert layout.dtype in [torch.float, torch.bfloat16]
assert layout.dtype in [torch.float, torch.bfloat16, torch.float16]
super().__init__("flex_attention", input_nodes, layout, parallel_num_threads())
self.scale = scale
self.score_mod = score_mod
@@ -958,6 +955,8 @@ class CppFlexAttentionTemplate(CppTemplate):
query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3])
key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3])
value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3])
self.accumulate_dtype = torch.float
self.input_dtype = query.layout.dtype
num_threads = parallel_num_threads()
buf_out = TensorBox.create(self.output_node)
@@ -975,8 +974,8 @@ class CppFlexAttentionTemplate(CppTemplate):
score_mod_other_buffers=self.score_mod_other_buffers,
mask_mod_other_buffers=self.mask_mod_other_buffers,
scale=self.scale,
accumulate_dtype=torch.float,
query_dtype=query.layout.dtype,
accumulate_dtype=self.accumulate_dtype,
query_dtype=self.input_dtype,
kvBlockSize=self.kv_block_size,
template=self,
output=buf_out,
@@ -1016,3 +1015,33 @@
buffer_size=buffer_size,
)
)
def micro_gemm_define(self, kernel_name: str):
from torch._inductor.codegen.cpp_gemm_template import (
CppTemplateKernel,
parallel_num_threads,
)
from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
from torch._inductor.virtualized import V
micro_gemm = CppMicroGemmFP32Vec(
kernel_name + "_kernel_micro_gemm",
self.input_dtype,
self.input_dtype,
self.accumulate_dtype,
self.accumulate_dtype,
GemmBlocking(1, 16, 1),
1,
True,
True,
)
with V.set_graph_handler(V.graph):
kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
code = micro_gemm.codegen_define(kernel)
return code
def codegen_micro_gemm(self, kernel_name: str):
micro_gemm = self.micro_gemm_define(kernel_name)
GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm)
return self._template_from_string(GEMM_SOURCE_CODE).render()
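For reference, the call emitted in the non-packed branch above passes arguments in the order (A, B, C, M, N, K, lda, ldb, ldc), with the boolean template parameter presumably being the accumulate flag. A Python sketch of the assumed semantics, where B is the untransposed K block addressed with ldb = kStrideN:

import torch

def micro_gemm_ref(A, B, C, M, N, K, accum=False):
    # Leading dimensions (lda/ldb/ldc) are elided in this reference; in the
    # generated kernel they let B be the untransposed K block read row-major.
    out = A[:M, :K].float() @ B[:N, :K].float().transpose(0, 1)  # A @ B^T, fp32 accumulation
    C[:M, :N] = C[:M, :N] + out if accum else out

A = torch.randn(32, 64, dtype=torch.float16)    # Q block  [M, K]
B = torch.randn(48, 64, dtype=torch.float16)    # K block  [N, K], not transposed
C = torch.zeros(32, 48, dtype=torch.float32)    # fp32 accumulator (qk_data)
micro_gemm_ref(A, B, C, M=32, N=48, K=64)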

View File

@@ -877,7 +877,7 @@ inline void {{kernel_name}}_transpose_b_kernel(
if self.trans_b:
# TODO supports tuning of sub_block_m/sub_block_n
# to get better performance for specific shapes
sub_block_m = min(4, self.register_blocking.block_m)
sub_block_m = min(1, self.register_blocking.block_m)
sub_block_n = min(4, self.register_blocking.block_n)
# update options to generate kernel with trans_b and sub-block size
options.update(
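The trans_b path above now fixes sub_block_m at 1 while keeping sub_block_n at up to 4. A hedged illustration of how a block_m x block_n register block is then split into sub-blocks (loop shape only; the generated intrinsics and iteration order may differ):

def sub_block_tiles(block_m, block_n, sub_block_m=1, sub_block_n=4):
    # Yield (m0, m_size, n0, n_size) tiles covering a block_m x block_n
    # register block, mirroring the sub_block_m/sub_block_n split above.
    for m0 in range(0, block_m, sub_block_m):
        for n0 in range(0, block_n, sub_block_n):
            yield m0, min(sub_block_m, block_m - m0), n0, min(sub_block_n, block_n - n0)

# With the GemmBlocking(1, 16, 1) used by the FlexAttention template, this
# splits the register block into four 1x4 sub-blocks.
print(list(sub_block_tiles(1, 16)))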

View File

@@ -1099,9 +1099,9 @@ def lower_cpu(
raise NotImplementedError(
"Unsupported for now if query, key, value are the same buffer."
)
if query.get_dtype() not in [torch.float, torch.bfloat16]:
if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]:
raise NotImplementedError(
"`torch.float` and `torch.bfloat16` are supported in FlexAttention for CPU device. "
"`torch.float` , `torch.float16` and `torch.bfloat16` are supported in FlexAttention for CPU device. "
f"Found input tensors are `{query.get_dtype()}`."
)
score_mod_other_buffers = maybe_realize(score_mod_other_buffers)