From 5a1954eb93801cf2261ea65b0f34bb81eb566064 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Fri, 28 Feb 2025 20:20:41 +0000 Subject: [PATCH] [Inductor-CPU] Fix broken int8 WoQ GEMM AMX implementation in main (#147895) #146843 broke int8 WoQ GEMM's (for BF16 activation) AMX ISA implementation in the main branch. UT: `python test/inductor/test_cpu_select_algorithm.py -v -k woq` The issue remained undetected because in case of templated kernel compilation failure, the auto-tuning infra marks its runtime as `inf`, and the op against which it was being benchmarked is used, so UTs didn't fail even on machines that support AMX ISA. `test/inductor/test_cpu_select_algorithm.py` UTs checked the value of the `select_algorithm_autotune` counter, which only counts how many ops were selected for autotuning against their templated codegened counterparts. @leslie-fang-intel advised using a new counter. I added `counters["inductor"]["cpp_templated_kernel_counter"]`, which is incremented after a codegened kernel's compilation, so it'd help catch breakage scenarios in which a templated kernel could not be codegened due to a compilation failure. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147895 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel --- test/inductor/test_cpu_select_algorithm.py | 103 +++++++++++---------- torch/_inductor/codegen/cpp.py | 4 +- torch/_inductor/codegen/cpp_micro_gemm.py | 2 +- 3 files changed, 59 insertions(+), 50 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 4fb05d29a2c..dfea48cf8c6 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -176,9 +176,9 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): or counters["inductor"]["decompose_addmm"] > 0 ): # This is a special case where we go directly with vectorized codegen - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 0) else: - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -207,7 +207,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): v = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -234,7 +234,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): v = torch.randn(in_features, batch_size).to(dtype=dtype) self.common(mod, (v.transpose(0, 1),)) # TODO(jgong5): support transposed input - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 0) @inductor_config.patch({"freezing": True}) @patches @@ -282,7 +282,7 @@ class 
TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) if ( ( ( @@ -360,7 +360,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias, epilogue=epilogue, other=other).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v, u), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @inductor_config.patch({"freezing": True}) @@ -462,7 +462,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): atol=atol, rtol=rtol, ) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) @inductor_config.patch({"freezing": True}) @@ -533,8 +533,8 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -661,7 +661,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): atol=atol, rtol=rtol, ) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2) # TODO: change cpp_epilogue_fusion_counter to 1 once supported 
self.assertEqual( counters["inductor"]["cpp_epilogue_fusion_counter"], 1 if epilogue else 0 @@ -708,8 +708,8 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -743,7 +743,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias, binary=binary, other=u).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -796,9 +796,12 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 3) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 3) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) + @unittest.skipIf( + not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required" + ) @inductor_config.patch({"freezing": True}) @patches @torch.no_grad @@ -821,7 +824,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) vec_amx = VecAMX() # Currently brgemm config is only 
added for half if dtype == torch.half: @@ -918,7 +921,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): atol=atol, rtol=rtol, ) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) @inductor_config.patch({"freezing": True}) @@ -948,7 +951,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias).eval() with verify(dtype) as (atol, rtol): self.common(mod, (idx, x), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @inductor_config.patch({"freezing": True}) @@ -1005,7 +1008,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): atol=atol, rtol=rtol, ) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @inductor_config.patch({"freezing": True}) @@ -1295,7 +1298,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): atol=atol, rtol=rtol, ) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) @inductor_config.patch({"freezing": True}) @@ -1362,7 +1365,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): equal_nan=True, exact_dtype=True, ) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0) @inductor_config.patch({"freezing": True}) @@ -1412,10 +1415,13 @@ class 
TestSelectAlgorithm(BaseTestSelectAlgorithm): w_int8pack, w_scales = _convert_weight_to_int8pack(w) mod = M(w_int8pack).eval() self.common(mod, (x, w_scales)) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) vec_amx = VecAMX() self._check_amx_counter(vec_amx) + @unittest.skipIf( + not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required" + ) @inductor_config.patch({"freezing": True}) @patches @torch.no_grad @@ -1498,6 +1504,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): vec_amx = VecAMX() self._check_amx_counter(vec_amx) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -1582,15 +1589,15 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): equal_nan=True, exact_dtype=True, ) - self.assertEqual( - counters["inductor"]["select_algorithm_autotune"], - 2, - ) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2) self.assertEqual( counters["inductor"]["cpp_epilogue_fusion_counter"], 0, ) + @unittest.skipIf( + not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required" + ) @inductor_config.patch({"freezing": True}) @patches @torch.no_grad @@ -1617,7 +1624,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): atol, rtol = 1e-2, 1e-2 with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)): self.common(ref_quantized_mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) vec_amx = VecAMX() self._check_amx_counter(vec_amx) @@ -1656,7 +1663,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - 
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @inductor_config.patch({"cpp.gemm_cache_blocking": "2,2,2"}) @@ -1685,7 +1692,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @inductor_config.patch({"cpp.gemm_thread_factors": "4,2,7"}) @@ -1714,7 +1721,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(bias=bias).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": False}) @patches @@ -1764,7 +1771,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): (v,), ) self.assertEqual(actual, expected, atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @inductor_config.patch({"cpp.enable_grouped_gemm_template": True}) @@ -1811,7 +1818,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): self.common(mod, (v,), atol=atol, rtol=rtol) # gemm_num independent template instead of grouped gemm template self.assertEqual( - counters["inductor"]["select_algorithm_autotune"], gemm_num + counters["inductor"]["cpp_templated_kernel_counter"], gemm_num ) self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 0) @@ -2000,7 +2007,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): (v,), ) 
self.assertEqual(actual, expected, atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @inductor_config.patch({"coordinate_descent_tuning": True}) @@ -2023,7 +2030,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): counters.clear() with verify(torch.bfloat16) as (atol, rtol), torch.autocast(device_type="cpu"): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -2049,7 +2056,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): v = torch.randn(*B, in_features).to(dtype=dtype) with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -2071,7 +2078,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): counters.clear() with verify(torch.bfloat16) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["select_algorithm_weight_prune"], 1) @patches @@ -2096,7 +2103,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u, v), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @patches @torch.no_grad @@ -2120,7 +2127,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = 
M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u, v), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) vec_amx = VecAMX() # Currently brgemm config is only added for half if dtype == torch.half: @@ -2150,7 +2157,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol), torch.amp.autocast("cpu"): self.common(mod, (u, v), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -2176,7 +2183,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(v).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @patches @torch.no_grad @@ -2225,7 +2232,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): with verify(dtype) as (atol, rtol): self.common(mod, (u, v), atol=atol, rtol=rtol) self.assertEqual( - counters["inductor"]["select_algorithm_autotune"], + counters["inductor"]["cpp_templated_kernel_counter"], 1 if order[0] == (0, 1, 2) else 0, ) @@ -2249,7 +2256,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u,), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @patches @torch.no_grad @@ -2270,7 +2277,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u,), atol=atol, 
rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @patches @torch.no_grad @@ -2306,7 +2313,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M(epilogue, other).to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (x, w), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @patches @@ -2339,7 +2346,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (x, w), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @patches @@ -2375,7 +2382,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): (x, w), ) self.assertEqual(actual, expected, atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2) @patches @torch.no_grad @@ -2432,7 +2439,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): (x,), ) self.assertEqual(actual, expected, atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2) @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) @@ -2504,7 +2511,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u, v), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 
+ self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @patches @@ -2540,7 +2547,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u, v, noise), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @patches @@ -2600,7 +2607,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u, v, arg5), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @patches @@ -2630,7 +2637,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): mod = M().to(dtype=dtype).eval() with verify(dtype) as (atol, rtol): self.common(mod, (u, v), atol=atol, rtol=rtol) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index ccf6a7d0804..90ba256459d 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -4967,7 +4967,9 @@ class CppScheduling(BaseScheduling): for epilogue_node in epilogue_nodes if isinstance(epilogue_node, (SchedulerNode, FusedSchedulerNode)) ] - + # The counter cpp_templated_kernel_counter is used for verifying if + # a templated kernel was successfully compiled in a UT + 
counters["inductor"]["cpp_templated_kernel_counter"] += 1 counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) assert self.is_cpp_template(template_node), ( "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index bbfc54f6a66..0825f677d44 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -584,7 +584,7 @@ class CppMicroGemmAMX(CppMicroGemm): // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements. // we cache K * {{block_n}} elements of dequantized B {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}} - + const auto buf_size = K * {{block_n}}; auto load_dequantized_B = [&](int base_idx) { // Load a tile of B & cache it in L1D. {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;