mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Migrate reshape_test to use PjRt.
PiperOrigin-RevId: 826087067
This commit is contained in:
parent
6dd75c4e8b
commit
0c87bef802
6
third_party/xla/xla/tests/BUILD
vendored
6
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -2491,9 +2491,13 @@ xla_test(
|
|||
name = "reshape_test",
|
||||
srcs = ["reshape_test.cc"],
|
||||
shard_count = 30,
|
||||
tags = [
|
||||
"test_migrated_to_hlo_runner_pjrt",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_runner_mixin",
|
||||
":hlo_test_base",
|
||||
":hlo_pjrt_interpreter_reference_mixin",
|
||||
":hlo_pjrt_test_base",
|
||||
":literal_test_util",
|
||||
":xla_internal_test_main", # fixdeps: keep
|
||||
"//xla:array2d",
|
||||
|
|
|
|||
8
third_party/xla/xla/tests/reshape_test.cc
vendored
8
third_party/xla/xla/tests/reshape_test.cc
vendored
|
|
@ -35,7 +35,8 @@ limitations under the License.
|
|||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tests/client_library_test_runner_mixin.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||
#include "xla/tests/literal_test_util.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
|
@ -47,7 +48,8 @@ namespace xla {
|
|||
namespace {
|
||||
|
||||
class ReshapeTest : public ::testing::WithParamInterface<PrimitiveType>,
|
||||
public ClientLibraryTestRunnerMixin<HloTestBase> {
|
||||
public ClientLibraryTestRunnerMixin<
|
||||
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
|
||||
public:
|
||||
ReshapeTest() { set_float_type(GetParam()); }
|
||||
|
||||
|
|
@ -957,7 +959,7 @@ TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
|
|||
INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest,
|
||||
::testing::ValuesIn({F32, BF16, F8E5M2, F8E4M3FN}));
|
||||
|
||||
using ReshapeHloTest = HloTestBase;
|
||||
using ReshapeHloTest = HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
|
||||
|
||||
TEST_F(ReshapeHloTest, NoHloPasses) {
|
||||
const std::string hlo_string = R"(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user