mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Migrate broadcast_simple_test to use PjRt.
PiperOrigin-RevId: 825775803
This commit is contained in:
parent
152b2338d9
commit
fe2a783077
6
third_party/xla/xla/tests/BUILD
vendored
6
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -2319,9 +2319,13 @@ xla_test(
|
|||
xla_test(
|
||||
name = "broadcast_simple_test",
|
||||
srcs = ["broadcast_simple_test.cc"],
|
||||
tags = [
|
||||
"test_migrated_to_hlo_runner_pjrt",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_runner_mixin",
|
||||
":hlo_test_base",
|
||||
":hlo_pjrt_interpreter_reference_mixin",
|
||||
":hlo_pjrt_test_base",
|
||||
":xla_internal_test_main", # fixdeps: keep
|
||||
"//xla:array2d",
|
||||
"//xla:array3d",
|
||||
|
|
|
|||
|
|
@ -32,7 +32,8 @@ limitations under the License.
|
|||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tests/client_library_test_runner_mixin.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
|
@ -101,7 +102,8 @@ float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
|
|||
}
|
||||
}
|
||||
|
||||
using BroadcastSimpleTest = ClientLibraryTestRunnerMixin<HloTestBase>;
|
||||
using BroadcastSimpleTest = ClientLibraryTestRunnerMixin<
|
||||
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>;
|
||||
|
||||
TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
|
||||
XlaBuilder b(TestName());
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user