mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[Autotuner] Add sharding support using KeyValueStore Interface.
- The logic is ported from gemm_fusion_autotuner. The key used in the key-value store is now just the module fingerprint; previously it was module-fingerprint + autotunable-fusion-set-from-the-module-fingerprint. The module fingerprint should already represent the fusion sets contained in the module. - We can improve or remove this functionality when we design storage for offline autotuning. PiperOrigin-RevId: 826103885
This commit is contained in:
parent
4ffcba9004
commit
dd3a14ace4
6
third_party/xla/xla/backends/autotuner/BUILD
vendored
6
third_party/xla/xla/backends/autotuner/BUILD
vendored
|
|
@ -39,6 +39,7 @@ cc_library(
|
|||
"//xla:autotune_results_proto_cc",
|
||||
"//xla:autotuning_proto_cc",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/pjrt/distributed:key_value_store_interface",
|
||||
"//xla/service:executable",
|
||||
"//xla/service:shaped_buffer",
|
||||
"//xla/tsl/platform:env",
|
||||
|
|
@ -60,7 +61,6 @@ cc_library(
|
|||
"@local_tsl//tsl/platform:blocking_counter",
|
||||
"@local_tsl//tsl/platform:fingerprint",
|
||||
"@local_tsl//tsl/platform:path",
|
||||
"@local_tsl//tsl/platform:protobuf",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -78,6 +78,7 @@ xla_cc_test(
|
|||
"//xla:shape_util",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
|
||||
"//xla/pjrt/distributed:key_value_store_interface",
|
||||
"//xla/service:executable",
|
||||
"//xla/service:shaped_buffer",
|
||||
"//xla/service/gpu:backend_configs_cc",
|
||||
|
|
@ -92,13 +93,12 @@ xla_cc_test(
|
|||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@com_google_protobuf//:any_cc_proto",
|
||||
"@com_google_protobuf//:protobuf",
|
||||
"@local_tsl//tsl/platform:path",
|
||||
"@local_tsl//tsl/platform:protobuf",
|
||||
"@local_tsl//tsl/platform:test",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
147
third_party/xla/xla/backends/autotuner/autotuner.cc
vendored
147
third_party/xla/xla/backends/autotuner/autotuner.cc
vendored
|
|
@ -16,7 +16,10 @@ limitations under the License.
|
|||
#include "xla/backends/autotuner/autotuner.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
|
@ -25,6 +28,7 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "google/protobuf/any.pb.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/memory/memory.h"
|
||||
|
|
@ -41,6 +45,7 @@ limitations under the License.
|
|||
#include "xla/backends/autotuner/profiler.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/ir/hlo_print_options.h"
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/service/executable.h"
|
||||
#include "xla/service/shaped_buffer.h"
|
||||
#include "xla/tsl/platform/env.h"
|
||||
|
|
@ -50,7 +55,6 @@ limitations under the License.
|
|||
#include "xla/tsl/util/proto/proto_utils.h"
|
||||
#include "tsl/platform/blocking_counter.h"
|
||||
#include "tsl/platform/fingerprint.h"
|
||||
#include "tsl/platform/protobuf.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
|
|
@ -74,6 +78,38 @@ std::string UnpackedAnyShortDebugString(const google::protobuf::Any& any) {
|
|||
return s;
|
||||
}
|
||||
|
||||
// It is important to fingerprint the entire module not just the autotuning
|
||||
// candidates, to avoid collisions in the key-value store when several
|
||||
// distinct modules have the same fusions, and are compiled at different
|
||||
// times by the same PjRt client.
|
||||
//
|
||||
// TODO(b/394763704): Eliminate the sharding feature when we have offline
|
||||
// autotuning. See below for an explanation of some issues.
|
||||
//
|
||||
// Theoretically, we also want to include the hash of the module config
|
||||
// to ensure that a module compiled twice with different configs is
|
||||
// autotuned twice.
|
||||
//
|
||||
// This is important since the config could e.g. affect codegen, or the
|
||||
// space of possible parameters for autotuning. As a result, the autotuning
|
||||
// results could look very different for the same module.
|
||||
//
|
||||
// Why is it not done here? Well, proto serialization is non-deterministic
|
||||
// and may change across different builds. Which means that users who run
|
||||
// on several hosts with different CPUs may end up generating different
|
||||
// fingerprints for the same module config. They would then fail to
|
||||
// exchange results through the key value store, which would lead to
|
||||
// deadlocks. Therefore, we don't hash the module config here.
|
||||
//
|
||||
// The flip side is this: if we compile the same module twice in the same
|
||||
// client, but with a different module config each time, we may hit the
|
||||
// cache the second time and recover potentially inferior, or incomplete
|
||||
// autotuning results.
|
||||
std::string GetKvStoreKey(const HloModule* module, int shard_index) {
|
||||
return absl::StrCat("autotune_results_", module->GetFingerprint128(), "_",
|
||||
shard_index);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<Autotuner::Config> Autotuner::GetDefaultConfig(
|
||||
|
|
@ -110,15 +146,15 @@ absl::StatusOr<std::unique_ptr<Autotuner>> Autotuner::Create(
|
|||
|
||||
absl::Status Autotuner::Autotune(HloModule* module,
|
||||
const InstructionFilterFn& should_autotune) {
|
||||
InstructionsByFingerprint instrunctions_by_fingerprint =
|
||||
InstructionsByFingerprint instructions_by_fingerprint =
|
||||
GetAutotuningCandidates(module, should_autotune);
|
||||
if (instrunctions_by_fingerprint.empty()) {
|
||||
if (instructions_by_fingerprint.empty()) {
|
||||
VLOG(1) << "No instructions to autotune.";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
VLOG(1) << "Finding configs for " << instrunctions_by_fingerprint.size()
|
||||
VLOG(1) << "Finding configs for " << instructions_by_fingerprint.size()
|
||||
<< " unique instructions.";
|
||||
for (auto& [_, instructions] : instrunctions_by_fingerprint) {
|
||||
for (auto& [_, instructions] : instructions_by_fingerprint) {
|
||||
CHECK(!instructions.empty());
|
||||
TF_ASSIGN_OR_RETURN(Config config, GetConfig(instructions[0]));
|
||||
CodegenBackend* codegen_backend = config.codegen_backend;
|
||||
|
|
@ -130,6 +166,101 @@ absl::Status Autotuner::Autotune(HloModule* module,
|
|||
return DumpLogsToFile();
|
||||
}
|
||||
|
||||
// Autotunes `module` with the work split across the processes described by
// `sharding_kv_store`.
//
// Every process computes the same set of candidates (instructions accepted by
// `should_autotune`, grouped by fingerprint), finds configs only for its own
// slice of them, and publishes the serialized results of that slice in the
// key-value store. It then reads the slices published by the other shards,
// merges them into `cache_`, and applies the cached config to every candidate
// instruction of the module.
//
// The slices are taken from the candidates sorted by fingerprint rather than
// from the hash map directly: hash map iteration order is unspecified, so
// different processes could otherwise pick overlapping slices and leave some
// candidates without any result.
//
// Returns OK without touching the key-value store when there is nothing to
// autotune. Returns InvalidArgument if the shard index/count are inconsistent,
// and forwards any error from GetConfig, cache (de)serialization, the
// key-value store or ApplyConfig. CHECK-fails if the autotuner has no cache,
// or if a candidate has no cached config after all shards have been merged.
absl::Status Autotuner::Autotune(HloModule* module,
                                 const InstructionFilterFn& should_autotune,
                                 MultiProcessKeyValueStore& sharding_kv_store) {
  CHECK(cache_ != nullptr) << "Sharding autotuning requires a cache.";
  const int total_shards = sharding_kv_store.process_count;
  const int my_shard_index = sharding_kv_store.process_index;

  // 1. Get all the instructions that could be autotuned.
  InstructionsByFingerprint all_instructions_by_fingerprint =
      GetAutotuningCandidates(module, should_autotune);
  if (all_instructions_by_fingerprint.empty()) {
    VLOG(1) << "No instructions to autotune.";
    return absl::OkStatus();
  }

  // A non-positive shard count would divide by zero below, and an index
  // outside [0, total_shards) would address a slice that does not exist.
  if (total_shards <= 0 || my_shard_index < 0 ||
      my_shard_index >= total_shards) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Invalid sharding configuration: process_index=", my_shard_index,
        ", process_count=", total_shards));
  }

  // 2. Shard and get instructions to autotune for current shard.
  // Order the candidates by fingerprint so that every process agrees on which
  // candidate belongs to which shard. The pointers stay valid because the map
  // is not modified for the rest of this function.
  std::vector<const InstructionsByFingerprint::value_type*> ordered_candidates;
  ordered_candidates.reserve(all_instructions_by_fingerprint.size());
  for (const auto& entry : all_instructions_by_fingerprint) {
    ordered_candidates.push_back(&entry);
  }
  std::sort(ordered_candidates.begin(), ordered_candidates.end(),
            [](const auto* lhs, const auto* rhs) {
              const auto& a = lhs->first;
              const auto& b = rhs->first;
              return a.high64 != b.high64 ? a.high64 < b.high64
                                          : a.low64 < b.low64;
            });

  const size_t num_candidates = ordered_candidates.size();
  const size_t shard_count = static_cast<size_t>(total_shards);
  // Ceiling division: the last shards may get fewer (possibly zero)
  // candidates.
  const size_t bucket_size = (num_candidates + shard_count - 1) / shard_count;
  // Clamp `start` as well as `end`: with more shards than needed, the
  // trailing shards start past the last candidate and get an empty slice.
  const size_t start = std::min(
      bucket_size * static_cast<size_t>(my_shard_index), num_candidates);
  const size_t end = std::min(start + bucket_size, num_candidates);

  // 3. Autotune instructions for this shard. Use cached configs if available,
  // otherwise autotune and cache the best config.
  VLOG(1) << "Shard " << my_shard_index << "/" << total_shards
          << ": finding configs for " << (end - start) << "/"
          << num_candidates << " unique instructions ";
  std::vector<const HloInstruction*> autotuned_instructions;
  autotuned_instructions.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
    const auto& instructions = ordered_candidates[i]->second;
    CHECK(!instructions.empty());
    // Only the side effect on the cache matters here; the config itself is
    // looked up again in step 6.
    TF_RETURN_IF_ERROR(GetConfig(instructions[0]).status());
    autotuned_instructions.push_back(instructions[0]);
  }
  TF_RETURN_IF_ERROR(DumpLogsToFile());

  // 4. Store the results for this shard as a serialized string to the KV
  // store. A shard with an empty slice still publishes an empty string so
  // that the other shards do not wait for it.
  KeyValueStoreInterface& kv_store = *sharding_kv_store.key_value_store;
  const std::string local_key = GetKvStoreKey(module, my_shard_index);
  std::string local_results;
  if (!autotuned_instructions.empty()) {
    TF_ASSIGN_OR_RETURN(local_results,
                        cache_->Serialize(autotuned_instructions));
  }
  absl::StatusOr<std::string> stored_result = kv_store.TryGet(local_key);
  if (stored_result.status().code() == absl::StatusCode::kNotFound) {
    VLOG(2) << "Storing results for " << local_key;
    TF_RETURN_IF_ERROR(kv_store.Set(local_key, local_results));
    VLOG(2) << "Shard " << my_shard_index << " stored results at " << local_key;
  } else if (!stored_result.ok()) {
    return stored_result.status();
  } else {
    VLOG(2) << "Results already exist for " << local_key << ", skipping store.";
  }

  // 5. Load the autotune results of other shards from the KV store and update
  // the current shard's cache by deserializing the results.
  for (int i = 0; i < total_shards; ++i) {
    if (i == my_shard_index) {
      continue;
    }
    const std::string remote_key = GetKvStoreKey(module, i);
    VLOG(2) << "Shard " << my_shard_index << ": waiting for results from shard "
            << i << " / " << total_shards << " at " << remote_key;
    // TODO(b/361009609): reset to infinite duration once issue with MPI is
    // fixed. https://github.com/google/jax/issues/22995.
    TF_ASSIGN_OR_RETURN(std::string remote_results,
                        kv_store.Get(remote_key, absl::Hours(24)));
    if (!remote_results.empty()) {
      TF_RETURN_IF_ERROR(cache_->Deserialize(remote_results));
    }
  }

  // 6. Apply the results to all candidate instructions, must be already in
  // cache_ due to step 3 and 5 above.
  for (auto& [_, instructions] : all_instructions_by_fingerprint) {
    CHECK(!instructions.empty());
    std::optional<Config> cached_config = LookUp(instructions[0]);
    CHECK(cached_config.has_value())
        << "Sharding autotuning failed: no config found for HLO: "
        << instructions[0]->ToString();
    CodegenBackend* codegen_backend = cached_config->codegen_backend;
    for (auto* instr : instructions) {
      TF_RETURN_IF_ERROR(
          codegen_backend->ApplyConfig(*instr, *cached_config->backend_config));
    }
  }

  return absl::OkStatus();
}
|
||||
|
||||
absl::Status Autotuner::Autotune(HloInstruction* instr) {
|
||||
TF_ASSIGN_OR_RETURN(Config config, GetConfig(instr));
|
||||
CodegenBackend* codegen_backend = config.codegen_backend;
|
||||
|
|
@ -221,15 +352,15 @@ absl::StatusOr<Autotuner::Config> Autotuner::TuneBestConfig(
|
|||
|
||||
Autotuner::InstructionsByFingerprint Autotuner::GetAutotuningCandidates(
|
||||
const HloModule* module, const InstructionFilterFn& should_autotune) {
|
||||
InstructionsByFingerprint instrunctions_by_fingerprint;
|
||||
InstructionsByFingerprint instructions_by_fingerprint;
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
|
||||
if (should_autotune(*instr)) {
|
||||
instrunctions_by_fingerprint[GetFingerprint(instr)].push_back(instr);
|
||||
instructions_by_fingerprint[GetFingerprint(instr)].push_back(instr);
|
||||
}
|
||||
}
|
||||
}
|
||||
return instrunctions_by_fingerprint;
|
||||
return instructions_by_fingerprint;
|
||||
}
|
||||
|
||||
std::optional<Autotuner::Config> Autotuner::LookUp(
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||
#include "xla/backends/autotuner/codegen_backend.h"
|
||||
#include "xla/backends/autotuner/profiler.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/service/executable.h"
|
||||
#include "xla/service/shaped_buffer.h"
|
||||
#include "xla/tsl/platform/threadpool.h"
|
||||
|
|
@ -103,6 +104,13 @@ class Autotuner {
|
|||
absl::Status Autotune(HloModule* module,
|
||||
const InstructionFilterFn& should_autotune);
|
||||
|
||||
// Same as above, but also takes a sharding KV store which helps to shard
|
||||
// the autotuning work across multiple processes.
|
||||
// This is used for distributed autotuning.
// Each process finds configs for its own share of the candidates, publishes
// the serialized results in the KV store, and then waits for the results of
// every other process before applying configs to the module.
// Requires the autotuner to have a cache; the implementation CHECK-fails
// otherwise.
|
||||
absl::Status Autotune(HloModule* module,
|
||||
const InstructionFilterFn& should_autotune,
|
||||
MultiProcessKeyValueStore& sharding_kv_store);
|
||||
|
||||
private:
|
||||
using InstructionsByFingerprint =
|
||||
absl::flat_hash_map<tsl::Fprint128, std::vector<HloInstruction*>,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "xla/autotune_results.pb.h"
|
||||
#include "xla/autotuning.pb.h"
|
||||
|
|
@ -39,6 +40,7 @@ limitations under the License.
|
|||
#include "xla/hlo/ir/hlo_opcode.h"
|
||||
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
|
||||
#include "xla/literal_util.h"
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/service/executable.h"
|
||||
#include "xla/service/gpu/backend_configs.pb.h"
|
||||
#include "xla/service/shaped_buffer.h"
|
||||
|
|
@ -66,6 +68,7 @@ MATCHER_P(ConfigMatcher, name, "") {
|
|||
}
|
||||
|
||||
// Matches an argument passed by value or reference whose opcode() equals
// `opcode`.
MATCHER_P(InstructionMatcher, opcode, "") { return arg.opcode() == opcode; }
|
||||
// Pointer flavour of the opcode matcher: matches a pointer-like argument
// whose pointee's opcode() equals `opcode`.
MATCHER_P(InstrPtrMatcher, opcode, "") { return arg->opcode() == opcode; }
|
||||
|
||||
std::unique_ptr<google::protobuf::Any> GetTestConfig(std::string name) {
|
||||
TestConfig config;
|
||||
|
|
@ -125,6 +128,11 @@ class MockAutotunerCache : public AutotunerCacheInterface {
|
|||
(const HloInstruction* instr,
|
||||
const AutotunerCacheInterface::Config& best_config),
|
||||
(override));
|
||||
MOCK_METHOD(absl::StatusOr<std::string>, Serialize,
|
||||
(absl::Span<const HloInstruction* const> instructions),
|
||||
(override));
|
||||
MOCK_METHOD(absl::Status, Deserialize, (absl::string_view serialized_cache),
|
||||
(override));
|
||||
};
|
||||
|
||||
using absl_testing::IsOk;
|
||||
|
|
@ -143,29 +151,27 @@ se::DeviceDescription CreateDummyDeviceDescription() {
|
|||
|
||||
absl::StatusOr<std::unique_ptr<Autotuner>> SetupAutotunerWithExpectations(
|
||||
HloOpcode instr_to_autotune,
|
||||
std::pair<HloOpcode, int> instr_to_apply_config_and_count) {
|
||||
auto cache_manager = std::make_unique<MockAutotunerCache>();
|
||||
EXPECT_CALL(*cache_manager, Lookup(_)).WillRepeatedly(Return(std::nullopt));
|
||||
EXPECT_CALL(*cache_manager, Insert(_, _))
|
||||
.WillRepeatedly(Return(absl::OkStatus()));
|
||||
|
||||
std::vector<std::pair<HloOpcode, int>> instrs_to_apply_config_and_count,
|
||||
std::unique_ptr<MockAutotunerCache> cache = nullptr) {
|
||||
std::vector<std::unique_ptr<BackendConfig>> configs;
|
||||
configs.push_back(GetTestConfig("test_config_1"));
|
||||
configs.push_back(GetTestConfig("test_config_2"));
|
||||
configs.push_back(GetTestConfig("another_config"));
|
||||
configs.push_back(GetTestConfig("best_config"));
|
||||
|
||||
auto backend = std::make_unique<MockCodegenBackend>();
|
||||
EXPECT_CALL(*backend, name()).WillRepeatedly(Return("mock_backend"));
|
||||
EXPECT_CALL(*backend,
|
||||
GetSupportedConfigs(InstructionMatcher(instr_to_autotune)))
|
||||
.WillOnce(Return(std::move(configs)));
|
||||
EXPECT_CALL(*backend, Compile(_, _))
|
||||
.WillOnce(Return(std::unique_ptr<Executable>()))
|
||||
.WillOnce(Return(std::unique_ptr<Executable>()));
|
||||
HloOpcode instr_to_apply_config = instr_to_apply_config_and_count.first;
|
||||
int count = instr_to_apply_config_and_count.second;
|
||||
EXPECT_CALL(*backend,
|
||||
ApplyConfig(InstructionMatcher(instr_to_apply_config), _))
|
||||
.Times(count)
|
||||
.WillRepeatedly(Return(absl::OkStatus()));
|
||||
for (const auto& [instr_to_apply_config, count] :
|
||||
instrs_to_apply_config_and_count) {
|
||||
EXPECT_CALL(*backend,
|
||||
ApplyConfig(InstructionMatcher(instr_to_apply_config), _))
|
||||
.Times(count)
|
||||
.WillRepeatedly(Return(absl::OkStatus()));
|
||||
}
|
||||
|
||||
auto profiler = std::make_unique<MockProfiler>();
|
||||
auto device_description = CreateDummyDeviceDescription();
|
||||
|
|
@ -178,7 +184,7 @@ absl::StatusOr<std::unique_ptr<Autotuner>> SetupAutotunerWithExpectations(
|
|||
std::vector<std::unique_ptr<CodegenBackend>> backends;
|
||||
backends.push_back(std::move(backend));
|
||||
return Autotuner::Create(std::move(backends), std::move(profiler),
|
||||
GetTestAutotuneConfig(), std::move(cache_manager));
|
||||
GetTestAutotuneConfig(), std::move(cache));
|
||||
}
|
||||
|
||||
constexpr absl::string_view kHlo = R"(
|
||||
|
|
@ -371,7 +377,7 @@ TEST_F(AutotunerTest, AutotuneModuleFollowsFilter) {
|
|||
std::unique_ptr<Autotuner> autotuner,
|
||||
SetupAutotunerWithExpectations(
|
||||
/*instr_to_autotune=*/HloOpcode::kCopy,
|
||||
/*instr_to_apply_config_and_count=*/{HloOpcode::kCopy, 1}));
|
||||
/*instrs_to_apply_config_and_count=*/{{HloOpcode::kCopy, 1}}));
|
||||
|
||||
EXPECT_THAT(autotuner->Autotune(module.get(), should_autotune),
|
||||
absl_testing::IsOk());
|
||||
|
|
@ -388,7 +394,7 @@ TEST_F(AutotunerTest, AutotuneModuleWithDuplicateInstructions) {
|
|||
std::unique_ptr<Autotuner> autotuner,
|
||||
SetupAutotunerWithExpectations(
|
||||
/*instr_to_autotune=*/HloOpcode::kAdd,
|
||||
/*instr_to_apply_config_and_count=*/{HloOpcode::kAdd, 2}));
|
||||
/*instrs_to_apply_config_and_count=*/{{HloOpcode::kAdd, 2}}));
|
||||
|
||||
EXPECT_THAT(autotuner->Autotune(module.get(), should_autotune), IsOk());
|
||||
}
|
||||
|
|
@ -745,5 +751,71 @@ TEST_F(AutotunerTest, UseDefaultConfigUnimplemented) {
|
|||
"GetDefaultConfig is not implemented for mock_backend");
|
||||
}
|
||||
|
||||
class MockKeyValueStore : public KeyValueStoreInterface {
|
||||
public:
|
||||
MOCK_METHOD(absl::Status, Set,
|
||||
(absl::string_view key, absl::string_view value), (override));
|
||||
MOCK_METHOD(absl::StatusOr<std::string>, Get,
|
||||
(absl::string_view key, absl::Duration timeout), (override));
|
||||
MOCK_METHOD(absl::StatusOr<std::string>, TryGet, (absl::string_view key),
|
||||
(override));
|
||||
};
|
||||
|
||||
// Builds a cache entry attributed to the backend named "mock_backend" whose
// backend config is a copy of the test config created by GetTestConfig(name).
AutotunerCacheInterface::Config GetCacheConfig(absl::string_view name) {
  AutotunerCacheInterface::Config config;
  config.codegen_backend_name = "mock_backend";
  // GetTestConfig takes a std::string and returns an owning pointer; store a
  // copy of the pointee.
  config.backend_config = *GetTestConfig(std::string(name));
  return config;
}
|
||||
|
||||
// Exercises the sharded Autotune() overload from the point of view of shard 0
// out of 2, with both the cache and the key-value store mocked.
//
// NOTE: the expectations assume that the kCopy candidate is the one assigned
// to shard 0 and that the kAdd candidates belong to shard 1.
TEST_F(AutotunerTest, ShardedAutotuning) {
  constexpr int kShardCount = 2;
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHlo));
  // Both kAdd and kCopy instructions are autotuning candidates.
  auto should_autotune = [](const HloInstruction& instruction) {
    const HloOpcode opcode = instruction.opcode();
    return opcode == HloOpcode::kAdd || opcode == HloOpcode::kCopy;
  };
  auto mock_kv_store = std::make_shared<MockKeyValueStore>();
  auto mock_cache = std::make_unique<MockAutotunerCache>();

  // Local work of shard 0: kCopy misses the cache while autotuning, gets
  // inserted, and is found again when the configs are applied.
  EXPECT_CALL(*mock_cache, Lookup(InstrPtrMatcher(HloOpcode::kCopy)))
      .WillOnce(Return(std::nullopt))                    // During autotuning.
      .WillOnce(Return(GetCacheConfig("best_config")));  // Config application.
  EXPECT_CALL(*mock_cache, Insert(InstrPtrMatcher(HloOpcode::kCopy), _))
      .WillOnce(Return(absl::OkStatus()));
  // The local results are serialized to "kCopy_autotune_result" ...
  EXPECT_CALL(*mock_cache, Serialize(_))
      .WillOnce(Return("kCopy_autotune_result"));
  // ... and written to the store, because the key of shard 0 is not there yet.
  EXPECT_CALL(*mock_kv_store, TryGet(testing::HasSubstr("_0")))
      .WillOnce(Return(absl::NotFoundError("not found")));
  EXPECT_CALL(*mock_kv_store,
              Set(testing::HasSubstr("_0"), "kCopy_autotune_result"))
      .WillOnce(Return(absl::OkStatus()));

  // Remote work: shard 0 fetches the entry of shard 1, merges it into its
  // cache, and afterwards finds the kAdd config there.
  EXPECT_CALL(*mock_kv_store, Get(testing::HasSubstr("_1"), _))
      .WillOnce(Return("kAdd_autotune_result"));
  EXPECT_CALL(*mock_cache, Deserialize("kAdd_autotune_result"))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*mock_cache, Lookup(InstrPtrMatcher(HloOpcode::kAdd)))
      .WillOnce(Return(GetCacheConfig("best_config")));

  // Only kCopy is tuned locally, but configs are applied to the single kCopy
  // and to both kAdd instructions.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<Autotuner> autotuner,
      SetupAutotunerWithExpectations(
          /*instr_to_autotune=*/HloOpcode::kCopy,
          /*instrs_to_apply_config_and_count=*/
          {{HloOpcode::kCopy, 1}, {HloOpcode::kAdd, 2}},
          std::move(mock_cache)));

  MultiProcessKeyValueStore sharding_kv_store;
  sharding_kv_store.process_index = 0;
  sharding_kv_store.process_count = kShardCount;
  sharding_kv_store.key_value_store = mock_kv_store;

  const absl::Status status =
      autotuner->Autotune(module.get(), should_autotune, sharding_kv_store);
  EXPECT_THAT(status, IsOk());
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user