[Inductor][CPP] Avoid transpose with cpp micro-gemm for FlexAttention (#147069)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147069
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/drisspg
ghstack dependencies: #147068
parent 6c089f5da3
commit d57f617844
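As context for the dtype changes in the hunks below (test_dtypes gaining torch.float16 and the CPU lowering check accepting it), here is a minimal usage sketch. It is not taken from this commit; it assumes a CPU build with the required oneDNN support, and the shapes are illustrative only.

    # Hypothetical sketch: compiled FlexAttention on CPU with the dtypes this
    # change targets (float32 and bfloat16, plus float16 after this PR).
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    compiled_attention = torch.compile(flex_attention)

    for dtype in (torch.float32, torch.bfloat16, torch.float16):
        # Illustrative [batch, heads, seq_len, head_dim] inputs on CPU.
        query, key, value = (
            torch.randn(2, 4, 128, 64, device="cpu", dtype=dtype) for _ in range(3)
        )
        out = compiled_attention(query, key, value)
        print(dtype, tuple(out.shape))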
@@ -135,7 +135,7 @@ else:
)

test_dtypes = (
[torch.float32, torch.bfloat16]
[torch.float32, torch.bfloat16, torch.float16]
if torch.backends.mkldnn.is_available()
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
else [torch.float32]
@@ -3677,23 +3677,6 @@ class GraphModule(torch.nn.Module):
):
attention(query, key, value, return_lse=True)

@unittest.skipIf(TEST_ON_CUDA, "Testing CPU error message")
def test_validate_cpu_dtype_error_message(self):
make_tensor = functools.partial(
torch.randn,
(2, 2, 128, 16),
device="cpu",
dtype=torch.half,
requires_grad=False,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()
attention = torch.compile(flex_attention)
with self.assertRaisesRegex(
torch._inductor.exc.InductorError,
r"`torch.float` and `torch.bfloat16` are supported in FlexAttention for CPU device. Found input tensors are `torch.float16`.",
):
attention(query, key, value)

@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_device_cuda_1(self):
class TestModule(torch.nn.Module):
@@ -17,6 +17,7 @@ from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import parallel_num_threads
from ..virtualized import V
from .cpp_template import CppTemplate
from .cpp_utils import GemmBlocking


log = logging.getLogger(__name__)
@@ -195,6 +196,10 @@ inline void {{kernel_name}}_copy_value_with_pad(
}
"""

MICRO_GEMM_TEMPLATE = r"""
GEMM_DEFINE
"""

ALLOCATE_BUFFER = r"""
int64_t {{buffer_name}}_dtype_itemsize = std::is_same_v<{{buffer_dtype}}, at::BFloat16> ? 2 : 4;
auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
@@ -208,6 +213,7 @@ FLEX_ATTENTION_TEMPLATE = r"""
#include <ATen/native/cpu/utils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/Context.h>
{{template.codegen_micro_gemm(kernel.kernel_name)}}
{{template.codegen_softmax_fusion(kernel.kernel_name)}}
{{template.codegen_brgemm_pack_function(kernel.kernel_name)}}
{%- set kernel_args = {"query": query, "key": key, "value": value,
@@ -329,7 +335,6 @@ extern "C"
need_pack = gemm_size_per_thread / pack_size >= 4;
}
}

// Pad is needed for packing when K is not even
bool headSize_even = headSize % 2 == 0;
int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
@@ -358,37 +363,37 @@ extern "C"
{{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}}
{{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}}

// Reorder K, V and transpose K
at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
int ompIdx = at::get_thread_num();
int64_t i = 0, j = 0, l = 0, n = 0;
scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr;
at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
n = l * kvSplitSize;
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
auto kv_block_num = n / cur_kvSplitSize;
auto kv_block_offset = n - kv_block_num * cur_kvSplitSize;
// getting kv indices by [BS, Head, 1, kv_block_num]
auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_block_num;
auto k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
auto v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
if (use_kv_indice) {
k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH +
(*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN;
v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH +
(*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN;
}
if (need_pack) {
if (need_pack) {
// Pack K, V
at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
int ompIdx = at::get_thread_num();
int64_t i = 0, j = 0, l = 0, n = 0;
scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize;
at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
n = l * kvSplitSize;
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
auto kv_block_num = n / cur_kvSplitSize;
auto kv_block_offset = n - kv_block_num * cur_kvSplitSize;
// getting kv indices by [BS, Head, 1, kv_block_num]
auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_block_num;
auto k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
auto v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
if (use_kv_indice) {
k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH +
(*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN;
v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH +
(*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN;
}
// transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize]
at::native::utils::transpose<uint16_t>(
cur_kvSplitSize,
@@ -417,23 +422,11 @@ extern "C"
/* ld_src */ vStrideN,
/* K */ cur_kvSplitSize,
/* N */ headSize_v);
} else {
using trans_t = std::conditional_t<std::is_same_v<scalar_t, at::BFloat16>, uint16_t, float>;
at::native::utils::transpose<trans_t>(
cur_kvSplitSize,
headSize,
/* src_ptr */
reinterpret_cast<const trans_t*>(k_addr),
/* ld_src */ kStrideN,
/* dst */ reinterpret_cast<trans_t*>(key_reorder_ptr + i * num_head * eheadSize * kvSize +
j * eheadSize * kvSize + n * eheadSize),
/* ld_dst */ cur_kvSplitSize);
// Move to the next query
at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
}
// Move to the next query
at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
}
});

});
}
// Attention loop below
at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, k = 0;
@@ -488,22 +481,26 @@ extern "C"
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_block_num;
if (!need_pack) {
auto k_addr_t = key_reorder_ptr + i * num_head * eheadSize * kvSize +
j * eheadSize * kvSize + n * eheadSize;
// TODO: use the micro-gemm template instead of brgemm API
at::native::cpublas::brgemm(
cur_qSplitSize,
cur_kvSplitSize,
eheadSize,
qStrideM,
cur_kvSplitSize,
cur_kvSplitSize,
false,
auto k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
if (use_kv_indice) {
k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH +
(*kv_logical_data * kvBlockSize + kv_block_offset) * kStrideN;
}

{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
k_addr_t,
k_addr,
qk_data,
need_pack);
cur_qSplitSize,
cur_kvSplitSize,
headSize,
qStrideM,
kStrideN,
cur_kvSplitSize);

} else {
at::native::cpublas::brgemm(
cur_qSplitSize,
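The hunks above are the core of the change: in the non-packed path, the explicit transpose of K into key_reorder_ptr and the brgemm call on the reordered keys are replaced by a call to the generated trans_b micro-gemm, which reads the key block in its original row-major layout (with stride kStrideN). A rough numerical sketch of the idea, not code from the PR:

    # Sketch: computing Q @ K^T without materializing a transposed copy of K.
    # A trans_b GEMM consumes K as stored; the result is identical to the
    # transpose-then-GEMM approach, only the memory access pattern differs.
    import torch

    q_block = torch.randn(32, 64)    # [qSplitSize, headSize]
    k_block = torch.randn(128, 64)   # [kvSplitSize, headSize], row-major

    # Old non-packed path (conceptually): reorder/transpose K, then GEMM.
    scores_with_copy = q_block @ k_block.transpose(0, 1).contiguous()

    # New non-packed path (conceptually): GEMM that treats K as B^T in place.
    scores_in_place = q_block @ k_block.T

    torch.testing.assert_close(scores_with_copy, scores_in_place)

The packed path is unchanged in spirit: when need_pack is true, K/V are still transposed and packed, and brgemm is still used in the attention loop's else branch.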
@@ -690,7 +687,7 @@ class CppFlexAttentionTemplate(CppTemplate):
kernel_input_name_to_buffer,
block_vars,
) -> None:
assert layout.dtype in [torch.float, torch.bfloat16]
assert layout.dtype in [torch.float, torch.bfloat16, torch.float16]
super().__init__("flex_attention", input_nodes, layout, parallel_num_threads())
self.scale = scale
self.score_mod = score_mod
@@ -958,6 +955,8 @@ class CppFlexAttentionTemplate(CppTemplate):
query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3])
key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3])
value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3])
self.accumulate_dtype = torch.float
self.input_dtype = query.layout.dtype

num_threads = parallel_num_threads()
buf_out = TensorBox.create(self.output_node)
@@ -975,8 +974,8 @@ class CppFlexAttentionTemplate(CppTemplate):
score_mod_other_buffers=self.score_mod_other_buffers,
mask_mod_other_buffers=self.mask_mod_other_buffers,
scale=self.scale,
accumulate_dtype=torch.float,
query_dtype=query.layout.dtype,
accumulate_dtype=self.accumulate_dtype,
query_dtype=self.input_dtype,
kvBlockSize=self.kv_block_size,
template=self,
output=buf_out,
@@ -1016,3 +1015,33 @@ class CppFlexAttentionTemplate(CppTemplate):
buffer_size=buffer_size,
)
)

def micro_gemm_define(self, kernel_name: str):
from torch._inductor.codegen.cpp_gemm_template import (
CppTemplateKernel,
parallel_num_threads,
)
from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
from torch._inductor.virtualized import V

micro_gemm = CppMicroGemmFP32Vec(
kernel_name + "_kernel_micro_gemm",
self.input_dtype,
self.input_dtype,
self.accumulate_dtype,
self.accumulate_dtype,
GemmBlocking(1, 16, 1),
1,
True,
True,
)

with V.set_graph_handler(V.graph):
kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
code = micro_gemm.codegen_define(kernel)
return code

def codegen_micro_gemm(self, kernel_name: str):
micro_gemm = self.micro_gemm_define(kernel_name)
GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm)
return self._template_from_string(GEMM_SOURCE_CODE).render()
@@ -877,7 +877,7 @@ inline void {{kernel_name}}_transpose_b_kernel(
if self.trans_b:
# TODO supports tuning of sub_block_m/sub_block_n
# to get better performance for specific shapes
sub_block_m = min(4, self.register_blocking.block_m)
sub_block_m = min(1, self.register_blocking.block_m)
sub_block_n = min(4, self.register_blocking.block_n)
# update options to generate kernel with trans_b and sub-block size
options.update(
@@ -1099,9 +1099,9 @@ def lower_cpu(
raise NotImplementedError(
"Unsupported for now if query, key, value are the same buffer."
)
if query.get_dtype() not in [torch.float, torch.bfloat16]:
if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]:
raise NotImplementedError(
"`torch.float` and `torch.bfloat16` are supported in FlexAttention for CPU device. "
"`torch.float` , `torch.float16` and `torch.bfloat16` are supported in FlexAttention for CPU device. "
f"Found input tensors are `{query.get_dtype()}`."
)
score_mod_other_buffers = maybe_realize(score_mod_other_buffers)