mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Migrate reshape_test to use PjRt.
PiperOrigin-RevId: 826087067
This commit is contained in:
parent
6dd75c4e8b
commit
0c87bef802
6
third_party/xla/xla/tests/BUILD
vendored
6
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -2491,9 +2491,13 @@ xla_test(
|
||||||
name = "reshape_test",
|
name = "reshape_test",
|
||||||
srcs = ["reshape_test.cc"],
|
srcs = ["reshape_test.cc"],
|
||||||
shard_count = 30,
|
shard_count = 30,
|
||||||
|
tags = [
|
||||||
|
"test_migrated_to_hlo_runner_pjrt",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":client_library_test_runner_mixin",
|
":client_library_test_runner_mixin",
|
||||||
":hlo_test_base",
|
":hlo_pjrt_interpreter_reference_mixin",
|
||||||
|
":hlo_pjrt_test_base",
|
||||||
":literal_test_util",
|
":literal_test_util",
|
||||||
":xla_internal_test_main", # fixdeps: keep
|
":xla_internal_test_main", # fixdeps: keep
|
||||||
"//xla:array2d",
|
"//xla:array2d",
|
||||||
|
|
|
||||||
8
third_party/xla/xla/tests/reshape_test.cc
vendored
8
third_party/xla/xla/tests/reshape_test.cc
vendored
|
|
@ -35,7 +35,8 @@ limitations under the License.
|
||||||
#include "xla/shape.h"
|
#include "xla/shape.h"
|
||||||
#include "xla/shape_util.h"
|
#include "xla/shape_util.h"
|
||||||
#include "xla/tests/client_library_test_runner_mixin.h"
|
#include "xla/tests/client_library_test_runner_mixin.h"
|
||||||
#include "xla/tests/hlo_test_base.h"
|
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||||
|
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||||
#include "xla/tests/literal_test_util.h"
|
#include "xla/tests/literal_test_util.h"
|
||||||
#include "xla/tsl/platform/statusor.h"
|
#include "xla/tsl/platform/statusor.h"
|
||||||
#include "xla/tsl/platform/test.h"
|
#include "xla/tsl/platform/test.h"
|
||||||
|
|
@ -47,7 +48,8 @@ namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class ReshapeTest : public ::testing::WithParamInterface<PrimitiveType>,
|
class ReshapeTest : public ::testing::WithParamInterface<PrimitiveType>,
|
||||||
public ClientLibraryTestRunnerMixin<HloTestBase> {
|
public ClientLibraryTestRunnerMixin<
|
||||||
|
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
|
||||||
public:
|
public:
|
||||||
ReshapeTest() { set_float_type(GetParam()); }
|
ReshapeTest() { set_float_type(GetParam()); }
|
||||||
|
|
||||||
|
|
@ -957,7 +959,7 @@ TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
|
||||||
INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest,
|
INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest,
|
||||||
::testing::ValuesIn({F32, BF16, F8E5M2, F8E4M3FN}));
|
::testing::ValuesIn({F32, BF16, F8E5M2, F8E4M3FN}));
|
||||||
|
|
||||||
using ReshapeHloTest = HloTestBase;
|
using ReshapeHloTest = HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
|
||||||
|
|
||||||
TEST_F(ReshapeHloTest, NoHloPasses) {
|
TEST_F(ReshapeHloTest, NoHloPasses) {
|
||||||
const std::string hlo_string = R"(
|
const std::string hlo_string = R"(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user