mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA:CPU] Fix flaky test.
PiperOrigin-RevId: 821835738
This commit is contained in:
parent
67e5eafb24
commit
dd4822d61c
|
|
@ -14,6 +14,7 @@
|
|||
# ==============================================================================
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Optional
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
|
@ -33,7 +34,7 @@ def compare_kernel(
|
|||
output_shape: tuple[int, ...],
|
||||
dtype,
|
||||
expected_output: Callable[[np.ndarray, ...], np.ndarray],
|
||||
exact: bool = True,
|
||||
maxulp: Optional[int] = None,
|
||||
) -> None:
|
||||
mlir_emitter = cpu_testlib.MlirTestKernelEmitter(
|
||||
ir, kernel_name, (num_workgroups, 1, 1)
|
||||
|
|
@ -47,16 +48,20 @@ def compare_kernel(
|
|||
inputs = [np.random.rand(*shape).astype(dtype) for shape in input_shapes]
|
||||
|
||||
input_tensors = [create_literal(input) for input in inputs]
|
||||
output_tensor = create_literal(np.zeros(output_shape, dtype=dtype))
|
||||
output_tensor = create_literal(
|
||||
np.zeros(shape=output_shape, dtype=dtype)
|
||||
if output_shape
|
||||
else np.array(0, dtype=dtype)
|
||||
)
|
||||
runner.call(input_tensors + [output_tensor])
|
||||
|
||||
if exact:
|
||||
np.testing.assert_array_equal(
|
||||
np.asarray(output_tensor), expected_output(*inputs)
|
||||
)
|
||||
output_np = np.asarray(output_tensor)
|
||||
expected_output_np = expected_output(*inputs)
|
||||
if maxulp is None:
|
||||
np.testing.assert_array_equal(output_np, expected_output_np)
|
||||
else:
|
||||
np.testing.assert_array_almost_equal_nulp(
|
||||
np.asarray(output_tensor), expected_output(*inputs), nulp=3
|
||||
np.testing.assert_array_max_ulp(
|
||||
output_np, expected_output_np, maxulp=maxulp
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -171,7 +176,7 @@ class XtileLoweringTest(absltest.TestCase):
|
|||
(8, 8),
|
||||
np.float32,
|
||||
lambda lhs, rhs: lhs @ rhs,
|
||||
False,
|
||||
maxulp=5,
|
||||
)
|
||||
|
||||
def test_dot_scalar_output(self):
|
||||
|
|
@ -197,10 +202,10 @@ class XtileLoweringTest(absltest.TestCase):
|
|||
"test_dot_scalar_output",
|
||||
1,
|
||||
[(8, 16), (16, 8)],
|
||||
(1,),
|
||||
(),
|
||||
np.float32,
|
||||
lambda lhs, rhs: np.tensordot(lhs, rhs, axes=([1, 0], [0, 1])),
|
||||
False,
|
||||
lambda lhs, rhs: np.tensordot(lhs, rhs, axes=[[1, 0], [0, 1]]),
|
||||
maxulp=8,
|
||||
)
|
||||
|
||||
def test_dot_fusion_single_tile(self):
|
||||
|
|
@ -233,7 +238,7 @@ class XtileLoweringTest(absltest.TestCase):
|
|||
(8, 1),
|
||||
np.float32,
|
||||
lambda lhs_0, lhs_1, rhs: np.tanh((lhs_0 + lhs_1) @ rhs),
|
||||
False,
|
||||
maxulp=5,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user