Migrate broadcast_simple_test to use PjRt.

PiperOrigin-RevId: 825775803
This commit is contained in:
Niklas Vangerow 2025-10-29 17:29:09 -07:00 committed by TensorFlower Gardener
parent 152b2338d9
commit fe2a783077
2 changed files with 9 additions and 3 deletions

View File

@ -2319,9 +2319,13 @@ xla_test(
xla_test(
name = "broadcast_simple_test",
srcs = ["broadcast_simple_test.cc"],
tags = [
"test_migrated_to_hlo_runner_pjrt",
],
deps = [
":client_library_test_runner_mixin",
":hlo_test_base",
":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":xla_internal_test_main", # fixdeps: keep
"//xla:array2d",
"//xla:array3d",

View File

@ -32,7 +32,8 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/client_library_test_runner_mixin.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tsl/platform/test.h"
namespace xla {
@ -101,7 +102,8 @@ float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
}
}
using BroadcastSimpleTest = ClientLibraryTestRunnerMixin<HloTestBase>;
using BroadcastSimpleTest = ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>;
TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
XlaBuilder b(TestName());