[XLA] Propagate debug option flags to hlo_test_base.

Specific HLO tests have to replace the generic test_main target with a manual
main() that invokes RUN_ALL_TESTS.

To get access to a module with debug options set up, a new convenience method
is created on HloTestBase.

Initially algebraic_simplifier_test is modified as a canary; in a followup
we'll convert all HLO tests to this approach.

PiperOrigin-RevId: 158309488
This commit is contained in:
Eli Bendersky 2017-06-07 13:31:32 -07:00 committed by TensorFlower Gardener
parent 0770393e95
commit 599727c654
5 changed files with 76 additions and 39 deletions

View File

@ -868,9 +868,10 @@ cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"//tensorflow/core:test",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace op = xla::testing::opcode_matchers;
@ -59,7 +61,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
@ -82,7 +84,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
@ -105,7 +107,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
@ -127,7 +129,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
@ -149,7 +151,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, div);
@ -171,7 +173,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, div);
@ -199,7 +201,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, add);
@ -225,7 +227,7 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -250,7 +252,7 @@ TEST_F(AlgebraicSimplifierTest, LnExp) {
builder.AddInstruction(
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0)));
@ -279,7 +281,7 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
builder.AddInstruction(
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -304,7 +306,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
@ -329,7 +331,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
@ -359,7 +361,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, one));
@ -382,7 +384,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, two));
@ -405,7 +407,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
param0, negative_one));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one));
@ -434,7 +436,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
auto computation = builder.Build();
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
EXPECT_THAT(module->entry_computation()->root_instruction(),
@ -455,7 +457,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
@ -476,7 +478,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
@ -497,7 +499,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
builder.AddInstruction(
HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0));
@ -527,7 +529,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(
@ -558,7 +560,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, empty_slice}, 0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -581,7 +583,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
HloInstruction* copy = builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
// Set to different layouts.
@ -608,7 +610,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
HloInstruction* copy = builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
// Set to same layouts.
@ -640,7 +642,7 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
*reshape->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
@ -686,7 +688,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
builder.AddInstruction(HloInstruction::CreateTuple(
{transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -716,7 +718,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
HloOpcode::kMaximum, movable_reshape, zero));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -744,7 +746,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
*transpose->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({0, 1, 2, 3});
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
@ -771,7 +773,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
*transpose->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({3, 1, 2, 0});
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
@ -797,7 +799,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -825,7 +827,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}),
HloOpcode::kCopy, copy1));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0)));
@ -850,7 +852,7 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1));
@ -874,7 +876,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3}));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -897,7 +899,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -919,7 +921,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -942,7 +944,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -966,7 +968,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -992,7 +994,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@ -1697,7 +1699,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
builder.AddInstruction(
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
auto module = MakeUnique<HloModule>(TestName());
auto module = CreateNewModule();
module->AddEmbeddedComputation(std::move(dot_computation));
module->AddEntryComputation(call_builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@ -1707,3 +1709,20 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
} // namespace
} // namespace xla
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
LOG(ERROR) << "\n" << usage;
return 2;
}
testing::InitGoogleTest(&argc, argv);
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return 2;
}
return RUN_ALL_TESTS();
}

View File

@ -93,6 +93,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:backend",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
@ -54,6 +55,8 @@ struct HloTestBase::EigenThreadPoolWrapper {
HloTestBase::HloTestBase()
: backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) {
// TODO(b/62411181): get rid of this flag entirely when the usual debug flags
// are piped to all HLO tests.
test_hlo_dumper_ = [](const HloModule& module, const string& label) {
legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags();
if (flags->xla_hlo_test_generate_hlo_graph) {
@ -73,6 +76,13 @@ HloTestBase::~HloTestBase() {
}
}
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
config);
}
StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>

View File

@ -44,6 +44,12 @@ class HloTestBase : public ::testing::Test {
~HloTestBase() override;
// Creates a new HLO module for a test. The module created will have
// TestName() for its name; it will also automatically populate its debug
// options from command-line flags. It's recommended to use this method to
// create all HloModules for tests.
std::unique_ptr<HloModule> CreateNewModule();
// Executes the given module and returns a global data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
std::unique_ptr<HloModule> module,