mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[xla:cpu] Do not register legacy runtime symbols with XLA:CPU custom calls
PiperOrigin-RevId: 826208548
This commit is contained in:
parent
31bb7c01ff
commit
d9024af6d4
|
|
@ -24,8 +24,6 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/functional/any_invocable.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ExecutionEngine/JITSymbol.h"
|
||||
#include "llvm/ExecutionEngine/Orc/AbsoluteSymbols.h"
|
||||
|
|
@ -35,29 +33,10 @@ limitations under the License.
|
|||
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
|
||||
#include "llvm/IR/DataLayout.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "mlir/ExecutionEngine/CRunnerUtils.h"
|
||||
#include "xla/service/cpu/cpu_runtime.h"
|
||||
#include "xla/service/cpu/runtime_conv2d.h"
|
||||
#include "xla/service/cpu/runtime_conv2d_acl.h"
|
||||
#include "xla/service/cpu/runtime_conv3d.h"
|
||||
#include "xla/service/cpu/runtime_custom_call_status.h"
|
||||
#include "xla/service/cpu/runtime_fp16.h"
|
||||
#include "xla/service/cpu/runtime_key_value_sort.h"
|
||||
#include "xla/service/cpu/runtime_matmul.h"
|
||||
#include "xla/service/cpu/runtime_matmul_acl.h"
|
||||
#include "xla/service/cpu/runtime_pow.h"
|
||||
#include "xla/service/cpu/runtime_single_threaded_conv2d.h"
|
||||
#include "xla/service/cpu/runtime_single_threaded_conv3d.h"
|
||||
#include "xla/service/cpu/runtime_single_threaded_matmul.h"
|
||||
#include "xla/service/cpu/runtime_topk.h"
|
||||
#include "xla/service/cpu/windows_compatibility.h"
|
||||
#include "xla/service/cpu/windows_compatibility.h" // IWYU pragma: keep
|
||||
#include "xla/service/custom_call_target_registry.h"
|
||||
#include "tsl/platform/logging.h"
|
||||
|
||||
#ifdef XLA_ONEDNN
|
||||
#include "xla/service/cpu/onednn_convolution.h"
|
||||
#include "xla/service/cpu/onednn_matmul.h"
|
||||
#endif // XLA_ONEDNN
|
||||
|
||||
namespace xla::cpu {
|
||||
|
||||
|
|
@ -127,16 +106,6 @@ float __extendhfsf2(uint16_t a);
|
|||
|
||||
} // extern "C"
|
||||
|
||||
#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
|
||||
do { \
|
||||
auto* function_address = \
|
||||
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
|
||||
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
|
||||
function_address, "Host"); \
|
||||
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
|
||||
"__xla_cpu_runtime_" #base_name); \
|
||||
} while (false)
|
||||
|
||||
// Register both the f32 (float) and f64 (double) versions of a libm symbol.
|
||||
// Unfortunately the double versions are overloaded on some systems, e.g.
|
||||
// Mac so we need an explicit cast. This requires passing the function signature
|
||||
|
|
@ -155,41 +124,6 @@ static bool RegisterKnownJITSymbols() {
|
|||
registry->Register("printf", reinterpret_cast<void*>(&printf), "Host");
|
||||
registry->Register("puts", reinterpret_cast<void*>(&puts), "Host");
|
||||
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC128);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulS32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenBatchMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ACLMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ACLBatchMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ACLConv2DF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv2DF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv2DF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF8E4M3FN);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF8E5M2);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC128);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulS32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulU8);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(StatusIsSuccess);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(TopKF32);
|
||||
#ifdef XLA_ONEDNN
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMul);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMulReorder);
|
||||
#endif // XLA_ONEDNN
|
||||
|
||||
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
|
||||
"Host");
|
||||
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
|
||||
|
|
@ -286,10 +220,6 @@ static bool RegisterKnownJITSymbols() {
|
|||
registry->Register("malloc", reinterpret_cast<void*>(malloc), "Host");
|
||||
registry->Register("calloc", reinterpret_cast<void*>(calloc), "Host");
|
||||
registry->Register("free", reinterpret_cast<void*>(free), "Host");
|
||||
#ifndef _WIN32
|
||||
// TODO(b/246980307): fails to link on windows because it's marked dllimport.
|
||||
registry->Register("memrefCopy", reinterpret_cast<void*>(memrefCopy), "Host");
|
||||
#endif
|
||||
|
||||
#ifdef __APPLE__
|
||||
registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
|
||||
|
|
@ -310,7 +240,6 @@ static bool RegisterKnownJITSymbols() {
|
|||
return true;
|
||||
}
|
||||
|
||||
#undef REGISTER_CPU_RUNTIME_SYMBOL
|
||||
#undef REGISTER_LIBM_SYMBOL
|
||||
|
||||
static bool unused = RegisterKnownJITSymbols();
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ static bool InstructionIsUnavailable(const HloInstruction* instr) {
|
|||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kScatter:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kFft:
|
||||
return true;
|
||||
default:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user