[Autotuner] Add support for sharded autotuning in AutotunerPass.

PiperOrigin-RevId: 826417614
This commit is contained in:
A. Unique TensorFlower 2025-10-31 03:42:35 -07:00 committed by TensorFlower Gardener
parent e32f20dd91
commit e32304ddc5
3 changed files with 28 additions and 7 deletions

View File

@ -724,6 +724,7 @@ cc_library(
"//xla/backends/gpu/autotuner:legacy_cache",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:compiler",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory_allocator",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "xla/backends/gpu/autotuner/legacy_cache.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/compiler.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory_allocator.h"
@ -91,7 +92,8 @@ absl::StatusOr<std::unique_ptr<AutotunerPass>> AutotunerPass::Create(
stream_executor::StreamExecutor* stream_executor,
tsl::thread::ThreadPool* thread_pool, InstructionFilterFn should_autotune,
const Compiler::TargetConfig* target_config,
se::DeviceMemoryAllocator* allocator, bool optimize_scratch_bytes) {
se::DeviceMemoryAllocator* allocator, bool optimize_scratch_bytes,
MultiProcessKeyValueStore key_value_store) {
std::unique_ptr<Profiler> profiler = nullptr;
bool is_deviceless = stream_executor == nullptr;
AutotuneConfig autotune_config =
@ -112,8 +114,9 @@ absl::StatusOr<std::unique_ptr<AutotunerPass>> AutotunerPass::Create(
std::unique_ptr<Autotuner> autotuner,
Autotuner::Create(std::move(backends), std::move(profiler),
autotune_config, std::move(cache), thread_pool));
return absl::WrapUnique(
new AutotunerPass(std::move(autotuner), should_autotune));
return absl::WrapUnique(new AutotunerPass(
std::move(autotuner), should_autotune, std::move(key_value_store),
debug_options.xla_gpu_shard_autotuning()));
}
absl::StatusOr<bool> AutotunerPass::Run(
@ -121,7 +124,14 @@ absl::StatusOr<bool> AutotunerPass::Run(
const absl::flat_hash_set<absl::string_view>& execution_threads) {
VLOG(1) << "Running Autotuner Pass";
TF_RETURN_IF_ERROR(autotuner_->Autotune(module, should_autotune_));
bool shard_autotuning =
enable_sharding_ && key_value_store_.process_count > 1;
if (shard_autotuning) {
TF_RETURN_IF_ERROR(
autotuner_->Autotune(module, should_autotune_, key_value_store_));
} else {
TF_RETURN_IF_ERROR(autotuner_->Autotune(module, should_autotune_));
}
return true;
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "xla/backends/autotuner/codegen_backend.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/compiler.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/stream_executor.h"
@ -36,6 +37,7 @@ limitations under the License.
namespace xla {
namespace gpu {
// HloModulePass that runs the autotuner.
class AutotunerPass : public HloModulePass {
public:
// Note: the target_config must outlive the pass.
@ -45,7 +47,8 @@ class AutotunerPass : public HloModulePass {
tsl::thread::ThreadPool* thread_pool, InstructionFilterFn should_autotune,
const Compiler::TargetConfig* target_config,
se::DeviceMemoryAllocator* allocator = nullptr,
bool optimize_scratch_bytes = true);
bool optimize_scratch_bytes = true,
MultiProcessKeyValueStore key_value_store = MultiProcessKeyValueStore());
absl::string_view name() const override { return "autotuner"; }
@ -56,11 +59,18 @@ class AutotunerPass : public HloModulePass {
private:
explicit AutotunerPass(std::unique_ptr<Autotuner> autotuner,
InstructionFilterFn should_autotune)
: autotuner_(std::move(autotuner)), should_autotune_(should_autotune) {}
InstructionFilterFn should_autotune,
MultiProcessKeyValueStore key_value_store,
bool enable_sharding)
: autotuner_(std::move(autotuner)),
should_autotune_(should_autotune),
key_value_store_(std::move(key_value_store)),
enable_sharding_(enable_sharding) {}
std::unique_ptr<Autotuner> autotuner_;
InstructionFilterFn should_autotune_;
MultiProcessKeyValueStore key_value_store_;
bool enable_sharding_ = false;
};
} // namespace gpu