[Intel GPU] Add SDPA implementation on XPU with OneDNN (#147612)

Add an XPU implementation of the oneDNN-based SDPA operator. It will be integrated and enabled later.

Depends on the BUILD_GRAPH switch in #147608.
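For context (hedged, not part of the diff), the intended entry point once this is wired up is the regular SDPA op on XPU tensors. A minimal C++ sketch, assuming an XPU-enabled ATen build; the helper name is illustrative:

#include <ATen/ATen.h>

// Usage sketch: 4-D {B, H, T, K} inputs on the XPU device. Until the oneDNN
// path is enabled, this call still resolves to the existing fallback.
at::Tensor xpu_sdpa_example() {
  auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kHalf);
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);
  return at::scaled_dot_product_attention(
      q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.0, /*is_causal=*/true);
}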

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147612
Approved by: https://github.com/EikanWang
Ding, Yi1 2025-02-24 16:12:01 +00:00 committed by PyTorch MergeBot
parent 576ed1e400
commit cde12207a0
6 changed files with 717 additions and 9 deletions


@@ -0,0 +1,230 @@
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/Array.h>
#include <torch/library.h>
namespace {
bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
const auto query_size_last = params.query.sym_size(-1);
const auto key_size_last = params.key.sym_size(-1);
const auto value_size_last = params.value.sym_size(-1);
if ((query_size_last != key_size_last) ||
(query_size_last != value_size_last)) {
if (debug) {
TORCH_WARN(
"OneDNN attention requires q,k,v to have the same last dimension.",
" Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
key_size_last,
", Value.size(-1): ",
value_size_last,
" instead.");
}
return false;
}
if (query_size_last > 256) {
if (debug) {
TORCH_WARN(
"OneDNN attention requires q,k,v to have head dimension less than 256.",
" Got ",
query_size_last,
" instead.");
}
return false;
}
return true;
}
bool check_no_grad(sdp::sdp_params const& params, bool debug) {
const bool any_inputs_require_grad = params.query.requires_grad() ||
params.key.requires_grad() || params.value.requires_grad();
const bool gradmode_enabled = at::GradMode::is_enabled();
if (debug && any_inputs_require_grad && gradmode_enabled) {
TORCH_WARN("Backward or grad to be supported.");
}
return !any_inputs_require_grad || !gradmode_enabled;
}
bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
constexpr auto supported_dtypes = c10::array_of<at::ScalarType>(
at::kFloat, at::kBFloat16, at::kHalf); // double is not supported
// Define gate functions that determine if a flash kernel can be run
constexpr auto constraints = c10::array_of<bool (*)(
sdp::sdp_params const&, bool)>(
sdp::check_nested_tensor,
sdp::check_for_dropout,
sdp::check_tensor_shapes,
sdp::check_batch_size_and_num_heads_dense<true /*supports GQA*/>,
sdp::check_attn_mask_shape,
sdp::check_nonzero_sequence_lengths_dense,
sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
check_head_dim_size_xpu,
check_no_grad);
for (auto& constraint : constraints) {
if (!constraint(params, debug)) {
return false;
}
}
return sdp::check_tensor_dtype(params, supported_dtypes, debug);
}
sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
// This function defines the priority order of the different sdp backends:
// 1. Overrideable (oneDNN fused attention)
// 2. Math fallback
auto& ctx = at::globalContext();
// The overrideable backend is backed by the oneDNN implementation.
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP()) {
return sdp::SDPBackend::error;
}
// Get ideal kernel ordering
const std::array<sdp::SDPBackend, 2> priority_order{
sdp::SDPBackend::overrideable,
sdp::SDPBackend::math,
};
// Because TORCH_CHECK passes when its condition is true, we negate debug so
// that the diagnostic statements are printed when debug is true.
bool print_debug = false;
for (auto& backend : priority_order) {
switch (backend) {
case sdp::SDPBackend::overrideable:
if (ctx.userEnabledOverrideableSDP() &&
use_overrideable_xpu(kernel_params, print_debug)) {
return sdp::SDPBackend::overrideable;
}
break;
case sdp::SDPBackend::math:
if (ctx.userEnabledMathSDP()) {
return sdp::SDPBackend::math;
}
break;
default:
TORCH_CHECK(false, "Invalid backend");
}
}
// If we have gotten to this point then two things have happened:
// 1. use_overrideable_xpu did not satisfy the constraints to be run
// 2. The user has explicitly disabled the math kernel
// We then re-run the kernel checks with debug enabled to print out the
// reason why the kernel was not selected
print_debug = true;
TORCH_WARN("OneDNN kernel not used because:");
use_overrideable_xpu(kernel_params, print_debug);
TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
return sdp::SDPBackend::error;
}
} // namespace
namespace at::native {
int64_t _fused_sdp_choice_xpu(
const at::Tensor& query_,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_mask_,
double dropout_p,
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
sdp::sdp_params kernel_params{
query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
auto backend = select_sdp_backend_xpu(kernel_params);
if (backend == sdp::SDPBackend::error) {
TORCH_CHECK(
false,
"No viable backend for scaled_dot_product_attention was found. ",
"This is likely due to turning off both the math kernel and the fused kernels.");
}
return static_cast<int64_t>(backend);
}
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
_scaled_dot_product_fused_attention_overrideable_xpu(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
TORCH_INTERNAL_ASSERT(
query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
"scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
TORCH_INTERNAL_ASSERT(
(key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
(key.size(2) == value.size(2)),
"scaled_dot_product_fused_attention_overrideable_xpu: K/V should have the same batch / seq / num_head");
TORCH_INTERNAL_ASSERT(
query.size(3) == key.size(3),
"scaled_dot_product_fused_attention_overrideable_xpu: Q/K should have the same head_dim");
TORCH_INTERNAL_ASSERT(
dropout_p == 0.0,
"scaled_dot_product_fused_attention_overrideable_xpu: Currently do not support dropout > 0");
TORCH_INTERNAL_ASSERT(
!(attn_bias.has_value() && is_causal),
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
const int64_t batch_size = query.size(0);
const int64_t num_head = query.size(1);
const int64_t num_head_kv = key.size(1);
const int64_t head_dim = query.size(3);
const int64_t head_dim_v = value.size(3);
const int64_t seq_len_q = query.size(2);
const int64_t seq_len_kv = key.size(2);
auto opts = query.options();
auto output = at::empty({batch_size, num_head, seq_len_q, head_dim}, opts);
// auto logsumexp =
// at::empty({batch_size, num_head, seq_len_q}, opts.dtype(at::kFloat));
auto logsumexp = at::empty({}, opts.dtype(at::kFloat));
at::native::onednn::gpu_float_sdpa(
batch_size,
seq_len_q,
seq_len_kv,
num_head,
num_head_kv,
head_dim,
head_dim_v,
query,
key,
value,
attn_bias,
is_causal,
scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim)),
output);
// rng and debug mask not used
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));
auto debug_attn_mask = at::empty(
{batch_size, num_head, seq_len_q, seq_len_kv}, at::dtype(at::kFloat));
return std::make_tuple(
output,
logsumexp,
/* cum_seq_q */ at::Tensor(),
/* cum_seq_k */ at::Tensor(),
seq_len_q,
seq_len_kv,
philox_seed,
philox_offset,
debug_attn_mask);
}
} // namespace at::native
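As a side note (hedged sketch, not part of the diff), the integer returned by _fused_sdp_choice_xpu can be cast back to sdp::SDPBackend to inspect which path select_sdp_backend_xpu picked, assuming the declaration is visible to the caller:

#include <ATen/ATen.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

// Hypothetical helper: report the backend chosen for half-precision 4-D
// inputs with no mask, no dropout, and no GQA.
sdp::SDPBackend query_xpu_sdpa_choice() {
  auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kHalf);
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);
  int64_t choice = at::native::_fused_sdp_choice_xpu(
      q, k, v, /*attn_mask_=*/std::nullopt, /*dropout_p=*/0.0,
      /*is_causal=*/false, /*scale=*/std::nullopt, /*enable_gqa=*/false);
  // overrideable -> oneDNN fused kernel; math -> composite fallback
  return static_cast<sdp::SDPBackend>(choice);
}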


@@ -0,0 +1,407 @@
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <oneapi/dnnl/dnnl.hpp>
using namespace at::native::onednn;
using logical_tensor = dnnl::graph::logical_tensor;
using data_type = logical_tensor::data_type;
using dims = logical_tensor::dims;
using op = dnnl::graph::op;
using partition = dnnl::graph::partition;
namespace {
struct SDPALogicalParams {
enum class TensorID {
query,
key,
scale,
neg_inf,
attn_mask,
value,
output,
end,
};
logical_tensor query{};
logical_tensor key{};
logical_tensor scale{};
std::optional<logical_tensor> neg_inf;
std::optional<logical_tensor> attn_mask;
logical_tensor value{};
logical_tensor output{};
SDPALogicalParams(
const at::Tensor& query_,
const at::Tensor& key_,
const at::Tensor& value_,
const std::optional<at::Tensor>& attn_mask_,
const at::Tensor& output_,
bool is_causal) {
const data_type dtype = // to logical_tensor data type
query_.scalar_type() == c10::ScalarType::Float ? data_type::f32
: query_.scalar_type() == c10::ScalarType::Half ? data_type::f16
: query_.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16
: data_type::undef;
TORCH_INTERNAL_ASSERT(
(dtype != data_type::undef),
"Only FP16/BF16/FP32 datatypes are currently supported");
const dims scalar_shape = {1};
std::vector<logical_tensor> inputLogicalTensors;
query = {
static_cast<size_t>(TensorID::query),
dtype,
query_.sizes().vec(),
query_.strides().vec()};
key = {
static_cast<size_t>(TensorID::key),
dtype,
key_.sizes().vec(),
key_.strides().vec()};
scale = {
static_cast<size_t>(TensorID::scale),
dtype,
scalar_shape,
logical_tensor::layout_type::strided,
logical_tensor::property_type::constant};
if (is_causal) {
neg_inf = {
static_cast<size_t>(TensorID::neg_inf),
dtype,
scalar_shape,
logical_tensor::layout_type::strided,
logical_tensor::property_type::constant};
}
if (attn_mask_.has_value()) {
attn_mask = {
static_cast<size_t>(TensorID::attn_mask),
dtype,
attn_mask_->sizes().vec(),
attn_mask_->strides().vec()};
}
value = {
static_cast<size_t>(TensorID::value),
dtype,
value_.sizes().vec(),
value_.strides().vec()};
output = {
static_cast<size_t>(TensorID::output),
dtype,
output_.sizes().vec(),
output_.strides().vec()};
}
std::vector<logical_tensor> get_input() const {
std::vector<logical_tensor> input = {query, key, scale};
if (neg_inf.has_value()) {
input.push_back(neg_inf.value());
}
if (attn_mask.has_value()) {
input.push_back(attn_mask.value());
}
input.push_back(value);
return input;
}
std::vector<logical_tensor> get_output() const {
return {output};
}
};
partition create_sdpa_graph_partition(
int batch_size,
int seq_len_q,
int seq_len_k,
int num_head,
int head_dim,
bool is_causal,
data_type dtype,
const SDPALogicalParams& params) {
// graph building and partitioning
// currently, we assume that Q and K have the same sequence length
dims qk_output_shape = {batch_size, num_head, seq_len_q, seq_len_k};
dims scale_shape = {1};
size_t lt_id = static_cast<size_t>(SDPALogicalParams::TensorID::end);
size_t op_id = 0;
logical_tensor matmul_qk_out{lt_id++, dtype};
op matmul_qk{
op_id++,
op::kind::MatMul,
{params.query, params.key},
{matmul_qk_out},
"matmul_qk"};
matmul_qk.set_attr<bool>(op::attr::transpose_b, true);
logical_tensor scaled_qk_out{lt_id++, dtype};
op scale_mul{
op_id++,
op::kind::Multiply,
{matmul_qk_out, params.scale},
{scaled_qk_out},
"scale_mul"};
std::optional<logical_tensor> masked_qk_out;
// For optional additive mask
std::optional<op> mask_add;
// For optional implicit causal mask
std::optional<op> mask_gen_idx_row;
std::optional<logical_tensor> mask_row_idx;
std::optional<op> mask_gen_idx_col;
std::optional<logical_tensor> mask_col_idx;
std::optional<op> mask_gt;
std::optional<logical_tensor> mask_gt_out;
std::optional<op> mask_select;
if (params.attn_mask.has_value()) {
TORCH_INTERNAL_ASSERT(
!is_causal, "Additive mask cannot use with is_causal.");
masked_qk_out = {lt_id++, dtype};
mask_add = {
op_id++,
op::kind::Add,
{scaled_qk_out, params.attn_mask.value()},
{masked_qk_out.value()},
"mask_add"};
} else if (is_causal) {
#if (DNNL_VERSION_MAJOR >= 3 && DNNL_VERSION_MINOR >= 7)
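// Implicit causal mask: generate row/column indices of the scaled QK logits,
// keep positions where row >= col (the lower triangle), and replace the rest
// with neg_inf before the softmax.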
mask_row_idx = {lt_id++, data_type::s32};
mask_gen_idx_row = {
op_id++,
op::kind::GenIndex,
{scaled_qk_out},
{mask_row_idx.value()},
"mask_gen_idx_row"};
mask_gen_idx_row->set_attr<int64_t>(op::attr::axis, -2);
mask_col_idx = {lt_id++, data_type::s32};
mask_gen_idx_col = {
op_id++,
op::kind::GenIndex,
{scaled_qk_out},
{mask_col_idx.value()},
"mask_gen_idx_col"};
mask_gen_idx_col->set_attr<int64_t>(op::attr::axis, -1);
mask_gt_out = {lt_id++, data_type::boolean};
mask_gt = {
op_id++,
op::kind::GreaterEqual,
{mask_row_idx.value(), mask_col_idx.value()},
{mask_gt_out.value()},
"mask_gt"};
masked_qk_out = {lt_id++, dtype};
mask_select = {
op_id++,
op::kind::Select,
{mask_gt_out.value(), scaled_qk_out, params.neg_inf.value()},
{masked_qk_out.value()},
"mask_select"};
#else
TORCH_CHECK(
false,
"OneDNN v3.7 or later is required for implicit causal mask support.");
#endif
}
op softmax{op_id++, op::kind::SoftMax, "softmax"};
softmax.set_attr<int64_t>(op::attr::axis, -1);
logical_tensor softmax_out{lt_id++, dtype};
softmax.add_input(masked_qk_out.value_or(scaled_qk_out));
softmax.add_output(softmax_out);
op matmul_v{
op_id++,
op::kind::MatMul,
{softmax_out, params.value},
{params.output},
"matmul_v"};
constexpr auto ekind = dnnl::engine::kind::gpu;
dnnl::graph::graph g(ekind);
g.add_op(matmul_qk);
g.add_op(scale_mul);
if (mask_add.has_value()) {
g.add_op(mask_add.value());
}
if (is_causal) {
g.add_op(mask_gen_idx_row.value());
g.add_op(mask_gen_idx_col.value());
g.add_op(mask_gt.value());
g.add_op(mask_select.value());
}
g.add_op(softmax);
g.add_op(matmul_v);
g.finalize();
auto partitions = g.get_partitions();
TORCH_INTERNAL_ASSERT(
(partitions.size() == 1) && partitions[0].is_supported(),
"oneDNN doesn't support this fusion pattern. If you'd like its support, please submit a issue.");
return partitions[0];
}
partition& find_or_create_graph_partition(
int batch_size,
int seq_len_q,
int seq_len_k,
int num_head,
int head_dim,
bool is_causal,
const SDPALogicalParams& params) {
thread_local static PartitionCache cache;
const data_type dtype = params.query.get_data_type();
// cache key creation
// patternID is determined on the basis of the arguments provided
std::bitset<32> patternID;
if (dtype == data_type::f32) {
// bit 3 corresponds to float32 dtype
patternID.set(3, 1);
}
if (dtype == data_type::bf16) {
// bit 2 is set for bf16 (and left unset for fp16)
patternID.set(2, 1);
}
// sdp pattern
patternID.set(4, 1);
// Refer to comments in Utils.h. The first 8 bits are reserved
int pos = 8;
// attn_mask
patternID.set(pos++, params.attn_mask.has_value());
patternID.set(pos++, is_causal);
auto partition_ = cache.find_partition(patternID);
if (!partition_.has_value()) {
// partition cache miss
// graph building and partitioning
partition sdp_partition = create_sdpa_graph_partition(
batch_size,
seq_len_q,
seq_len_k,
num_head,
head_dim,
is_causal,
dtype,
params);
partition_ = cache.insert_partition_cache(patternID, sdp_partition);
}
return *partition_;
}
} // namespace
namespace at::native::onednn {
void gpu_float_sdpa(
int batch_size,
int seq_len_q,
int seq_len_k,
int num_head,
int num_head_kv,
int head_dim,
int head_dim_v,
const Tensor& query,
const Tensor& key,
const Tensor& value,
std::optional<at::Tensor> attn_mask,
bool is_causal,
float softmax_scale,
const Tensor& output) {
auto eng = GpuEngineManager::Instance().get_engine(
{c10::kXPU, c10::xpu::current_device()});
auto strm = GpuStreamManager::Instance().get_stream();
const auto get_tril_mask = [&]() {
auto opts = query.options();
auto bool_tril =
at::ones_symint(
{query.sym_size(-2), key.sym_size(-2)}, opts.dtype(at::kBool))
.tril();
return at::where(
bool_tril,
0.f,
at::scalar_tensor(-std::numeric_limits<float>::infinity(), opts));
};
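// For example, get_tril_mask() with seq_len_q == seq_len_k == 3 produces the
// additive mask
//   [  0, -inf, -inf]
//   [  0,    0, -inf]
//   [  0,    0,    0]
// i.e. zeros on and below the diagonal, -inf above it.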
static bool driver_support_implict_causal = true;
if (attn_mask.has_value()) {
TORCH_INTERNAL_ASSERT(
!is_causal,
"scaled_dot_product_fused_attention_overrideable_xpu: "
"attn_mask cannot present with is_causal");
} else {
// Currently, the implicit causal mask only supports square fp16/bf16 cases
const bool support_implict_causal = driver_support_implict_causal &&
(query.dtype() == at::kHalf || query.dtype() == at::kBFloat16) &&
seq_len_q == seq_len_k;
if (is_causal && !support_implict_causal) {
attn_mask = get_tril_mask();
is_causal = false;
}
}
std::vector<logical_tensor> l_inputs, l_outputs;
std::optional<dnnl::graph::compiled_partition> compiled_partition;
auto get_compiled_partition = [&]() {
const SDPALogicalParams logical_params(
query, key, value, attn_mask, output, is_causal);
auto& partition_ = find_or_create_graph_partition(
batch_size,
seq_len_q,
seq_len_k,
num_head,
head_dim,
is_causal,
logical_params);
auto i = logical_params.get_input();
auto o = logical_params.get_output();
auto compiled_partition = partition_.compile(i, o, eng);
l_inputs = std::move(i);
l_outputs = std::move(o);
return compiled_partition;
};
// Compile the graph; if the implicit causal mask is not supported, retry with an explicit mask.
try {
compiled_partition = get_compiled_partition();
} catch (std::exception& e) {
if (is_causal) {
attn_mask = get_tril_mask();
is_causal = false;
compiled_partition = get_compiled_partition();
driver_support_implict_causal = false;
} else {
throw e;
}
}
Tensor softmax_scale1 = at::full({}, softmax_scale, query.options());
std::optional<at::Tensor> neg_inf;
if (is_causal) {
neg_inf = at::full({}, -INFINITY, query.options());
}
std::vector<dnnl::graph::tensor> outputs = {
{l_outputs[0], eng, output.data_ptr()},
};
size_t i = 0;
std::vector<dnnl::graph::tensor> inputs;
inputs.reserve(l_inputs.size());
inputs.emplace_back(l_inputs[i++], eng, query.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, key.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, softmax_scale1.data_ptr());
if (neg_inf.has_value()) {
inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr());
}
if (attn_mask.has_value()) {
inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr());
}
inputs.emplace_back(l_inputs[i++], eng, value.data_ptr());
compiled_partition->execute(strm, inputs, outputs);
}
} // namespace at::native::onednn


@@ -7,6 +7,8 @@
#include <ATen/core/grad_mode.h>
#include <c10/core/MemoryFormat.h>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <oneapi/dnnl/dnnl_graph_sycl.hpp>
#include <oneapi/dnnl/dnnl_sycl.hpp>
#include <oneapi/dnnl/dnnl_version.h>
@@ -99,4 +101,33 @@ dnnl::memory dnnl_memory_from_host_scalar(
return mem;
}
struct PartitionCache {
std::unordered_map<std::bitset<32>, dnnl::graph::partition> partition_map_{};
// The first 8 bits are reserved
// bit 0: is int8
// bit 1: is uint8
// bit 2: fp16(0) / bf16(1)
// bit 3: is fp32
// bit 4: is sdp pattern
// bit 5-7: N/A
// The rest of the bits depend upon the arguments provided
// However, down the line, we might have different bitsets for different
// patterns
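// For the SDPA pattern built in this commit, for example, a bf16 graph with
// an implicit causal mask and no attn_mask sets bits 2 (bf16), 4 (sdp
// pattern), and 9 (is_causal).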
dnnl::graph::partition& insert_partition_cache(
std::bitset<32>& patternID,
dnnl::graph::partition& p) {
partition_map_[patternID] = std::move(p);
return partition_map_[patternID];
}
std::optional<std::reference_wrapper<dnnl::graph::partition>> find_partition(
std::bitset<32>& patternID) {
auto iter = partition_map_.find(patternID);
if (iter != partition_map_.end()) {
return iter->second;
}
return std::nullopt;
}
};
} // namespace at::native::onednn


@@ -155,4 +155,19 @@ void quantized_matmul(
c10::string_view unary_post_op_algorithm,
bool m2_trnas);
void gpu_float_sdpa(
int batch_size,
int seq_len_q,
int seq_len_k,
int num_head,
int num_head_kv,
int head_dim,
int head_dim_v,
const Tensor& query,
const Tensor& key,
const Tensor& value,
std::optional<at::Tensor> attn_mask,
bool is_causal,
float softmax_scale,
const Tensor& output);
} // namespace at::native::onednn


@@ -1,5 +1,8 @@
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <c10/xpu/XPUCachingAllocator.h>
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <oneapi/dnnl/dnnl_graph_sycl.hpp>
/* *
* Do NOT put any kernels or call any device binaries here!
@@ -9,6 +12,36 @@ namespace at::native::onednn {
using namespace dnnl;
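// Allocation callbacks handed to the oneDNN Graph SYCL allocator below, so
// that oneDNN Graph's internal allocations go through PyTorch's XPU caching
// allocator instead of raw SYCL allocations.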
static inline void* dnnl_alloc(
size_t size,
size_t /*alignment*/,
const void* /*dev*/,
const void* /*context*/) {
return c10::xpu::XPUCachingAllocator::raw_alloc(size);
}
static inline void dnnl_delete(
void* buf,
const void* /*dev*/,
const void* /*context*/,
void* /*event*/) {
return c10::xpu::XPUCachingAllocator::raw_delete(buf);
}
GpuEngineManager::GpuEngineManager() {
c10::DeviceIndex device_count = c10::xpu::device_count();
TORCH_INTERNAL_ASSERT(device_count > 0);
for (const auto i : c10::irange(device_count)) {
static dnnl::graph::allocator alloc =
dnnl::graph::sycl_interop::make_allocator(dnnl_alloc, dnnl_delete);
engine_pool.push_back(std::make_shared<dnnl::engine>(
dnnl::graph::sycl_interop::make_engine_with_allocator(
c10::xpu::get_raw_device(i),
c10::xpu::get_device_context(),
alloc)));
}
}
GpuEngineManager& GpuEngineManager::Instance() {
static GpuEngineManager myInstance;
return myInstance;


@@ -37,15 +37,7 @@ struct TORCH_XPU_API GpuEngineManager {
GpuEngineManager& operator=(GpuEngineManager&&) = default;
protected:
GpuEngineManager() {
c10::DeviceIndex device_count = c10::xpu::device_count();
TORCH_INTERNAL_ASSERT(device_count > 0);
for (const auto i : c10::irange(device_count)) {
engine_pool.push_back(
std::make_shared<dnnl::engine>(dnnl::sycl_interop::make_engine(
c10::xpu::get_raw_device(i), c10::xpu::get_device_context())));
}
}
GpuEngineManager();
~GpuEngineManager() = default;
private: