[Inductor-CPU] Faster int8 WoQ GEMM for small M with explicit prefetching and different outer loops (#149373)
### Summary

Fixes #148494

Explicitly prefetch the cache lines of the next `B` block to accelerate int8 WoQ (BF16 activation, int8 statically quantized weights) GEMM for a small `M` dimension. Some of this code (the outer loops of the GEMM) is ported over from Intel Extension for PyTorch. The macro-kernel\* and the micro-kernel\* are essentially the same as before, but now optionally prefetch a block of B. Templatization is used so that the prefetch decision is made at compile time, avoiding both branching overhead and unnecessary prefetching.

\* - in [BLIS](https://dl.acm.org/doi/10.1145/2764454) parlance

### Performance data with BS 1

Machine: 32 cores of one socket of an Intel Xeon SP Gen 5 machine

| Model | Input tokens | Output tokens | Next-token latency before this PR | Next-token latency after this PR | Latency reduction |
|-------|--------------|---------------|-----------------------------------|----------------------------------|-------------------|
| GPT-J | 128 | 128 | 42 ms | 38 ms | 9.52% |
| GPT-J | 1024 | 1024 | 48 ms | 45 ms | 6.25% |
| LLaMA 3.1 8B Instruct | 128 | 128 | 52 ms | 47 ms | 9.61% |
| LLaMA 3.1 8B Instruct | 1024 | 1024 | 57 ms | 53 ms | 7.01% |

While the input shapes of the GEMMs corresponding to the linear layers remain the same during next-token computation regardless of the number of input and output tokens, the difference in next-token latency between these cases is due to attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149373
Approved by: https://github.com/leslie-fang-intel, https://github.com/Xia-Weiwen

Co-authored-by: Xia Weiwen <xia.weiwen@hotmail.com>
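The templatization technique shows up in the diff below as `template <bool accum, bool prefetch=false>` on the generated micro-kernels plus an `if constexpr (prefetch)` block around `_mm_prefetch`. Here is a minimal standalone C++ sketch of the same idea; the function name, parameter list, and the scalar inner loop are illustrative, not the generated code:

```cpp
#include <immintrin.h>
#include <cstdint>

// Sketch of a micro-kernel that optionally prefetches the next B block.
// Because `prefetch` is a template parameter, `if constexpr` resolves the
// decision at compile time: the non-prefetching instantiation contains no
// prefetch instructions and no runtime branch at all.
template <bool accum, bool prefetch = false>
inline void micro_gemm_block(const float* A, const int8_t* B, float* C,
                             int64_t M, int64_t N, int64_t K,
                             int64_t lda, int64_t ldb, int64_t ldc,
                             int64_t block_k) {
  if constexpr (!accum) {
    // First K block: initialize the accumulator instead of adding to it.
    for (int64_t m = 0; m < M; m++)
      for (int64_t n = 0; n < N; n++)
        C[m * ldc + n] = 0.0f;
  }
  for (int64_t k = 0; k < K; k++) {
    if constexpr (prefetch) {
      // Pull a cache line of the corresponding row of the *next* B block
      // (block_k rows ahead) into L1 while the current block is computed.
      _mm_prefetch(reinterpret_cast<const char*>(B + (k + block_k) * ldb),
                   _MM_HINT_T0);
    }
    for (int64_t m = 0; m < M; m++)
      for (int64_t n = 0; n < N; n++)
        C[m * ldc + n] += A[m * lda + k] * static_cast<float>(B[k * ldb + n]);
  }
}
```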
This commit is contained in:
parent
e5e06d9cab
commit
7482eb217c
test/inductor/test_cpu_select_algorithm.py:

```diff
@@ -1378,6 +1378,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
     @parametrize(
         "batch_size",
         (
+            1,
             17,
             32,
         ),
@@ -1429,6 +1430,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
         mod = M(w_int8pack).eval()
         self.common(mod, (x, w_scales))
         self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
+        if batch_size * mid_dim >= 16:
             vec_amx = VecAMX()
             self._check_amx_counter(vec_amx)
```
torch/_inductor/codegen/cpp_gemm_template.py:

```diff
@@ -26,7 +26,9 @@ from .cpp_micro_gemm import (
     CppMicroBrgemm,
     CppMicroGemm,
     CppMicroGemmAMX,
+    CppMicroGemmFP32Vec,
     create_micro_gemm,
+    is_int8_woq_gemm_small_m_dim_corner_case,
     LayoutType,
 )
 from .cpp_template import CppTemplate
@@ -41,7 +43,7 @@ from .cpp_utils import (
 log = logging.getLogger(__name__)

-GEMM_TEMPLATE_INIT_BLOCKING = r"""
+GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK = r"""
     constexpr int64_t num_threads = {{num_threads}};
     constexpr int64_t N = {{N}};
     constexpr int64_t K = {{K}};
@@ -50,10 +52,17 @@ GEMM_TEMPLATE_INIT_BLOCKING = r"""
     constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
     constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
     constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;

 {%- if is_dynamic_M %}
     const int64_t M = {{kernel.size(GemmOut, 0)}};
     const int64_t Mr_blocks = (M + Mr - 1) / Mr;
+{%- else %}
+    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
+    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
+{%- endif %}
+"""
+
+GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED = r"""
+{%- if is_dynamic_M %}
     {%- if num_threads > 1 %}
     int64_t Mt_blocks, Nt_blocks, Kt_blocks;
     mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
@@ -88,8 +97,6 @@ GEMM_TEMPLATE_INIT_BLOCKING = r"""
     const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
     const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
 {%- else %}
-    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
-    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
     constexpr int64_t Mt_blocks = {{template.thread_blocking(num_threads).block_m}};
     constexpr int64_t Nt_blocks = {{template.thread_blocking(num_threads).block_n}};
     constexpr int64_t Kt_blocks = {{template.thread_blocking(num_threads).block_k}};
@@ -318,6 +325,52 @@ GEMM_TEMPLATE = r"""
 }
 """

+SMALL_M_GEMM_TEMPLATE = r"""
+{{ template.codegen_gemm_stub_def() }}
+{
+    {{ kernel.maybe_codegen_profile() }}
+    {{ template.codegen_blocks(
+        num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
+    ) }}
+    #pragma omp parallel
+    {
+        #pragma omp for nowait
+        for (int64_t nr_block_id = 0; nr_block_id < Nr_blocks; nr_block_id++) {
+            // Handle one output M * Nr block in each thread
+            int64_t n_start = nr_block_id * Nr;
+            int64_t n_end = (nr_block_id + 1) * Nr;
+{%- if use_local_acc %}
+    {%- set acc_buf_name = "local_acc_buf" %}
+            {{ kernel.define_stack_allocated_buffer(acc_buf_name, ["M", "Nr"], acc_buf_dtype) }}
+    {%- set acc = kernel.local_buffers[acc_buf_name] %}
+{%- else %}
+    {%- set acc = kernel.slice_nd(GemmOut, [(0, "M"), ("n_start", "n_end")]) %}
+{%- endif %}
+            for (int64_t kr_block_id = 0; kr_block_id < Kr_blocks; kr_block_id++) {
+                // this loop is not parallelized
+                int64_t k_start = kr_block_id * Kr;
+                int64_t k_end = std::min((kr_block_id + 1) * Kr, K);
+{%- set tile_X = kernel.slice_nd(X, [(0, "M"), ("k_start", "k_end")]) %}
+{%- set tile_W_3d = kernel.slice_nd(W, [("nr_block_id", "nr_block_id + 1"), ("k_start", "k_end"), ()]) %}
+{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
+                if C10_UNLIKELY(kr_block_id == 0) {
+                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False, prefetch=True)|indent(20, false) }}
+                } else if C10_UNLIKELY(k_end == K) {
+                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=False)|indent(20, false) }}
+                } else {
+                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=True)|indent(20, false) }}
+                }
+            }
+{%- set tile_Y = kernel.slice_nd(Y_2d, [("0", "M"), ("n_start", "n_end")]) %}
+{%- set tile_acc = kernel.slice_nd(acc, [("0", "M"), ("0", "n_end - n_start")]) %}
+            {{ kernel.store_output(
+                tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("0", "n_start"), reindexers=reindexers
+            )|indent(20, false) }}
+        }
+    }
+}
+"""
+

 def _is_int8_gemm(inputs):
     return (
@@ -1485,6 +1538,24 @@ class CppGemmTemplate(CppTemplate):
         )
         return options

+    def is_int8_woq_gemm_small_m_dim(
+        self,
+        X: ir.ReinterpretView,
+        W: ir.ReinterpretView,
+        N,
+        K,
+        micro_gemm,
+    ):
+        """Use SMALL_M_GEMM_TEMPLATE"""
+        return (
+            isinstance(micro_gemm, CppMicroGemmFP32Vec)
+            and is_int8_woq_gemm_small_m_dim_corner_case(
+                micro_gemm, X.get_size()[0], N, K
+            )
+            and X.get_dtype() is torch.bfloat16
+            and W.get_dtype() is torch.int8
+        )
+
     def render(  # type: ignore[override, return]
         self,
         kernel: CppTemplateKernel,
@@ -1506,7 +1577,17 @@ class CppGemmTemplate(CppTemplate):
             stack.enter_context(
                 patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
             )
-        return self._template_from_string(GEMM_TEMPLATE).render(**options)
+        if not options["is_dynamic_M"] and self.is_int8_woq_gemm_small_m_dim(
+            options["X"],
+            options["W"],
+            options["N"],
+            options["K"],
+            options["micro_gemm"],
+        ):
+            template_str = SMALL_M_GEMM_TEMPLATE
+        else:
+            template_str = GEMM_TEMPLATE
+        return self._template_from_string(template_str).render(**options)

     def codegen_blocks(
         self,
@@ -1539,7 +1620,13 @@ class CppGemmTemplate(CppTemplate):
             W=W,
             is_woq_int4=self.is_woq_int4(),
         )
-        return self._template_from_string(GEMM_TEMPLATE_INIT_BLOCKING).render(options)
+        template_str = GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK
+        if not (
+            not is_dynamic_M
+            and self.is_int8_woq_gemm_small_m_dim(X, W, N, K, micro_gemm)
+        ):
+            template_str += GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED
+        return self._template_from_string(template_str).render(options)

     def codegen_microkernel_def(self):
         return self._template_from_string(GEMM_TEMPLATE_MICROKERNEL_DEF).render(
```
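To make the loop structure of `SMALL_M_GEMM_TEMPLATE` above easier to follow, here is a plain, runnable C++ sketch of the same outer-loop strategy, with illustrative names and a scalar FP32 inner kernel standing in for the generated micro-kernel: parallelize only across Nr-wide column strips so each thread produces a full M x Nr output block, keep the K loop sequential inside the thread, and derive the accum/prefetch flags from the K-block position the way the template does.

```cpp
#include <algorithm>
#include <cstdint>

// Small-M outer loops: no blocking over M at all; each thread owns an
// M x Nr strip of C and walks the K dimension sequentially.
void small_m_gemm(const float* A, const float* B, float* C,
                  int64_t M, int64_t N, int64_t K,
                  int64_t Nr, int64_t Kr) {
  const int64_t Nr_blocks = (N + Nr - 1) / Nr;
  const int64_t Kr_blocks = (K + Kr - 1) / Kr;
  #pragma omp parallel for
  for (int64_t nb = 0; nb < Nr_blocks; nb++) {
    const int64_t n_start = nb * Nr;
    const int64_t n_end = std::min(n_start + Nr, N);
    for (int64_t kb = 0; kb < Kr_blocks; kb++) {
      const int64_t k_start = kb * Kr;
      const int64_t k_end = std::min(k_start + Kr, K);
      const bool accum = (kb != 0);          // first K block overwrites C
      // const bool prefetch = (k_end != K); // last K block: nothing left to prefetch
      for (int64_t m = 0; m < M; m++) {
        for (int64_t n = n_start; n < n_end; n++) {
          float sum = accum ? C[m * N + n] : 0.0f;
          for (int64_t k = k_start; k < k_end; k++) {
            sum += A[m * K + k] * B[k * N + n];
          }
          C[m * N + n] = sum;
        }
      }
    }
  }
}
```

Compared with the standard template, which distributes Mt/Nt/Kt thread blocks, this keeps every core busy when M is too small to split, and each K-block iteration can prefetch exactly the B block the same thread will consume next.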
torch/_inductor/codegen/cpp_micro_gemm.py:

```diff
@@ -4,8 +4,6 @@ import sys
 from enum import Enum
 from typing import Callable, Optional

-import sympy
-
 import torch

 from .. import cpp_builder, ir
@@ -55,7 +53,7 @@ class CppMicroGemm:
     # TODO(jgong5): support constant shapes and lds as template args.
     DECLARE_KERNEL = r"""
-template <bool accum>
+template <bool accum, bool prefetch=false>
 inline void {{kernel_name}}(
 {%- if kernel_extra_args_declare %}
     {{kernel_extra_args_declare}}
@@ -138,6 +136,7 @@ inline void {{kernel_name}}(
         B: ir.Buffer,
         C: ir.Buffer,
         accum: bool,
+        prefetch: bool = False,
         **kwargs_for_extra_args,
     ) -> str:
         """
@@ -154,7 +153,9 @@ inline void {{kernel_name}}(
         ldb = kernel.stride(B, 0)
         ldc = kernel.stride(C, 0)
         res = IndentedBuffer()
-        res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
+        res.writeline(
+            f"{self.name}<{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}>("
+        )
         with res.indent():
             kwargs_for_extra_args.update({"kernel": kernel})
             extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args)
@@ -317,6 +318,26 @@ class CppMicroGemmRef(CppMicroGemm):
         return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)


+def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k):
+    return (
+        k % config.register_blocking.block_k == 0
+        and n % config.register_blocking.block_n == 0
+        and m < 16
+    )
+
+
+# extra check for small M dimension for int8 WoQ case
+def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs):
+    return is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k) and not kwargs.get(
+        "dynamic_M", False
+    )
+
+
+# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise
+def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs):
+    return not check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs)
+
+
 @register_micro_gemm(
     *generate_gemm_config(
         VecAVX512,
@@ -342,6 +363,19 @@ class CppMicroGemmRef(CppMicroGemm):
         input2_dtype=torch.int8,
         output_dtype=torch.float,
         compute_dtype=torch.float,
+        extra_check=do_not_use_with_small_m_for_int8_woq,
+    ),
+    *generate_gemm_config(
+        VecAVX512,
+        [
+            (4, 32, 64),
+            (8, 32, 64),
+        ],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.int8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=check_int8_woq_small_m_dim,
     ),
     *generate_gemm_config(
         VecAVX2,
@@ -367,6 +401,19 @@ class CppMicroGemmRef(CppMicroGemm):
         input2_dtype=torch.int8,
         output_dtype=torch.float,
         compute_dtype=torch.float,
+        extra_check=do_not_use_with_small_m_for_int8_woq,
+    ),
+    *generate_gemm_config(
+        VecAVX2,
+        [
+            (2, 16, 64),
+            (4, 16, 64),
+        ],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.int8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=check_int8_woq_small_m_dim,
     ),
     *generate_gemm_config(
         VecNEON,
@@ -397,7 +444,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
 {{declare_kernel}} {
     using Vectorized = at::vec::Vectorized<{{compute_t}}>;
     constexpr auto VLEN = Vectorized::size();
-    {{kernel.assert_function}}({{block_n}} % VLEN == 0, "{{block_n}} dimension must be multiple of Vector size");
+    {{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size");
     {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
     // TODO(jgong5): loop unroll for M and N
     for (int64_t m = 0; m < M; m += {{block_m}}) {
@@ -406,9 +453,9 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
             int64_t block_n = std::min<int64_t>(N - n, {{block_n}});
             if (block_m == {{block_m}} && block_n == {{block_n}}) {
 {%- if not trans_b %}
-                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
+                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
 {%- else %}
-                {{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum>(
+                {{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
 {%- endif %}
                     A + m * lda,
 {%- if not trans_b %}
@@ -431,9 +478,9 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
 {%- for b in range(block_m - 1, 0, -1) %}
                 case {{b}}:
 {%- if not trans_b %}
-                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
+                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>(
 {%- else %}
-                    {{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum>(
+                    {{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
 {%- endif %}
                         A + m * lda,
 {%- if not trans_b %}
@@ -459,9 +506,9 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
 {%- for b in range(block_m, 0, -1) %}
                 case {{b}}:
 {%- if not trans_b %}
-                    {{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum>(
+                    {{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>(
 {%- else %}
-                    {{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum>(
+                    {{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
 {%- endif %}
                         A + m * lda,
 {%- if not trans_b %}
@@ -492,7 +539,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
     TEMPLATE_KERNEL = r"""

-template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
+template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum, bool prefetch=false>
 {%- if not trans_b %}
 {%- if tail_n %}
 inline void {{kernel_name}}_ntail_kernel(
@@ -592,6 +639,9 @@ inline void {{kernel_name}}_transpose_b_kernel(
 {%- elif input2_dtype == torch.int8 %}
             // Convert VLEN int8 elements to int32, and then fp32
             auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN);
+            if constexpr (prefetch) {
+                _mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0);
+            }
             vb[col] = at::vec::convert<float>(b32);
 {%- else %}
             vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
@@ -1065,7 +1115,7 @@ class CppMicroGemmAMX(CppMicroGemm):
     TEMPLATE_KERNEL = r"""

-template <bool accum>
+template <bool accum, bool prefetch=false>
 inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
     AMXState& amx_state,
     const {{input_t}}* {{restrict_keyword}} A,
@@ -1855,6 +1905,11 @@ def create_micro_gemm(
     use_ref=True,
     q_group_size=None,
 ) -> Optional[CppMicroGemm]:
+    """
+    Based on the provided info, try to find the config of the micro-kernel that would
+    deliver the best performance in terms of lower latency for this case.
+    """
+
     def create_from_config(cls, config: CppMicroGemmConfig):
         return cls(
             name,
@@ -1868,8 +1923,11 @@ def create_micro_gemm(
     assert isinstance(n, int) or n.is_number, n
     assert isinstance(k, int) or k.is_number, k
-    m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m
-    assert isinstance(m, int), m
+    from ..utils import has_free_symbols
+
+    dynamic_M = has_free_symbols((m,))
+    m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m
+    assert isinstance(m, int) or m.is_number, m
     if output_dtype is None:
         output_dtype = input_dtype
     if compute_dtype is None:
@@ -1894,17 +1952,26 @@ def create_micro_gemm(
             # subject to change in the future.
         ):
             if config.extra_check is not None and not config.extra_check(
-                config, m, n, k, alpha, num_threads, q_group_size=q_group_size
+                config,
+                m,
+                n,
+                k,
+                alpha,
+                num_threads,
+                dynamic_M=dynamic_M,
+                q_group_size=q_group_size,
             ):
                 continue
             block_m, block_n, block_k = config.register_blocking
             if (
                 config.vec_isa_cls == VecAMX
                 and m < block_m
+                and not dynamic_M
                 and input_dtype == torch.bfloat16
                 and input2_dtype in [torch.int8, torch.uint8]
             ):
-                # For WoQ GEMM, AMX micro-kernel may not perform well if m < block_m
+                # For WoQ GEMM, AMX micro-kernel may not perform well if m < block_m.
+                # Exception: for dynamic shapes, we consider using the AMX micro-kernel.
                 continue
             # Criteria on the ranking of configurations
             # 1. ISA: AMX > VEC
```
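The B-operand handling added to the FP32-vec micro-kernel above (convert int8 weights to fp32 via `at::vec::convert_to_int32` and `at::vec::convert<float>`, with an optional `_mm_prefetch` of the next K block) can be sketched with raw AVX-512 intrinsics. This is an illustrative approximation of what the ATen vec wrappers expand to, not the actual generated code, and assumes an AVX-512 machine with 16-lane fp32 vectors:

```cpp
#include <immintrin.h>
#include <cstdint>

// Sketch (AVX-512, illustrative): load 16 int8 weight values, widen them to
// int32, convert to fp32, and optionally prefetch the same columns of the
// next K block while the current one is being consumed.
template <bool prefetch>
inline __m512 load_b_as_fp32(const int8_t* B, int64_t k, int64_t ldb,
                             int64_t col, int64_t vlen, int64_t block_k) {
  // 16 x int8 -> 16 x int32
  __m128i b8 = _mm_loadu_si128(
      reinterpret_cast<const __m128i*>(B + k * ldb + col * vlen));
  __m512i b32 = _mm512_cvtepi8_epi32(b8);
  if constexpr (prefetch) {
    // Touch the cache line of the next B block (block_k rows ahead).
    _mm_prefetch(reinterpret_cast<const char*>(
                     B + (k + block_k) * ldb + col * vlen),
                 _MM_HINT_T0);
  }
  // 16 x int32 -> 16 x fp32
  return _mm512_cvtepi32_ps(b32);
}
```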
torch/_inductor/codegen/cpp_template_kernel.py:

```diff
@@ -208,6 +208,19 @@ class CppTemplateKernel(CppKernel):
         numel = f"{cexpr_index(buf.get_numel())}"
         return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();"

+    def define_stack_allocated_buffer(
+        self, name, sizes: list[Any], dtype=torch.float
+    ) -> str:
+        """Define stack-allocated buffer"""
+        sizes = parse_expr_with_index_symbols(sizes)
+        buf = ir.Buffer(
+            name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes)
+        )
+        self.local_buffers[name] = buf
+        ctype = f"{DTYPE_TO_CPP[dtype]}"
+        numel = f"{cexpr_index(buf.get_numel())}"
+        return f"alignas(64) {ctype} _{name}[{numel}]; {ctype}* {name} = _{name};"
+
     def reinit_buffer_if_null(self, name):
         """Reinit the previously defined local buffer if it is null"""
         assert name in self.local_buffers
```
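For context, a small runnable sketch (assumed, illustrative sizes) contrasting the C++ the two buffer-definition paths emit: the existing `define_buffer` heap-allocates via `std::make_unique`, while the new `define_stack_allocated_buffer` emits a cache-line-aligned stack array, which is cheaper to set up inside the per-thread loop of the small-M template. This works because the accumulator dimensions (`M`, `Nr`) are compile-time constants in the generated kernel:

```cpp
#include <cstdint>
#include <memory>

int main() {
  constexpr int64_t M = 4;    // small M, known at compile time here
  constexpr int64_t Nr = 32;  // one register block of N

  // Heap-allocated local accumulator (define_buffer style):
  auto _heap_acc = std::make_unique<float[]>(M * Nr);
  float* heap_acc = _heap_acc.get();

  // Stack-allocated local accumulator (define_stack_allocated_buffer style):
  // no allocator call, and 64-byte alignment matches the cache-line size.
  alignas(64) float _stack_acc[M * Nr];
  float* stack_acc = _stack_acc;

  (void)heap_acc;
  (void)stack_acc;
  return 0;
}
```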