[Inductor-CPU] Faster int8 WoQ GEMM for small M with explicit prefetching and different outer loops (#149373)

### Summary

Fixes #148494

Explicitly prefetch the cache lines of the next `B` block to accelerate int8 WoQ (BF16 activation, int8 statically quantized weights) GEMM for small `M` dimension.

Some of this code (the outer loops of the GEMM) has been ported over from Intel Extension for PyTorch. The macro-kernel* and the micro-kernel* are essentially the same as before, except that they can optionally prefetch a block of `B`. Templatization is used so that the prefetch decision is made at compile time, avoiding a runtime branch (and the slowdown that unnecessary prefetching would cause).

\* - in [BLIS](https://dl.acm.org/doi/10.1145/2764454) parlance
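
To illustrate the templatization point, here is a minimal, hand-written sketch (not the code Inductor generates; `microgemm_block`, its arguments, and the packed-`B` layout are illustrative assumptions, and the scalar loop stands in for the vectorized BF16/int8 micro-kernel): the prefetch flag is a template parameter, so `if constexpr` compiles the prefetch code away entirely when it is not requested.

```cpp
#include <immintrin.h>  // _mm_prefetch (x86)
#include <cstdint>

// Sketch only: scalar stand-in for the vectorized micro-kernel, with float
// activations instead of BF16, and an assumed packed layout in which the next
// K x N block of B starts K * ldb elements after the current one.
template <bool accum, bool prefetch>
inline void microgemm_block(const float* A, const int8_t* B, float* C,
                            int64_t M, int64_t N, int64_t K,
                            int64_t lda, int64_t ldb, int64_t ldc) {
  for (int64_t k = 0; k < K; ++k) {
    if constexpr (prefetch) {
      // While row k of this B block is being consumed, pull the matching row of
      // the *next* B block into L1 (one prefetch per 64-byte cache line of int8 data).
      for (int64_t n = 0; n < N; n += 64) {
        _mm_prefetch(reinterpret_cast<const char*>(B + (K + k) * ldb + n), _MM_HINT_T0);
      }
    }
    for (int64_t m = 0; m < M; ++m) {
      const float a = A[m * lda + k];
      for (int64_t n = 0; n < N; ++n) {
        const float prod = a * static_cast<float>(B[k * ldb + n]);
        if (k == 0 && !accum) {
          C[m * ldc + n] = prod;   // first K block of the reduction: overwrite
        } else {
          C[m * ldc + n] += prod;  // otherwise accumulate into C
        }
      }
    }
  }
}
```

The generated outer K loop then picks the `<accum, prefetch>` instantiation per block: the first K block writes and prefetches, the last one accumulates without prefetching, and the middle blocks do both (see the new `SMALL_M_GEMM_TEMPLATE` below).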

### Performance data with BS 1

Machine: 32 cores on one socket of an Intel Xeon SP Gen 5 machine

| Model | Input tokens | Output tokens | Next-token latency before this PR | Next-token latency after this change | Latency reduction |
|-----------|-------------|-----------------|--------------------------------------|------------------------------------------|-----------|
| GPT-J | 128 | 128 | 42 ms | 38 ms | 9.52% |
| GPT-J | 1024 | 1024 | 48 ms | 45 ms | 6.25% |
| LLaMA 3.1 8B Instruct | 128 | 128 | 52 ms | 47 ms | 9.61% |
| LLaMA 3.1 8B Instruct | 1024 | 1024 | 57 ms | 53 ms | 7.01% |
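
For example, the first row's figure is computed as (42 − 38) / 42 ≈ 9.52%, i.e., the relative reduction in next-token latency.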

While the input shapes of the GEMMs corresponding to linear layers during next-token computation are the same regardless of the number of input & output tokens, the difference in next-token latency between the two configurations comes from the attention computation, whose cost grows with the context length.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149373
Approved by: https://github.com/leslie-fang-intel, https://github.com/Xia-Weiwen

Co-authored-by: Xia Weiwen <xia.weiwen@hotmail.com>
Authored by sanchitintel on 2025-05-15 11:55:54 +00:00; committed by PyTorch MergeBot
parent e5e06d9cab
commit 7482eb217c
4 changed files with 194 additions and 25 deletions


@@ -1378,6 +1378,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
@parametrize(
"batch_size",
(
1,
17,
32,
),
@@ -1429,6 +1430,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(w_int8pack).eval()
self.common(mod, (x, w_scales))
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
if batch_size * mid_dim >= 16:
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)


@@ -26,7 +26,9 @@ from .cpp_micro_gemm import (
CppMicroBrgemm,
CppMicroGemm,
CppMicroGemmAMX,
CppMicroGemmFP32Vec,
create_micro_gemm,
is_int8_woq_gemm_small_m_dim_corner_case,
LayoutType,
)
from .cpp_template import CppTemplate
@@ -41,7 +43,7 @@ from .cpp_utils import (
log = logging.getLogger(__name__)
GEMM_TEMPLATE_INIT_BLOCKING = r"""
GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK = r"""
constexpr int64_t num_threads = {{num_threads}};
constexpr int64_t N = {{N}};
constexpr int64_t K = {{K}};
@@ -50,10 +52,17 @@ GEMM_TEMPLATE_INIT_BLOCKING = r"""
constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
{%- if is_dynamic_M %}
const int64_t M = {{kernel.size(GemmOut, 0)}};
const int64_t Mr_blocks = (M + Mr - 1) / Mr;
{%- else %}
constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
{%- endif %}
"""
GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED = r"""
{%- if is_dynamic_M %}
{%- if num_threads > 1 %}
int64_t Mt_blocks, Nt_blocks, Kt_blocks;
mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
@@ -88,8 +97,6 @@ GEMM_TEMPLATE_INIT_BLOCKING = r"""
const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- else %}
constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
constexpr int64_t Mt_blocks = {{template.thread_blocking(num_threads).block_m}};
constexpr int64_t Nt_blocks = {{template.thread_blocking(num_threads).block_n}};
constexpr int64_t Kt_blocks = {{template.thread_blocking(num_threads).block_k}};
@@ -318,6 +325,52 @@ GEMM_TEMPLATE = r"""
}
"""
SMALL_M_GEMM_TEMPLATE = r"""
{{ template.codegen_gemm_stub_def() }}
{
{{ kernel.maybe_codegen_profile() }}
{{ template.codegen_blocks(
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
) }}
# pragma omp parallel
{
#pragma omp for nowait
for (int64_t nr_block_id = 0; nr_block_id < Nr_blocks; nr_block_id++) {
// Handle one output M * Nr block in each thread
int64_t n_start = nr_block_id * Nr;
int64_t n_end = (nr_block_id + 1) * Nr;
{%- if use_local_acc %}
{%- set acc_buf_name = "local_acc_buf" %}
{{ kernel.define_stack_allocated_buffer(acc_buf_name, ["M", "Nr"], acc_buf_dtype) }}
{%- set acc = kernel.local_buffers[acc_buf_name] %}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [(0, "M"), ("n_start", "n_end")]) %}
{%- endif %}
for (int64_t kr_block_id = 0; kr_block_id < Kr_blocks; kr_block_id++) {
// this loop is not parallelized
int64_t k_start = kr_block_id * Kr;
int64_t k_end = std::min((kr_block_id + 1) * Kr, K);
{%- set tile_X = kernel.slice_nd(X, [(0, "M"), ("k_start", "k_end")]) %}
{%- set tile_W_3d = kernel.slice_nd(W, [("nr_block_id", "nr_block_id + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
if C10_UNLIKELY(kr_block_id == 0) {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False, prefetch=True)|indent(20, false) }}
} else if C10_UNLIKELY(k_end == K) {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=False)|indent(20, false) }}
} else {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=True)|indent(20, false) }}
}
}
{%- set tile_Y = kernel.slice_nd(Y_2d, [("0", "M"), ("n_start", "n_end")]) %}
{%- set tile_acc = kernel.slice_nd(acc, [("0", "M"), ("0", "n_end - n_start")]) %}
{{ kernel.store_output(
tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("0", "n_start"), reindexers=reindexers
)|indent(20, false) }}
}
}
}
"""
def _is_int8_gemm(inputs):
return (
@@ -1485,6 +1538,24 @@ class CppGemmTemplate(CppTemplate):
)
return options
def is_int8_woq_gemm_small_m_dim(
self,
X: ir.ReinterpretView,
W: ir.ReinterpretView,
N,
K,
micro_gemm,
):
"""Use SMALL_M_GEMM_TEMPLATE"""
return (
isinstance(micro_gemm, CppMicroGemmFP32Vec)
and is_int8_woq_gemm_small_m_dim_corner_case(
micro_gemm, X.get_size()[0], N, K
)
and X.get_dtype() is torch.bfloat16
and W.get_dtype() is torch.int8
)
def render( # type: ignore[override, return]
self,
kernel: CppTemplateKernel,
@@ -1506,7 +1577,17 @@ class CppGemmTemplate(CppTemplate):
stack.enter_context(
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
)
return self._template_from_string(GEMM_TEMPLATE).render(**options)
if not options["is_dynamic_M"] and self.is_int8_woq_gemm_small_m_dim(
options["X"],
options["W"],
options["N"],
options["K"],
options["micro_gemm"],
):
template_str = SMALL_M_GEMM_TEMPLATE
else:
template_str = GEMM_TEMPLATE
return self._template_from_string(template_str).render(**options)
def codegen_blocks(
self,
@@ -1539,7 +1620,13 @@ class CppGemmTemplate(CppTemplate):
W=W,
is_woq_int4=self.is_woq_int4(),
)
return self._template_from_string(GEMM_TEMPLATE_INIT_BLOCKING).render(options)
template_str = GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK
if not (
not is_dynamic_M
and self.is_int8_woq_gemm_small_m_dim(X, W, N, K, micro_gemm)
):
template_str += GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED
return self._template_from_string(template_str).render(options)
def codegen_microkernel_def(self):
return self._template_from_string(GEMM_TEMPLATE_MICROKERNEL_DEF).render(


@@ -4,8 +4,6 @@ import sys
from enum import Enum
from typing import Callable, Optional
import sympy
import torch
from .. import cpp_builder, ir
@@ -55,7 +53,7 @@ class CppMicroGemm:
# TODO(jgong5): support constant shapes and lds as template args.
DECLARE_KERNEL = r"""
template <bool accum>
template <bool accum, bool prefetch=false>
inline void {{kernel_name}}(
{%- if kernel_extra_args_declare %}
{{kernel_extra_args_declare}}
@@ -138,6 +136,7 @@ inline void {{kernel_name}}(
B: ir.Buffer,
C: ir.Buffer,
accum: bool,
prefetch: bool = False,
**kwargs_for_extra_args,
) -> str:
"""
@@ -154,7 +153,9 @@ inline void {{kernel_name}}(
ldb = kernel.stride(B, 0)
ldc = kernel.stride(C, 0)
res = IndentedBuffer()
res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
res.writeline(
f"{self.name}<{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}>("
)
with res.indent():
kwargs_for_extra_args.update({"kernel": kernel})
extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args)
@@ -317,6 +318,26 @@ class CppMicroGemmRef(CppMicroGemm):
return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)
def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k):
return (
k % config.register_blocking.block_k == 0
and n % config.register_blocking.block_n == 0
and m < 16
)
# extra check for small M dimension for int8 WoQ case
def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs):
return is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k) and not kwargs.get(
"dynamic_M", False
)
# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise
def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs):
return not check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs)
@register_micro_gemm(
*generate_gemm_config(
VecAVX512,
@@ -342,6 +363,19 @@ class CppMicroGemmRef(CppMicroGemm):
input2_dtype=torch.int8,
output_dtype=torch.float,
compute_dtype=torch.float,
extra_check=do_not_use_with_small_m_for_int8_woq,
),
*generate_gemm_config(
VecAVX512,
[
(4, 32, 64),
(8, 32, 64),
],
input_dtype=torch.bfloat16,
input2_dtype=torch.int8,
output_dtype=torch.float,
compute_dtype=torch.float,
extra_check=check_int8_woq_small_m_dim,
),
*generate_gemm_config(
VecAVX2,
@@ -367,6 +401,19 @@ class CppMicroGemmRef(CppMicroGemm):
input2_dtype=torch.int8,
output_dtype=torch.float,
compute_dtype=torch.float,
extra_check=do_not_use_with_small_m_for_int8_woq,
),
*generate_gemm_config(
VecAVX2,
[
(2, 16, 64),
(4, 16, 64),
],
input_dtype=torch.bfloat16,
input2_dtype=torch.int8,
output_dtype=torch.float,
compute_dtype=torch.float,
extra_check=check_int8_woq_small_m_dim,
),
*generate_gemm_config(
VecNEON,
@@ -397,7 +444,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
{{declare_kernel}} {
using Vectorized = at::vec::Vectorized<{{compute_t}}>;
constexpr auto VLEN = Vectorized::size();
{{kernel.assert_function}}({{block_n}} % VLEN == 0, "{{block_n}} dimension must be multiple of Vector size");
{{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size");
{{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
// TODO(jgong5): loop unroll for M and N
for (int64_t m = 0; m < M; m += {{block_m}}) {
@@ -406,9 +453,9 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
int64_t block_n = std::min<int64_t>(N - n, {{block_n}});
if (block_m == {{block_m}} && block_n == {{block_n}}) {
{%- if not trans_b %}
{{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
{{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
{%- else %}
{{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum>(
{{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
{%- endif %}
A + m * lda,
{%- if not trans_b %}
@@ -431,9 +478,9 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
{%- for b in range(block_m - 1, 0, -1) %}
case {{b}}:
{%- if not trans_b %}
{{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
{{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>(
{%- else %}
{{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum>(
{{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
{%- endif %}
A + m * lda,
{%- if not trans_b %}
@@ -459,9 +506,9 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
{%- for b in range(block_m, 0, -1) %}
case {{b}}:
{%- if not trans_b %}
{{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum>(
{{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>(
{%- else %}
{{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum>(
{{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
{%- endif %}
A + m * lda,
{%- if not trans_b %}
@@ -492,7 +539,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
TEMPLATE_KERNEL = r"""
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum, bool prefetch=false>
{%- if not trans_b %}
{%- if tail_n %}
inline void {{kernel_name}}_ntail_kernel(
@@ -592,6 +639,9 @@ inline void {{kernel_name}}_transpose_b_kernel(
{%- elif input2_dtype == torch.int8 %}
// Convert VLEN int8 elements to int32, and then fp32
auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN);
if constexpr (prefetch) {
_mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0);
}
vb[col] = at::vec::convert<float>(b32);
{%- else %}
vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
@@ -1065,7 +1115,7 @@ class CppMicroGemmAMX(CppMicroGemm):
TEMPLATE_KERNEL = r"""
template <bool accum>
template <bool accum, bool prefetch=false>
inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
AMXState& amx_state,
const {{input_t}}* {{restrict_keyword}} A,
@@ -1855,6 +1905,11 @@ def create_micro_gemm(
use_ref=True,
q_group_size=None,
) -> Optional[CppMicroGemm]:
"""
Based on the provided info, try to find the config of the micro-kernel that would
deliver the best performance in terms of lower latency for this case.
"""
def create_from_config(cls, config: CppMicroGemmConfig):
return cls(
name,
@@ -1868,8 +1923,11 @@ def create_micro_gemm(
assert isinstance(n, int) or n.is_number, n
assert isinstance(k, int) or k.is_number, k
m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m
assert isinstance(m, int), m
from ..utils import has_free_symbols
dynamic_M = has_free_symbols((m,))
m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m
assert isinstance(m, int) or m.is_number, m
if output_dtype is None:
output_dtype = input_dtype
if compute_dtype is None:
@@ -1894,17 +1952,26 @@ def create_micro_gemm(
# subject to change in the future.
):
if config.extra_check is not None and not config.extra_check(
config, m, n, k, alpha, num_threads, q_group_size=q_group_size
config,
m,
n,
k,
alpha,
num_threads,
dynamic_M=dynamic_M,
q_group_size=q_group_size,
):
continue
block_m, block_n, block_k = config.register_blocking
if (
config.vec_isa_cls == VecAMX
and m < block_m
and not dynamic_M
and input_dtype == torch.bfloat16
and input2_dtype in [torch.int8, torch.uint8]
):
# For WoQ GEMM, AMX micro-kernel may not perform well if m < block_m
# For WoQ GEMM, AMX micro-kernel may not perform well if m < block_m.
# Exception: for dynamic shapes, we consider using the AMX micro-kernel.
continue
# Criteria on the ranking of configurations
# 1. ISA: AMX > VEC


@@ -208,6 +208,19 @@ class CppTemplateKernel(CppKernel):
numel = f"{cexpr_index(buf.get_numel())}"
return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();"
def define_stack_allocated_buffer(
self, name, sizes: list[Any], dtype=torch.float
) -> str:
"""Define stack-allocated buffer"""
sizes = parse_expr_with_index_symbols(sizes)
buf = ir.Buffer(
name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes)
)
self.local_buffers[name] = buf
ctype = f"{DTYPE_TO_CPP[dtype]}"
numel = f"{cexpr_index(buf.get_numel())}"
return f"alignas(64) {ctype} _{name}[{numel}]; {ctype}* {name} = _{name};"
def reinit_buffer_if_null(self, name):
"""Reinit the previously defined local buffer if it is null"""
assert name in self.local_buffers