[Intel GPU] Add SDPA implementation on XPU with OneDNN (#147612)
Add an XPU implementation of the oneDNN-based SDPA operator. It will be integrated and enabled later and depends on the BUILD_GRAPH switch in #147608.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147612
Approved by: https://github.com/EikanWang
This commit is contained in: parent 576ed1e400, commit cde12207a0
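For orientation, a minimal usage sketch (not part of this commit): it assumes a PyTorch build with XPU support, an available XPU device, and that the kernels below get registered for the XPU dispatch key once the follow-up integration lands. Under those assumptions, the standard ATen SDPA entry point on XPU tensors would go through _fused_sdp_choice_xpu and, when all gate checks pass, the fused oneDNN kernel added here; otherwise it falls back to the math path.

// Hypothetical usage sketch; not part of the diff below.
#include <ATen/ATen.h>

int main() {
  // fp16 tensors on the XPU device; shapes are {batch, num_heads, seq_len, head_dim}.
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kXPU);
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);
  // No explicit mask, zero dropout, causal masking: this satisfies the gate
  // checks below (no gradients required, head_dim <= 256), so the overrideable
  // oneDNN backend would be eligible.
  auto out = at::scaled_dot_product_attention(
      q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.0, /*is_causal=*/true);
  return 0;
}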
aten/src/ATen/native/mkldnn/xpu/Attention.cpp (new file, 230 lines)
@@ -0,0 +1,230 @@
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/Array.h>
#include <torch/library.h>

namespace {
bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  if ((query_size_last != key_size_last) ||
      (query_size_last != value_size_last)) {
    if (debug) {
      TORCH_WARN(
          "OneDNN attention requires q,k,v to have the same last dimension.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  if (query_size_last > 256) {
    if (debug) {
      TORCH_WARN(
          "OneDNN attention requires q,k,v to have head dimension of at most 256.",
          " Got ",
          query_size_last,
          " instead.");
    }
    return false;
  }
  return true;
}

bool check_no_grad(sdp::sdp_params const& params, bool debug) {
  const bool any_inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  if (debug && any_inputs_require_grad && gradmode_enabled) {
    TORCH_WARN("Backward/gradient is not supported yet.");
  }
  return !any_inputs_require_grad || !gradmode_enabled;
}

bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
  constexpr auto supported_dtypes = c10::array_of<at::ScalarType>(
      at::kFloat, at::kBFloat16, at::kHalf); // double is not supported

  // Define gate functions that determine if the fused kernel can be run
  constexpr auto constraints = c10::array_of<bool (*)(
      sdp::sdp_params const&, bool)>(
      sdp::check_nested_tensor,
      sdp::check_for_dropout,
      sdp::check_tensor_shapes,
      sdp::check_batch_size_and_num_heads_dense<true /*supports GQA*/>,
      sdp::check_attn_mask_shape,
      sdp::check_nonzero_sequence_lengths_dense,
      sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
      check_head_dim_size_xpu,
      check_no_grad);
  for (auto& constraint : constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }
  return sdp::check_tensor_dtype(params, supported_dtypes, debug);
}

sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
  // This function defines the priority order of the different sdp backends:
  // 1. Overrideable (OneDNN)
  // 2. Math fallback
  auto& ctx = at::globalContext();
  // The overrideable backend is backed by the OneDNN implementation.
  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP()) {
    return sdp::SDPBackend::error;
  }

  // Get ideal kernel ordering
  const std::array<sdp::SDPBackend, 2> priority_order{
      sdp::SDPBackend::overrideable,
      sdp::SDPBackend::math,
  };

  // Because TORCH_CHECK checks whether the condition is true, we negate debug
  // so that the statements are printed when debug is true.
  bool print_debug = false;
  for (auto& backend : priority_order) {
    switch (backend) {
      case sdp::SDPBackend::overrideable:
        if (ctx.userEnabledOverrideableSDP() &&
            use_overrideable_xpu(kernel_params, print_debug)) {
          return sdp::SDPBackend::overrideable;
        }
        break;
      case sdp::SDPBackend::math:
        if (ctx.userEnabledMathSDP()) {
          return sdp::SDPBackend::math;
        }
        break;
      default:
        TORCH_CHECK(false, "Invalid backend");
    }
  }
  // If we have gotten to this point then two things have happened:
  // 1. use_overrideable_xpu did not satisfy the constraints to be run
  // 2. The user has explicitly disabled the math kernel
  // We then re-run the kernel checks with debug enabled to print out the
  // reason why the kernel was not selected.

  print_debug = true;
  TORCH_WARN("OneDNN kernel not used because:");
  use_overrideable_xpu(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.");
  return sdp::SDPBackend::error;
}
} // namespace

namespace at::native {
int64_t _fused_sdp_choice_xpu(
    const at::Tensor& query_,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    bool enable_gqa) {
  sdp::sdp_params kernel_params{
      query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
  auto backend = select_sdp_backend_xpu(kernel_params);

  if (backend == sdp::SDPBackend::error) {
    TORCH_CHECK(
        false,
        "No viable backend for scaled_dot_product_attention was found. ",
        "This is likely due to turning off both the math kernel and the fused kernels.");
  }
  return static_cast<int64_t>(backend);
}

std::tuple<
    at::Tensor,
    at::Tensor,
    at::Tensor,
    at::Tensor,
    c10::SymInt,
    c10::SymInt,
    at::Tensor,
    at::Tensor,
    at::Tensor>
_scaled_dot_product_fused_attention_overrideable_xpu(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_bias,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  TORCH_INTERNAL_ASSERT(
      query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
      "scaled_dot_product_fused_attention_overrideable_xpu: Accepts only 4-dim inputs with shape {B, H, T, K}");
  TORCH_INTERNAL_ASSERT(
      (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
          (key.size(2) == value.size(2)),
      "scaled_dot_product_fused_attention_overrideable_xpu: K/V should have the same batch / seq / num_head");
  TORCH_INTERNAL_ASSERT(
      query.size(3) == key.size(3),
      "scaled_dot_product_fused_attention_overrideable_xpu: Q/K should have the same head_dim");
  TORCH_INTERNAL_ASSERT(
      dropout_p == 0.0,
      "scaled_dot_product_fused_attention_overrideable_xpu: dropout > 0 is currently not supported");
  TORCH_INTERNAL_ASSERT(
      !(attn_bias.has_value() && is_causal),
      "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot be present when is_causal is set");

  const int64_t batch_size = query.size(0);
  const int64_t num_head = query.size(1);
  const int64_t num_head_kv = key.size(1);
  const int64_t head_dim = query.size(3);
  const int64_t head_dim_v = value.size(3);
  const int64_t seq_len_q = query.size(2);
  const int64_t seq_len_kv = key.size(2);

  auto opts = query.options();
  auto output = at::empty({batch_size, num_head, seq_len_q, head_dim}, opts);
  // auto logsumexp =
  //     at::empty({batch_size, num_head, seq_len_q}, opts.dtype(at::kFloat));
  auto logsumexp = at::empty({}, opts.dtype(at::kFloat));

  at::native::onednn::gpu_float_sdpa(
      batch_size,
      seq_len_q,
      seq_len_kv,
      num_head,
      num_head_kv,
      head_dim,
      head_dim_v,
      query,
      key,
      value,
      attn_bias,
      is_causal,
      scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim)),
      output);

  // rng and debug mask not used
  auto philox_seed = at::empty({}, at::dtype(at::kLong));
  auto philox_offset = at::empty({}, at::dtype(at::kLong));
  auto debug_attn_mask = at::empty(
      {batch_size, num_head, seq_len_q, seq_len_kv}, at::dtype(at::kFloat));

  return std::make_tuple(
      output,
      logsumexp,
      /* cum_seq_q */ at::Tensor(),
      /* cum_seq_k */ at::Tensor(),
      seq_len_q,
      seq_len_kv,
      philox_seed,
      philox_offset,
      debug_attn_mask);
}
} // namespace at::native
aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp (new file, 407 lines)
@@ -0,0 +1,407 @@
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>

#include <oneapi/dnnl/dnnl.hpp>

using namespace at::native::onednn;
using logical_tensor = dnnl::graph::logical_tensor;
using data_type = logical_tensor::data_type;
using dims = logical_tensor::dims;
using op = dnnl::graph::op;
using partition = dnnl::graph::partition;

namespace {
struct SDPALogicalParams {
  enum class TensorID {
    query,
    key,
    scale,
    neg_inf,
    attn_mask,
    value,
    output,
    end,
  };

  logical_tensor query{};
  logical_tensor key{};
  logical_tensor scale{};
  std::optional<logical_tensor> neg_inf;
  std::optional<logical_tensor> attn_mask;
  logical_tensor value{};
  logical_tensor output{};

  SDPALogicalParams(
      const at::Tensor& query_,
      const at::Tensor& key_,
      const at::Tensor& value_,
      const std::optional<at::Tensor>& attn_mask_,
      const at::Tensor& output_,
      bool is_causal) {
    const data_type dtype = // map to the logical_tensor data type
        query_.scalar_type() == c10::ScalarType::Float ? data_type::f32
        : query_.scalar_type() == c10::ScalarType::Half ? data_type::f16
        : query_.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16
                                                            : data_type::undef;
    TORCH_INTERNAL_ASSERT(
        (dtype != data_type::undef),
        "Only FP16/BF16/FP32 datatypes are currently supported");
    const dims scalar_shape = {1};
    std::vector<logical_tensor> inputLogicalTensors;
    query = {
        static_cast<size_t>(TensorID::query),
        dtype,
        query_.sizes().vec(),
        query_.strides().vec()};
    key = {
        static_cast<size_t>(TensorID::key),
        dtype,
        key_.sizes().vec(),
        key_.strides().vec()};
    scale = {
        static_cast<size_t>(TensorID::scale),
        dtype,
        scalar_shape,
        logical_tensor::layout_type::strided,
        logical_tensor::property_type::constant};
    if (is_causal) {
      neg_inf = {
          static_cast<size_t>(TensorID::neg_inf),
          dtype,
          scalar_shape,
          logical_tensor::layout_type::strided,
          logical_tensor::property_type::constant};
    }
    if (attn_mask_.has_value()) {
      attn_mask = {
          static_cast<size_t>(TensorID::attn_mask),
          dtype,
          attn_mask_->sizes().vec(),
          attn_mask_->strides().vec()};
    }
    value = {
        static_cast<size_t>(TensorID::value),
        dtype,
        value_.sizes().vec(),
        value_.strides().vec()};
    output = {
        static_cast<size_t>(TensorID::output),
        dtype,
        output_.sizes().vec(),
        output_.strides().vec()};
  }
  std::vector<logical_tensor> get_input() const {
    std::vector<logical_tensor> input = {query, key, scale};
    if (neg_inf.has_value()) {
      input.push_back(neg_inf.value());
    }
    if (attn_mask.has_value()) {
      input.push_back(attn_mask.value());
    }
    input.push_back(value);
    return input;
  }
  std::vector<logical_tensor> get_output() const {
    return {output};
  }
};

partition create_sdpa_graph_partition(
    int batch_size,
    int seq_len_q,
    int seq_len_k,
    int num_head,
    int head_dim,
    bool is_causal,
    data_type dtype,
    const SDPALogicalParams& params) {
  // graph building and partitioning
  // currently, we assume that Q and K have the same sequence length

  dims qk_output_shape = {batch_size, num_head, seq_len_q, seq_len_k};
  dims scale_shape = {1};
  size_t lt_id = static_cast<size_t>(SDPALogicalParams::TensorID::end);
  size_t op_id = 0;

  logical_tensor matmul_qk_out{lt_id++, dtype};
  op matmul_qk{
      op_id++,
      op::kind::MatMul,
      {params.query, params.key},
      {matmul_qk_out},
      "matmul_qk"};
  matmul_qk.set_attr<bool>(op::attr::transpose_b, true);

  logical_tensor scaled_qk_out{lt_id++, dtype};
  op scale_mul{
      op_id++,
      op::kind::Multiply,
      {matmul_qk_out, params.scale},
      {scaled_qk_out},
      "scale_mul"};

  std::optional<logical_tensor> masked_qk_out;

  // For the optional additive mask
  std::optional<op> mask_add;

  // For the optional implicit causal mask
  std::optional<op> mask_gen_idx_row;
  std::optional<logical_tensor> mask_row_idx;
  std::optional<op> mask_gen_idx_col;
  std::optional<logical_tensor> mask_col_idx;
  std::optional<op> mask_gt;
  std::optional<logical_tensor> mask_gt_out;
  std::optional<op> mask_select;

  if (params.attn_mask.has_value()) {
    TORCH_INTERNAL_ASSERT(
        !is_causal, "An additive mask cannot be used together with is_causal.");
    masked_qk_out = {lt_id++, dtype};
    mask_add = {
        op_id++,
        op::kind::Add,
        {scaled_qk_out, params.attn_mask.value()},
        {masked_qk_out.value()},
        "mask_add"};
  } else if (is_causal) {
#if (DNNL_VERSION_MAJOR >= 3 && DNNL_VERSION_MINOR >= 7)
    mask_row_idx = {lt_id++, data_type::s32};
    mask_gen_idx_row = {
        op_id++,
        op::kind::GenIndex,
        {scaled_qk_out},
        {mask_row_idx.value()},
        "mask_gen_idx_row"};
    mask_gen_idx_row->set_attr<int64_t>(op::attr::axis, -2);

    mask_col_idx = {lt_id++, data_type::s32};
    mask_gen_idx_col = {
        op_id++,
        op::kind::GenIndex,
        {scaled_qk_out},
        {mask_col_idx.value()},
        "mask_gen_idx_col"};
    mask_gen_idx_col->set_attr<int64_t>(op::attr::axis, -1);

    mask_gt_out = {lt_id++, data_type::boolean};
    mask_gt = {
        op_id++,
        op::kind::GreaterEqual,
        {mask_row_idx.value(), mask_col_idx.value()},
        {mask_gt_out.value()},
        "mask_gt"};

    masked_qk_out = {lt_id++, dtype};
    mask_select = {
        op_id++,
        op::kind::Select,
        {mask_gt_out.value(), scaled_qk_out, params.neg_inf.value()},
        {masked_qk_out.value()},
        "mask_select"};
#else
    TORCH_CHECK(
        false,
        "OneDNN v3.7 or later is required for implicit causal mask support.");
#endif
  }

  op softmax{op_id++, op::kind::SoftMax, "softmax"};
  softmax.set_attr<int64_t>(op::attr::axis, -1);

  logical_tensor softmax_out{lt_id++, dtype};
  softmax.add_input(masked_qk_out.value_or(scaled_qk_out));
  softmax.add_output(softmax_out);

  op matmul_v{
      op_id++,
      op::kind::MatMul,
      {softmax_out, params.value},
      {params.output},
      "matmul_v"};

  constexpr auto ekind = dnnl::engine::kind::gpu;
  dnnl::graph::graph g(ekind);
  g.add_op(matmul_qk);
  g.add_op(scale_mul);
  if (mask_add.has_value()) {
    g.add_op(mask_add.value());
  }
  if (is_causal) {
    g.add_op(mask_gen_idx_row.value());
    g.add_op(mask_gen_idx_col.value());
    g.add_op(mask_gt.value());
    g.add_op(mask_select.value());
  }

  g.add_op(softmax);
  g.add_op(matmul_v);
  g.finalize();
  auto partitions = g.get_partitions();
  TORCH_INTERNAL_ASSERT(
      (partitions.size() == 1) && partitions[0].is_supported(),
      "oneDNN doesn't support this fusion pattern. If you'd like it supported, please submit an issue.");
  return partitions[0];
}

partition& find_or_create_graph_partition(
    int batch_size,
    int seq_len_q,
    int seq_len_k,
    int num_head,
    int head_dim,
    bool is_causal,
    const SDPALogicalParams& params) {
  thread_local static PartitionCache cache;
  const data_type dtype = params.query.get_data_type();

  // cache key creation
  // patternID is determined on the basis of the arguments provided
  std::bitset<32> patternID;
  if (dtype == data_type::f32) {
    // bit 3 corresponds to float32 dtype
    patternID.set(3, 1);
  }
  if (dtype == data_type::bf16) {
    // bit 2 corresponds to fp16/bf16 dtype
    patternID.set(2, 1);
  }
  // sdp pattern
  patternID.set(4, 1);

  // Refer to comments in Utils.h. The first 8 bits are reserved
  int pos = 8;
  // attn_mask
  patternID.set(pos++, params.attn_mask.has_value());
  patternID.set(pos++, is_causal);

  auto partition_ = cache.find_partition(patternID);
  if (!partition_.has_value()) {
    // partition cache miss
    // graph building and partitioning
    partition sdp_partition = create_sdpa_graph_partition(
        batch_size,
        seq_len_q,
        seq_len_k,
        num_head,
        head_dim,
        is_causal,
        dtype,
        params);
    partition_ = cache.insert_partition_cache(patternID, sdp_partition);
  }
  return *partition_;
}
} // namespace

namespace at::native::onednn {
void gpu_float_sdpa(
    int batch_size,
    int seq_len_q,
    int seq_len_k,
    int num_head,
    int num_head_kv,
    int head_dim,
    int head_dim_v,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    std::optional<at::Tensor> attn_mask,
    bool is_causal,
    float softmax_scale,
    const Tensor& output) {
  auto eng = GpuEngineManager::Instance().get_engine(
      {c10::kXPU, c10::xpu::current_device()});
  auto strm = GpuStreamManager::Instance().get_stream();

  const auto get_tril_mask = [&]() {
    auto opts = query.options();
    auto bool_tril =
        at::ones_symint(
            {query.sym_size(-2), key.sym_size(-2)}, opts.dtype(at::kBool))
            .tril();
    return at::where(
        bool_tril,
        0.f,
        at::scalar_tensor(-std::numeric_limits<float>::infinity(), opts));
  };

  static bool driver_support_implict_causal = true;
  if (attn_mask.has_value()) {
    TORCH_INTERNAL_ASSERT(
        !is_causal,
        "scaled_dot_product_fused_attention_overrideable_xpu: "
        "attn_mask cannot be present when is_causal is set");
  } else {
    // Currently the implicit causal mask only supports square fp16/bf16 cases
    const bool support_implict_causal = driver_support_implict_causal &&
        (query.dtype() == at::kHalf || query.dtype() == at::kBFloat16) &&
        seq_len_q == seq_len_k;
    if (is_causal && !support_implict_causal) {
      attn_mask = get_tril_mask();
      is_causal = false;
    }
  }

  std::vector<logical_tensor> l_inputs, l_outputs;
  std::optional<dnnl::graph::compiled_partition> compiled_partition;

  auto get_compiled_partition = [&]() {
    const SDPALogicalParams logical_params(
        query, key, value, attn_mask, output, is_causal);
    auto& partition_ = find_or_create_graph_partition(
        batch_size,
        seq_len_q,
        seq_len_k,
        num_head,
        head_dim,
        is_causal,
        logical_params);
    auto i = logical_params.get_input();
    auto o = logical_params.get_output();
    auto compiled_partition = partition_.compile(i, o, eng);
    l_inputs = std::move(i);
    l_outputs = std::move(o);
    return compiled_partition;
  };

  // maybe retry without the implicit causal mask
  try {
    compiled_partition = get_compiled_partition();
  } catch (std::exception& e) {
    if (is_causal) {
      attn_mask = get_tril_mask();
      is_causal = false;
      compiled_partition = get_compiled_partition();
      driver_support_implict_causal = false;
    } else {
      throw e;
    }
  }

  Tensor softmax_scale1 = at::full({}, softmax_scale, query.options());
  std::optional<at::Tensor> neg_inf;
  if (is_causal) {
    neg_inf = at::full({}, -INFINITY, query.options());
  }

  std::vector<dnnl::graph::tensor> outputs = {
      {l_outputs[0], eng, output.data_ptr()},
  };
  size_t i = 0;
  std::vector<dnnl::graph::tensor> inputs;
  inputs.reserve(l_inputs.size());
  inputs.emplace_back(l_inputs[i++], eng, query.data_ptr());
  inputs.emplace_back(l_inputs[i++], eng, key.data_ptr());
  inputs.emplace_back(l_inputs[i++], eng, softmax_scale1.data_ptr());
  if (neg_inf.has_value()) {
    inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr());
  }
  if (attn_mask.has_value()) {
    inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr());
  }
  inputs.emplace_back(l_inputs[i++], eng, value.data_ptr());
  compiled_partition->execute(strm, inputs, outputs);
}
} // namespace at::native::onednn
@@ -7,6 +7,8 @@
#include <ATen/core/grad_mode.h>
#include <c10/core/MemoryFormat.h>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <oneapi/dnnl/dnnl_graph_sycl.hpp>
#include <oneapi/dnnl/dnnl_sycl.hpp>
#include <oneapi/dnnl/dnnl_version.h>
@@ -99,4 +101,33 @@ dnnl::memory dnnl_memory_from_host_scalar(
  return mem;
}

struct PartitionCache {
  std::unordered_map<std::bitset<32>, dnnl::graph::partition> partition_map_{};

  // The first 8 bits are reserved
  // bit 0: is int8
  // bit 1: is uint8
  // bit 2: fp16(0) / bf16(1)
  // bit 3: is fp32
  // bit 4: is sdp pattern
  // bit 5-7: N/A
  // The rest of the bits depend upon the arguments provided
  // However, down the line, we might have different bitsets for different
  // patterns
  dnnl::graph::partition& insert_partition_cache(
      std::bitset<32>& patternID,
      dnnl::graph::partition& p) {
    partition_map_[patternID] = std::move(p);
    return partition_map_[patternID];
  }
  std::optional<std::reference_wrapper<dnnl::graph::partition>> find_partition(
      std::bitset<32>& patternID) {
    auto iter = partition_map_.find(patternID);
    if (iter != partition_map_.end()) {
      return iter->second;
    }
    return std::nullopt;
  }
};

} // namespace at::native::onednn
@@ -155,4 +155,19 @@ void quantized_matmul(
    c10::string_view unary_post_op_algorithm,
    bool m2_trnas);

void gpu_float_sdpa(
    int batch_size,
    int seq_len_q,
    int seq_len_k,
    int num_head,
    int num_head_kv,
    int head_dim,
    int head_dim_v,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    std::optional<at::Tensor> attn_mask,
    bool is_causal,
    float softmax_scale,
    const Tensor& output);
} // namespace at::native::onednn
@@ -1,5 +1,8 @@
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <c10/xpu/XPUCachingAllocator.h>
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <oneapi/dnnl/dnnl_graph_sycl.hpp>

/* *
 * Do NOT put any kernels or call any device binaries here!
@@ -9,6 +12,36 @@ namespace at::native::onednn {

using namespace dnnl;

static inline void* dnnl_alloc(
    size_t size,
    size_t /*alignment*/,
    const void* /*dev*/,
    const void* /*context*/) {
  return c10::xpu::XPUCachingAllocator::raw_alloc(size);
}

static inline void dnnl_delete(
    void* buf,
    const void* /*dev*/,
    const void* /*context*/,
    void* /*event*/) {
  return c10::xpu::XPUCachingAllocator::raw_delete(buf);
}

GpuEngineManager::GpuEngineManager() {
  c10::DeviceIndex device_count = c10::xpu::device_count();
  TORCH_INTERNAL_ASSERT(device_count > 0);
  for (const auto i : c10::irange(device_count)) {
    static dnnl::graph::allocator alloc =
        dnnl::graph::sycl_interop::make_allocator(dnnl_alloc, dnnl_delete);
    engine_pool.push_back(std::make_shared<dnnl::engine>(
        dnnl::graph::sycl_interop::make_engine_with_allocator(
            c10::xpu::get_raw_device(i),
            c10::xpu::get_device_context(),
            alloc)));
  }
}

GpuEngineManager& GpuEngineManager::Instance() {
  static GpuEngineManager myInstance;
  return myInstance;
@@ -37,15 +37,7 @@ struct TORCH_XPU_API GpuEngineManager {
  GpuEngineManager& operator=(GpuEngineManager&&) = default;

 protected:
  GpuEngineManager() {
    c10::DeviceIndex device_count = c10::xpu::device_count();
    TORCH_INTERNAL_ASSERT(device_count > 0);
    for (const auto i : c10::irange(device_count)) {
      engine_pool.push_back(
          std::make_shared<dnnl::engine>(dnnl::sycl_interop::make_engine(
              c10::xpu::get_raw_device(i), c10::xpu::get_device_context())));
    }
  }
  GpuEngineManager();
  ~GpuEngineManager() = default;

 private: