[Autotuner] Add sharding support using KeyValueStore Interface.

- The logic is ported from gemm_fusion_autotuner. I have changed the key of the Key Value store to be just module-fingerprint, earlier it was module-fingerprint + autotunable-fusion-set-from-the-module-fingerprint. The module fingerprint should already represent the fusion-sets contained in it.
- We can improve or just remove this functionality when we design storage for offline autotuning.

PiperOrigin-RevId: 826103885
This commit is contained in:
A. Unique TensorFlower 2025-10-30 10:57:45 -07:00 committed by TensorFlower Gardener
parent 4ffcba9004
commit dd3a14ace4
4 changed files with 239 additions and 28 deletions

View File

@ -39,6 +39,7 @@ cc_library(
"//xla:autotune_results_proto_cc",
"//xla:autotuning_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:executable",
"//xla/service:shaped_buffer",
"//xla/tsl/platform:env",
@ -60,7 +61,6 @@ cc_library(
"@local_tsl//tsl/platform:blocking_counter",
"@local_tsl//tsl/platform:fingerprint",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:protobuf",
],
)
@ -78,6 +78,7 @@ xla_cc_test(
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:executable",
"//xla/service:shaped_buffer",
"//xla/service/gpu:backend_configs_cc",
@ -92,13 +93,12 @@ xla_cc_test(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@com_google_protobuf//:any_cc_proto",
"@com_google_protobuf//:protobuf",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:test",
],
)

View File

