Migrate reshape_test to use PjRt.

PiperOrigin-RevId: 826087067
This commit is contained in:
Niklas Vangerow 2025-10-30 10:19:16 -07:00 committed by TensorFlower Gardener
parent 6dd75c4e8b
commit 0c87bef802
2 changed files with 10 additions and 4 deletions

View File

@ -2491,9 +2491,13 @@ xla_test(
name = "reshape_test", name = "reshape_test",
srcs = ["reshape_test.cc"], srcs = ["reshape_test.cc"],
shard_count = 30, shard_count = 30,
tags = [
"test_migrated_to_hlo_runner_pjrt",
],
deps = [ deps = [
":client_library_test_runner_mixin", ":client_library_test_runner_mixin",
":hlo_test_base", ":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":literal_test_util", ":literal_test_util",
":xla_internal_test_main", # fixdeps: keep ":xla_internal_test_main", # fixdeps: keep
"//xla:array2d", "//xla:array2d",

View File

@ -35,7 +35,8 @@ limitations under the License.
#include "xla/shape.h" #include "xla/shape.h"
#include "xla/shape_util.h" #include "xla/shape_util.h"
#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/client_library_test_runner_mixin.h"
#include "xla/tests/hlo_test_base.h" #include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tests/literal_test_util.h" #include "xla/tests/literal_test_util.h"
#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test.h"
@ -47,7 +48,8 @@ namespace xla {
namespace { namespace {
class ReshapeTest : public ::testing::WithParamInterface<PrimitiveType>, class ReshapeTest : public ::testing::WithParamInterface<PrimitiveType>,
public ClientLibraryTestRunnerMixin<HloTestBase> { public ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
public: public:
ReshapeTest() { set_float_type(GetParam()); } ReshapeTest() { set_float_type(GetParam()); }
@ -957,7 +959,7 @@ TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest,
::testing::ValuesIn({F32, BF16, F8E5M2, F8E4M3FN})); ::testing::ValuesIn({F32, BF16, F8E5M2, F8E4M3FN}));
using ReshapeHloTest = HloTestBase; using ReshapeHloTest = HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
TEST_F(ReshapeHloTest, NoHloPasses) { TEST_F(ReshapeHloTest, NoHloPasses) {
const std::string hlo_string = R"( const std::string hlo_string = R"(