diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 1be4ec37dfe..e1d329fbf30 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -147,6 +147,7 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( c10::DeviceType::MPS, c10::DeviceType::MTIA, c10::DeviceType::XPU, + c10::DeviceType::HPU, c10::DeviceType::PrivateUse1 ); // Check if the device type is supported. @@ -203,6 +204,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel; #endif + case DeviceType::HPU: + return hpu_dispatch_ptr != nullptr ? DispatchResult(hpu_dispatch_ptr) : ErrorType::MissingDeviceKernel; + case DeviceType::PrivateUse1: return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel; diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 725d0d08bae..cbe4b23c671 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -44,6 +44,7 @@ // - MPS: Apple Silicon GPUs (Metal Performance Shaders) // - MTIA: Meta Training and Inference Devices // - XPU: Intel GPUs +// - HPU: Reserved for HPU (Intel Gaudi) device types // - PrivateUse1: Reserved for private/custom device types // // If you want to update the list of supported devices, add a new dispatch_ptr @@ -196,6 +197,7 @@ struct TORCH_API DispatchStubImpl { #if defined(USE_XPU) void* xpu_dispatch_ptr; #endif + void* hpu_dispatch_ptr; void* privateuse1_dispatch_ptr; #else std::atomic cpu_dispatch_ptr{nullptr}; @@ -206,6 +208,7 @@ struct TORCH_API DispatchStubImpl { #if defined(USE_XPU) void* xpu_dispatch_ptr = nullptr; #endif + void* hpu_dispatch_ptr = nullptr; void* privateuse1_dispatch_ptr = nullptr; #endif }; @@ -259,6 +262,10 @@ public: } #endif + void set_hpu_dispatch_ptr(FnPtr fn_ptr) { + impl.hpu_dispatch_ptr = reinterpret_cast(fn_ptr); + } + void set_hip_dispatch_ptr(FnPtr fn_ptr) { impl.hip_dispatch_ptr = reinterpret_cast(fn_ptr); } @@ -337,6 +344,13 @@ struct RegisterXPUDispatch { } }; +template +struct RegisterHPUDispatch { + RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){ + stub.set_hpu_dispatch_ptr(value); + } +}; + template struct RegisterMPSDispatch { RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { @@ -437,6 +451,9 @@ struct RegisterPRIVATEUSE1Dispatch { #define REGISTER_XPU_DISPATCH(name, fn) \ static RegisterXPUDispatch name ## __register(name, fn); +#define REGISTER_HPU_DISPATCH(name, fn) \ + static RegisterHPUDispatch name ## __register(name, fn); + #define REGISTER_HIP_DISPATCH(name, fn) \ static RegisterHIPDispatch name ## __register(name, fn); diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 27397bf7889..66bdaa0baa8 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -28,6 +28,7 @@ #include #else #include +#include #include #include #include @@ -448,6 +449,7 @@ REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) +REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta); int64_t _fused_sdp_choice_meta( const Tensor& query_, @@ -459,6 +461,20 @@ int64_t _fused_sdp_choice_meta( std::optional scale, bool enable_gqa) { auto query_key_set = query_.key_set(); + bool has_hpu = query_key_set.has(c10::DispatchKey::HPU); + if (has_hpu) { + auto choice_int = at::_ops::_fused_sdp_choice::redispatch( + c10::DispatchKeySet(DispatchKey::HPU), + query_, + key, + value, + attn_mask_, + dropout_p, + is_causal, + scale, + enable_gqa); + return choice_int; + } #if defined(USE_ROCM) bool has_rocm = query_key_set.has(c10::DispatchKey::HIP); if (has_rocm) {