diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index ad27dd3190f..4e1c48496eb 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -2936,10 +2936,20 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm): x = torch.randn(batch_size, in_features).to(dtype=dtype) mod = M().to(dtype=dtype).eval() - self.common(mod, (x)) - _, code = run_and_get_cpp_code(mod, x) - # Check that only 2 kernels are in the generated code - assert code.count("AMXState amx_state") == 2 + with verify(dtype) as (atol, rtol): + ref_res = mod(x) + m = torch.compile(mod) + res, code = run_and_get_cpp_code(m, x) + self.assertEqual( + res, + ref_res, + atol=atol, + rtol=rtol, + equal_nan=True, + exact_dtype=True, + ) + # Check that only 2 kernels are in the generated code + assert code.count("AMXState amx_state") == 2 @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})