From e32304ddc52b41093da1505830007c1d15de69f2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Oct 2025 03:42:35 -0700 Subject: [PATCH] [Autotuner]Add support for sharded autotuning in the pass. PiperOrigin-RevId: 826417614 --- .../xla/xla/service/gpu/autotuning/BUILD | 1 + .../service/gpu/autotuning/autotuner_pass.cc | 18 ++++++++++++++---- .../service/gpu/autotuning/autotuner_pass.h | 16 +++++++++++++--- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 8aca9cf65b2..bdd919aa69b 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -724,6 +724,7 @@ cc_library( "//xla/backends/gpu/autotuner:legacy_cache", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:compiler", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.cc index a11da2525df..d62cfbee1b6 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/backends/gpu/autotuner/legacy_cache.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/compiler.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory_allocator.h" @@ -91,7 +92,8 @@ absl::StatusOr> AutotunerPass::Create( stream_executor::StreamExecutor* stream_executor, tsl::thread::ThreadPool* thread_pool, InstructionFilterFn should_autotune, const Compiler::TargetConfig* target_config, - se::DeviceMemoryAllocator* allocator, bool optimize_scratch_bytes) { + se::DeviceMemoryAllocator* allocator, bool optimize_scratch_bytes, + MultiProcessKeyValueStore key_value_store) { std::unique_ptr profiler = nullptr; bool is_deviceless = stream_executor == nullptr; AutotuneConfig autotune_config = @@ -112,8 +114,9 @@ absl::StatusOr> AutotunerPass::Create( std::unique_ptr autotuner, Autotuner::Create(std::move(backends), std::move(profiler), autotune_config, std::move(cache), thread_pool)); - return absl::WrapUnique( - new AutotunerPass(std::move(autotuner), should_autotune)); + return absl::WrapUnique(new AutotunerPass( + std::move(autotuner), should_autotune, std::move(key_value_store), + debug_options.xla_gpu_shard_autotuning())); } absl::StatusOr AutotunerPass::Run( @@ -121,7 +124,14 @@ absl::StatusOr AutotunerPass::Run( const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running Autotuner Pass"; - TF_RETURN_IF_ERROR(autotuner_->Autotune(module, should_autotune_)); + bool shard_autotuning = + enable_sharding_ && key_value_store_.process_count > 1; + if (shard_autotuning) { + TF_RETURN_IF_ERROR( + autotuner_->Autotune(module, should_autotune_, key_value_store_)); + } else { + TF_RETURN_IF_ERROR(autotuner_->Autotune(module, should_autotune_)); + } return true; } diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.h index 372a5a9b334..5d55ae570d6 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.h @@ -27,6 +27,7 @@ limitations under the License. #include "xla/backends/autotuner/codegen_backend.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/compiler.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" @@ -36,6 +37,7 @@ limitations under the License. namespace xla { namespace gpu { +// HloModulePass that runs the autotuner. class AutotunerPass : public HloModulePass { public: // Note: the target_config must outlive the pass. @@ -45,7 +47,8 @@ class AutotunerPass : public HloModulePass { tsl::thread::ThreadPool* thread_pool, InstructionFilterFn should_autotune, const Compiler::TargetConfig* target_config, se::DeviceMemoryAllocator* allocator = nullptr, - bool optimize_scratch_bytes = true); + bool optimize_scratch_bytes = true, + MultiProcessKeyValueStore key_value_store = MultiProcessKeyValueStore()); absl::string_view name() const override { return "autotuner"; } @@ -56,11 +59,18 @@ class AutotunerPass : public HloModulePass { private: explicit AutotunerPass(std::unique_ptr autotuner, - InstructionFilterFn should_autotune) - : autotuner_(std::move(autotuner)), should_autotune_(should_autotune) {} + InstructionFilterFn should_autotune, + MultiProcessKeyValueStore key_value_store, + bool enable_sharding) + : autotuner_(std::move(autotuner)), + should_autotune_(should_autotune), + key_value_store_(std::move(key_value_store)), + enable_sharding_(enable_sharding) {} std::unique_ptr autotuner_; InstructionFilterFn should_autotune_; + MultiProcessKeyValueStore key_value_store_; + bool enable_sharding_ = false; }; } // namespace gpu