[XLA:GPU] Don't fail Autotuner::GetSupportedConfigs if one of the backend fails

PiperOrigin-RevId: 820303427
This commit is contained in:
A. Unique TensorFlower 2025-10-16 10:46:19 -07:00 committed by TensorFlower Gardener
parent 3c991bd608
commit 83c407040a
2 changed files with 43 additions and 4 deletions

View File

@ -248,10 +248,12 @@ absl::StatusOr<std::vector<Autotuner::Config>> Autotuner::GetSupportedConfigs(
HloInstruction* instr) {
std::vector<Config> configs;
for (auto& codegen_backend : codegen_backends_) {
std::vector<std::unique_ptr<BackendConfig>> per_backend_configs;
TF_ASSIGN_OR_RETURN(per_backend_configs,
codegen_backend->GetSupportedConfigs(*instr));
for (auto& config : per_backend_configs) {
absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
per_backend_configs = codegen_backend->GetSupportedConfigs(*instr);
if (!per_backend_configs.ok()) {
continue;
}
for (auto& config : *per_backend_configs) {
configs.push_back({codegen_backend.get(), std::move(config)});
}
}

View File

@ -393,6 +393,43 @@ TEST_F(AutotunerTest, AutotuneModuleWithDuplicateInstructions) {
EXPECT_THAT(autotuner->Autotune(module.get(), should_autotune), IsOk());
}
TEST_F(AutotunerTest, AutotuneButOneBackendFails) {
auto cache_manager = std::make_unique<MockAutotunerCache>();
EXPECT_CALL(*cache_manager, Lookup(_)).WillOnce(Return(std::nullopt));
EXPECT_CALL(*cache_manager, Insert(_, _)).WillOnce(Return(absl::OkStatus()));
std::vector<std::unique_ptr<BackendConfig>> configs;
configs.push_back(GetTestConfig("test_config"));
auto good_backend = std::make_unique<MockCodegenBackend>();
EXPECT_CALL(*good_backend, GetSupportedConfigs)
.WillOnce(Return(std::move(configs)));
EXPECT_CALL(*good_backend, Compile(_, _))
.WillOnce(Return(std::unique_ptr<Executable>()));
EXPECT_CALL(*good_backend, ApplyConfig(_, ConfigMatcher("test_config")))
.Times(1)
.WillRepeatedly(Return(absl::OkStatus()));
auto bad_backend = std::make_unique<MockCodegenBackend>();
EXPECT_CALL(*bad_backend, GetSupportedConfigs)
.WillOnce(Return(absl::InternalError("test error")));
auto profiler = std::make_unique<MockProfiler>();
auto device_description = CreateDummyDeviceDescription();
EXPECT_CALL(*profiler, CreateInputBuffers(_))
.WillOnce(Return(std::make_unique<InputBuffers>()));
EXPECT_CALL(*profiler, Profile(_, _))
.WillOnce(Return(ProfileResult({absl::Seconds(1)})));
std::vector<std::unique_ptr<CodegenBackend>> backends;
backends.push_back(std::move(good_backend));
backends.push_back(std::move(bad_backend));
TF_ASSERT_OK_AND_ASSIGN(
auto autotuner,
Autotuner::Create(std::move(backends), std::move(profiler), config_,
std::move(cache_manager)));
auto dummy_instr = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1));
EXPECT_THAT(autotuner->Autotune(dummy_instr.get()), absl_testing::IsOk());
}
TEST_F(AutotunerTest, CacheHit) {
auto cache_manager = std::make_unique<MockAutotunerCache>();
AutotunerCacheInterface::Config config;