mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor-CPU] Fix broken int8 WoQ GEMM AMX implementation in main (#147895)
#146843 broke int8 WoQ GEMM's (for BF16 activation) AMX ISA implementation in the main branch. UT: `python test/inductor/test_cpu_select_algorithm.py -v -k woq` The issue remained undetected because in case of templated kernel compilation failure, the auto-tuning infra marks its runtime as `inf`, and the op against which it was being benchmarked is used, so UTs didn't fail even on machines that support AMX ISA. `test/inductor/test_cpu_select_algorithm.py` UTs checked the value of the `select_algorithm_autotune` counter, which only counts how many ops were selected for autotuning against their templated codegened counterparts. @leslie-fang-intel advised using a new counter. I added `counters["inductor"]["cpp_templated_kernel_counter"]`, which is incremented after a codegened kernel's compilation, so it'd help catch breakage scenarios in which a templated kernel could not be codegened due to a compilation failure. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147895 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
This commit is contained in:
parent
e0e516c554
commit
5a1954eb93
|
|
@ -176,9 +176,9 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
or counters["inductor"]["decompose_addmm"] > 0
|
||||
):
|
||||
# This is a special case where we go directly with vectorized codegen
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 0)
|
||||
else:
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -207,7 +207,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
v = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -234,7 +234,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
v = torch.randn(in_features, batch_size).to(dtype=dtype)
|
||||
self.common(mod, (v.transpose(0, 1),))
|
||||
# TODO(jgong5): support transposed input
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 0)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -282,7 +282,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
if (
|
||||
(
|
||||
(
|
||||
|
|
@ -360,7 +360,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias, epilogue=epilogue, other=other).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v, u), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
|
|
@ -462,7 +462,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
|
|
@ -533,8 +533,8 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -661,7 +661,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
|
||||
# TODO: change cpp_epilogue_fusion_counter to 1 once supported
|
||||
self.assertEqual(
|
||||
counters["inductor"]["cpp_epilogue_fusion_counter"], 1 if epilogue else 0
|
||||
|
|
@ -708,8 +708,8 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -743,7 +743,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias, binary=binary, other=u).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -796,9 +796,12 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 3)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 3)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
|
||||
)
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
@ -821,7 +824,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
vec_amx = VecAMX()
|
||||
# Currently brgemm config is only added for half
|
||||
if dtype == torch.half:
|
||||
|
|
@ -918,7 +921,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
|
|
@ -948,7 +951,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (idx, x), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
|
|
@ -1005,7 +1008,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
|
|
@ -1295,7 +1298,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
|
|
@ -1362,7 +1365,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
equal_nan=True,
|
||||
exact_dtype=True,
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
|
|
@ -1412,10 +1415,13 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
w_int8pack, w_scales = _convert_weight_to_int8pack(w)
|
||||
mod = M(w_int8pack).eval()
|
||||
self.common(mod, (x, w_scales))
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
vec_amx = VecAMX()
|
||||
self._check_amx_counter(vec_amx)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
|
||||
)
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
@ -1498,6 +1504,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
|
||||
vec_amx = VecAMX()
|
||||
self._check_amx_counter(vec_amx)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -1582,15 +1589,15 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
equal_nan=True,
|
||||
exact_dtype=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["select_algorithm_autotune"],
|
||||
2,
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["cpp_epilogue_fusion_counter"],
|
||||
0,
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
|
||||
)
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
@ -1617,7 +1624,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
atol, rtol = 1e-2, 1e-2
|
||||
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
|
||||
self.common(ref_quantized_mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
vec_amx = VecAMX()
|
||||
self._check_amx_counter(vec_amx)
|
||||
|
||||
|
|
@ -1656,7 +1663,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@inductor_config.patch({"cpp.gemm_cache_blocking": "2,2,2"})
|
||||
|
|
@ -1685,7 +1692,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@inductor_config.patch({"cpp.gemm_thread_factors": "4,2,7"})
|
||||
|
|
@ -1714,7 +1721,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(bias=bias).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": False})
|
||||
@patches
|
||||
|
|
@ -1764,7 +1771,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
(v,),
|
||||
)
|
||||
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
|
||||
|
|
@ -1811,7 +1818,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
# gemm_num independent template instead of grouped gemm template
|
||||
self.assertEqual(
|
||||
counters["inductor"]["select_algorithm_autotune"], gemm_num
|
||||
counters["inductor"]["cpp_templated_kernel_counter"], gemm_num
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 0)
|
||||
|
||||
|
|
@ -2000,7 +2007,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
(v,),
|
||||
)
|
||||
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@inductor_config.patch({"coordinate_descent_tuning": True})
|
||||
|
|
@ -2023,7 +2030,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
counters.clear()
|
||||
with verify(torch.bfloat16) as (atol, rtol), torch.autocast(device_type="cpu"):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -2049,7 +2056,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
v = torch.randn(*B, in_features).to(dtype=dtype)
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -2071,7 +2078,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
counters.clear()
|
||||
with verify(torch.bfloat16) as (atol, rtol):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_weight_prune"], 1)
|
||||
|
||||
@patches
|
||||
|
|
@ -2096,7 +2103,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u, v), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
@ -2120,7 +2127,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u, v), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
vec_amx = VecAMX()
|
||||
# Currently brgemm config is only added for half
|
||||
if dtype == torch.half:
|
||||
|
|
@ -2150,7 +2157,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol), torch.amp.autocast("cpu"):
|
||||
self.common(mod, (u, v), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
|
@ -2176,7 +2183,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(v).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
@ -2225,7 +2232,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u, v), atol=atol, rtol=rtol)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["select_algorithm_autotune"],
|
||||
counters["inductor"]["cpp_templated_kernel_counter"],
|
||||
1 if order[0] == (0, 1, 2) else 0,
|
||||
)
|
||||
|
||||
|
|
@ -2249,7 +2256,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
@ -2270,7 +2277,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
@ -2306,7 +2313,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M(epilogue, other).to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (x, w), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@patches
|
||||
|
|
@ -2339,7 +2346,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (x, w), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@patches
|
||||
|
|
@ -2375,7 +2382,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
(x, w),
|
||||
)
|
||||
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
|
||||
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
@ -2432,7 +2439,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
(x,),
|
||||
)
|
||||
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
|
||||
|
||||
|
||||
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
|
||||
|
|
@ -2504,7 +2511,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u, v), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@patches
|
||||
|
|
@ -2540,7 +2547,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u, v, noise), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@patches
|
||||
|
|
@ -2600,7 +2607,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u, v, arg5), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@patches
|
||||
|
|
@ -2630,7 +2637,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
|
|||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol):
|
||||
self.common(mod, (u, v), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
|
||||
|
|
|
|||
|
|
@ -4967,7 +4967,9 @@ class CppScheduling(BaseScheduling):
|
|||
for epilogue_node in epilogue_nodes
|
||||
if isinstance(epilogue_node, (SchedulerNode, FusedSchedulerNode))
|
||||
]
|
||||
|
||||
# The counter cpp_templated_kernel_counter is used for verifying if a
|
||||
# a templated kernel was successfully compiled in a UT
|
||||
counters["inductor"]["cpp_templated_kernel_counter"] += 1
|
||||
counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
|
||||
assert self.is_cpp_template(template_node), (
|
||||
"Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
|
||||
|
|
|
|||
|
|
@ -584,7 +584,7 @@ class CppMicroGemmAMX(CppMicroGemm):
|
|||
// Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
|
||||
// we cache K * {{block_n}} elements of dequantized B
|
||||
{{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
|
||||
|
||||
const auto buf_size = K * {{block_n}};
|
||||
auto load_dequantized_B = [&](int base_idx) {
|
||||
// Load a tile of B & cache it in L1D.
|
||||
{{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user