Migrate multioutput_fusion_test to use PjRt.

PiperOrigin-RevId: 826203532
This commit is contained in:
Niklas Vangerow 2025-10-30 15:09:38 -07:00 committed by TensorFlower Gardener
parent c3d0bf7023
commit 31bb7c01ff
2 changed files with 13 additions and 10 deletions

View File

@ -3188,11 +3188,12 @@ xla_test(
name = "multioutput_fusion_test",
srcs = ["multioutput_fusion_test.cc"],
backends = ["gpu"],
tags = ["test_migrated_to_hlo_runner_pjrt"],
deps = [
":client_library_test_base",
":hlo_test_base",
":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":literal_test_util",
":xla_internal_test_main",
":xla_internal_test_main", # fixdeps: keep
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",

View File

@ -30,7 +30,8 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/statusor.h"
@ -40,15 +41,16 @@ limitations under the License.
namespace xla {
namespace {
class MultiOutputFusionTest : public HloTestBase {
protected:
MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
constexpr ErrorSpec kErrorSpec{0.0001, 1e-2};
class MultiOutputFusionTest
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
protected:
// Layout assignment assumes that there are no fusions in the input graph.
// Since the purpose of this test is to send pre-fused graphs to XLA, we have
// to do layout assignment ourselves.
DebugOptions GetDebugOptionsForTest() const override {
auto opts = HloTestBase::GetDebugOptionsForTest();
auto opts = HloPjRtTestBase::GetDebugOptionsForTest();
opts.add_xla_disable_hlo_passes("layout-assignment");
return opts;
}
@ -110,7 +112,7 @@ class MultiOutputFusionTest : public HloTestBase {
Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
TF_ASSERT_OK_AND_ASSIGN(
Literal actual, Execute(std::move(hlo_module), {&literal_r0, &arg1}));
EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, kErrorSpec));
}
void RunTest1D(bool manual_fusion, int size) {
@ -174,7 +176,7 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
TF_ASSERT_OK_AND_ASSIGN(Literal actual,
Execute(std::move(hlo_module), {&input0, &input1}));
EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, kErrorSpec));
}
};