[cuDNN][SDPA] Support attn_bias in cuDNN (#130482)

CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130482
Approved by: https://github.com/drisspg
eqy authored on 2024-07-16 23:45:19 +00:00; committed by PyTorch MergeBot
parent 4f40a7078e
commit de177b50f8
7 changed files with 163 additions and 126 deletions
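For orientation, this change lets an additive attn_bias flow through the cuDNN SDPA backend end to end. A minimal libtorch sketch of what that enables at the public API level (shapes, dtype, and device are illustrative assumptions; actual dispatch to cuDNN still depends on the runtime checks updated at the bottom of this diff):

#include <ATen/ATen.h>

// Sketch only: [batch, heads, seq, head_dim] inputs with a broadcastable
// additive bias of shape [batch, 1, s_q, s_kv].
at::Tensor sdpa_with_bias_demo() {
  auto opts = at::device(at::kCUDA).dtype(at::kHalf);
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);
  auto bias = at::randn({2, 1, 128, 128}, opts);
  // Before this change the cuDNN backend rejected any non-null attn_mask;
  // now the mask/bias is forwarded into the cuDNN graph.
  return at::scaled_dot_product_attention(
      q, k, v, bias, /*dropout_p=*/0.0, /*is_causal=*/false);
}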

View File

@@ -22,6 +22,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@@ -43,6 +44,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@@ -86,9 +88,9 @@ using graph_and_tensors = std::tuple<
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>>, // Bias
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
// TODO(eqy): additional options
// std::shared_ptr<fe::graph::Tensor_attributes>, // Bias,
// std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_Q,
// std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_KV,
std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
@@ -104,7 +106,8 @@ using graph_and_tensors_backward = std::tuple<
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>>, // Bias,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
std::shared_ptr<fe::graph::Tensor_attributes>, // O,
@@ -126,6 +129,8 @@ struct MHAParams {
std::array<int, MAX_MHA_DIM> q_stride;
std::array<int, MAX_MHA_DIM> k_stride;
std::array<int, MAX_MHA_DIM> v_stride;
std::array<int, MAX_MHA_DIM> bias_dim;
std::array<int, MAX_MHA_DIM> bias_stride;
int64_t b;
int64_t h;
int64_t s_q;
@@ -135,6 +140,9 @@ struct MHAParams {
double dropout_probability;
bool is_causal;
bool return_softmaxstats;
// might be redundant if we treat a zeroed dim/stride
// as signaling no bias
bool has_attn_bias;
};
void setMHAParams(
@@ -148,6 +156,7 @@ void setMHAParams(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
double dropout_probability,
bool is_causal,
bool return_softmaxstats) {
@@ -166,6 +175,7 @@ void setMHAParams(
params.dropout_probability = dropout_probability;
params.is_causal = is_causal;
params.return_softmaxstats = return_softmaxstats;
params.has_attn_bias = attn_bias.has_value();
TORCH_INTERNAL_ASSERT(
q.sizes().size() == MAX_MHA_DIM,
"Q tensor has unexpected number of dims, please report a bug to PyTorch.");
@@ -190,6 +200,17 @@ void setMHAParams(
std::copy(k.strides().begin(), k.strides().end(), params.k_stride.begin());
std::copy(v.sizes().begin(), v.sizes().end(), params.v_dim.begin());
std::copy(v.strides().begin(), v.strides().end(), params.v_stride.begin());
// uninit is OK as the struct is memset 0'd
if (params.has_attn_bias) {
std::copy(
attn_bias.value().sizes().begin(),
attn_bias.value().sizes().end(),
params.bias_dim.begin());
std::copy(
attn_bias.value().strides().begin(),
attn_bias.value().strides().end(),
params.bias_stride.begin());
}
}
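The "memset 0'd" note above is load-bearing: ParamsWrapper-style keys are typically hashed and compared byte-for-byte, so padding bytes and the bias dim/stride arrays left untouched in the no-bias case must already be zero for equal keys to collide. A standalone sketch of that pattern (the struct and hash here are illustrative, not PyTorch's actual ParamsHash):

#include <cstdint>
#include <cstring>

struct KeyPOD {
  int64_t bias_dim[4];  // stays all-zero when there is no bias
  bool has_bias;        // trailing padding would be garbage without memset
};

KeyPOD make_key(bool has_bias) {
  KeyPOD k;
  std::memset(&k, 0, sizeof(KeyPOD));  // zero padding and unused fields first
  k.has_bias = has_bias;
  return k;
}

// Byte-wise FNV-1a: only meaningful because every byte of the struct,
// including padding, is deterministic after the memset.
uint64_t hash_key(const KeyPOD& k) {
  const auto* p = reinterpret_cast<const unsigned char*>(&k);
  uint64_t h = 0xcbf29ce484222325ull;
  for (std::size_t i = 0; i < sizeof(KeyPOD); ++i) {
    h = (h ^ p[i]) * 0x100000001b3ull;
  }
  return h;
}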
struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
@@ -203,6 +224,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
double dropout_probability,
bool is_causal,
bool return_softmaxstats) {
@@ -217,6 +239,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
return_softmaxstats);
@@ -285,6 +308,7 @@ auto build_graph_and_tensors(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@@ -301,36 +325,6 @@ auto build_graph_and_tensors(
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(std::vector<int64_t>(
q.sizes().data(), q.sizes().data() + q.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
q.sizes(),
std::vector<int64_t>(
q.strides().data(),
q.strides().data() + q.strides().size()))));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(std::vector<int64_t>(
k.sizes().data(), k.sizes().data() + k.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
k.sizes(),
std::vector<int64_t>(
k.strides().data(),
k.strides().data() + k.strides().size()))));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(std::vector<int64_t>(
v.sizes().data(), v.sizes().data() + v.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
v.sizes(),
std::vector<int64_t>(
v.strides().data(),
v.strides().data() + v.strides().size()))));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
@@ -338,11 +332,6 @@ auto build_graph_and_tensors(
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
// TODO(eqy): support bias in the future in a follow-up PR
// auto bias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("bias")
// .set_dim({b, 1, s_q, s_kv})
// .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1}));
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
@@ -360,11 +349,30 @@ auto build_graph_and_tensors(
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
.set_dropout(dropout_probability, seed, offset);
// Optional bias in flash attention is only supported 8.9.3 onwards
if (cudnnGetVersion() >= 8904) {
// scaled_dot_product_flash_attention_options.set_alibi_mask(true);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(q.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec())));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(k.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec())));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(v.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec())));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
}
auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_q")
.set_dim({b, 1, 1, 1})
@@ -376,20 +384,9 @@ auto build_graph_and_tensors(
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
// if (cudnnGetVersion() >= 8903) {
// scaled_dot_product_flash_attention_options.set_bias(bias)
// .set_padding_mask(true)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }
auto [O, Stats] =
mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
O->set_output(true)
.set_dim(std::vector<int64_t>(
o.sizes().data(), o.sizes().data() + o.sizes().size()))
.set_stride(std::vector<int64_t>(
o.strides().data(), o.strides().data() + o.strides().size()));
O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec());
if (Stats) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
@@ -407,6 +404,7 @@ auto build_graph_and_tensors(
std::move(Q),
std::move(K),
std::move(V),
std::move(bias),
std::move(attn_scale),
std::move(seed),
std::move(offset),
@@ -427,6 +425,7 @@ auto build_graph_and_tensors_backward(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@@ -447,24 +446,6 @@ auto build_graph_and_tensors_backward(
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(std::vector<int64_t>(q.sizes().begin(), q.sizes().end()))
.set_stride(
std::vector<int64_t>(q.strides().begin(), q.strides().end())));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(std::vector<int64_t>(k.sizes().begin(), k.sizes().end()))
.set_stride(
std::vector<int64_t>(k.strides().begin(), k.strides().end())));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(std::vector<int64_t>(v.sizes().begin(), v.sizes().end()))
.set_stride(
std::vector<int64_t>(v.strides().begin(), v.strides().end())));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
@@ -472,6 +453,31 @@ auto build_graph_and_tensors_backward(
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("CUDNN_SDPA_BACKWARD")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(q.sizes().vec())
.set_stride(q.strides().vec()));
auto K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(k.sizes().vec())
.set_stride(k.strides().vec()));
auto V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(v.sizes().vec())
.set_stride(v.strides().vec()));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
sdpa_backward_options.set_bias(bias.value());
}
auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
@@ -482,47 +488,27 @@ auto build_graph_and_tensors_backward(
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto O = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("O")
.set_dim(std::vector<int64_t>(o.sizes().begin(), o.sizes().end()))
.set_stride(
std::vector<int64_t>(o.strides().begin(), o.strides().end())));
auto STATS = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Stats")
.set_dim(std::vector<int64_t>(
softmaxstats.sizes().begin(), softmaxstats.sizes().end()))
.set_stride(std::vector<int64_t>(
softmaxstats.strides().begin(), softmaxstats.strides().end()))
.set_data_type(fe::DataType_t::FLOAT));
auto DO = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("DO")
.set_dim(std::vector<int64_t>(dO.sizes().begin(), dO.sizes().end()))
.set_stride(
std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("CUDNN_SDPA_BACKWARD")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
auto O = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim(o.sizes().vec())
.set_stride(o.strides().vec()));
auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Stats")
.set_dim(softmaxstats.sizes().vec())
.set_stride(softmaxstats.strides().vec())
.set_data_type(fe::DataType_t::FLOAT));
auto DO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("DO")
.set_dim(dO.sizes().vec())
.set_stride(dO.strides().vec()));
if (dropout_probability != 0.0f) {
sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
}
auto [DQ, DK, DV] =
mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
DQ->set_output(true)
.set_dim(std::vector<int64_t>(dQ.sizes().begin(), dQ.sizes().end()))
.set_stride(
std::vector<int64_t>(dQ.strides().begin(), dQ.strides().end()));
DK->set_output(true)
.set_dim(std::vector<int64_t>(dK.sizes().begin(), dK.sizes().end()))
.set_stride(
std::vector<int64_t>(dK.strides().begin(), dK.strides().end()));
DV->set_output(true)
.set_dim(std::vector<int64_t>(dV.sizes().begin(), dV.sizes().end()))
.set_stride(
std::vector<int64_t>(dV.strides().begin(), dV.strides().end()));
DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec());
DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec());
DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec());
AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
AT_CUDNN_FRONTEND_CHECK(
@@ -534,6 +520,7 @@ auto build_graph_and_tensors_backward(
std::move(Q),
std::move(K),
std::move(V),
std::move(bias),
std::move(attn_scale),
std::move(Seed),
std::move(Offset),
@@ -559,6 +546,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@@ -583,6 +571,7 @@ void run_cudnn_SDP_fprop(
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
return_softmaxstats);
@@ -605,13 +594,14 @@ void run_cudnn_SDP_fprop(
q,
k,
v,
attn_bias,
softmaxstats,
o,
dropoutseed,
dropoutoffset,
handle);
}
auto [mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats] =
auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] =
graph_and_tensors_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
variant_pack = {
@@ -619,13 +609,15 @@ void run_cudnn_SDP_fprop(
{K, k.data_ptr()},
{V, v.data_ptr()},
{attn_scale, &scaling_factor},
//{bias, bias.data_ptr()},
{seed, dropoutseed.data_ptr()},
{offset, dropoutoffset.data_ptr()},
{O, o.data_ptr()}};
if (return_softmaxstats) {
variant_pack[Stats] = softmaxstats.data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[bias.value()] = attn_bias.value().data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
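The hunk cuts off before the graph launch itself; with cudnn_frontend's v1 graph API the variant pack is consumed roughly as below (a hedged sketch of the unchanged lines that follow in the full file, not part of this diff):

// Map each Tensor_attributes handle to its device pointer and run the
// prebuilt plan; the bias entry participates only when it was inserted above.
auto status = mha_graph->execute(handle, variant_pack, workspace_ptr.get());
TORCH_INTERNAL_ASSERT(status.is_good());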
@@ -647,6 +639,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@@ -690,6 +683,7 @@ void run_cudnn_SDP_bprop(
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
true);
@@ -711,6 +705,7 @@ void run_cudnn_SDP_bprop(
q,
k,
v,
attn_bias,
o,
dO_,
softmaxstats,
@@ -722,8 +717,20 @@ void run_cudnn_SDP_bprop(
handle);
}
auto
[mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] =
graph_and_tensors_backward_values;
[mha_graph,
Q,
K,
V,
bias,
attn_scale,
Seed,
Offset,
O,
Do,
Stats,
Dq,
Dk,
Dv] = graph_and_tensors_backward_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
variant_pack = {// inputs
{Q, q.data_ptr()},
@@ -742,6 +749,9 @@ void run_cudnn_SDP_bprop(
variant_pack[Seed] = dropoutseed.data_ptr();
variant_pack[Offset] = dropoutoffset.data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[bias.value()] = attn_bias.value().data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);

View File

@@ -18,6 +18,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@@ -36,6 +37,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,

View File

@@ -658,7 +658,7 @@ Tensor scaled_dot_product_attention(
case sdp::SDPBackend::cudnn_attention: {
bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
query_, key, value, attn_mask_, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale);
query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale);
return std::get<0>(out_lse_softmax);
}
case sdp::SDPBackend::flash_attention: {
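The one-token change above is meaningful: attn_mask_ is the raw optional argument, while attn_mask is the preprocessed mask built earlier in this function, where boolean masks are converted into an additive float bias, so the cuDNN path now receives something it can add to the logits directly. The conversion semantics, sketched by hand (the function itself uses its own helper for this):

#include <ATen/ATen.h>
#include <limits>

// Illustrative: a boolean "keep" mask becomes an additive bias in which
// masked-out positions are -inf, which softmax turns into zero weight.
at::Tensor boolean_to_additive(const at::Tensor& keep, at::ScalarType dtype) {
  auto bias = at::zeros_like(keep, keep.options().dtype(dtype));
  return bias.masked_fill(
      keep.logical_not(), -std::numeric_limits<double>::infinity());
}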

View File

@@ -774,6 +774,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
TORCH_CHECK(
max_seqlen_batch_k == max_seqlen_batch_v,
"Key and Value must have the same sequence length");
auto attn_bias_ = attn_bias;
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else {
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
}
}
Tensor attention, log_sumexp;
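The expand calls in this hunk normalize 2D ([s_q, s_kv]) and 3D masks to the 4D [b, h, s_q, s_kv] layout the kernel expects without materializing copies: expand returns a view whose broadcast dimensions have stride 0, which is also why MHAParams records bias strides, since two biases of equal shape can still need different graphs. A standalone illustration:

#include <ATen/ATen.h>

void expand_demo() {
  const int64_t b = 4, s_q = 128, s_kv = 128;
  auto mask2d = at::randn({s_q, s_kv});
  auto mask4d = mask2d.expand({b, 1, s_q, s_kv});  // no allocation
  // The broadcast batch dimension has stride 0 and the storage is shared.
  TORCH_INTERNAL_ASSERT(mask4d.stride(0) == 0);
  TORCH_INTERNAL_ASSERT(mask4d.data_ptr() == mask2d.data_ptr());
}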
@@ -818,13 +830,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
query/* Tensor q*/,
key/* Tensor k*/,
value/* Tensor v*/,
attn_bias_ /* std::optional<Tensor> */,
log_sumexp/*Tensor softmaxstats*/,
attention/*Tensor o*/,
cudnn_seed/*Tensor dropoutseed*/,
cudnn_offset/*Tensor dropoutoffset*/);
// TODO(eqy): support debug_attn_mask
return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor());
return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(Tensor()), std::move(Tensor()), max_seqlen_batch_q, max_seqlen_batch_k, std::move(cudnn_seed), std::move(cudnn_offset), std::move(Tensor()));
}
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(

View File

@@ -29,6 +29,7 @@
#include <ATen/ops/_efficient_attention_backward.h>
#include <ATen/ops/_efficient_attention_backward_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
#include <ATen/ops/zeros.h>
#endif
#ifdef USE_FLASH_ATTENTION
@@ -195,6 +196,27 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
const int64_t num_heads = query.size(1);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_batch_q = query.size(2);
const int64_t max_seqlen_batch_k = key.size(2);
// This is needed because SaveVariable automatically converts
// std::optional<Tensor> to an undefined tensor
std::optional<Tensor> attn_bias_;
if (attn_bias.defined()) {
attn_bias_ = attn_bias;
}
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else {
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
}
}
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
auto dq = at::empty_like(query);
auto dk = at::empty_like(key);
@@ -211,6 +233,7 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
query /*const Tensor& q*/,
key /*const Tensor& k*/,
value /*const Tensor& v*/,
attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
out /*const Tensor& o*/,
grad_out/*const Tensor& dO*/,
logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
@@ -219,7 +242,7 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
dv/*Tensor& dV*/,
philox_seed/*Tensor& dropoutseed*/,
philox_offset/*Tensor& dropoutoffset*/);
return std::make_tuple(dq, dk, dv);
return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}
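The std::move additions in the two return statements are a micro-optimization: at::Tensor is an intrusively refcounted handle, so passing an lvalue to make_tuple copies the handle and pays an atomic refcount bump, while moving steals it. The same effect, sketched with shared_ptr (which Tensor resembles in this respect):

#include <cassert>
#include <memory>
#include <tuple>

void move_into_tuple_demo() {
  auto p = std::make_shared<int>(42);
  auto copied = std::make_tuple(p);            // use_count 1 -> 2 (atomic op)
  assert(p.use_count() == 2);
  auto moved = std::make_tuple(std::move(p));  // steals the handle, no atomics
  assert(p == nullptr);
  (void)copied;
  (void)moved;
}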
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>

View File

@@ -550,7 +550,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
check_cudnn_deterministic,
// check_is_causal,
check_dtypes_low_precision,
check_for_attn_mask_cudnn,
check_attn_mask_shape,
check_cudnn_hardware_support
);
for (auto& constraint : general_constraints) {
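can_use_cudnn_attention evaluates an array of predicates, each with the (params, debug) -> bool shape, and rejects the backend on the first failure; swapping check_for_attn_mask_cudnn (reject any mask) for check_attn_mask_shape (reject only unsupported shapes) is the whole behavioral change here. A minimal sketch of the dispatch pattern, with simplified stand-in types:

#include <array>
#include <cstdio>

struct params_stub {
  bool has_mask;
  bool mask_shape_ok;
};

// Mirrors the check_* signature: return false (optionally warning in debug
// mode) when the constraint is unsatisfied.
bool check_mask_shape_stub(const params_stub& p, bool debug) {
  if (p.has_mask && !p.mask_shape_ok) {
    if (debug) std::fprintf(stderr, "unsupported attn_mask shape\n");
    return false;
  }
  return true;
}

bool can_use_backend_stub(const params_stub& p, bool debug) {
  const std::array<bool (*)(const params_stub&, bool), 1> constraints{
      check_mask_shape_stub};
  for (auto& constraint : constraints) {
    if (!constraint(p, debug)) {
      return false;  // first unmet constraint disqualifies the backend
    }
  }
  return true;
}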

View File

@@ -274,17 +274,6 @@ inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
return true;
}
// TODO(eqy): remove this once support is added
inline bool check_for_attn_mask_cudnn(sdp_params const& params, bool debug) {
if (params.attn_mask.has_value()) {
if (debug) {
TORCH_WARN("cuDNN Attention does not support non-null attn_mask.");
}
return false;
}
return true;
}
inline bool check_attn_mask_shape(sdp_params const& params, bool debug) {
auto attn_mask = params.attn_mask;
if (!attn_mask.has_value()) {