mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[XLA] Propagate debug option flags to hlo_test_base.
Specific HLO tests have to replace the generic test_main target with a manual main() that invokes RUN_ALL_TESTS. To get access to a module with debug options set up, a new convenience method is created on HloTestBase. Initially algebraic_simplifier_test is modified as a canary; in a followup we'll convert all HLO tests to this approach. PiperOrigin-RevId: 158309488
This commit is contained in:
parent
0770393e95
commit
599727c654
|
|
@ -868,9 +868,10 @@ cc_test(
|
|||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
|
|
@ -32,6 +33,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
|
|
@ -59,7 +61,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
|
||||
|
|
@ -82,7 +84,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
|
||||
|
|
@ -105,7 +107,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
|
||||
|
|
@ -127,7 +129,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
|
||||
|
|
@ -149,7 +151,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
|
|||
HloInstruction* div = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root, div);
|
||||
|
|
@ -171,7 +173,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
|
|||
HloInstruction* div = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root, div);
|
||||
|
|
@ -199,7 +201,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
|
|||
HloInstruction* add = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root, add);
|
||||
|
|
@ -225,7 +227,7 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -250,7 +252,7 @@ TEST_F(AlgebraicSimplifierTest, LnExp) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0)));
|
||||
|
|
@ -279,7 +281,7 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -304,7 +306,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
|
||||
|
|
@ -329,7 +331,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
|
||||
|
|
@ -359,7 +361,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Power(param0, one));
|
||||
|
|
@ -382,7 +384,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Power(param0, two));
|
||||
|
|
@ -405,7 +407,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
|
|||
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
|
||||
param0, negative_one));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one));
|
||||
|
|
@ -434,7 +436,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
|
|||
ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
|
||||
|
||||
auto computation = builder.Build();
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
module->AddEntryComputation(std::move(computation));
|
||||
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
|
|
@ -455,7 +457,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
|
||||
|
|
@ -476,7 +478,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
|
||||
|
|
@ -497,7 +499,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0));
|
||||
|
|
@ -527,7 +529,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
|
|||
builder.AddInstruction(HloInstruction::CreateConcatenate(
|
||||
result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(
|
||||
|
|
@ -558,7 +560,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
|
|||
builder.AddInstruction(HloInstruction::CreateConcatenate(
|
||||
result_shape, {empty_literal, empty_slice}, 0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -581,7 +583,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
|
|||
HloInstruction* copy = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Set to different layouts.
|
||||
|
|
@ -608,7 +610,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
|
|||
HloInstruction* copy = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Set to same layouts.
|
||||
|
|
@ -640,7 +642,7 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
|
|||
*reshape->mutable_shape()->mutable_layout() =
|
||||
LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
|
||||
|
|
@ -686,7 +688,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
|
|||
builder.AddInstruction(HloInstruction::CreateTuple(
|
||||
{transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -716,7 +718,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
|
||||
HloOpcode::kMaximum, movable_reshape, zero));
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -744,7 +746,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
|
|||
*transpose->mutable_shape()->mutable_layout() =
|
||||
LayoutUtil::MakeLayout({0, 1, 2, 3});
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
|
||||
|
|
@ -771,7 +773,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
|
|||
*transpose->mutable_shape()->mutable_layout() =
|
||||
LayoutUtil::MakeLayout({3, 1, 2, 0});
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
|
||||
|
|
@ -797,7 +799,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
|
|||
builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -825,7 +827,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
|
|||
ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}),
|
||||
HloOpcode::kCopy, copy1));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0)));
|
||||
|
|
@ -850,7 +852,7 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
|
|||
builder.AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1));
|
||||
|
|
@ -874,7 +876,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
|
|||
builder.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3}));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -897,7 +899,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
|
|||
builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -919,7 +921,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -942,7 +944,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
|
|||
builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
HloComputation* computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -966,7 +968,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
|
|||
builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
HloComputation* computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -992,7 +994,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
|
|||
builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
HloComputation* computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
|
|
@ -1697,7 +1699,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
|
|||
builder.AddInstruction(
|
||||
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
|
||||
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
auto module = CreateNewModule();
|
||||
module->AddEmbeddedComputation(std::move(dot_computation));
|
||||
module->AddEntryComputation(call_builder.Build());
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||
|
|
@ -1707,3 +1709,20 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
|
|||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
|
||||
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
if (!parse_result) {
|
||||
LOG(ERROR) << "\n" << usage;
|
||||
return 2;
|
||||
}
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if (argc > 1) {
|
||||
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
|
||||
return 2;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ cc_library(
|
|||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags",
|
||||
"//tensorflow/compiler/xla/service",
|
||||
"//tensorflow/compiler/xla/service:backend",
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/backend.h"
|
||||
|
|
@ -54,6 +55,8 @@ struct HloTestBase::EigenThreadPoolWrapper {
|
|||
|
||||
HloTestBase::HloTestBase()
|
||||
: backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) {
|
||||
// TODO(b/62411181): get rid of this flag entirely when the usual debug flags
|
||||
// are piped to all HLO tests.
|
||||
test_hlo_dumper_ = [](const HloModule& module, const string& label) {
|
||||
legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags();
|
||||
if (flags->xla_hlo_test_generate_hlo_graph) {
|
||||
|
|
@ -73,6 +76,13 @@ HloTestBase::~HloTestBase() {
|
|||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
|
||||
return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
|
||||
config);
|
||||
}
|
||||
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
|
||||
std::unique_ptr<HloModule> module,
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
|
|
|
|||
|
|
@ -44,6 +44,12 @@ class HloTestBase : public ::testing::Test {
|
|||
|
||||
~HloTestBase() override;
|
||||
|
||||
// Creates a new HLO module for a test. The module created will have
|
||||
// TestName() for its name; it will also automatically populate its debug
|
||||
// options from command-line flags. It's recommended to use this method to
|
||||
// create all HloModules for tests.
|
||||
std::unique_ptr<HloModule> CreateNewModule();
|
||||
|
||||
// Executes the given module and returns a global data handle.
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
|
||||
std::unique_ptr<HloModule> module,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user