Migrate map_test to use PjRt.

PiperOrigin-RevId: 826107887
This commit is contained in:
Niklas Vangerow 2025-10-30 11:06:56 -07:00 committed by TensorFlower Gardener
parent dd3a14ace4
commit 061041963e
2 changed files with 12 additions and 5 deletions

View File

@ -678,9 +678,13 @@ xla_test(
xla_test(
name = "map_test",
srcs = ["map_test.cc"],
tags = [
"test_migrated_to_hlo_runner_pjrt",
],
deps = [
":client_library_test_runner_mixin",
":hlo_test_base",
":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":xla_internal_test_main", # fixdeps: keep
"//xla:array2d",
"//xla:array3d",

View File

@ -34,7 +34,8 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/client_library_test_runner_mixin.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/test.h"
#include "xla/xla_data.pb.h"
@ -42,7 +43,8 @@ limitations under the License.
namespace xla {
namespace {
class MapTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
class MapTest : public ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
public:
MapTest() {
mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
@ -454,7 +456,7 @@ TEST_F(MapTest, MapOperationWithBuildError) {
"different element types: f32[] and u16[]"));
}
class MapHloTest : public HloTestBase {};
using MapHloTest = HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
// TODO(b/230123847): Enable this on GPU once mhlo allows mixed-type map.
TEST_F(MapHloTest, MapWithMixedInputTypes) {
@ -484,7 +486,8 @@ TEST_F(MapHloTest, MapWithMixedInputTypes) {
// MapTest disables inline and algsimp. MapTestWithFullOpt runs all
// optimizations.
using MapTestWithFullOpt = ClientLibraryTestRunnerMixin<HloTestBase>;
using MapTestWithFullOpt = ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>;
// Regression test for b/31466798. The inliner simplifies map(param0, param1,
// power) to power(param0, param1) without deleting the old subcomputation which