mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[XLA] Add xla_disable_hlo_passes to DebugOptions
Also add a SetDebugOptions method to ClientLibraryTestBas; this lets us set debug options in tests by calling it. As an example, this CL removes the current way of passing xla_disable_hlo_passes programmatically in tests - it used to employ a special constructor parameter which is no longer required. PiperOrigin-RevId: 158169006
This commit is contained in:
parent
2b3535c649
commit
cabc5c35c2
|
|
@ -46,12 +46,21 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||||
|
|
||||||
legacy_flags::HloPassPipelineFlags* flags =
|
legacy_flags::HloPassPipelineFlags* flags =
|
||||||
legacy_flags::GetHloPassPipelineFlags();
|
legacy_flags::GetHloPassPipelineFlags();
|
||||||
std::vector<string> tmp =
|
std::unique_ptr<tensorflow::gtl::FlatSet<string>> disabled_passes;
|
||||||
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
|
if (!flags->xla_disable_hlo_passes.empty()) {
|
||||||
tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
|
std::vector<string> passes_vec =
|
||||||
if (!disabled_passes.empty()) {
|
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
|
||||||
|
disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
|
||||||
|
passes_vec.begin(), passes_vec.end());
|
||||||
|
} else {
|
||||||
|
auto repeated_field =
|
||||||
|
module->config().debug_options().xla_disable_hlo_passes();
|
||||||
|
disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
|
||||||
|
repeated_field.begin(), repeated_field.end());
|
||||||
|
}
|
||||||
|
if (!disabled_passes->empty()) {
|
||||||
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
|
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
|
||||||
<< tensorflow::str_util::Join(disabled_passes, ", ");
|
<< tensorflow::str_util::Join(*disabled_passes, ", ");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto run_invariant_checkers = [this, module]() -> Status {
|
auto run_invariant_checkers = [this, module]() -> Status {
|
||||||
|
|
@ -66,8 +75,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
string message;
|
string message;
|
||||||
for (auto& pass : passes_) {
|
for (auto& pass : passes_) {
|
||||||
if (!disabled_passes.empty() &&
|
if (!disabled_passes->empty() &&
|
||||||
disabled_passes.count(pass->name().ToString()) > 0) {
|
disabled_passes->count(pass->name().ToString()) > 0) {
|
||||||
VLOG(1) << " Skipping HLO pass " << pass->name()
|
VLOG(1) << " Skipping HLO pass " << pass->name()
|
||||||
<< ", disabled by --xla_disable_hlo_passes";
|
<< ", disabled by --xla_disable_hlo_passes";
|
||||||
continue;
|
continue;
|
||||||
|
|
|
||||||
|
|
@ -44,15 +44,8 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
ClientLibraryTestBase::ClientLibraryTestBase(
|
ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
|
||||||
se::Platform* platform,
|
: client_(GetOrCreateLocalClientOrDie(platform)) {}
|
||||||
tensorflow::gtl::ArraySlice<string> disabled_pass_names)
|
|
||||||
: client_(GetOrCreateLocalClientOrDie(platform)) {
|
|
||||||
legacy_flags::HloPassPipelineFlags* flags =
|
|
||||||
legacy_flags::GetHloPassPipelineFlags();
|
|
||||||
flags->xla_disable_hlo_passes =
|
|
||||||
tensorflow::str_util::Join(disabled_pass_names, ",");
|
|
||||||
}
|
|
||||||
|
|
||||||
string ClientLibraryTestBase::TestName() const {
|
string ClientLibraryTestBase::TestName() const {
|
||||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,7 @@ namespace xla {
|
||||||
class ClientLibraryTestBase : public ::testing::Test {
|
class ClientLibraryTestBase : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
explicit ClientLibraryTestBase(
|
explicit ClientLibraryTestBase(
|
||||||
perftools::gputools::Platform* platform = nullptr,
|
perftools::gputools::Platform* platform = nullptr);
|
||||||
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});
|
|
||||||
|
|
||||||
// Returns the name of the test currently being run.
|
// Returns the name of the test currently being run.
|
||||||
string TestName() const;
|
string TestName() const;
|
||||||
|
|
@ -58,6 +57,12 @@ class ClientLibraryTestBase : public ::testing::Test {
|
||||||
|
|
||||||
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
|
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
|
||||||
|
|
||||||
|
void SetDebugOptions(const DebugOptions& debug_options) {
|
||||||
|
*(execution_options_.mutable_debug_options()) = debug_options;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(b/25566808): Add helper that populates a literal from a testdata file.
|
||||||
|
|
||||||
// Convenience methods for building and running a computation from a builder.
|
// Convenience methods for building and running a computation from a builder.
|
||||||
StatusOr<std::unique_ptr<GlobalData>> Execute(
|
StatusOr<std::unique_ptr<GlobalData>> Execute(
|
||||||
ComputationBuilder* builder,
|
ComputationBuilder* builder,
|
||||||
|
|
|
||||||
|
|
@ -46,14 +46,8 @@ ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
|
||||||
class ComputeConstantTest : public ::testing::Test {
|
class ComputeConstantTest : public ::testing::Test {
|
||||||
public:
|
public:
|
||||||
explicit ComputeConstantTest(
|
explicit ComputeConstantTest(
|
||||||
perftools::gputools::Platform* platform = nullptr,
|
perftools::gputools::Platform* platform = nullptr)
|
||||||
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {})
|
: platform_(platform) {}
|
||||||
: platform_(platform) {
|
|
||||||
legacy_flags::HloPassPipelineFlags* flags =
|
|
||||||
legacy_flags::GetHloPassPipelineFlags();
|
|
||||||
flags->xla_disable_hlo_passes =
|
|
||||||
tensorflow::str_util::Join(disabled_pass_names, ",");
|
|
||||||
}
|
|
||||||
|
|
||||||
string TestName() const {
|
string TestName() const {
|
||||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,12 @@ namespace {
|
||||||
class ConvertTest : public ClientLibraryTestBase {
|
class ConvertTest : public ClientLibraryTestBase {
|
||||||
public:
|
public:
|
||||||
explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
|
explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
|
||||||
: ClientLibraryTestBase(platform,
|
: ClientLibraryTestBase(platform) {
|
||||||
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
|
DebugOptions debug_options;
|
||||||
|
debug_options.add_xla_disable_hlo_passes("algsimp");
|
||||||
|
debug_options.add_xla_disable_hlo_passes("inline");
|
||||||
|
SetDebugOptions(debug_options);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
|
TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
|
||||||
|
|
|
||||||
|
|
@ -41,8 +41,12 @@ namespace {
|
||||||
class MapTest : public ClientLibraryTestBase {
|
class MapTest : public ClientLibraryTestBase {
|
||||||
public:
|
public:
|
||||||
explicit MapTest(perftools::gputools::Platform* platform = nullptr)
|
explicit MapTest(perftools::gputools::Platform* platform = nullptr)
|
||||||
: ClientLibraryTestBase(platform,
|
: ClientLibraryTestBase(platform) {
|
||||||
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
|
DebugOptions debug_options;
|
||||||
|
debug_options.add_xla_disable_hlo_passes("algsimp");
|
||||||
|
debug_options.add_xla_disable_hlo_passes("inline");
|
||||||
|
SetDebugOptions(debug_options);
|
||||||
|
}
|
||||||
|
|
||||||
// Creates a function that adds its scalar argument with the constant 1.0.
|
// Creates a function that adds its scalar argument with the constant 1.0.
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -41,8 +41,12 @@ namespace {
|
||||||
class VecOpsSimpleTest : public ClientLibraryTestBase {
|
class VecOpsSimpleTest : public ClientLibraryTestBase {
|
||||||
public:
|
public:
|
||||||
explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr)
|
explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr)
|
||||||
: ClientLibraryTestBase(platform,
|
: ClientLibraryTestBase(platform) {
|
||||||
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
|
DebugOptions debug_options;
|
||||||
|
debug_options.add_xla_disable_hlo_passes("algsimp");
|
||||||
|
debug_options.add_xla_disable_hlo_passes("inline");
|
||||||
|
SetDebugOptions(debug_options);
|
||||||
|
}
|
||||||
|
|
||||||
ErrorSpec error_spec_{0.0001};
|
ErrorSpec error_spec_{0.0001};
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,10 @@ message DebugOptions {
|
||||||
// various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to
|
// various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to
|
||||||
// dump *all* HLO modules.
|
// dump *all* HLO modules.
|
||||||
string xla_generate_hlo_graph = 1;
|
string xla_generate_hlo_graph = 1;
|
||||||
|
|
||||||
|
// List of HLO passes to disable. These names must exactly match the pass
|
||||||
|
// names as specified by the HloPassInterface::name() method.
|
||||||
|
repeated string xla_disable_hlo_passes = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// These settings control how XLA compiles and/or runs code. Not all settings
|
// These settings control how XLA compiles and/or runs code. Not all settings
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user