[XLA] Add xla_disable_hlo_passes to DebugOptions

Also add a SetDebugOptions method to ClientLibraryTestBas; this lets us set
debug options in tests by calling it.

As an example, this CL removes the current way of passing
xla_disable_hlo_passes programmatically in tests - it used to employ a special
constructor parameter which is no longer required.

PiperOrigin-RevId: 158169006
This commit is contained in:
Eli Bendersky 2017-06-06 11:48:39 -07:00 committed by TensorFlower Gardener
parent 2b3535c649
commit cabc5c35c2
8 changed files with 49 additions and 32 deletions

View File

@ -46,12 +46,21 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
legacy_flags::HloPassPipelineFlags* flags = legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags(); legacy_flags::GetHloPassPipelineFlags();
std::vector<string> tmp = std::unique_ptr<tensorflow::gtl::FlatSet<string>> disabled_passes;
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ','); if (!flags->xla_disable_hlo_passes.empty()) {
tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end()); std::vector<string> passes_vec =
if (!disabled_passes.empty()) { tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
passes_vec.begin(), passes_vec.end());
} else {
auto repeated_field =
module->config().debug_options().xla_disable_hlo_passes();
disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
repeated_field.begin(), repeated_field.end());
}
if (!disabled_passes->empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
<< tensorflow::str_util::Join(disabled_passes, ", "); << tensorflow::str_util::Join(*disabled_passes, ", ");
} }
auto run_invariant_checkers = [this, module]() -> Status { auto run_invariant_checkers = [this, module]() -> Status {
@ -66,8 +75,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
bool changed = false; bool changed = false;
string message; string message;
for (auto& pass : passes_) { for (auto& pass : passes_) {
if (!disabled_passes.empty() && if (!disabled_passes->empty() &&
disabled_passes.count(pass->name().ToString()) > 0) { disabled_passes->count(pass->name().ToString()) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name() VLOG(1) << " Skipping HLO pass " << pass->name()
<< ", disabled by --xla_disable_hlo_passes"; << ", disabled by --xla_disable_hlo_passes";
continue; continue;

View File

@ -44,15 +44,8 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
} }
} // namespace } // namespace
ClientLibraryTestBase::ClientLibraryTestBase( ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
se::Platform* platform, : client_(GetOrCreateLocalClientOrDie(platform)) {}
tensorflow::gtl::ArraySlice<string> disabled_pass_names)
: client_(GetOrCreateLocalClientOrDie(platform)) {
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
flags->xla_disable_hlo_passes =
tensorflow::str_util::Join(disabled_pass_names, ",");
}
string ClientLibraryTestBase::TestName() const { string ClientLibraryTestBase::TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name(); return ::testing::UnitTest::GetInstance()->current_test_info()->name();

View File

@ -46,8 +46,7 @@ namespace xla {
class ClientLibraryTestBase : public ::testing::Test { class ClientLibraryTestBase : public ::testing::Test {
protected: protected:
explicit ClientLibraryTestBase( explicit ClientLibraryTestBase(
perftools::gputools::Platform* platform = nullptr, perftools::gputools::Platform* platform = nullptr);
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});
// Returns the name of the test currently being run. // Returns the name of the test currently being run.
string TestName() const; string TestName() const;
@ -58,6 +57,12 @@ class ClientLibraryTestBase : public ::testing::Test {
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
void SetDebugOptions(const DebugOptions& debug_options) {
*(execution_options_.mutable_debug_options()) = debug_options;
}
// TODO(b/25566808): Add helper that populates a literal from a testdata file.
// Convenience methods for building and running a computation from a builder. // Convenience methods for building and running a computation from a builder.
StatusOr<std::unique_ptr<GlobalData>> Execute( StatusOr<std::unique_ptr<GlobalData>> Execute(
ComputationBuilder* builder, ComputationBuilder* builder,

View File

@ -46,14 +46,8 @@ ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
class ComputeConstantTest : public ::testing::Test { class ComputeConstantTest : public ::testing::Test {
public: public:
explicit ComputeConstantTest( explicit ComputeConstantTest(
perftools::gputools::Platform* platform = nullptr, perftools::gputools::Platform* platform = nullptr)
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {}) : platform_(platform) {}
: platform_(platform) {
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
flags->xla_disable_hlo_passes =
tensorflow::str_util::Join(disabled_pass_names, ",");
}
string TestName() const { string TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name(); return ::testing::UnitTest::GetInstance()->current_test_info()->name();

View File

@ -36,8 +36,12 @@ namespace {
class ConvertTest : public ClientLibraryTestBase { class ConvertTest : public ClientLibraryTestBase {
public: public:
explicit ConvertTest(perftools::gputools::Platform* platform = nullptr) explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform, : ClientLibraryTestBase(platform) {
/*disabled_pass_names=*/{"algsimp", "inline"}) {} DebugOptions debug_options;
debug_options.add_xla_disable_hlo_passes("algsimp");
debug_options.add_xla_disable_hlo_passes("inline");
SetDebugOptions(debug_options);
}
}; };
TEST_F(ConvertTest, ConvertR1S32ToR1S32) { TEST_F(ConvertTest, ConvertR1S32ToR1S32) {

View File

@ -41,8 +41,12 @@ namespace {
class MapTest : public ClientLibraryTestBase { class MapTest : public ClientLibraryTestBase {
public: public:
explicit MapTest(perftools::gputools::Platform* platform = nullptr) explicit MapTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform, : ClientLibraryTestBase(platform) {
/*disabled_pass_names=*/{"algsimp", "inline"}) {} DebugOptions debug_options;
debug_options.add_xla_disable_hlo_passes("algsimp");
debug_options.add_xla_disable_hlo_passes("inline");
SetDebugOptions(debug_options);
}
// Creates a function that adds its scalar argument with the constant 1.0. // Creates a function that adds its scalar argument with the constant 1.0.
// //

View File

@ -41,8 +41,12 @@ namespace {
class VecOpsSimpleTest : public ClientLibraryTestBase { class VecOpsSimpleTest : public ClientLibraryTestBase {
public: public:
explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr) explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform, : ClientLibraryTestBase(platform) {
/*disabled_pass_names=*/{"algsimp", "inline"}) {} DebugOptions debug_options;
debug_options.add_xla_disable_hlo_passes("algsimp");
debug_options.add_xla_disable_hlo_passes("inline");
SetDebugOptions(debug_options);
}
ErrorSpec error_spec_{0.0001}; ErrorSpec error_spec_{0.0001};
}; };

View File

@ -27,6 +27,10 @@ message DebugOptions {
// various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to // various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to
// dump *all* HLO modules. // dump *all* HLO modules.
string xla_generate_hlo_graph = 1; string xla_generate_hlo_graph = 1;
// List of HLO passes to disable. These names must exactly match the pass
// names as specified by the HloPassInterface::name() method.
repeated string xla_disable_hlo_passes = 2;
} }
// These settings control how XLA compiles and/or runs code. Not all settings // These settings control how XLA compiles and/or runs code. Not all settings