[XLA:CPU] When loading an AOT result, compare the host CPU features against the features of the machine it was compiled on

PiperOrigin-RevId: 826043058
This commit is contained in:
Karlo Basioli 2025-10-30 08:28:39 -07:00 committed by TensorFlower Gardener
parent fd85062199
commit 4461afa7ef
4 changed files with 46 additions and 4 deletions

View File

@ -538,7 +538,6 @@ xla_test(
"//xla:literal",
"//xla:literal_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/service:compiler",
"//xla/service:executable",
"//xla/service:hlo_runner",
@ -551,6 +550,7 @@ xla_test(
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
"@local_tsl//tsl/platform:casts",
],
@ -2176,9 +2176,11 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:TargetParser",
"@llvm-project//llvm:ir_headers",
] + if_llvm_aarch64_available([
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep

View File

@ -18,9 +18,10 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringMap.h" // IWYU pragma: keep
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/compiler.h"
@ -169,6 +170,17 @@ ENTRY main {
llvm::Triple(cpu_aot_result->proto().target_machine_options().triple())
.getArchName(),
llvm::Triple(kTargetTripleForHost).getArchName());
auto host_machine_features = llvm::sys::getHostCPUFeatures();
std::vector<absl::string_view> enabled_features;
for (const auto& feature : host_machine_features) {
if (feature.getValue()) {
enabled_features.push_back(feature.getKey());
}
}
EXPECT_EQ(cpu_aot_result->proto().target_machine_options().features(),
absl::StrJoin(enabled_features, ","));
}
} // namespace

View File

@ -22,13 +22,14 @@ limitations under the License.
#include <vector>
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Host.h"
#include "xla/backends/cpu/codegen/cpu_features.h"
#include "xla/backends/cpu/codegen/execution_engine.h"
#include "xla/backends/cpu/codegen/ir_compiler.h"
@ -184,6 +185,21 @@ CpuAotLoader::LoadAotCompilationResult(
expected_triple.getArchName(), triple.getArchName());
}
auto compile_machine_features =
absl::StrSplit(aot_result_proto.target_machine_options().features(), ',');
auto host_machine_features = llvm::sys::getHostCPUFeatures();
for (const auto& feature : compile_machine_features) {
if (!host_machine_features.contains(feature) ||
!host_machine_features[feature]) {
return Internal(
"Cannot load AOT result. Target machine feature %s is not supported "
"on the host machine.",
feature);
}
}
std::vector<SymbolProto> compiled_symbols_proto;
for (const auto& symbol_proto : aot_result_proto.compiled_symbols()) {
compiled_symbols_proto.push_back(symbol_proto);

View File

@ -1988,8 +1988,20 @@ CpuCompiler::CompileCpuExecutable(
target_machine_options_proto.set_triple(
target_machine->getTargetTriple().getTriple());
target_machine_options_proto.set_cpu(target_machine->getTargetCPU());
// TODO(basioli): The target machine currently reports an empty feature
// string, so for now we record the host CPU features instead. This should be
// updated to use the features of the target we are actually compiling for,
// as we might want to support cross-compilation.
auto host_machine_features = llvm::sys::getHostCPUFeatures();
std::vector<absl::string_view> enabled_features;
for (const auto& feature : host_machine_features) {
if (feature.getValue()) {
enabled_features.push_back(feature.getKey());
}
}
target_machine_options_proto.set_features(
target_machine->getTargetFeatureString());
absl::StrJoin(enabled_features, ","));
TF_ASSIGN_OR_RETURN(
auto cpu_executable,