mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[Autotuner] Add support for sharded autotuning in the pass.
PiperOrigin-RevId: 826417614
This commit is contained in:
parent
e32f20dd91
commit
e32304ddc5
|
|
@ -724,6 +724,7 @@ cc_library(
|
|||
"//xla/backends/gpu/autotuner:legacy_cache",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/pass:hlo_pass",
|
||||
"//xla/pjrt/distributed:key_value_store_interface",
|
||||
"//xla/service:compiler",
|
||||
"//xla/stream_executor:device_description",
|
||||
"//xla/stream_executor:device_memory_allocator",
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||
#include "xla/backends/gpu/autotuner/legacy_cache.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/service/compiler.h"
|
||||
#include "xla/stream_executor/device_description.h"
|
||||
#include "xla/stream_executor/device_memory_allocator.h"
|
||||
|
|
@ -91,7 +92,8 @@ absl::StatusOr<std::unique_ptr<AutotunerPass>> AutotunerPass::Create(
|
|||
stream_executor::StreamExecutor* stream_executor,
|
||||
tsl::thread::ThreadPool* thread_pool, InstructionFilterFn should_autotune,
|
||||
const Compiler::TargetConfig* target_config,
|
||||
se::DeviceMemoryAllocator* allocator, bool optimize_scratch_bytes) {
|
||||
se::DeviceMemoryAllocator* allocator, bool optimize_scratch_bytes,
|
||||
MultiProcessKeyValueStore key_value_store) {
|
||||
std::unique_ptr<Profiler> profiler = nullptr;
|
||||
bool is_deviceless = stream_executor == nullptr;
|
||||
AutotuneConfig autotune_config =
|
||||
|
|
@ -112,8 +114,9 @@ absl::StatusOr<std::unique_ptr<AutotunerPass>> AutotunerPass::Create(
|
|||
std::unique_ptr<Autotuner> autotuner,
|
||||
Autotuner::Create(std::move(backends), std::move(profiler),
|
||||
autotune_config, std::move(cache), thread_pool));
|
||||
return absl::WrapUnique(
|
||||
new AutotunerPass(std::move(autotuner), should_autotune));
|
||||
return absl::WrapUnique(new AutotunerPass(
|
||||
std::move(autotuner), should_autotune, std::move(key_value_store),
|
||||
debug_options.xla_gpu_shard_autotuning()));
|
||||
}
|
||||
|
||||
absl::StatusOr<bool> AutotunerPass::Run(
|
||||
|
|
@ -121,7 +124,14 @@ absl::StatusOr<bool> AutotunerPass::Run(
|
|||
const absl::flat_hash_set<absl::string_view>& execution_threads) {
|
||||
VLOG(1) << "Running Autotuner Pass";
|
||||
|
||||
bool shard_autotuning =
|
||||
enable_sharding_ && key_value_store_.process_count > 1;
|
||||
if (shard_autotuning) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
autotuner_->Autotune(module, should_autotune_, key_value_store_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(autotuner_->Autotune(module, should_autotune_));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||
#include "xla/backends/autotuner/codegen_backend.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/hlo/pass/hlo_pass_interface.h"
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/service/compiler.h"
|
||||
#include "xla/stream_executor/device_memory_allocator.h"
|
||||
#include "xla/stream_executor/stream_executor.h"
|
||||
|
|
@ -36,6 +37,7 @@ limitations under the License.
|
|||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
// HloModulePass that runs the autotuner.
|
||||
class AutotunerPass : public HloModulePass {
|
||||
public:
|
||||
// Note: the target_config must outlive the pass.
|
||||
|
|
@ -45,7 +47,8 @@ class AutotunerPass : public HloModulePass {
|
|||
tsl::thread::ThreadPool* thread_pool, InstructionFilterFn should_autotune,
|
||||
const Compiler::TargetConfig* target_config,
|
||||
se::DeviceMemoryAllocator* allocator = nullptr,
|
||||
bool optimize_scratch_bytes = true);
|
||||
bool optimize_scratch_bytes = true,
|
||||
MultiProcessKeyValueStore key_value_store = MultiProcessKeyValueStore());
|
||||
|
||||
absl::string_view name() const override { return "autotuner"; }
|
||||
|
||||
|
|
@ -56,11 +59,18 @@ class AutotunerPass : public HloModulePass {
|
|||
|
||||
private:
|
||||
explicit AutotunerPass(std::unique_ptr<Autotuner> autotuner,
|
||||
InstructionFilterFn should_autotune)
|
||||
: autotuner_(std::move(autotuner)), should_autotune_(should_autotune) {}
|
||||
InstructionFilterFn should_autotune,
|
||||
MultiProcessKeyValueStore key_value_store,
|
||||
bool enable_sharding)
|
||||
: autotuner_(std::move(autotuner)),
|
||||
should_autotune_(should_autotune),
|
||||
key_value_store_(std::move(key_value_store)),
|
||||
enable_sharding_(enable_sharding) {}
|
||||
|
||||
std::unique_ptr<Autotuner> autotuner_;
|
||||
InstructionFilterFn should_autotune_;
|
||||
MultiProcessKeyValueStore key_value_store_;
|
||||
bool enable_sharding_ = false;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user