[xla:cpu] Do not register legacy runtime symbols with XLA:CPU custom calls

PiperOrigin-RevId: 826208548
This commit is contained in:
Eugene Zhulenev 2025-10-30 15:24:26 -07:00 committed by TensorFlower Gardener
parent 31bb7c01ff
commit d9024af6d4
2 changed files with 2 additions and 72 deletions

View File

@ -24,8 +24,6 @@ limitations under the License.
#include <string>
#include <utility>
#include "absl/functional/any_invocable.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/AbsoluteSymbols.h"
@ -35,29 +33,10 @@ limitations under the License.
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Error.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "xla/service/cpu/cpu_runtime.h"
#include "xla/service/cpu/runtime_conv2d.h"
#include "xla/service/cpu/runtime_conv2d_acl.h"
#include "xla/service/cpu/runtime_conv3d.h"
#include "xla/service/cpu/runtime_custom_call_status.h"
#include "xla/service/cpu/runtime_fp16.h"
#include "xla/service/cpu/runtime_key_value_sort.h"
#include "xla/service/cpu/runtime_matmul.h"
#include "xla/service/cpu/runtime_matmul_acl.h"
#include "xla/service/cpu/runtime_pow.h"
#include "xla/service/cpu/runtime_single_threaded_conv2d.h"
#include "xla/service/cpu/runtime_single_threaded_conv3d.h"
#include "xla/service/cpu/runtime_single_threaded_matmul.h"
#include "xla/service/cpu/runtime_topk.h"
#include "xla/service/cpu/windows_compatibility.h"
#include "xla/service/cpu/windows_compatibility.h" // IWYU pragma: keep
#include "xla/service/custom_call_target_registry.h"
#include "tsl/platform/logging.h"
#ifdef XLA_ONEDNN
#include "xla/service/cpu/onednn_convolution.h"
#include "xla/service/cpu/onednn_matmul.h"
#endif // XLA_ONEDNN
namespace xla::cpu {
@ -127,16 +106,6 @@ float __extendhfsf2(uint16_t a);
} // extern "C"
#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
do { \
auto* function_address = \
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
function_address, "Host"); \
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
"__xla_cpu_runtime_" #base_name); \
} while (false)
// Register both the f32 (float) and f64 (double) versions of a libm symbol.
// Unfortunately the double versions are overloaded on some systems, e.g.
// Mac so we need an explicit cast. This requires passing the function signature
@ -155,41 +124,6 @@ static bool RegisterKnownJITSymbols() {
registry->Register("printf", reinterpret_cast<void*>(&printf), "Host");
registry->Register("puts", reinterpret_cast<void*>(&puts), "Host");
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC64);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC128);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulS32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenBatchMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(ACLMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(ACLBatchMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(ACLConv2DF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv2DF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv2DF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF8E4M3FN);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF8E5M2);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC64);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC128);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulS32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulU8);
REGISTER_CPU_RUNTIME_SYMBOL(StatusIsSuccess);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
REGISTER_CPU_RUNTIME_SYMBOL(TopKF32);
#ifdef XLA_ONEDNN
REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMul);
REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMulReorder);
#endif // XLA_ONEDNN
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
"Host");
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
@ -286,10 +220,6 @@ static bool RegisterKnownJITSymbols() {
registry->Register("malloc", reinterpret_cast<void*>(malloc), "Host");
registry->Register("calloc", reinterpret_cast<void*>(calloc), "Host");
registry->Register("free", reinterpret_cast<void*>(free), "Host");
#ifndef _WIN32
// TODO(b/246980307): fails to link on windows because it's marked dllimport.
registry->Register("memrefCopy", reinterpret_cast<void*>(memrefCopy), "Host");
#endif
#ifdef __APPLE__
registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
@ -310,7 +240,6 @@ static bool RegisterKnownJITSymbols() {
return true;
}
#undef REGISTER_CPU_RUNTIME_SYMBOL
#undef REGISTER_LIBM_SYMBOL
static bool unused = RegisterKnownJITSymbols();

View File

@ -47,6 +47,7 @@ static bool InstructionIsUnavailable(const HloInstruction* instr) {
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kScatter:
case HloOpcode::kSort:
case HloOpcode::kFft:
return true;
default: