mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Migrate map_test to use PjRt.
PiperOrigin-RevId: 826107887
This commit is contained in:
parent
dd3a14ace4
commit
061041963e
6
third_party/xla/xla/tests/BUILD
vendored
6
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -678,9 +678,13 @@ xla_test(
|
|||
xla_test(
|
||||
name = "map_test",
|
||||
srcs = ["map_test.cc"],
|
||||
tags = [
|
||||
"test_migrated_to_hlo_runner_pjrt",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_runner_mixin",
|
||||
":hlo_test_base",
|
||||
":hlo_pjrt_interpreter_reference_mixin",
|
||||
":hlo_pjrt_test_base",
|
||||
":xla_internal_test_main", # fixdeps: keep
|
||||
"//xla:array2d",
|
||||
"//xla:array3d",
|
||||
|
|
|
|||
11
third_party/xla/xla/tests/map_test.cc
vendored
11
third_party/xla/xla/tests/map_test.cc
vendored
|
|
@ -34,7 +34,8 @@ limitations under the License.
|
|||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tests/client_library_test_runner_mixin.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||
#include "xla/tsl/platform/status.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
|
@ -42,7 +43,8 @@ limitations under the License.
|
|||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class MapTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
|
||||
class MapTest : public ClientLibraryTestRunnerMixin<
|
||||
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
|
||||
public:
|
||||
MapTest() {
|
||||
mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
|
||||
|
|
@ -454,7 +456,7 @@ TEST_F(MapTest, MapOperationWithBuildError) {
|
|||
"different element types: f32[] and u16[]"));
|
||||
}
|
||||
|
||||
class MapHloTest : public HloTestBase {};
|
||||
using MapHloTest = HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
|
||||
|
||||
// TODO(b/230123847): Enable this on GPU once mhlo allows mixed-type map.
|
||||
TEST_F(MapHloTest, MapWithMixedInputTypes) {
|
||||
|
|
@ -484,7 +486,8 @@ TEST_F(MapHloTest, MapWithMixedInputTypes) {
|
|||
|
||||
// MapTest disables inline and algsimp. MapTestWithFullOpt runs all
|
||||
// optimizations.
|
||||
using MapTestWithFullOpt = ClientLibraryTestRunnerMixin<HloTestBase>;
|
||||
using MapTestWithFullOpt = ClientLibraryTestRunnerMixin<
|
||||
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>;
|
||||
|
||||
// Regression test for b/31466798. The inliner simplifies map(param0, param1,
|
||||
// power) to power(param0, param1) without deleting the old subcomputation which
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user