[Inductor-CPU] Fix broken int8 WoQ GEMM AMX implementation in main (#147895)

#146843 broke the AMX ISA implementation of int8 WoQ GEMM (for BF16 activations) in the main branch.
UT: `python test/inductor/test_cpu_select_algorithm.py -v -k woq`

The issue remained undetected because in case of templated kernel compilation failure, the auto-tuning infra marks its runtime as `inf`, and the op against which it was being benchmarked is used, so UTs didn't fail even on machines that support AMX ISA.

`test/inductor/test_cpu_select_algorithm.py` UTs checked the value of the `select_algorithm_autotune` counter, which only counts how many ops were selected for autotuning against their templated codegened counterparts.

@leslie-fang-intel advised using a new counter. I added `counters["inductor"]["cpp_templated_kernel_counter"]`, which is incremented after a codegened kernel's compilation, so it'd help catch breakage scenarios in which a templated kernel could not be codegened due to a compilation failure.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147895
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
This commit is contained in:
sanchitintel 2025-02-28 20:20:41 +00:00 committed by PyTorch MergeBot
parent e0e516c554
commit 5a1954eb93
3 changed files with 59 additions and 50 deletions

View File

@ -176,9 +176,9 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
or counters["inductor"]["decompose_addmm"] > 0
):
# This is a special case where we go directly with vectorized codegen
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 0)
else:
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -207,7 +207,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
v = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -234,7 +234,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
v = torch.randn(in_features, batch_size).to(dtype=dtype)
self.common(mod, (v.transpose(0, 1),))
# TODO(jgong5): support transposed input
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 0)
@inductor_config.patch({"freezing": True})
@patches
@ -282,7 +282,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
if (
(
(
@ -360,7 +360,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias, epilogue=epilogue, other=other).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v, u), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@inductor_config.patch({"freezing": True})
@ -462,7 +462,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
@inductor_config.patch({"freezing": True})
@ -533,8 +533,8 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -661,7 +661,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
# TODO: change cpp_epilogue_fusion_counter to 1 once supported
self.assertEqual(
counters["inductor"]["cpp_epilogue_fusion_counter"], 1 if epilogue else 0
@ -708,8 +708,8 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -743,7 +743,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias, binary=binary, other=u).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -796,9 +796,12 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 3)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 3)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@ -821,7 +824,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
vec_amx = VecAMX()
# Currently brgemm config is only added for half
if dtype == torch.half:
@ -918,7 +921,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
@inductor_config.patch({"freezing": True})
@ -948,7 +951,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (idx, x), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@inductor_config.patch({"freezing": True})
@ -1005,7 +1008,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@inductor_config.patch({"freezing": True})
@ -1295,7 +1298,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
@inductor_config.patch({"freezing": True})
@ -1362,7 +1365,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
equal_nan=True,
exact_dtype=True,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)
@inductor_config.patch({"freezing": True})
@ -1412,10 +1415,13 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
w_int8pack, w_scales = _convert_weight_to_int8pack(w)
mod = M(w_int8pack).eval()
self.common(mod, (x, w_scales))
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@ -1498,6 +1504,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -1582,15 +1589,15 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
equal_nan=True,
exact_dtype=True,
)
self.assertEqual(
counters["inductor"]["select_algorithm_autotune"],
2,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
self.assertEqual(
counters["inductor"]["cpp_epilogue_fusion_counter"],
0,
)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@ -1617,7 +1624,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
self.common(ref_quantized_mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)
@ -1656,7 +1663,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.gemm_cache_blocking": "2,2,2"})
@ -1685,7 +1692,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.gemm_thread_factors": "4,2,7"})
@ -1714,7 +1721,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": False})
@patches
@ -1764,7 +1771,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
(v,),
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
@ -1811,7 +1818,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
self.common(mod, (v,), atol=atol, rtol=rtol)
# gemm_num independent template instead of grouped gemm template
self.assertEqual(
counters["inductor"]["select_algorithm_autotune"], gemm_num
counters["inductor"]["cpp_templated_kernel_counter"], gemm_num
)
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 0)
@ -2000,7 +2007,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
(v,),
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"coordinate_descent_tuning": True})
@ -2023,7 +2030,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
counters.clear()
with verify(torch.bfloat16) as (atol, rtol), torch.autocast(device_type="cpu"):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -2049,7 +2056,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
v = torch.randn(*B, in_features).to(dtype=dtype)
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -2071,7 +2078,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
counters.clear()
with verify(torch.bfloat16) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_weight_prune"], 1)
@patches
@ -2096,7 +2103,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@patches
@torch.no_grad
@ -2120,7 +2127,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
vec_amx = VecAMX()
# Currently brgemm config is only added for half
if dtype == torch.half:
@ -2150,7 +2157,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol), torch.amp.autocast("cpu"):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@ -2176,7 +2183,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(v).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@patches
@torch.no_grad
@ -2225,7 +2232,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(
counters["inductor"]["select_algorithm_autotune"],
counters["inductor"]["cpp_templated_kernel_counter"],
1 if order[0] == (0, 1, 2) else 0,
)
@ -2249,7 +2256,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@patches
@torch.no_grad
@ -2270,7 +2277,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@patches
@torch.no_grad
@ -2306,7 +2313,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M(epilogue, other).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (x, w), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@ -2339,7 +2346,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (x, w), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@ -2375,7 +2382,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
(x, w),
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
@patches
@torch.no_grad
@ -2432,7 +2439,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
(x,),
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
@ -2504,7 +2511,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@ -2540,7 +2547,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v, noise), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@ -2600,7 +2607,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v, arg5), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@ -2630,7 +2637,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")

View File

@ -4967,7 +4967,9 @@ class CppScheduling(BaseScheduling):
for epilogue_node in epilogue_nodes
if isinstance(epilogue_node, (SchedulerNode, FusedSchedulerNode))
]
# The counter cpp_templated_kernel_counter is used for verifying if a
# templated kernel was successfully compiled in a UT
counters["inductor"]["cpp_templated_kernel_counter"] += 1
counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cpp_template(template_node), (
"Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"

View File

@ -584,7 +584,7 @@ class CppMicroGemmAMX(CppMicroGemm):
// Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
// we cache K * {{block_n}} elements of dequantized B
{{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
const auto buf_size = K * {{block_n}};
auto load_dequantized_B = [&](int base_idx) {
// Load a tile of B & cache it in L1D.
{{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;