mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Migrate multioutput_fusion_test to use PjRt.
PiperOrigin-RevId: 826203532
This commit is contained in:
parent
c3d0bf7023
commit
31bb7c01ff
7
third_party/xla/xla/tests/BUILD
vendored
7
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -3188,11 +3188,12 @@ xla_test(
|
|||
name = "multioutput_fusion_test",
|
||||
srcs = ["multioutput_fusion_test.cc"],
|
||||
backends = ["gpu"],
|
||||
tags = ["test_migrated_to_hlo_runner_pjrt"],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
":hlo_pjrt_interpreter_reference_mixin",
|
||||
":hlo_pjrt_test_base",
|
||||
":literal_test_util",
|
||||
":xla_internal_test_main",
|
||||
":xla_internal_test_main", # fixdeps: keep
|
||||
"//xla:error_spec",
|
||||
"//xla:literal",
|
||||
"//xla:literal_util",
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ limitations under the License.
|
|||
#include "xla/literal_util.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||
#include "xla/tests/literal_test_util.h"
|
||||
#include "xla/tsl/platform/status.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
|
@ -40,15 +41,16 @@ limitations under the License.
|
|||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class MultiOutputFusionTest : public HloTestBase {
|
||||
protected:
|
||||
MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
|
||||
constexpr ErrorSpec kErrorSpec{0.0001, 1e-2};
|
||||
|
||||
class MultiOutputFusionTest
|
||||
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
|
||||
protected:
|
||||
// Layout assignment assumes that there are no fusions in the input graph.
|
||||
// Since the purpose of this test is to send pre-fused graphs to XLA, we have
|
||||
// to do layout assignment ourselves.
|
||||
DebugOptions GetDebugOptionsForTest() const override {
|
||||
auto opts = HloTestBase::GetDebugOptionsForTest();
|
||||
auto opts = HloPjRtTestBase::GetDebugOptionsForTest();
|
||||
opts.add_xla_disable_hlo_passes("layout-assignment");
|
||||
return opts;
|
||||
}
|
||||
|
|
@ -110,7 +112,7 @@ class MultiOutputFusionTest : public HloTestBase {
|
|||
Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal actual, Execute(std::move(hlo_module), {&literal_r0, &arg1}));
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, kErrorSpec));
|
||||
}
|
||||
|
||||
void RunTest1D(bool manual_fusion, int size) {
|
||||
|
|
@ -174,7 +176,7 @@ class MultiOutputFusionTest : public HloTestBase {
|
|||
Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
|
||||
TF_ASSERT_OK_AND_ASSIGN(Literal actual,
|
||||
Execute(std::move(hlo_module), {&input0, &input1}));
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, kErrorSpec));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user