mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Migrate map_test to use PjRt.
PiperOrigin-RevId: 826107887
This commit is contained in:
parent
dd3a14ace4
commit
061041963e
6
third_party/xla/xla/tests/BUILD
vendored
6
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -678,9 +678,13 @@ xla_test(
|
||||||
xla_test(
|
xla_test(
|
||||||
name = "map_test",
|
name = "map_test",
|
||||||
srcs = ["map_test.cc"],
|
srcs = ["map_test.cc"],
|
||||||
|
tags = [
|
||||||
|
"test_migrated_to_hlo_runner_pjrt",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":client_library_test_runner_mixin",
|
":client_library_test_runner_mixin",
|
||||||
":hlo_test_base",
|
":hlo_pjrt_interpreter_reference_mixin",
|
||||||
|
":hlo_pjrt_test_base",
|
||||||
":xla_internal_test_main", # fixdeps: keep
|
":xla_internal_test_main", # fixdeps: keep
|
||||||
"//xla:array2d",
|
"//xla:array2d",
|
||||||
"//xla:array3d",
|
"//xla:array3d",
|
||||||
|
|
|
||||||
11
third_party/xla/xla/tests/map_test.cc
vendored
11
third_party/xla/xla/tests/map_test.cc
vendored
|
|
@ -34,7 +34,8 @@ limitations under the License.
|
||||||
#include "xla/shape.h"
|
#include "xla/shape.h"
|
||||||
#include "xla/shape_util.h"
|
#include "xla/shape_util.h"
|
||||||
#include "xla/tests/client_library_test_runner_mixin.h"
|
#include "xla/tests/client_library_test_runner_mixin.h"
|
||||||
#include "xla/tests/hlo_test_base.h"
|
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||||
|
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||||
#include "xla/tsl/platform/status.h"
|
#include "xla/tsl/platform/status.h"
|
||||||
#include "xla/tsl/platform/test.h"
|
#include "xla/tsl/platform/test.h"
|
||||||
#include "xla/xla_data.pb.h"
|
#include "xla/xla_data.pb.h"
|
||||||
|
|
@ -42,7 +43,8 @@ limitations under the License.
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class MapTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
|
class MapTest : public ClientLibraryTestRunnerMixin<
|
||||||
|
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
|
||||||
public:
|
public:
|
||||||
MapTest() {
|
MapTest() {
|
||||||
mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
|
mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
|
||||||
|
|
@ -454,7 +456,7 @@ TEST_F(MapTest, MapOperationWithBuildError) {
|
||||||
"different element types: f32[] and u16[]"));
|
"different element types: f32[] and u16[]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
class MapHloTest : public HloTestBase {};
|
using MapHloTest = HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
|
||||||
|
|
||||||
// TODO(b/230123847): Enable this on GPU once mhlo allows mixed-type map.
|
// TODO(b/230123847): Enable this on GPU once mhlo allows mixed-type map.
|
||||||
TEST_F(MapHloTest, MapWithMixedInputTypes) {
|
TEST_F(MapHloTest, MapWithMixedInputTypes) {
|
||||||
|
|
@ -484,7 +486,8 @@ TEST_F(MapHloTest, MapWithMixedInputTypes) {
|
||||||
|
|
||||||
// MapTest disables inline and algsimp. MapTestWithFullOpt runs all
|
// MapTest disables inline and algsimp. MapTestWithFullOpt runs all
|
||||||
// optimizations.
|
// optimizations.
|
||||||
using MapTestWithFullOpt = ClientLibraryTestRunnerMixin<HloTestBase>;
|
using MapTestWithFullOpt = ClientLibraryTestRunnerMixin<
|
||||||
|
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>;
|
||||||
|
|
||||||
// Regression test for b/31466798. The inliner simplifies map(param0, param1,
|
// Regression test for b/31466798. The inliner simplifies map(param0, param1,
|
||||||
// power) to power(param0, param1) without deleting the old subcomputation which
|
// power) to power(param0, param1) without deleting the old subcomputation which
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user