mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[Intel GPU] Dispatch Stub support (#130019)
# Motivation Structured codegen is beneficial for easier decoupling of tensor meta setting and kernel implementation. At present, XPU operators need to handle tensor metas in a hand-written way. We plan to leverage the codegen system to auto-generate structured operators. This PR facilitates `DispatchStub` support for Intel GPUs. Based on that, XPU operators will be able to register kernel functors to operator stubs. This is a prerequisite of PR #130082, where we will modify the codegen system to generate XPU-specific source files and headers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130019 Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
parent
5b3b2b9cc7
commit
2a02b5cd22
|
|
@ -1165,6 +1165,10 @@ if(APPLE)
|
|||
append_cxx_flag_if_supported("-Wno-missing-braces" CMAKE_CXX_FLAGS)
|
||||
endif()
|
||||
|
||||
if(USE_XPU)
|
||||
string(APPEND CMAKE_CXX_FLAGS " -DUSE_XPU")
|
||||
endif()
|
||||
|
||||
if(EMSCRIPTEN)
|
||||
string(
|
||||
APPEND
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
|
|||
c10::DeviceType::CUDA,
|
||||
c10::DeviceType::HIP,
|
||||
c10::DeviceType::MPS,
|
||||
c10::DeviceType::XPU,
|
||||
c10::DeviceType::PrivateUse1
|
||||
);
|
||||
// Check if the device type is supported.
|
||||
|
|
@ -158,6 +159,11 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
|
|||
return mps_dispatch_ptr != nullptr ? DispatchResult(mps_dispatch_ptr) : ErrorType::MissingDeviceKernel;
|
||||
#endif
|
||||
|
||||
#if defined(USE_XPU)
|
||||
case DeviceType::XPU:
|
||||
return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
|
||||
#endif
|
||||
|
||||
case DeviceType::PrivateUse1:
|
||||
return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel;
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@
|
|||
// - CUDA: NVIDIA GPUs
|
||||
// - HIP: AMD GPUs
|
||||
// - MPS: Apple Silicon GPUs (Metal Performance Shaders)
|
||||
// - XPU: Intel GPUs
|
||||
// - PrivateUse1: Reserved for private/custom device types
|
||||
//
|
||||
// If you want to update the list of supported devices, add a new dispatch_ptr
|
||||
|
|
@ -177,12 +178,18 @@ struct TORCH_API DispatchStubImpl {
|
|||
void* cuda_dispatch_ptr;
|
||||
void* hip_dispatch_ptr;
|
||||
void* mps_dispatch_ptr;
|
||||
#if defined(USE_XPU)
|
||||
void* xpu_dispatch_ptr;
|
||||
#endif
|
||||
void* privateuse1_dispatch_ptr;
|
||||
#else
|
||||
std::atomic<void*> cpu_dispatch_ptr{nullptr};
|
||||
void* cuda_dispatch_ptr = nullptr;
|
||||
void* hip_dispatch_ptr = nullptr;
|
||||
void* mps_dispatch_ptr = nullptr;
|
||||
#if defined(USE_XPU)
|
||||
void* xpu_dispatch_ptr = nullptr;
|
||||
#endif
|
||||
void* privateuse1_dispatch_ptr = nullptr;
|
||||
#endif
|
||||
};
|
||||
|
|
@ -227,6 +234,12 @@ public:
|
|||
impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
||||
}
|
||||
|
||||
#if defined(USE_XPU)
|
||||
void set_xpu_dispatch_ptr(FnPtr fn_ptr){
|
||||
impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
void set_hip_dispatch_ptr(FnPtr fn_ptr) {
|
||||
impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
||||
}
|
||||
|
|
@ -288,6 +301,13 @@ struct RegisterCUDADispatch {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename DispatchStub>
|
||||
struct RegisterXPUDispatch {
|
||||
RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
|
||||
stub.set_xpu_dispatch_ptr(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DispatchStub>
|
||||
struct RegisterMPSDispatch {
|
||||
RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
|
||||
|
|
@ -368,6 +388,9 @@ struct RegisterPRIVATEUSE1Dispatch {
|
|||
#define REGISTER_CUDA_DISPATCH(name, fn) \
|
||||
static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
|
||||
|
||||
#define REGISTER_XPU_DISPATCH(name, fn) \
|
||||
static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
|
||||
|
||||
#define REGISTER_HIP_DISPATCH(name, fn) \
|
||||
static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from torch.testing._internal.common_utils import (
|
|||
skipIfTorchDynamo,
|
||||
TemporaryFileName,
|
||||
TEST_CUDA,
|
||||
TEST_XPU,
|
||||
)
|
||||
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
||||
|
||||
|
|
@ -68,6 +69,7 @@ def generate_faked_module():
|
|||
|
||||
|
||||
@unittest.skipIf(IS_ARM64, "Does not work on arm")
|
||||
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
|
||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||
class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
"""Tests Open Device Registration with C++ extensions."""
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from torch.testing._internal.common_utils import (
|
|||
NOTEST_CPU,
|
||||
IS_WINDOWS,
|
||||
TEST_WITH_TORCHDYNAMO,
|
||||
TEST_XPU,
|
||||
)
|
||||
from torch._dynamo.testing import CompileCounterWithBackend
|
||||
|
||||
|
|
@ -3617,6 +3618,7 @@ class TestAttnBias(NNTestCase):
|
|||
with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
|
||||
scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
|
||||
|
||||
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
|
||||
@unittest.skipIf(IS_FBCODE, "Ninja is required to load C++ extensions and it's not compatible with Buck ")
|
||||
class TestSDPAPrivateUse1Only(NNTestCase):
|
||||
@classmethod
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user