Added _fused_sdp_choice_stub dispatcher support for HPU device (#149512)

Currently the HPU device has no support for the `_fused_sdp_choice_stub` dispatcher function, so `scaled_dot_product_attention` always falls back to the `MATH` backend on HPU. This PR adds `_fused_sdp_choice_stub` dispatcher support for the HPU device, so that any backend (for example math, flash_attention, efficient_attention, cudnn_attention, overrideable) can be selected according to the user's choice.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149512
Approved by: https://github.com/drisspg
This commit is contained in:
pralay 2025-04-09 15:48:09 +00:00 committed by PyTorch MergeBot
parent d0e3482266
commit 97a5e5c6b3
3 changed files with 37 additions and 0 deletions

View File

@ -147,6 +147,7 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
c10::DeviceType::MPS,
c10::DeviceType::MTIA,
c10::DeviceType::XPU,
c10::DeviceType::HPU,
c10::DeviceType::PrivateUse1
);
// Check if the device type is supported.
@ -203,6 +204,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
#endif
case DeviceType::HPU:
return hpu_dispatch_ptr != nullptr ? DispatchResult(hpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
case DeviceType::PrivateUse1:
return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel;

View File

@ -44,6 +44,7 @@
// - MPS: Apple Silicon GPUs (Metal Performance Shaders)
// - MTIA: Meta Training and Inference Devices
// - XPU: Intel GPUs
// - HPU: Intel Gaudi AI accelerators (Habana Processing Units)
// - PrivateUse1: Reserved for private/custom device types
//
// If you want to update the list of supported devices, add a new dispatch_ptr
@ -196,6 +197,7 @@ struct TORCH_API DispatchStubImpl {
#if defined(USE_XPU)
void* xpu_dispatch_ptr;
#endif
void* hpu_dispatch_ptr;
void* privateuse1_dispatch_ptr;
#else
std::atomic<void*> cpu_dispatch_ptr{nullptr};
@ -206,6 +208,7 @@ struct TORCH_API DispatchStubImpl {
#if defined(USE_XPU)
void* xpu_dispatch_ptr = nullptr;
#endif
void* hpu_dispatch_ptr = nullptr;
void* privateuse1_dispatch_ptr = nullptr;
#endif
};
@ -259,6 +262,10 @@ public:
}
#endif
// Registers the HPU kernel for this stub. The typed function pointer is
// type-erased to void* for storage in DispatchStubImpl; try_get_call_ptr
// casts it back to FnPtr when dispatching to DeviceType::HPU.
void set_hpu_dispatch_ptr(FnPtr fn_ptr) {
impl.hpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
// Registers the HIP (ROCm) kernel for this stub, mirroring the other
// per-device setters: the pointer is stored type-erased as void*.
void set_hip_dispatch_ptr(FnPtr fn_ptr) {
impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
@ -337,6 +344,13 @@ struct RegisterXPUDispatch {
}
};
template <typename DispatchStub>
// Static-initialization helper: constructing an instance installs `fn`
// as the HPU kernel on the given dispatch stub (see REGISTER_HPU_DISPATCH).
struct RegisterHPUDispatch {
  RegisterHPUDispatch(DispatchStub& stub, typename DispatchStub::FnPtr fn) {
    stub.set_hpu_dispatch_ptr(fn);
  }
};
template <typename DispatchStub>
struct RegisterMPSDispatch {
RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
@ -437,6 +451,9 @@ struct RegisterPRIVATEUSE1Dispatch {
#define REGISTER_XPU_DISPATCH(name, fn) \
static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_HPU_DISPATCH(name, fn) \
static RegisterHPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_HIP_DISPATCH(name, fn) \
static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

View File

@ -28,6 +28,7 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_sdp_choice_native.h>
#include <ATen/ops/_fused_sdp_choice_ops.h>
#include <ATen/ops/_masked_softmax.h>
#include <ATen/ops/_native_multi_head_attention_native.h>
#include <ATen/ops/_nested_from_padded.h>
@ -448,6 +449,7 @@ REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
// Route HPU backend selection through the meta implementation, which
// redispatches _fused_sdp_choice to the HPU kernel. No trailing semicolon:
// the macro already expands to a terminated declaration, matching the
// sibling REGISTER_*_DISPATCH invocations above.
REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta)
int64_t _fused_sdp_choice_meta(
const Tensor& query_,
@ -459,6 +461,20 @@ int64_t _fused_sdp_choice_meta(
std::optional<double> scale,
bool enable_gqa) {
auto query_key_set = query_.key_set();
bool has_hpu = query_key_set.has(c10::DispatchKey::HPU);
if (has_hpu) {
auto choice_int = at::_ops::_fused_sdp_choice::redispatch(
c10::DispatchKeySet(DispatchKey::HPU),
query_,
key,
value,
attn_mask_,
dropout_p,
is_causal,
scale,
enable_gqa);
return choice_int;
}
#if defined(USE_ROCM)
bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
if (has_rocm) {