Added _fused_sdp_choice_stub dispatcher support for HPU device (#149512)
Currently the HPU device has no support for the `_fused_sdp_choice_stub` dispatcher function, so `scaled_dot_product_attention` always falls back to the `MATH` backend whenever `_fused_sdp_choice_stub` is queried for an HPU tensor. This PR enables `_fused_sdp_choice_stub` dispatcher support for the HPU device, so that any backend (for example math, flash_attention, efficient_attention, cudnn_attention, overrideable) can be selected according to the user's choice on HPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149512
Approved by: https://github.com/drisspg
This commit is contained in:
parent d0e3482266
commit 97a5e5c6b3
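For orientation before the diff: `_fused_sdp_choice_stub` is the per-device dispatch stub that `scaled_dot_product_attention` consults to decide which backend to run. A minimal sketch of that query from the C++ side, assuming the generated `at::_fused_sdp_choice` operator interface; the wrapper function and the argument values are illustrative only:

// Illustrative sketch, not upstream code: asks the dispatcher which SDPA
// backend to use for these inputs. Before this commit, HPU tensors could only
// resolve to the math backend; after it, an HPU-registered choice function runs.
#include <ATen/ATen.h>

int64_t pick_sdpa_backend(const at::Tensor& q,
                          const at::Tensor& k,
                          const at::Tensor& v) {
  // The returned integer encodes the chosen backend (math, flash_attention,
  // efficient_attention, cudnn_attention, overrideable).
  return at::_fused_sdp_choice(q, k, v,
                               /*attn_mask=*/{},
                               /*dropout_p=*/0.0,
                               /*is_causal=*/false,
                               /*scale=*/{},
                               /*enable_gqa=*/false);
}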
@@ -147,6 +147,7 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
         c10::DeviceType::MPS,
         c10::DeviceType::MTIA,
         c10::DeviceType::XPU,
+        c10::DeviceType::HPU,
         c10::DeviceType::PrivateUse1
     );
     // Check if the device type is supported.
@@ -203,6 +204,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
       return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
 #endif
 
+    case DeviceType::HPU:
+      return hpu_dispatch_ptr != nullptr ? DispatchResult(hpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
+
     case DeviceType::PrivateUse1:
       return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel;
 
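The two hunks above are the core stub changes: HPU joins the list of supported device types in `try_get_call_ptr`, and the switch gains an HPU case that returns the registered `hpu_dispatch_ptr` (or `MissingDeviceKernel` when nothing was registered). A standalone toy sketch of that pattern, assuming nothing from ATen, just to illustrate the per-device pointer table with a default fallback:

#include <array>
#include <cstddef>
#include <cstdio>

// Toy model of the DispatchStub idea (not the real ATen implementation):
// one function-pointer slot per device type, filled by whichever backend
// registers a kernel; unregistered devices fall back to a default path.
enum class Device { CPU, CUDA, HPU, COUNT };

using ChoiceFn = int (*)();

std::array<ChoiceFn, static_cast<std::size_t>(Device::COUNT)> table{};  // all slots start null

int math_choice() { return 0; }  // stand-in for the default math backend
int hpu_choice()  { return 1; }  // stand-in for an HPU-specific choice

int choose_backend(Device d) {
  ChoiceFn fn = table[static_cast<std::size_t>(d)];
  return fn ? fn() : math_choice();  // missing kernel -> default fallback
}

int main() {
  std::printf("HPU before registration: %d\n", choose_backend(Device::HPU));  // prints 0
  table[static_cast<std::size_t>(Device::HPU)] = &hpu_choice;                 // like REGISTER_HPU_DISPATCH
  std::printf("HPU after registration:  %d\n", choose_backend(Device::HPU));  // prints 1
  return 0;
}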
@@ -44,6 +44,7 @@
 // - MPS: Apple Silicon GPUs (Metal Performance Shaders)
 // - MTIA: Meta Training and Inference Devices
 // - XPU: Intel GPUs
+// - HPU: Reserved for HPU (Intel Gaudi) device types
 // - PrivateUse1: Reserved for private/custom device types
 //
 // If you want to update the list of supported devices, add a new dispatch_ptr
@@ -196,6 +197,7 @@ struct TORCH_API DispatchStubImpl {
 #if defined(USE_XPU)
   void* xpu_dispatch_ptr;
 #endif
+  void* hpu_dispatch_ptr;
   void* privateuse1_dispatch_ptr;
 #else
   std::atomic<void*> cpu_dispatch_ptr{nullptr};
@@ -206,6 +208,7 @@ struct TORCH_API DispatchStubImpl {
 #if defined(USE_XPU)
   void* xpu_dispatch_ptr = nullptr;
 #endif
+  void* hpu_dispatch_ptr = nullptr;
   void* privateuse1_dispatch_ptr = nullptr;
 #endif
 };
@@ -259,6 +262,10 @@ public:
   }
 #endif
 
+  void set_hpu_dispatch_ptr(FnPtr fn_ptr) {
+    impl.hpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
+  }
+
   void set_hip_dispatch_ptr(FnPtr fn_ptr) {
     impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
   }
@@ -337,6 +344,13 @@ struct RegisterXPUDispatch {
   }
 };
 
+template <typename DispatchStub>
+struct RegisterHPUDispatch {
+  RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
+    stub.set_hpu_dispatch_ptr(value);
+  }
+};
+
 template <typename DispatchStub>
 struct RegisterMPSDispatch {
   RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
@@ -437,6 +451,9 @@ struct RegisterPRIVATEUSE1Dispatch {
 #define REGISTER_XPU_DISPATCH(name, fn) \
   static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
 
+#define REGISTER_HPU_DISPATCH(name, fn) \
+  static RegisterHPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
+
 #define REGISTER_HIP_DISPATCH(name, fn) \
   static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
 
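The header changes above follow the existing per-device pattern: a raw `hpu_dispatch_ptr` slot on `DispatchStubImpl`, a `set_hpu_dispatch_ptr` setter, a `RegisterHPUDispatch` helper whose constructor stores the pointer, and a `REGISTER_HPU_DISPATCH` macro that instantiates that helper as a file-scope static. Expanded by hand for the registration added later in this commit (approximate spelling, shown only to make the mechanism concrete):

// Hand-expanded approximation of
//   REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta);
// The static object's constructor runs at library-load time and stores the
// function pointer into the stub's HPU slot via set_hpu_dispatch_ptr().
static RegisterHPUDispatch<struct _fused_sdp_choice_stub_DECLARE_DISPATCH_type>
    _fused_sdp_choice_stub__register(_fused_sdp_choice_stub, &_fused_sdp_choice_meta);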
@@ -28,6 +28,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/_fused_sdp_choice_native.h>
+#include <ATen/ops/_fused_sdp_choice_ops.h>
 #include <ATen/ops/_masked_softmax.h>
 #include <ATen/ops/_native_multi_head_attention_native.h>
 #include <ATen/ops/_nested_from_padded.h>
@@ -448,6 +449,7 @@ REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
+REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta);
 
 int64_t _fused_sdp_choice_meta(
     const Tensor& query_,
@@ -459,6 +461,20 @@ int64_t _fused_sdp_choice_meta(
     std::optional<double> scale,
     bool enable_gqa) {
   auto query_key_set = query_.key_set();
+  bool has_hpu = query_key_set.has(c10::DispatchKey::HPU);
+  if (has_hpu) {
+    auto choice_int = at::_ops::_fused_sdp_choice::redispatch(
+        c10::DispatchKeySet(DispatchKey::HPU),
+        query_,
+        key,
+        value,
+        attn_mask_,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa);
+    return choice_int;
+  }
 #if defined(USE_ROCM)
   bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
   if (has_rocm) {
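`_fused_sdp_choice_meta` itself contains no HPU policy; when the inputs carry the HPU dispatch key it simply redispatches `_fused_sdp_choice` to that key, so the actual selection logic lives in the out-of-tree HPU backend. A hypothetical sketch of what that backend-side registration might look like; the kernel name, its body, and the return codes are placeholders, while `TORCH_LIBRARY_IMPL` and the `aten::_fused_sdp_choice` schema are existing PyTorch mechanisms:

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical out-of-tree HPU backend code (illustrative only): registers a
// kernel for aten::_fused_sdp_choice under the HPU dispatch key, which is
// where _fused_sdp_choice_meta's redispatch lands.
namespace {

int64_t fused_sdp_choice_hpu(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    bool enable_gqa) {
  // Placeholder policy: a real backend would inspect dtypes, shapes, the mask,
  // etc., and return the integer code of the SDPBackend it wants
  // scaled_dot_product_attention to use. The codes below are stand-ins.
  const int64_t kFallbackCode = 0;  // stand-in for a non-fused (math-like) path
  const int64_t kFusedCode = 1;     // stand-in for the backend's fused path
  bool fused_ok = !attn_mask.has_value() && dropout_p == 0.0 && !is_causal;
  return fused_ok ? kFusedCode : kFallbackCode;
}

} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, HPU, m) {
  m.impl("_fused_sdp_choice", &fused_sdp_choice_hpu);
}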