[Intel GPU] Dispatch Stub support (#130019)

# Motivation
Structured codegen makes it easier to decouple tensor meta computation from kernel implementation. At present, XPU operators have to handle tensor metas in a hand-written way.

We plan to leverage the codegen system to auto-generate structured operators. This PR adds `DispatchStub` support for Intel GPUs. Based on that, XPU operators will be able to register kernel functors to operator stubs.

This is a prerequisite for PR #130082, where we will modify the codegen system to generate the source files and headers needed for XPU.
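As an illustration, here is a minimal sketch of how an XPU backend kernel could be hooked into an operator stub once this support is in place. The `sigmoid_fn`/`sigmoid_stub`/`sigmoid_kernel_xpu` names are hypothetical and used only for demonstration; `DECLARE_DISPATCH`, `DEFINE_DISPATCH`, and the `REGISTER_XPU_DISPATCH` macro added here are the real pieces of the `DispatchStub` machinery.

```cpp
// Sketch only: illustrative stub registration for a hypothetical XPU kernel.
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>

namespace at::native {

// Normally declared in a shared header and defined in a common .cpp file.
using sigmoid_fn = void (*)(TensorIteratorBase&);
DECLARE_DISPATCH(sigmoid_fn, sigmoid_stub);
DEFINE_DISPATCH(sigmoid_stub);

// Implemented in an XPU (SYCL) source file.
void sigmoid_kernel_xpu(TensorIteratorBase& iter) {
  // ... launch the SYCL kernel for `iter` ...
}

// With this PR, the XPU functor can be attached to the stub, mirroring
// REGISTER_CUDA_DISPATCH / REGISTER_HIP_DISPATCH.
REGISTER_XPU_DISPATCH(sigmoid_stub, &sigmoid_kernel_xpu);

// A structured operator then stays backend-agnostic and simply calls
//   sigmoid_stub(iter.device_type(), iter);
// which dispatches to the pointer registered for the tensor's device type.

} // namespace at::native
```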

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130019
Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
Yan Zhiwei 2024-07-26 07:05:35 +00:00 committed by PyTorch MergeBot
parent 5b3b2b9cc7
commit 2a02b5cd22
5 changed files with 37 additions and 0 deletions


@@ -1165,6 +1165,10 @@ if(APPLE)
  append_cxx_flag_if_supported("-Wno-missing-braces" CMAKE_CXX_FLAGS)
endif()
if(USE_XPU)
  string(APPEND CMAKE_CXX_FLAGS " -DUSE_XPU")
endif()
if(EMSCRIPTEN)
  string(
    APPEND


@@ -112,6 +112,7 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
        c10::DeviceType::CUDA,
        c10::DeviceType::HIP,
        c10::DeviceType::MPS,
        c10::DeviceType::XPU,
        c10::DeviceType::PrivateUse1
    );
    // Check if the device type is supported.
@@ -158,6 +159,11 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
      return mps_dispatch_ptr != nullptr ? DispatchResult(mps_dispatch_ptr) : ErrorType::MissingDeviceKernel;
#endif
#if defined(USE_XPU)
    case DeviceType::XPU:
      return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel;
#endif
    case DeviceType::PrivateUse1:
      return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel;


@@ -43,6 +43,7 @@
// - CUDA: NVIDIA GPUs
// - HIP: AMD GPUs
// - MPS: Apple Silicon GPUs (Metal Performance Shaders)
// - XPU: Intel GPUs
// - PrivateUse1: Reserved for private/custom device types
//
// If you want to update the list of supported devices, add a new dispatch_ptr
@@ -177,12 +178,18 @@ struct TORCH_API DispatchStubImpl {
  void* cuda_dispatch_ptr;
  void* hip_dispatch_ptr;
  void* mps_dispatch_ptr;
#if defined(USE_XPU)
  void* xpu_dispatch_ptr;
#endif
  void* privateuse1_dispatch_ptr;
#else
  std::atomic<void*> cpu_dispatch_ptr{nullptr};
  void* cuda_dispatch_ptr = nullptr;
  void* hip_dispatch_ptr = nullptr;
  void* mps_dispatch_ptr = nullptr;
#if defined(USE_XPU)
  void* xpu_dispatch_ptr = nullptr;
#endif
  void* privateuse1_dispatch_ptr = nullptr;
#endif
};
@@ -227,6 +234,12 @@ public:
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
#if defined(USE_XPU)
  void set_xpu_dispatch_ptr(FnPtr fn_ptr){
    impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
#endif
  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
@@ -288,6 +301,13 @@ struct RegisterCUDADispatch {
  }
};
template <typename DispatchStub>
struct RegisterXPUDispatch {
  RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
    stub.set_xpu_dispatch_ptr(value);
  }
};
template <typename DispatchStub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
@@ -368,6 +388,9 @@ struct RegisterPRIVATEUSE1Dispatch {
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_XPU_DISPATCH(name, fn) \
  static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);


@@ -17,6 +17,7 @@ from torch.testing._internal.common_utils import (
    skipIfTorchDynamo,
    TemporaryFileName,
    TEST_CUDA,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
@@ -68,6 +69,7 @@ def generate_faked_module():
@unittest.skipIf(IS_ARM64, "Does not work on arm")
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionOpenRgistration(common.TestCase):
    """Tests Open Device Registration with C++ extensions."""


@@ -37,6 +37,7 @@ from torch.testing._internal.common_utils import (
    NOTEST_CPU,
    IS_WINDOWS,
    TEST_WITH_TORCHDYNAMO,
    TEST_XPU,
)
from torch._dynamo.testing import CompileCounterWithBackend
@@ -3617,6 +3618,7 @@ class TestAttnBias(NNTestCase):
        with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
            scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
@unittest.skipIf(IS_FBCODE, "Ninja is required to load C++ extensions and it's not compatible with Buck ")
class TestSDPAPrivateUse1Only(NNTestCase):
    @classmethod