mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:GPU] Don't fail Autotuner::GetSupportedConfigs if one of the backend fails
PiperOrigin-RevId: 820303427
This commit is contained in:
parent
3c991bd608
commit
83c407040a
|
|
@ -248,10 +248,12 @@ absl::StatusOr<std::vector<Autotuner::Config>> Autotuner::GetSupportedConfigs(
|
|||
HloInstruction* instr) {
|
||||
std::vector<Config> configs;
|
||||
for (auto& codegen_backend : codegen_backends_) {
|
||||
std::vector<std::unique_ptr<BackendConfig>> per_backend_configs;
|
||||
TF_ASSIGN_OR_RETURN(per_backend_configs,
|
||||
codegen_backend->GetSupportedConfigs(*instr));
|
||||
for (auto& config : per_backend_configs) {
|
||||
absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
|
||||
per_backend_configs = codegen_backend->GetSupportedConfigs(*instr);
|
||||
if (!per_backend_configs.ok()) {
|
||||
continue;
|
||||
}
|
||||
for (auto& config : *per_backend_configs) {
|
||||
configs.push_back({codegen_backend.get(), std::move(config)});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -393,6 +393,43 @@ TEST_F(AutotunerTest, AutotuneModuleWithDuplicateInstructions) {
|
|||
EXPECT_THAT(autotuner->Autotune(module.get(), should_autotune), IsOk());
|
||||
}
|
||||
|
||||
TEST_F(AutotunerTest, AutotuneButOneBackendFails) {
|
||||
auto cache_manager = std::make_unique<MockAutotunerCache>();
|
||||
EXPECT_CALL(*cache_manager, Lookup(_)).WillOnce(Return(std::nullopt));
|
||||
EXPECT_CALL(*cache_manager, Insert(_, _)).WillOnce(Return(absl::OkStatus()));
|
||||
|
||||
std::vector<std::unique_ptr<BackendConfig>> configs;
|
||||
configs.push_back(GetTestConfig("test_config"));
|
||||
|
||||
auto good_backend = std::make_unique<MockCodegenBackend>();
|
||||
EXPECT_CALL(*good_backend, GetSupportedConfigs)
|
||||
.WillOnce(Return(std::move(configs)));
|
||||
EXPECT_CALL(*good_backend, Compile(_, _))
|
||||
.WillOnce(Return(std::unique_ptr<Executable>()));
|
||||
EXPECT_CALL(*good_backend, ApplyConfig(_, ConfigMatcher("test_config")))
|
||||
.Times(1)
|
||||
.WillRepeatedly(Return(absl::OkStatus()));
|
||||
auto bad_backend = std::make_unique<MockCodegenBackend>();
|
||||
EXPECT_CALL(*bad_backend, GetSupportedConfigs)
|
||||
.WillOnce(Return(absl::InternalError("test error")));
|
||||
|
||||
auto profiler = std::make_unique<MockProfiler>();
|
||||
auto device_description = CreateDummyDeviceDescription();
|
||||
EXPECT_CALL(*profiler, CreateInputBuffers(_))
|
||||
.WillOnce(Return(std::make_unique<InputBuffers>()));
|
||||
EXPECT_CALL(*profiler, Profile(_, _))
|
||||
.WillOnce(Return(ProfileResult({absl::Seconds(1)})));
|
||||
std::vector<std::unique_ptr<CodegenBackend>> backends;
|
||||
backends.push_back(std::move(good_backend));
|
||||
backends.push_back(std::move(bad_backend));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto autotuner,
|
||||
Autotuner::Create(std::move(backends), std::move(profiler), config_,
|
||||
std::move(cache_manager)));
|
||||
auto dummy_instr = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1));
|
||||
EXPECT_THAT(autotuner->Autotune(dummy_instr.get()), absl_testing::IsOk());
|
||||
}
|
||||
|
||||
TEST_F(AutotunerTest, CacheHit) {
|
||||
auto cache_manager = std::make_unique<MockAutotunerCache>();
|
||||
AutotunerCacheInterface::Config config;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user