diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 5139c0305c2..af339d467b3 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -538,7 +538,6 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", "//xla/service:compiler", "//xla/service:executable", "//xla/service:hlo_runner", @@ -551,6 +550,7 @@ xla_test( "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", "@local_tsl//tsl/platform:casts", ], @@ -2176,9 +2176,11 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:ir_headers", ] + if_llvm_aarch64_available([ "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep diff --git a/third_party/xla/xla/service/cpu/cpu_aot_compiler_test.cc b/third_party/xla/xla/service/cpu/cpu_aot_compiler_test.cc index d0d621b9ad9..e869d323087 100644 --- a/third_party/xla/xla/service/cpu/cpu_aot_compiler_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_aot_compiler_test.cc @@ -18,9 +18,10 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "llvm/ADT/StringMap.h" // IWYU pragma: keep +#include "llvm/TargetParser/Host.h" #include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_module_group.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/compiler.h" @@ -169,6 +170,17 @@ ENTRY main { llvm::Triple(cpu_aot_result->proto().target_machine_options().triple()) .getArchName(), llvm::Triple(kTargetTripleForHost).getArchName()); + + auto host_machine_features = llvm::sys::getHostCPUFeatures(); + std::vector enabled_features; + for (const auto& feature : host_machine_features) { + if (feature.getValue()) { + enabled_features.push_back(feature.getKey()); + } + } + + EXPECT_EQ(cpu_aot_result->proto().target_machine_options().features(), + absl::StrJoin(enabled_features, ",")); } } // namespace diff --git a/third_party/xla/xla/service/cpu/cpu_aot_loader.cc b/third_party/xla/xla/service/cpu/cpu_aot_loader.cc index ea29096a3cb..a4595bec665 100644 --- a/third_party/xla/xla/service/cpu/cpu_aot_loader.cc +++ b/third_party/xla/xla/service/cpu/cpu_aot_loader.cc @@ -22,13 +22,14 @@ limitations under the License. #include #include "absl/log/log.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_split.h" #include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/DataLayout.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "llvm/TargetParser/Host.h" #include "xla/backends/cpu/codegen/cpu_features.h" #include "xla/backends/cpu/codegen/execution_engine.h" #include "xla/backends/cpu/codegen/ir_compiler.h" @@ -184,6 +185,21 @@ CpuAotLoader::LoadAotCompilationResult( expected_triple.getArchName(), triple.getArchName()); } + auto compile_machine_features = + absl::StrSplit(aot_result_proto.target_machine_options().features(), ','); + + auto host_machine_features = llvm::sys::getHostCPUFeatures(); + + for (const auto& feature : compile_machine_features) { + if (!host_machine_features.contains(feature) || + !host_machine_features[feature]) { + return Internal( + "Cannot load AOT result. Target machine feature %s is not supported " + "on the host machine.", + feature); + } + } + std::vector compiled_symbols_proto; for (const auto& symbol_proto : aot_result_proto.compiled_symbols()) { compiled_symbols_proto.push_back(symbol_proto); diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 7005327351d..ceb9275be79 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -1988,8 +1988,20 @@ CpuCompiler::CompileCpuExecutable( target_machine_options_proto.set_triple( target_machine->getTargetTriple().getTriple()); target_machine_options_proto.set_cpu(target_machine->getTargetCPU()); + + // TODO(basioli): Target machine features are returning an empty string at the + // moment so for now we are using the host CPU features. This should be + // updated to use the target machine features of the target we are actually + // compiling for as we might want to support cross-compilation. + auto host_machine_features = llvm::sys::getHostCPUFeatures(); + std::vector enabled_features; + for (const auto& feature : host_machine_features) { + if (feature.getValue()) { + enabled_features.push_back(feature.getKey()); + } + } target_machine_options_proto.set_features( - target_machine->getTargetFeatureString()); + absl::StrJoin(enabled_features, ",")); TF_ASSIGN_OR_RETURN( auto cpu_executable,