Add an option to AutotuneConfig to dump autotuned instructions before and after the chosen config is applied.

- This change is required so that the new autotuner infrastructure keeps supporting the functionality of xla_gpu_dump_autotuned_gemm_fusions.

PiperOrigin-RevId: 826161466
This commit is contained in:
A. Unique TensorFlower 2025-10-30 13:24:39 -07:00 committed by TensorFlower Gardener
parent 8f60516a86
commit c40bb10b96
4 changed files with 109 additions and 26 deletions

View File

@ -40,8 +40,10 @@ cc_library(
"//xla:autotuning_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:dump",
"//xla/service:executable",
"//xla/service:shaped_buffer",
"//xla/tools:hlo_decomposer_lib",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
@ -86,6 +88,7 @@ xla_cc_test(
"//xla/tsl/platform:env",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"//xla/tsl/testing:temporary_directory",
"//xla/tsl/util/proto:proto_matchers",
"//xla/tsl/util/proto:proto_utils",
"@com_google_absl//absl/status",

View File

@ -46,8 +46,10 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_print_options.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/dump.h"
#include "xla/service/executable.h"
#include "xla/service/shaped_buffer.h"
#include "xla/tools/hlo_decomposer.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
@ -158,6 +160,9 @@ absl::Status Autotuner::Autotune(HloModule* module,
CHECK(!instructions.empty());
TF_ASSIGN_OR_RETURN(Config config, GetConfig(instructions[0]));
CodegenBackend* codegen_backend = config.codegen_backend;
if (autotune_config_.dump_hlos) {
TF_RETURN_IF_ERROR(DumpHlo(instructions[0], config));
}
for (auto* instr : instructions) {
TF_RETURN_IF_ERROR(
codegen_backend->ApplyConfig(*instr, *config.backend_config));
@ -251,6 +256,9 @@ absl::Status Autotuner::Autotune(HloModule* module,
CHECK(cached_config.has_value())
<< "Sharding autotuning failed: no config found for HLO: " +
instructions[0]->ToString();
if (autotune_config_.dump_hlos) {
TF_RETURN_IF_ERROR(DumpHlo(instructions[0], *cached_config));
}
CodegenBackend* codegen_backend = cached_config->codegen_backend;
for (auto* instr : instructions) {
TF_RETURN_IF_ERROR(
@ -264,6 +272,9 @@ absl::Status Autotuner::Autotune(HloModule* module,
absl::Status Autotuner::Autotune(HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(Config config, GetConfig(instr));
CodegenBackend* codegen_backend = config.codegen_backend;
if (autotune_config_.dump_hlos) {
TF_RETURN_IF_ERROR(DumpHlo(instr, config));
}
TF_RETURN_IF_ERROR(
codegen_backend->ApplyConfig(*instr, *config.backend_config));
return DumpLogsToFile();
@ -531,6 +542,22 @@ absl::StatusOr<Autotuner::ConfigResult> Autotuner::PickBestConfig(
return std::move(*best_result);
}
// Dumps two snapshots of `instr` isolated in a module of its own: one as
// extracted, and one after `config` has been applied to the root of the
// extracted module. Each snapshot is handed to DumpToFileInDirOrStdout together
// with the module that owns `instr`, using the file-name suffix
// "autotuner_<counter>.<instr name>.before.txt" / "...after.txt".
//
// `instr` itself is never passed to ApplyConfig here; only the extracted copy
// is. Returns the ApplyConfig error if applying the config fails, in which
// case only the "before" snapshot has been written.
absl::Status Autotuner::DumpHlo(HloInstruction* instr, const Config& config) {
  const HloModule* owning_module = instr->GetModule();

  std::unique_ptr<HloModule> extracted =
      ExtractInstructionIntoNewModule(*instr);
  extracted->set_name(std::string(instr->name()));

  // The counter advances on every call, so the two files written by one call
  // share a prefix that differs from the prefix of any other call.
  const std::string file_prefix =
      absl::StrCat("autotuner_", dump_counter_++, ".", instr->name());

  // Writes the current text of the extracted module under the given suffix.
  auto write_snapshot = [&](const char* suffix) {
    DumpToFileInDirOrStdout(*owning_module, "",
                            absl::StrCat(file_prefix, suffix),
                            extracted->ToString());
  };

  write_snapshot(".before.txt");

  HloInstruction* extracted_root =
      extracted->entry_computation()->root_instruction();
  TF_RETURN_IF_ERROR(config.codegen_backend->ApplyConfig(
      *extracted_root, *config.backend_config));

  write_snapshot(".after.txt");
  return absl::OkStatus();
}
absl::StatusOr<ScopedShapedBuffer> Autotuner::GetReferenceOutput(
std::vector<ExecutableCandidate>& candidates, InputBuffers& input_buffers) {
for (auto& candidate : candidates) {

View File

@ -82,6 +82,9 @@ struct AutotuneConfig {
// Note: If cache is provided, the cached config will be used instead of the
// default config.
bool use_default_config = false;
// If true, dump the autotuned instructions to the module's xla_dump_to
// directory, or to stdout if xla_dump_to is not set.
bool dump_hlos = false;
};
class Autotuner {
@ -205,6 +208,8 @@ class Autotuner {
void LogConfigResults(const HloInstruction& instr,
const std::vector<ConfigResult>& results);
absl::Status DumpLogsToFile();
// Dumps HLO before and after applying the config.
absl::Status DumpHlo(HloInstruction* instr, const Config& config);
std::vector<std::unique_ptr<CodegenBackend>> codegen_backends_;
std::unique_ptr<Profiler> profiler_;
@ -212,6 +217,7 @@ class Autotuner {
std::unique_ptr<AutotunerCacheInterface> cache_;
tsl::thread::ThreadPool* thread_pool_;
AutotuningLogs logs_;
int dump_counter_ = 0;
};
} // namespace xla

View File

@ -50,6 +50,7 @@ limitations under the License.
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/threadpool.h"
#include "xla/tsl/testing/temporary_directory.h"
#include "xla/tsl/util/proto/proto_matchers.h"
#include "xla/tsl/util/proto/proto_utils.h"
#include "tsl/platform/path.h"
@ -140,7 +141,9 @@ using absl_testing::StatusIs;
using ::testing::_;
using ::testing::AtMost;
using ::testing::ByMove;
using ::testing::MatchesRegex;
using ::testing::Return;
using ::testing::UnorderedElementsAre;
using tsl::proto_utils::ToDurationProto;
se::DeviceDescription CreateDummyDeviceDescription() {
@ -150,21 +153,32 @@ se::DeviceDescription CreateDummyDeviceDescription() {
}
absl::StatusOr<std::unique_ptr<Autotuner>> SetupAutotunerWithExpectations(
HloOpcode instr_to_autotune,
std::vector<HloOpcode> instrs_to_autotune,
std::vector<std::pair<HloOpcode, int>> instrs_to_apply_config_and_count,
std::unique_ptr<MockAutotunerCache> cache = nullptr) {
std::vector<std::unique_ptr<BackendConfig>> configs;
configs.push_back(GetTestConfig("another_config"));
configs.push_back(GetTestConfig("best_config"));
std::unique_ptr<MockAutotunerCache> cache = nullptr,
bool dump_hlos = false) {
auto backend = std::make_unique<MockCodegenBackend>();
auto profiler = std::make_unique<MockProfiler>();
EXPECT_CALL(*backend, name()).WillRepeatedly(Return("mock_backend"));
EXPECT_CALL(*backend,
GetSupportedConfigs(InstructionMatcher(instr_to_autotune)))
.WillOnce(Return(std::move(configs)));
for (const auto& instr_to_autotune : instrs_to_autotune) {
std::vector<std::unique_ptr<BackendConfig>> configs;
// Best config is just by notion here since profiler time is same for all.
configs.push_back(GetTestConfig("best_config"));
configs.push_back(GetTestConfig("another_config"));
EXPECT_CALL(*backend,
GetSupportedConfigs(InstructionMatcher(instr_to_autotune)))
.WillOnce(Return(std::move(configs)));
}
EXPECT_CALL(*profiler, CreateInputBuffers(_))
.Times(instrs_to_autotune.size())
.WillRepeatedly([] { return std::make_unique<InputBuffers>(); });
EXPECT_CALL(*backend, Compile(_, _))
.WillOnce(Return(std::unique_ptr<Executable>()))
.WillOnce(Return(std::unique_ptr<Executable>()));
.Times(2 * instrs_to_autotune.size())
.WillRepeatedly([] { return std::unique_ptr<Executable>(); });
EXPECT_CALL(*profiler, Profile(_, _))
.Times(2 * instrs_to_autotune.size())
.WillRepeatedly([] { return ProfileResult({absl::Seconds(1)}); });
for (const auto& [instr_to_apply_config, count] :
instrs_to_apply_config_and_count) {
EXPECT_CALL(*backend,
@ -172,19 +186,12 @@ absl::StatusOr<std::unique_ptr<Autotuner>> SetupAutotunerWithExpectations(
.Times(count)
.WillRepeatedly(Return(absl::OkStatus()));
}
auto profiler = std::make_unique<MockProfiler>();
auto device_description = CreateDummyDeviceDescription();
EXPECT_CALL(*profiler, CreateInputBuffers(_))
.WillOnce(Return(std::make_unique<InputBuffers>()));
EXPECT_CALL(*profiler, Profile(_, _))
.WillOnce(Return(ProfileResult({absl::Seconds(2)})))
.WillOnce(Return(ProfileResult({absl::Seconds(1)})));
std::vector<std::unique_ptr<CodegenBackend>> backends;
backends.push_back(std::move(backend));
return Autotuner::Create(std::move(backends), std::move(profiler),
GetTestAutotuneConfig(), std::move(cache));
AutotuneConfig config = GetTestAutotuneConfig();
config.dump_hlos = dump_hlos;
return Autotuner::Create(std::move(backends), std::move(profiler), config,
std::move(cache));
}
constexpr absl::string_view kHlo = R"(
@ -376,7 +383,7 @@ TEST_F(AutotunerTest, AutotuneModuleFollowsFilter) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Autotuner> autotuner,
SetupAutotunerWithExpectations(
/*instr_to_autotune=*/HloOpcode::kCopy,
/*instrs_to_autotune=*/{HloOpcode::kCopy},
/*instrs_to_apply_config_and_count=*/{{HloOpcode::kCopy, 1}}));
EXPECT_THAT(autotuner->Autotune(module.get(), should_autotune),
@ -393,7 +400,7 @@ TEST_F(AutotunerTest, AutotuneModuleWithDuplicateInstructions) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Autotuner> autotuner,
SetupAutotunerWithExpectations(
/*instr_to_autotune=*/HloOpcode::kAdd,
/*instrs_to_autotune=*/{HloOpcode::kAdd},
/*instrs_to_apply_config_and_count=*/{{HloOpcode::kAdd, 2}}));
EXPECT_THAT(autotuner->Autotune(module.get(), should_autotune), IsOk());
@ -585,7 +592,10 @@ TEST_F(AutotunerTest, ExpectAllInstructionsInCache) {
}
TEST_F(AutotunerTest, DumpLogsToFile) {
config_.dump_logs_to = tsl::io::JoinPath(tsl::testing::TmpDir(), "dump.log");
TF_ASSERT_OK_AND_ASSIGN(
tsl::testing::TemporaryDirectory temp_dir,
tsl::testing::TemporaryDirectory::CreateForCurrentTestcase());
config_.dump_logs_to = tsl::io::JoinPath(temp_dir.path(), "dump.log");
std::vector<std::unique_ptr<BackendConfig>> configs;
configs.push_back(GetTestConfig("test_config_1"));
@ -804,7 +814,7 @@ TEST_F(AutotunerTest, ShardedAutotuning) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Autotuner> autotuner,
SetupAutotunerWithExpectations(
/*instr_to_autotune=*/HloOpcode::kCopy,
/*instrs_to_autotune=*/{HloOpcode::kCopy},
/*instrs_to_apply_config_and_count=*/
{{HloOpcode::kCopy, 1}, {HloOpcode::kAdd, 2}}, std::move(cache)));
@ -817,5 +827,42 @@ TEST_F(AutotunerTest, ShardedAutotuning) {
IsOk());
}
// Checks that, with `dump_hlos` enabled, autotuning the copy and add
// instructions of the test module leaves exactly one "before" and one "after"
// HLO file per autotuned opcode in the module's xla_dump_to directory.
TEST_F(AutotunerTest, DumpHlos) {
  TF_ASSERT_OK_AND_ASSIGN(
      tsl::testing::TemporaryDirectory dump_dir,
      tsl::testing::TemporaryDirectory::CreateForCurrentTestcase());
  // Fail the test cleanly on a parse error instead of unwrapping unchecked.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  module->mutable_config().mutable_debug_options().set_xla_dump_to(
      dump_dir.path());

  auto should_autotune = [](const HloInstruction& instruction) {
    return instruction.opcode() == HloOpcode::kCopy ||
           instruction.opcode() == HloOpcode::kAdd;
  };

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<Autotuner> autotuner,
      SetupAutotunerWithExpectations(
          /*instrs_to_autotune=*/{HloOpcode::kCopy, HloOpcode::kAdd},
          // Dumping applies the config once more, to an extracted copy of the
          // instruction, so each opcode sees one ApplyConfig call on top of
          // the calls made for the instructions in the module itself.
          /*instrs_to_apply_config_and_count=*/
          {{HloOpcode::kCopy, 2}, {HloOpcode::kAdd, 3}},
          /*cache=*/nullptr,
          /*dump_hlos=*/true));

  // The file checks below are meaningless unless autotuning succeeded.
  ASSERT_THAT(autotuner->Autotune(module.get(), should_autotune), IsOk());

  std::vector<std::string> files;
  ASSERT_THAT(tsl::Env::Default()->GetChildren(dump_dir.path(), &files),
              IsOk());
  EXPECT_EQ(files.size(), 4u);
  EXPECT_THAT(
      files,
      UnorderedElementsAre(
          MatchesRegex(".*\\.test_module\\.autotuner_0\\.copy\\.before\\.txt"),
          MatchesRegex(".*\\.test_module\\.autotuner_0\\.copy\\.after\\.txt"),
          MatchesRegex(".*\\.test_module\\.autotuner_1\\.add\\.after\\.txt"),
          MatchesRegex(".*\\.test_module\\.autotuner_1\\.add\\.before\\.txt")));
}
} // namespace
} // namespace xla