[XLA:CPU] Fix flaky test.

PiperOrigin-RevId: 821835738
This commit is contained in:
Will Froom 2025-10-20 15:42:55 -07:00 committed by TensorFlower Gardener
parent 67e5eafb24
commit dd4822d61c

View File

@ -14,6 +14,7 @@
# ==============================================================================
from collections.abc import Callable, Iterable
from typing import Optional
from absl.testing import absltest
import numpy as np
@ -33,7 +34,7 @@ def compare_kernel(
output_shape: tuple[int, ...],
dtype,
expected_output: Callable[[np.ndarray, ...], np.ndarray],
exact: bool = True,
maxulp: Optional[int] = None,
) -> None:
mlir_emitter = cpu_testlib.MlirTestKernelEmitter(
ir, kernel_name, (num_workgroups, 1, 1)
@ -47,16 +48,20 @@ def compare_kernel(
inputs = [np.random.rand(*shape).astype(dtype) for shape in input_shapes]
input_tensors = [create_literal(input) for input in inputs]
output_tensor = create_literal(np.zeros(output_shape, dtype=dtype))
output_tensor = create_literal(
np.zeros(shape=output_shape, dtype=dtype)
if output_shape
else np.array(0, dtype=dtype)
)
runner.call(input_tensors + [output_tensor])
if exact:
np.testing.assert_array_equal(
np.asarray(output_tensor), expected_output(*inputs)
)
output_np = np.asarray(output_tensor)
expected_output_np = expected_output(*inputs)
if maxulp is None:
np.testing.assert_array_equal(output_np, expected_output_np)
else:
np.testing.assert_array_almost_equal_nulp(
np.asarray(output_tensor), expected_output(*inputs), nulp=3
np.testing.assert_array_max_ulp(
output_np, expected_output_np, maxulp=maxulp
)
@ -171,7 +176,7 @@ class XtileLoweringTest(absltest.TestCase):
(8, 8),
np.float32,
lambda lhs, rhs: lhs @ rhs,
False,
maxulp=5,
)
def test_dot_scalar_output(self):
@ -197,10 +202,10 @@ class XtileLoweringTest(absltest.TestCase):
"test_dot_scalar_output",
1,
[(8, 16), (16, 8)],
(1,),
(),
np.float32,
lambda lhs, rhs: np.tensordot(lhs, rhs, axes=([1, 0], [0, 1])),
False,
lambda lhs, rhs: np.tensordot(lhs, rhs, axes=[[1, 0], [0, 1]]),
maxulp=8,
)
def test_dot_fusion_single_tile(self):
@ -233,7 +238,7 @@ class XtileLoweringTest(absltest.TestCase):
(8, 1),
np.float32,
lambda lhs_0, lhs_1, rhs: np.tanh((lhs_0 + lhs_1) @ rhs),
False,
maxulp=5,
)