Added _fused_sdp_choice_stub dispatcher support for HPU device (#149512)
Currently the HPU device has no support for the `_fused_sdp_choice_stub` dispatcher function, so whenever `scaled_dot_product_attention` queries `_fused_sdp_choice_stub` on HPU it always falls back to the math backend. This PR adds `_fused_sdp_choice_stub` dispatcher support for the HPU device, so that any backend (for example math, flash_attention, efficient_attention, cudnn_attention, overrideable) can be invoked according to the user's choice.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149512
Approved by: https://github.com/drisspg
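As a rough illustration of the user-visible effect, the sketch below exercises the same entry points from the ATen C++ API (the Python-level equivalent is `torch.nn.functional.scaled_dot_product_attention`). It assumes a PyTorch build with an HPU backend registered (for example the out-of-tree Intel Gaudi / habana_frameworks integration); the shapes, dtype, and printed value are purely illustrative and not part of this PR:

// Sketch only: requires a PyTorch build with an HPU backend registered.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kHPU);
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);

  // Previously this always resolved to the math backend on HPU; with the new
  // _fused_sdp_choice_stub entry, the HPU backend's own choice function runs.
  int64_t backend = at::_fused_sdp_choice(q, k, v);
  std::cout << "selected SDPA backend id: " << backend << "\n";

  // scaled_dot_product_attention then executes whichever backend was chosen.
  auto out = at::scaled_dot_product_attention(q, k, v);
  return 0;
}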
This commit is contained in:
parent d0e3482266
commit 97a5e5c6b3
aten/src/ATen/native/DispatchStub.cpp

@@ -147,6 +147,7 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
       c10::DeviceType::MPS,
       c10::DeviceType::MTIA,
       c10::DeviceType::XPU,
+      c10::DeviceType::HPU,
       c10::DeviceType::PrivateUse1
   );
   // Check if the device type is supported.

@@ -203,6 +204,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
       return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
 #endif
 
+    case DeviceType::HPU:
+      return hpu_dispatch_ptr != nullptr ? DispatchResult(hpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
+
     case DeviceType::PrivateUse1:
       return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel;
 
aten/src/ATen/native/DispatchStub.h

@@ -44,6 +44,7 @@
 // - MPS: Apple Silicon GPUs (Metal Performance Shaders)
 // - MTIA: Meta Training and Inference Devices
 // - XPU: Intel GPUs
+// - HPU: Reserved for HPU (Intel Gaudi) device types
 // - PrivateUse1: Reserved for private/custom device types
 //
 // If you want to update the list of supported devices, add a new dispatch_ptr

@@ -196,6 +197,7 @@ struct TORCH_API DispatchStubImpl {
 #if defined(USE_XPU)
   void* xpu_dispatch_ptr;
 #endif
+  void* hpu_dispatch_ptr;
   void* privateuse1_dispatch_ptr;
 #else
   std::atomic<void*> cpu_dispatch_ptr{nullptr};

@@ -206,6 +208,7 @@ struct TORCH_API DispatchStubImpl {
 #if defined(USE_XPU)
   void* xpu_dispatch_ptr = nullptr;
 #endif
+  void* hpu_dispatch_ptr = nullptr;
   void* privateuse1_dispatch_ptr = nullptr;
 #endif
 };

@@ -259,6 +262,10 @@ public:
   }
 #endif
 
+  void set_hpu_dispatch_ptr(FnPtr fn_ptr) {
+    impl.hpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
+  }
+
   void set_hip_dispatch_ptr(FnPtr fn_ptr) {
     impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
   }

@@ -337,6 +344,13 @@ struct RegisterXPUDispatch {
   }
 };
 
+template <typename DispatchStub>
+struct RegisterHPUDispatch {
+  RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
+    stub.set_hpu_dispatch_ptr(value);
+  }
+};
+
 template <typename DispatchStub>
 struct RegisterMPSDispatch {
   RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {

@@ -437,6 +451,9 @@ struct RegisterPRIVATEUSE1Dispatch {
 #define REGISTER_XPU_DISPATCH(name, fn) \
   static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
 
+#define REGISTER_HPU_DISPATCH(name, fn) \
+  static RegisterHPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
+
 #define REGISTER_HIP_DISPATCH(name, fn) \
   static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
 
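For context, here is a minimal sketch of how a backend can use the registration path added above: declare and define a dispatch stub as usual, then bind an HPU kernel to it with the new REGISTER_HPU_DISPATCH macro. The stub and kernel names (my_op_stub, my_op_hpu) are hypothetical and not part of this PR:

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

// Hypothetical stub: a function taking a Tensor and returning an int64_t.
using my_op_fn = int64_t (*)(const Tensor&);
DECLARE_DISPATCH(my_op_fn, my_op_stub);
DEFINE_DISPATCH(my_op_stub);

// Hypothetical HPU kernel for the stub.
static int64_t my_op_hpu(const Tensor& t) {
  return t.numel();
}

// Stores &my_op_hpu in hpu_dispatch_ptr, so querying the stub for
// DeviceType::HPU resolves to this kernel.
REGISTER_HPU_DISPATCH(my_op_stub, &my_op_hpu)

} // namespace at::native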
aten/src/ATen/native/transformers/attention.cpp

@@ -28,6 +28,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/_fused_sdp_choice_native.h>
+#include <ATen/ops/_fused_sdp_choice_ops.h>
 #include <ATen/ops/_masked_softmax.h>
 #include <ATen/ops/_native_multi_head_attention_native.h>
 #include <ATen/ops/_nested_from_padded.h>

@@ -448,6 +449,7 @@ REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
+REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta);
 
 int64_t _fused_sdp_choice_meta(
     const Tensor& query_,

@@ -459,6 +461,20 @@ int64_t _fused_sdp_choice_meta(
     std::optional<double> scale,
     bool enable_gqa) {
   auto query_key_set = query_.key_set();
+  bool has_hpu = query_key_set.has(c10::DispatchKey::HPU);
+  if (has_hpu) {
+    auto choice_int = at::_ops::_fused_sdp_choice::redispatch(
+        c10::DispatchKeySet(DispatchKey::HPU),
+        query_,
+        key,
+        value,
+        attn_mask_,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa);
+    return choice_int;
+  }
 #if defined(USE_ROCM)
   bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
   if (has_rocm) {
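The redispatch above hands control to whatever kernel the HPU backend has registered for `aten::_fused_sdp_choice`. A purely hypothetical sketch of what that out-of-tree registration could look like follows; the kernel name and the returned value are illustrative and not taken from this PR:

#include <ATen/core/Tensor.h>
#include <torch/library.h>

// Hypothetical out-of-tree HPU implementation of aten::_fused_sdp_choice.
static int64_t fused_sdp_choice_hpu(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    bool enable_gqa) {
  // A real backend would inspect the inputs and return the matching entry of
  // the SDPBackend enum (math, flash_attention, efficient_attention, ...).
  // Returning 0 (math) here is purely illustrative.
  return 0;
}

// Bind the kernel to the HPU dispatch key so the redispatch in
// _fused_sdp_choice_meta reaches it.
TORCH_LIBRARY_IMPL(aten, HPU, m) {
  m.impl("_fused_sdp_choice", &fused_sdp_choice_hpu);
}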