mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:CPU] Compare host cpu features when loading AOT result to the compilation machine features
PiperOrigin-RevId: 826043058
This commit is contained in:
parent
fd85062199
commit
4461afa7ef
4
third_party/xla/xla/service/cpu/BUILD
vendored
4
third_party/xla/xla/service/cpu/BUILD
vendored
|
|
@ -538,7 +538,6 @@ xla_test(
|
||||||
"//xla:literal",
|
"//xla:literal",
|
||||||
"//xla:literal_util",
|
"//xla:literal_util",
|
||||||
"//xla/hlo/ir:hlo",
|
"//xla/hlo/ir:hlo",
|
||||||
"//xla/hlo/ir:hlo_module_group",
|
|
||||||
"//xla/service:compiler",
|
"//xla/service:compiler",
|
||||||
"//xla/service:executable",
|
"//xla/service:executable",
|
||||||
"//xla/service:hlo_runner",
|
"//xla/service:hlo_runner",
|
||||||
|
|
@ -551,6 +550,7 @@ xla_test(
|
||||||
"//xla/tsl/platform:statusor",
|
"//xla/tsl/platform:statusor",
|
||||||
"//xla/tsl/platform:test",
|
"//xla/tsl/platform:test",
|
||||||
"@com_google_absl//absl/strings:string_view",
|
"@com_google_absl//absl/strings:string_view",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//llvm:TargetParser",
|
"@llvm-project//llvm:TargetParser",
|
||||||
"@local_tsl//tsl/platform:casts",
|
"@local_tsl//tsl/platform:casts",
|
||||||
],
|
],
|
||||||
|
|
@ -2176,9 +2176,11 @@ cc_library(
|
||||||
"@com_google_absl//absl/log",
|
"@com_google_absl//absl/log",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//llvm:Target",
|
"@llvm-project//llvm:Target",
|
||||||
|
"@llvm-project//llvm:TargetParser",
|
||||||
"@llvm-project//llvm:ir_headers",
|
"@llvm-project//llvm:ir_headers",
|
||||||
] + if_llvm_aarch64_available([
|
] + if_llvm_aarch64_available([
|
||||||
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
|
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
|
||||||
|
|
|
||||||
|
|
@ -18,9 +18,10 @@ limitations under the License.
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "llvm/ADT/StringMap.h" // IWYU pragma: keep
|
||||||
|
#include "llvm/TargetParser/Host.h"
|
||||||
#include "llvm/TargetParser/Triple.h"
|
#include "llvm/TargetParser/Triple.h"
|
||||||
#include "xla/hlo/ir/hlo_module.h"
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
#include "xla/hlo/ir/hlo_module_group.h"
|
|
||||||
#include "xla/literal.h"
|
#include "xla/literal.h"
|
||||||
#include "xla/literal_util.h"
|
#include "xla/literal_util.h"
|
||||||
#include "xla/service/compiler.h"
|
#include "xla/service/compiler.h"
|
||||||
|
|
@ -169,6 +170,17 @@ ENTRY main {
|
||||||
llvm::Triple(cpu_aot_result->proto().target_machine_options().triple())
|
llvm::Triple(cpu_aot_result->proto().target_machine_options().triple())
|
||||||
.getArchName(),
|
.getArchName(),
|
||||||
llvm::Triple(kTargetTripleForHost).getArchName());
|
llvm::Triple(kTargetTripleForHost).getArchName());
|
||||||
|
|
||||||
|
auto host_machine_features = llvm::sys::getHostCPUFeatures();
|
||||||
|
std::vector<absl::string_view> enabled_features;
|
||||||
|
for (const auto& feature : host_machine_features) {
|
||||||
|
if (feature.getValue()) {
|
||||||
|
enabled_features.push_back(feature.getKey());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(cpu_aot_result->proto().target_machine_options().features(),
|
||||||
|
absl::StrJoin(enabled_features, ","));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,14 @@ limitations under the License.
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/log/log.h"
|
#include "absl/log/log.h"
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/IR/DataLayout.h"
|
#include "llvm/IR/DataLayout.h"
|
||||||
#include "llvm/Target/TargetMachine.h"
|
#include "llvm/Target/TargetMachine.h"
|
||||||
#include "llvm/Target/TargetOptions.h"
|
#include "llvm/Target/TargetOptions.h"
|
||||||
|
#include "llvm/TargetParser/Host.h"
|
||||||
#include "xla/backends/cpu/codegen/cpu_features.h"
|
#include "xla/backends/cpu/codegen/cpu_features.h"
|
||||||
#include "xla/backends/cpu/codegen/execution_engine.h"
|
#include "xla/backends/cpu/codegen/execution_engine.h"
|
||||||
#include "xla/backends/cpu/codegen/ir_compiler.h"
|
#include "xla/backends/cpu/codegen/ir_compiler.h"
|
||||||
|
|
@ -184,6 +185,21 @@ CpuAotLoader::LoadAotCompilationResult(
|
||||||
expected_triple.getArchName(), triple.getArchName());
|
expected_triple.getArchName(), triple.getArchName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto compile_machine_features =
|
||||||
|
absl::StrSplit(aot_result_proto.target_machine_options().features(), ',');
|
||||||
|
|
||||||
|
auto host_machine_features = llvm::sys::getHostCPUFeatures();
|
||||||
|
|
||||||
|
for (const auto& feature : compile_machine_features) {
|
||||||
|
if (!host_machine_features.contains(feature) ||
|
||||||
|
!host_machine_features[feature]) {
|
||||||
|
return Internal(
|
||||||
|
"Cannot load AOT result. Target machine feature %s is not supported "
|
||||||
|
"on the host machine.",
|
||||||
|
feature);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<SymbolProto> compiled_symbols_proto;
|
std::vector<SymbolProto> compiled_symbols_proto;
|
||||||
for (const auto& symbol_proto : aot_result_proto.compiled_symbols()) {
|
for (const auto& symbol_proto : aot_result_proto.compiled_symbols()) {
|
||||||
compiled_symbols_proto.push_back(symbol_proto);
|
compiled_symbols_proto.push_back(symbol_proto);
|
||||||
|
|
|
||||||
14
third_party/xla/xla/service/cpu/cpu_compiler.cc
vendored
14
third_party/xla/xla/service/cpu/cpu_compiler.cc
vendored
|
|
@ -1988,8 +1988,20 @@ CpuCompiler::CompileCpuExecutable(
|
||||||
target_machine_options_proto.set_triple(
|
target_machine_options_proto.set_triple(
|
||||||
target_machine->getTargetTriple().getTriple());
|
target_machine->getTargetTriple().getTriple());
|
||||||
target_machine_options_proto.set_cpu(target_machine->getTargetCPU());
|
target_machine_options_proto.set_cpu(target_machine->getTargetCPU());
|
||||||
|
|
||||||
|
// TODO(basioli): Target machine features are returning an empty string at the
|
||||||
|
// moment so for now we are using the host CPU features. This should be
|
||||||
|
// updated to use the target machine features of the target we are actually
|
||||||
|
// compiling for as we might want to support cross-compilation.
|
||||||
|
auto host_machine_features = llvm::sys::getHostCPUFeatures();
|
||||||
|
std::vector<absl::string_view> enabled_features;
|
||||||
|
for (const auto& feature : host_machine_features) {
|
||||||
|
if (feature.getValue()) {
|
||||||
|
enabled_features.push_back(feature.getKey());
|
||||||
|
}
|
||||||
|
}
|
||||||
target_machine_options_proto.set_features(
|
target_machine_options_proto.set_features(
|
||||||
target_machine->getTargetFeatureString());
|
absl::StrJoin(enabled_features, ","));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto cpu_executable,
|
auto cpu_executable,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user