Migrate cpu_gpu_fusion_test to use PjRt.

PiperOrigin-RevId: 824627421
This commit is contained in:
Niklas Vangerow 2025-10-27 12:22:17 -07:00 committed by TensorFlower Gardener
parent a1219cfa94
commit 56a660da4b
2 changed files with 12 additions and 5 deletions

View File

@ -3121,9 +3121,13 @@ xla_test(
"gpu", "gpu",
"interpreter", "interpreter",
], ],
tags = [
"test_migrated_to_hlo_runner_pjrt",
],
deps = [ deps = [
":client_library_test_runner_mixin", ":client_library_test_runner_mixin",
":hlo_test_base", ":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":literal_test_util", ":literal_test_util",
":xla_internal_test_main", # fixdeps: keep ":xla_internal_test_main", # fixdeps: keep
"//xla:array2d", "//xla:array2d",

View File

@ -52,7 +52,8 @@ limitations under the License.
#include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/client_library_test_runner_mixin.h"
#include "xla/tests/hlo_test_base.h" #include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tests/literal_test_util.h" #include "xla/tests/literal_test_util.h"
#include "xla/tsl/platform/env.h" #include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/statusor.h"
@ -74,7 +75,8 @@ const float test_float_vals[3][test_width][test_height] = {
// Test whether fusion operations are emitted with no errors and compute // Test whether fusion operations are emitted with no errors and compute
// accurate outputs. // accurate outputs.
class CpuGpuFusionTest : public HloTestBase { class CpuGpuFusionTest
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
protected: protected:
template <typename T, int Arity> template <typename T, int Arity>
void TestElementwise2D( void TestElementwise2D(
@ -155,7 +157,7 @@ class CpuGpuFusionTest : public HloTestBase {
bool ComputeElementwiseAnswerCompare(ComparisonDirection direction, bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
absl::Span<const float> xs); absl::Span<const float> xs);
DebugOptions GetDebugOptionsForTest() const override { DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); DebugOptions debug_options = HloPjRtTestBase::GetDebugOptionsForTest();
debug_options.add_xla_disable_hlo_passes("layout-assignment"); debug_options.add_xla_disable_hlo_passes("layout-assignment");
return debug_options; return debug_options;
} }
@ -884,7 +886,8 @@ TEST_F(CpuGpuFusionTest, Clamp2D) {
} }
class FusionClientLibraryTest class FusionClientLibraryTest
: public ClientLibraryTestRunnerMixin<HloTestBase> {}; : public ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {};
TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
// On the GPU backend, it's possible to have too many transposes within one // On the GPU backend, it's possible to have too many transposes within one