@ -16,7 +16,10 @@ limitations under the License.
#include "xla/backends/autotuner/autotuner.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <memory>
#include <optional>
@ -25,6 +28,7 @@ limitations under the License.
#include <vector>
#include "google/protobuf/any.pb.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
@ -41,6 +45,7 @@ limitations under the License.
#include "xla/backends/autotuner/profiler.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_print_options.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/executable.h"
#include "xla/service/shaped_buffer.h"
#include "xla/tsl/platform/env.h"
@ -50,7 +55,6 @@ limitations under the License.
#include "xla/tsl/util/proto/proto_utils.h"
#include "tsl/platform/blocking_counter.h"
#include "tsl/platform/fingerprint.h"
#include "tsl/platform/protobuf.h"
namespace xla {
@ -74,6 +78,38 @@ std::string UnpackedAnyShortDebugString(const google::protobuf::Any& any) {
return s;
}
// It is important to fingerprint the entire module not just the autotuning
// candidates, to avoid collisions in the key-value store when several
// distinct modules have the same fusions, and are compiled at different
// times by the same PjRt client.
//
// TODO(b/394763704): Eliminate the sharding feature when we have offline
// autotuning. See below for an explanation of some issues.
//
// Theoretically, we also want to include the hash of the module config
// to ensure that a module compiled twice with different configs is
// autotuned twice.
//
// This is important since the config could e.g. affect codegen, or the
// space of possible parameters for autotuning. As a result, the autotuning
// results could look very different for the same module.
//
// Why is it not done here? Well, proto serialization is non-deterministic
// and may change across different builds. Which means that users who run
// on several hosts with different CPUs may end up generating different
// fingerprints for the same module config. They would then fail to
// exchange results through the key value store, which would lead to
// deadlocks. Therefore, we don't hash the module config here.
//
// The flip side is this: if we compile the same module twice in the same
// client, but with a different module config each time, we may hit the
// cache the second time and recover potentially inferior, or incomplete
// autotuning results.
std::string GetKvStoreKey(const HloModule* module, int shard_index) {
return absl::StrCat("autotune_results_", module->GetFingerprint128(), "_",
shard_index);
}
} // namespace
absl::StatusOr<Autotuner::Config> Autotuner::GetDefaultConfig(
@ -110,15 +146,15 @@ absl::StatusOr<std::unique_ptr<Autotuner>> Autotuner::Create(
absl::Status Autotuner::Autotune(HloModule* module,
const InstructionFilterFn& should_autotune) {
InstructionsByFingerprint instrunctions_by_fingerprint =
InstructionsByFingerprint instructions_by_fingerprint =
GetAutotuningCandidates(module, should_autotune);
if (instrunctions_by_fingerprint.empty()) {
if (instructions_by_fingerprint.empty()) {
VLOG(1) << "No instructions to autotune.";
return absl::OkStatus();
}
VLOG(1) << "Finding configs for " << instrunctions_by_fingerprint.size()
VLOG(1) << "Finding configs for " << instructions_by_fingerprint.size()
<< " unique instructions.";
for (auto& [_, instructions] : instrunctions_by_fingerprint) {
for (auto& [_, instructions] : instructions_by_fingerprint) {
CHECK(!instructions.empty());
TF_ASSIGN_OR_RETURN(Config config, GetConfig(instructions[0]));
CodegenBackend* codegen_backend = config.codegen_backend;
@ -130,6 +166,101 @@ absl::Status Autotuner::Autotune(HloModule* module,
return DumpLogsToFile();
}
absl::Status Autotuner::Autotune(HloModule* module,
const InstructionFilterFn& should_autotune,
MultiProcessKeyValueStore& sharding_kv_store) {
CHECK(cache_ != nullptr) << "Sharding autotuning requires a cache.";
int total_shards = sharding_kv_store.process_count;
int my_shard_index = sharding_kv_store.process_index;
// 1. Get all the instructions that could be autotuned.
InstructionsByFingerprint all_instructions_by_fingerprint =
GetAutotuningCandidates(module, should_autotune);
if (all_instructions_by_fingerprint.empty()) {
VLOG(1) << "No instructions to autotune.";
return absl::OkStatus();
}
// 2. Shard and get instructions to autotune for current shard.
const size_t bucket_size =
std::ceil(static_cast<double>(all_instructions_by_fingerprint.size()) /
static_cast<double>(total_shards));
const size_t start = bucket_size * my_shard_index;
const size_t end =
std::min(start + bucket_size, all_instructions_by_fingerprint.size());
InstructionsByFingerprint instructions_by_fingerprint(
std::next(all_instructions_by_fingerprint.begin(), start),
std::next(all_instructions_by_fingerprint.begin(), end));
// 3. Autotune instructions for this shard. Use cached configs if available,
// otherwise autotune and cache the best config.
VLOG(1) << "Shard " << my_shard_index << "/" << total_shards
<< ": finding configs for " << instructions_by_fingerprint.size()
<< "/" << all_instructions_by_fingerprint.size()
<< " unique instructions ";
std::vector<const HloInstruction*> autotuned_instructions;
for (auto& [_, instructions] : instructions_by_fingerprint) {
CHECK(!instructions.empty());
TF_ASSIGN_OR_RETURN(Config config, GetConfig(instructions[0]));
autotuned_instructions.push_back(instructions[0]);
}
TF_RETURN_IF_ERROR(DumpLogsToFile());
// 4. Store the results for this shard as a serialized string to the KV store.
KeyValueStoreInterface& kv_store = *sharding_kv_store.key_value_store;
const std::string local_key = GetKvStoreKey(module, my_shard_index);
std::string local_results;
if (!autotuned_instructions.empty()) {
TF_ASSIGN_OR_RETURN(local_results,
cache_->Serialize(autotuned_instructions));
}
absl::StatusOr<std::string> stored_result = kv_store.TryGet(local_key);
if (stored_result.status().code() == absl::StatusCode::kNotFound) {
VLOG(2) << "Storing results for " << local_key;
TF_RETURN_IF_ERROR(kv_store.Set(local_key, local_results));
VLOG(2) << "Shard " << my_shard_index << " stored results at " << local_key;
} else if (!stored_result.ok()) {
return stored_result.status();
} else {
VLOG(2) << "Results already exist for " << local_key << ", skipping store.";
}
// 5. Load the autotune results of other shards from the KV store and update
// the current shard's cache by deserializing the results.
for (int i = 0; i < total_shards; ++i) {
if (i == my_shard_index) {
continue;
}
const std::string remote_key = GetKvStoreKey(module, i);
VLOG(2) << "Shard " << my_shard_index << ": waiting for results from shard "
<< i << " / " << total_shards << " at " << remote_key;
// TODO(b/361009609): reset to infinite duration once issue with MPI is
// fixed. https://github.com/google/jax/issues/22995.
TF_ASSIGN_OR_RETURN(std::string remote_results,
kv_store.Get(remote_key, absl::Hours(24)));
if (!remote_results.empty()) {
TF_RETURN_IF_ERROR(cache_->Deserialize(remote_results));
}
}
// 6. Apply the results to all candidate instructions, must be already in
// cache_ due to step 3 and 5 above.
for (auto& [_, instructions] : all_instructions_by_fingerprint) {
CHECK(!instructions.empty());
std::optional<Config> cached_config = LookUp(instructions[0]);
CHECK(cached_config.has_value())
<< "Sharding autotuning failed: no config found for HLO: " +
instructions[0]->ToString();
CodegenBackend* codegen_backend = cached_config->codegen_backend;
for (auto* instr : instructions) {
TF_RETURN_IF_ERROR(
codegen_backend->ApplyConfig(*instr, *cached_config->backend_config));
}
}
return absl::OkStatus();
}
absl::Status Autotuner::Autotune(HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(Config config, GetConfig(instr));
CodegenBackend* codegen_backend = config.codegen_backend;
@ -221,15 +352,15 @@ absl::StatusOr<Autotuner::Config> Autotuner::TuneBestConfig(
Autotuner::InstructionsByFingerprint Autotuner::GetAutotuningCandidates(
const HloModule* module, const InstructionFilterFn& should_autotune) {
InstructionsByFingerprint instrunctions_by_fingerprint;
InstructionsByFingerprint instructions_by_fingerprint;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
if (should_autotune(*instr)) {
instrunctions_by_fingerprint[GetFingerprint(instr)].push_back(instr);
instructions_by_fingerprint[GetFingerprint(instr)].push_back(instr);
}
}
}
return instrunctions_by_fingerprint;
return instructions_by_fingerprint;
}
std::optional<Autotuner::Config> Autotuner::LookUp(

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "xla/backends/autotuner/codegen_backend.h"
#include "xla/backends/autotuner/profiler.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/executable.h"
#include "xla/service/shaped_buffer.h"
#include "xla/tsl/platform/threadpool.h"
@ -103,6 +104,13 @@ class Autotuner {
absl::Status Autotune(HloModule* module,
const InstructionFilterFn& should_autotune);
// Same as above, but also takes a sharding KV store which helps to shard
// the autotuning work across multiple processes.
// This is used for distributed autotuning.
absl::Status Autotune(HloModule* module,
const InstructionFilterFn& should_autotune,
MultiProcessKeyValueStore& sharding_kv_store);
private:
using InstructionsByFingerprint =
absl::flat_hash_map<tsl::Fprint128, std::vector<HloInstruction*>,

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "google/protobuf/text_format.h"
#include "xla/autotune_results.pb.h"
#include "xla/autotuning.pb.h"
@ -39,6 +40,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/literal_util.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/shaped_buffer.h"
@ -66,6 +68,7 @@ MATCHER_P(ConfigMatcher, name, "") {
}
MATCHER_P(InstructionMatcher, opcode, "") { return arg.opcode() == opcode; }
MATCHER_P(InstrPtrMatcher, opcode, "") { return arg->opcode() == opcode; }
std::unique_ptr<google::protobuf::Any> GetTestConfig(std::string name) {
TestConfig config;
@ -125,6 +128,11 @@ class MockAutotunerCache : public AutotunerCacheInterface {
(const HloInstruction* instr,
const AutotunerCacheInterface::Config& best_config),
(override));
MOCK_METHOD(absl::StatusOr<std::string>, Serialize,
(absl::Span<const HloInstruction* const> instructions),
(override));
MOCK_METHOD(absl::Status, Deserialize, (absl::string_view serialized_cache),
(override));
};
using absl_testing::IsOk;
@ -143,29 +151,27 @@ se::DeviceDescription CreateDummyDeviceDescription() {
absl::StatusOr<std::unique_ptr<Autotuner>> SetupAutotunerWithExpectations(
HloOpcode instr_to_autotune,
std::pair<HloOpcode, int> instr_to_apply_config_and_count) {
auto cache_manager = std::make_unique<MockAutotunerCache>();
EXPECT_CALL(*cache_manager, Lookup(_)).WillRepeatedly(Return(std::nullopt));
EXPECT_CALL(*cache_manager, Insert(_, _))
.WillRepeatedly(Return(absl::OkStatus()));
std::vector<std::pair<HloOpcode, int>> instrs_to_apply_config_and_count,
std::unique_ptr<MockAutotunerCache> cache = nullptr) {
std::vector<std::unique_ptr<BackendConfig>> configs;
configs.push_back(GetTestConfig("test_config_1"));
configs.push_back(GetTestConfig("test_config_2"));
configs.push_back(GetTestConfig("another_config"));
configs.push_back(GetTestConfig("best_config"));
auto backend = std::make_unique<MockCodegenBackend>();
EXPECT_CALL(*backend, name()).WillRepeatedly(Return("mock_backend"));
EXPECT_CALL(*backend,
GetSupportedConfigs(InstructionMatcher(instr_to_autotune)))
.WillOnce(Return(std::move(configs)));
EXPECT_CALL(*backend, Compile(_, _))
.WillOnce(Return(std::unique_ptr<Executable>()))
.WillOnce(Return(std::unique_ptr<Executable>()));
HloOpcode instr_to_apply_config = instr_to_apply_config_and_count.first;
int count = instr_to_apply_config_and_count.second;
EXPECT_CALL(*backend,
ApplyConfig(InstructionMatcher(instr_to_apply_config), _))
.Times(count)
.WillRepeatedly(Return(absl::OkStatus()));
for (const auto& [instr_to_apply_config, count] :
instrs_to_apply_config_and_count) {
EXPECT_CALL(*backend,
ApplyConfig(InstructionMatcher(instr_to_apply_config), _))
.Times(count)
.WillRepeatedly(Return(absl::OkStatus()));
}
auto profiler = std::make_unique<MockProfiler>();
auto device_description = CreateDummyDeviceDescription();
@ -178,7 +184,7 @@ absl::StatusOr<std::unique_ptr<Autotuner>> SetupAutotunerWithExpectations(
std::vector<std::unique_ptr<CodegenBackend>> backends;
backends.push_back(std::move(backend));
return Autotuner::Create(std::move(backends), std::move(profiler),
GetTestAutotuneConfig(), std::move(cache_manager));
GetTestAutotuneConfig(), std::move(cache));
}
constexpr absl::string_view kHlo = R"(
@ -371,7 +377,7 @@ TEST_F(AutotunerTest, AutotuneModuleFollowsFilter) {
std::unique_ptr<Autotuner> autotuner,
SetupAutotunerWithExpectations(
/*instr_to_autotune=*/HloOpcode::kCopy,
/*instr_to_apply_config_and_count=*/{HloOpcode::kCopy, 1}));
/*instrs_to_apply_config_and_count=*/{{HloOpcode::kCopy, 1}}));
EXPECT_THAT(autotuner->Autotune(module.get(), should_autotune),
absl_testing::IsOk());
@ -388,7 +394,7 @@ TEST_F(AutotunerTest, AutotuneModuleWithDuplicateInstructions) {
std::unique_ptr<Autotuner> autotuner,
SetupAutotunerWithExpectations(
/*instr_to_autotune=*/HloOpcode::kAdd,
/*instr_to_apply_config_and_count=*/{HloOpcode::kAdd, 2}));
/*instrs_to_apply_config_and_count=*/{{HloOpcode::kAdd, 2}}));
EXPECT_THAT(autotuner->Autotune(module.get(), should_autotune), IsOk());
}
@ -745,5 +751,71 @@ TEST_F(AutotunerTest, UseDefaultConfigUnimplemented) {
"GetDefaultConfig is not implemented for mock_backend");
}
class MockKeyValueStore : public KeyValueStoreInterface {
public:
MOCK_METHOD(absl::Status, Set,
(absl::string_view key, absl::string_view value), (override));
MOCK_METHOD(absl::StatusOr<std::string>, Get,
(absl::string_view key, absl::Duration timeout), (override));
MOCK_METHOD(absl::StatusOr<std::string>, TryGet, (absl::string_view key),
(override));
};
AutotunerCacheInterface::Config GetCacheConfig(absl::string_view name) {
AutotunerCacheInterface::Config config;
config.codegen_backend_name = "mock_backend";
config.backend_config = *GetTestConfig(std::string(name));
return config;
};
TEST_F(AutotunerTest, ShardedAutotuning) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHlo));
constexpr int kShardCount = 2;
auto should_autotune = [](const HloInstruction& instruction) {
return instruction.opcode() == HloOpcode::kAdd ||
instruction.opcode() == HloOpcode::kCopy;
};
auto kv_store = std::make_shared<MockKeyValueStore>();
auto cache = std::make_unique<MockAutotunerCache>();
// Shard 0 autotunes kCopy instructions, updates the cache and serializes the
// result to a string "kCopy_autotune_result".
EXPECT_CALL(*cache, Lookup(InstrPtrMatcher(HloOpcode::kCopy)))
.WillOnce(Return(std::nullopt)) // During autotuning.
.WillOnce(Return(GetCacheConfig("best_config"))); // Config application.
EXPECT_CALL(*cache, Insert(InstrPtrMatcher(HloOpcode::kCopy), _))
.WillOnce(Return(absl::OkStatus()));
EXPECT_CALL(*cache, Serialize(_)).WillOnce(Return("kCopy_autotune_result"));
// Stores the serialized results to the KV store if it does not exist.
EXPECT_CALL(*kv_store, TryGet(testing::HasSubstr("_0")))
.WillOnce(Return(absl::NotFoundError("not found")));
EXPECT_CALL(*kv_store, Set(testing::HasSubstr("_0"), "kCopy_autotune_result"))
.WillOnce(Return(absl::OkStatus()));
// Shard 0 reads the KV store entry for shard 1 and updates the current cache.
EXPECT_CALL(*kv_store, Get(testing::HasSubstr("_1"), _))
.WillOnce(Return("kAdd_autotune_result"));
EXPECT_CALL(*cache, Deserialize("kAdd_autotune_result"))
.WillOnce(Return(absl::OkStatus()));
EXPECT_CALL(*cache, Lookup(InstrPtrMatcher(HloOpcode::kAdd)))
.WillOnce(Return(GetCacheConfig("best_config")));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Autotuner> autotuner,
SetupAutotunerWithExpectations(
/*instr_to_autotune=*/HloOpcode::kCopy,
/*instrs_to_apply_config_and_count=*/
{{HloOpcode::kCopy, 1}, {HloOpcode::kAdd, 2}}, std::move(cache)));
MultiProcessKeyValueStore sharding_kv_store;
sharding_kv_store.key_value_store = kv_store;
sharding_kv_store.process_count = kShardCount;
sharding_kv_store.process_index = 0;
EXPECT_THAT(
autotuner->Autotune(module.get(), should_autotune, sharding_kv_store),
IsOk());
}
} // namespace
} // namespace xla