[cuDNN][SDPA] Support attn_bias in cuDNN (#130482)
CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130482
Approved by: https://github.com/drisspg
This commit is contained in:
parent 4f40a7078e
commit de177b50f8
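At a high level, this change lets the cuDNN SDPA backend accept an additive attention bias (attn_mask) instead of rejecting the backend whenever a mask is present. A minimal usage sketch from the C++ side, assuming the public at::scaled_dot_product_attention entry point and that the dispatcher picks the cuDNN backend; the shapes and half-precision dtype are illustrative:

#include <ATen/ATen.h>

void cudnn_sdpa_with_bias_example() {
  auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kHalf);
  auto q = at::randn({2, 8, 128, 64}, opts); // (batch, heads, seq_q, head_dim)
  auto k = at::randn({2, 8, 128, 64}, opts); // (batch, heads, seq_kv, head_dim)
  auto v = at::randn({2, 8, 128, 64}, opts);
  // 2D (seq_q, seq_kv) additive bias; the kernels below expand 2D/3D biases
  // to a broadcast 4D (batch, 1, seq_q, seq_kv) view before handing them to cuDNN.
  auto bias = at::randn({128, 128}, opts);
  auto out = at::scaled_dot_product_attention(
      q, k, v, bias, /*dropout_p=*/0.0, /*is_causal=*/false);
}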
@@ -22,6 +22,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,

@@ -43,6 +44,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@@ -86,9 +88,9 @@ using graph_and_tensors = std::tuple<
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>>, // Bias
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
// TODO(eqy): additional options
// std::shared_ptr<fe::graph::Tensor_attributes>, // Bias,
// std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_Q,
// std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_KV,
std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,

@@ -104,7 +106,8 @@ using graph_and_tensors_backward = std::tuple<
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>>, // Bias,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
std::shared_ptr<fe::graph::Tensor_attributes>, // O,
@@ -126,6 +129,8 @@ struct MHAParams {
std::array<int, MAX_MHA_DIM> q_stride;
std::array<int, MAX_MHA_DIM> k_stride;
std::array<int, MAX_MHA_DIM> v_stride;
std::array<int, MAX_MHA_DIM> bias_dim;
std::array<int, MAX_MHA_DIM> bias_stride;
int64_t b;
int64_t h;
int64_t s_q;

@@ -135,6 +140,9 @@ struct MHAParams {
double dropout_probability;
bool is_causal;
bool return_softmaxstats;
// might be redundant if we take 0 dim/stride
// as signaling no-bias
bool has_attn_bias;
};

void setMHAParams(
@@ -148,6 +156,7 @@ void setMHAParams(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
double dropout_probability,
bool is_causal,
bool return_softmaxstats) {

@@ -166,6 +175,7 @@ void setMHAParams(
params.dropout_probability = dropout_probability;
params.is_causal = is_causal;
params.return_softmaxstats = return_softmaxstats;
params.has_attn_bias = attn_bias.has_value();
TORCH_INTERNAL_ASSERT(
q.sizes().size() == MAX_MHA_DIM,
"Q tensor has unexpected number of dims, please report a bug to PyTorch.");
@@ -190,6 +200,17 @@ void setMHAParams(
std::copy(k.strides().begin(), k.strides().end(), params.k_stride.begin());
std::copy(v.sizes().begin(), v.sizes().end(), params.v_dim.begin());
std::copy(v.strides().begin(), v.strides().end(), params.v_stride.begin());
// uninit is OK as the struct is memset 0'd
if (params.has_attn_bias) {
std::copy(
attn_bias.value().sizes().begin(),
attn_bias.value().sizes().end(),
params.bias_dim.begin());
std::copy(
attn_bias.value().strides().begin(),
attn_bias.value().strides().end(),
params.bias_stride.begin());
}
}

struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
@@ -203,6 +224,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
double dropout_probability,
bool is_causal,
bool return_softmaxstats) {

@@ -217,6 +239,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
return_softmaxstats);
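The bias sizes and strides recorded above feed into the MHA graph cache key, so a distinct cuDNN graph is built and cached per bias shape/stride, while the zero-initialized fields keep the no-bias case comparable. A toy sketch of that byte-wise keying idea; the names here are illustrative and not the actual ParamsWrapper machinery:

#include <array>
#include <cstring>

struct ToyMHAParams {
  std::array<int, 4> bias_dim;
  std::array<int, 4> bias_stride;
  bool has_attn_bias;
};

ToyMHAParams make_key(
    bool has_bias,
    const std::array<int, 4>& dim,
    const std::array<int, 4>& stride) {
  ToyMHAParams p;
  std::memset(&p, 0, sizeof(p)); // unused bias fields stay zero when absent
  p.has_attn_bias = has_bias;
  if (has_bias) {
    p.bias_dim = dim;
    p.bias_stride = stride;
  }
  return p;
}

bool same_key(const ToyMHAParams& a, const ToyMHAParams& b) {
  // byte-wise comparison is valid because both structs were memset to zero
  return std::memcmp(&a, &b, sizeof(ToyMHAParams)) == 0;
}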
@@ -285,6 +308,7 @@ auto build_graph_and_tensors(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@@ -301,36 +325,6 @@ auto build_graph_and_tensors(
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(std::vector<int64_t>(
q.sizes().data(), q.sizes().data() + q.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
q.sizes(),
std::vector<int64_t>(
q.strides().data(),
q.strides().data() + q.strides().size()))));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(std::vector<int64_t>(
k.sizes().data(), k.sizes().data() + k.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
k.sizes(),
std::vector<int64_t>(
k.strides().data(),
k.strides().data() + k.strides().size()))));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(std::vector<int64_t>(
v.sizes().data(), v.sizes().data() + v.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
v.sizes(),
std::vector<int64_t>(
v.strides().data(),
v.strides().data() + v.strides().size()))));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
@@ -338,11 +332,6 @@ auto build_graph_and_tensors(
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
// TODO(eqy): support bias in the future in a follow-up PR
// auto bias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("bias")
// .set_dim({b, 1, s_q, s_kv})
// .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1}));
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
@@ -360,11 +349,30 @@ auto build_graph_and_tensors(
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
.set_dropout(dropout_probability, seed, offset);
// Optional bias in flash attention is only supported 8.9.3 onwards
if (cudnnGetVersion() >= 8904) {
// scaled_dot_product_flash_attention_options.set_alibi_mask(true);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(q.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec())));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(k.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec())));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(v.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec())));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
}

auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_q")
.set_dim({b, 1, 1, 1})
@@ -376,20 +384,9 @@ auto build_graph_and_tensors(
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));

// if (cudnnGetVersion() >= 8903) {
// scaled_dot_product_flash_attention_options.set_bias(bias)
// .set_padding_mask(true)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }

auto [O, Stats] =
mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
O->set_output(true)
.set_dim(std::vector<int64_t>(
o.sizes().data(), o.sizes().data() + o.sizes().size()))
.set_stride(std::vector<int64_t>(
o.strides().data(), o.strides().data() + o.strides().size()));
O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec());

if (Stats) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
@@ -407,6 +404,7 @@ auto build_graph_and_tensors(
std::move(Q),
std::move(K),
std::move(V),
std::move(bias),
std::move(attn_scale),
std::move(seed),
std::move(offset),
@@ -427,6 +425,7 @@ auto build_graph_and_tensors_backward(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@@ -447,24 +446,6 @@ auto build_graph_and_tensors_backward(
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(std::vector<int64_t>(q.sizes().begin(), q.sizes().end()))
.set_stride(
std::vector<int64_t>(q.strides().begin(), q.strides().end())));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(std::vector<int64_t>(k.sizes().begin(), k.sizes().end()))
.set_stride(
std::vector<int64_t>(k.strides().begin(), k.strides().end())));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(std::vector<int64_t>(v.sizes().begin(), v.sizes().end()))
.set_stride(
std::vector<int64_t>(v.strides().begin(), v.strides().end())));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
@@ -472,6 +453,31 @@ auto build_graph_and_tensors_backward(
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("CUDNN_SDPA_BACKWARD")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(q.sizes().vec())
.set_stride(q.strides().vec()));
auto K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(k.sizes().vec())
.set_stride(k.strides().vec()));
auto V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(v.sizes().vec())
.set_stride(v.strides().vec()));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
sdpa_backward_options.set_bias(bias.value());
}
auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
@@ -482,47 +488,27 @@ auto build_graph_and_tensors_backward(
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto O = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("O")
.set_dim(std::vector<int64_t>(o.sizes().begin(), o.sizes().end()))
.set_stride(
std::vector<int64_t>(o.strides().begin(), o.strides().end())));
auto STATS = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Stats")
.set_dim(std::vector<int64_t>(
softmaxstats.sizes().begin(), softmaxstats.sizes().end()))
.set_stride(std::vector<int64_t>(
softmaxstats.strides().begin(), softmaxstats.strides().end()))
.set_data_type(fe::DataType_t::FLOAT));
auto DO = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("DO")
.set_dim(std::vector<int64_t>(dO.sizes().begin(), dO.sizes().end()))
.set_stride(
std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("CUDNN_SDPA_BACKWARD")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
auto O = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim(o.sizes().vec())
.set_stride(o.strides().vec()));
auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Stats")
.set_dim(softmaxstats.sizes().vec())
.set_stride(softmaxstats.strides().vec())
.set_data_type(fe::DataType_t::FLOAT));
auto DO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("DO")
.set_dim(dO.sizes().vec())
.set_stride(dO.strides().vec()));
if (dropout_probability != 0.0f) {
sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
}
auto [DQ, DK, DV] =
mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
DQ->set_output(true)
.set_dim(std::vector<int64_t>(dQ.sizes().begin(), dQ.sizes().end()))
.set_stride(
std::vector<int64_t>(dQ.strides().begin(), dQ.strides().end()));
DK->set_output(true)
.set_dim(std::vector<int64_t>(dK.sizes().begin(), dK.sizes().end()))
.set_stride(
std::vector<int64_t>(dK.strides().begin(), dK.strides().end()));
DV->set_output(true)
.set_dim(std::vector<int64_t>(dV.sizes().begin(), dV.sizes().end()))
.set_stride(
std::vector<int64_t>(dV.strides().begin(), dV.strides().end()));
DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec());
DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec());
DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec());
AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
AT_CUDNN_FRONTEND_CHECK(
@@ -534,6 +520,7 @@ auto build_graph_and_tensors_backward(
std::move(Q),
std::move(K),
std::move(V),
std::move(bias),
std::move(attn_scale),
std::move(Seed),
std::move(Offset),
@@ -559,6 +546,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@@ -583,6 +571,7 @@ void run_cudnn_SDP_fprop(
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
return_softmaxstats);
@@ -605,13 +594,14 @@ void run_cudnn_SDP_fprop(
q,
k,
v,
attn_bias,
softmaxstats,
o,
dropoutseed,
dropoutoffset,
handle);
}
auto [mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats] =
auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] =
graph_and_tensors_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
variant_pack = {
@@ -619,13 +609,15 @@ void run_cudnn_SDP_fprop(
{K, k.data_ptr()},
{V, v.data_ptr()},
{attn_scale, &scaling_factor},
//{bias, bias.data_ptr()},
{seed, dropoutseed.data_ptr()},
{offset, dropoutoffset.data_ptr()},
{O, o.data_ptr()}};
if (return_softmaxstats) {
variant_pack[Stats] = softmaxstats.data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[bias.value()] = attn_bias.value().data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
@@ -647,6 +639,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@@ -690,6 +683,7 @@ void run_cudnn_SDP_bprop(
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
true);
@@ -711,6 +705,7 @@ void run_cudnn_SDP_bprop(
q,
k,
v,
attn_bias,
o,
dO_,
softmaxstats,
@@ -722,8 +717,20 @@ void run_cudnn_SDP_bprop(
handle);
}
auto
[mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] =
graph_and_tensors_backward_values;
[mha_graph,
Q,
K,
V,
bias,
attn_scale,
Seed,
Offset,
O,
Do,
Stats,
Dq,
Dk,
Dv] = graph_and_tensors_backward_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
variant_pack = {// inputs
{Q, q.data_ptr()},
@@ -742,6 +749,9 @@ void run_cudnn_SDP_bprop(
variant_pack[Seed] = dropoutseed.data_ptr();
variant_pack[Offset] = dropoutoffset.data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[bias.value()] = attn_bias.value().data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
@@ -18,6 +18,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@@ -36,6 +37,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@@ -658,7 +658,7 @@ Tensor scaled_dot_product_attention(
case sdp::SDPBackend::cudnn_attention: {
bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
query_, key, value, attn_mask_, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale);
query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale);
return std::get<0>(out_lse_softmax);
}
case sdp::SDPBackend::flash_attention: {
@@ -774,6 +774,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
TORCH_CHECK(
max_seqlen_batch_k == max_seqlen_batch_v,
"Key and Value must have the same sequence length");
auto attn_bias_ = attn_bias;
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else {
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
}
}

Tensor attention, log_sumexp;
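The branches above normalize every accepted bias rank to the 4D (batch, heads, seq_q, seq_kv) layout handed to cuDNN: 2D and 3D biases are routed to a (batch, 1, seq_q, seq_kv) view, while a 4D bias keeps its own head dimension. A small, hedged illustration of the shape arithmetic with made-up sizes:

#include <ATen/ATen.h>

void bias_expansion_example() {
  const int64_t batch = 4, heads = 8, seq_q = 128, seq_kv = 128;
  // 2D (seq_q, seq_kv) bias: broadcast over batch plus a singleton head dim.
  auto bias2d = at::randn({seq_q, seq_kv});
  auto expanded2d = bias2d.expand({batch, 1, seq_q, seq_kv}); // (4, 1, 128, 128)
  // 4D bias: the existing head dimension is preserved as-is.
  auto bias4d = at::randn({batch, heads, seq_q, seq_kv});
  auto expanded4d = bias4d.expand({batch, bias4d.size(1), seq_q, seq_kv});
}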
@@ -818,13 +830,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
query/* Tensor q*/,
key/* Tensor k*/,
value/* Tensor v*/,
attn_bias_ /* std::optional<Tensor> */,
log_sumexp/*Tensor softmaxstats*/,
attention/*Tensor o*/,
cudnn_seed/*Tensor dropoutseed*/,
cudnn_offset/*Tensor dropoutoffset*/);

// TODO(eqy): support debug_attn_mask
return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor());
return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(Tensor()), std::move(Tensor()), max_seqlen_batch_q, max_seqlen_batch_k, std::move(cudnn_seed), std::move(cudnn_offset), std::move(Tensor()));
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
@@ -29,6 +29,7 @@
#include <ATen/ops/_efficient_attention_backward.h>
#include <ATen/ops/_efficient_attention_backward_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
#include <ATen/ops/zeros.h>
#endif

#ifdef USE_FLASH_ATTENTION
@@ -195,6 +196,27 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
const int64_t num_heads = query.size(1);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_batch_q = query.size(2);
const int64_t max_seqlen_batch_k = key.size(2);

// This is needed because SaveVariable automatically converts
// std::optional to undefined tensor
std::optional<Tensor> attn_bias_;
if (attn_bias.defined()) {
attn_bias_ = attn_bias;
}
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else {
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
}
}

const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
auto dq = at::empty_like(query);
auto dk = at::empty_like(key);
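The defined() check above exists because autograd's SaveVariable machinery hands the backward an undefined Tensor where the forward saw std::nullopt, so the bias has to be re-wrapped into an optional before the same expansion is applied. A minimal, standalone illustration of that round-trip (not the actual autograd plumbing):

#include <ATen/ATen.h>
#include <optional>

std::optional<at::Tensor> rewrap_saved_bias(const at::Tensor& saved_bias) {
  // An absent bias comes back from the saved variables as a default-constructed
  // (undefined) tensor; only a defined tensor becomes a real optional value.
  std::optional<at::Tensor> attn_bias;
  if (saved_bias.defined()) {
    attn_bias = saved_bias;
  }
  return attn_bias;
}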
@@ -211,6 +233,7 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
query /*const Tensor& q*/,
key /*const Tensor& k*/,
value /*const Tensor& v*/,
attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
out /*const Tensor& o*/,
grad_out/*const Tensor& dO*/,
logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
@@ -219,7 +242,7 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
dv/*Tensor& dV*/,
philox_seed/*Tensor& dropoutseed*/,
philox_offset/*Tensor& dropoutoffset*/);
return std::make_tuple(dq, dk, dv);
return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -550,7 +550,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
check_cudnn_deterministic,
// check_is_causal,
check_dtypes_low_precision,
check_for_attn_mask_cudnn,
check_attn_mask_shape,
check_cudnn_hardware_support
);
for (auto& constraint : general_constraints) {
@@ -274,17 +274,6 @@ inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
return true;
}

// TODO(eqy): remove this once support is added
inline bool check_for_attn_mask_cudnn(sdp_params const& params, bool debug) {
if (params.attn_mask.has_value()) {
if (debug) {
TORCH_WARN("cuDNN Attention does not support non-null attn_mask.");
}
return false;
}
return true;
}

inline bool check_attn_mask_shape(sdp_params const& params, bool debug) {
auto attn_mask = params.attn_mask;
if (!attn_mask.has_value()) {