diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp
index 5d146edb90b..177f3bd3c12 100644
--- a/aten/src/ATen/native/cudnn/MHA.cpp
+++ b/aten/src/ATen/native/cudnn/MHA.cpp
@@ -84,6 +84,37 @@ void run_cudnn_SDP_bprop(
       false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
 }
 
+void run_cudnn_SDP_bprop_nestedtensor(
+    int64_t b,
+    int64_t h_q,
+    int64_t h_k,
+    int64_t h_v,
+    int64_t s_q,
+    int64_t s_kv,
+    int64_t d_qk,
+    int64_t d_v,
+
+    float scaling_factor,
+    bool is_causal,
+    float dropout_probability,
+    const Tensor& cum_seqlen_q,
+    const Tensor& cum_seqlen_kv,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const std::optional<Tensor>& attn_bias,
+    const Tensor& o,
+    const Tensor& dO,
+    const Tensor& softmaxstats,
+    Tensor& dQ,
+    Tensor& dK,
+    Tensor& dV,
+    const Tensor& dropoutseed,
+    const Tensor& dropoutoffset) {
+  TORCH_CHECK(
+      false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
+}
+
 } // namespace native
 } // namespace at
 
@@ -110,40 +141,6 @@ namespace native {
 
 #include <cudnn_frontend.h>
 
 namespace fe = cudnn_frontend;
 
-using graph_and_tensors = std::tuple<
-    std::shared_ptr<fe::graph::Graph>,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // K,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // V,
-    std::optional<std::shared_ptr<fe::graph::Tensor_attributes>>, // Bias
-    std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
-    // TODO(eqy): additional options
-    // std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_Q,
-    // std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_KV,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
-    // std::shared_ptr<fe::graph::Tensor_attributes>, // Dropout_mask,
-    // std::shared_ptr<fe::graph::Tensor_attributes>, // Dropout_scale
-    std::shared_ptr<fe::graph::Tensor_attributes>, // O
-    std::shared_ptr<fe::graph::Tensor_attributes> // Stats
-    >;
-
-using graph_and_tensors_backward = std::tuple<
-    std::shared_ptr<fe::graph::Graph>,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // K,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // V,
-    std::optional<std::shared_ptr<fe::graph::Tensor_attributes>>, // Bias,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // O,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // dO,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // stats,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
-    std::shared_ptr<fe::graph::Tensor_attributes>, // dK,
-    std::shared_ptr<fe::graph::Tensor_attributes> // dV,
-    >;
 
 #define MAX_MHA_DIM 4
 
@@ -297,11 +294,40 @@ struct MHAGraphCache {
 // @eqy: use thread local caches as cuDNN Execution Plans are not guaranteed to
 // be thread safe across all engines see Limitations in
 // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html
-thread_local MHAGraphCache<graph_and_tensors, MHACacheKeyWrapper> mhagraphcache;
-thread_local MHAGraphCache<graph_and_tensors_backward, MHACacheKeyWrapper>
+thread_local MHAGraphCache<
+    std::shared_ptr<fe::graph::Graph>,
+    MHACacheKeyWrapper>
+    mhagraphcache;
+thread_local MHAGraphCache<
+    std::shared_ptr<fe::graph::Graph>,
+    MHACacheKeyWrapper>
     mhagraphbackwardcache;
 
 namespace {
+
+enum UIDS {
+  Q,
+  K,
+  V,
+  O,
+  BIAS,
+  SCALE,
+  SEED,
+  OFFSET,
+  LSE,
+  DO,
+  DQ,
+  DK,
+  DV,
+  SEQ_LEN_Q,
+  SEQ_LEN_KV,
+  RAG_Q_OFF,
+  RAG_K_OFF,
+  RAG_V_OFF,
+  RAG_O_OFF,
+  RAG_LSE_OFF
+};
+
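The enum above gives every graph tensor a stable integer UID, which is what lets the two thread-local caches hold only a `std::shared_ptr<fe::graph::Graph>` instead of the old tuples of `Tensor_attributes` handles: the variant pack can be rebuilt from the UIDs on every call. A minimal, self-contained sketch of the pattern (`FakeGraph`, `run`, and the `UID_*` names are illustrative stand-ins, not the cudnn_frontend API):

```cpp
#include <cstdint>
#include <unordered_map>

enum Uid : int64_t { UID_Q, UID_K, UID_V, UID_O, UID_SCALE, UID_SEED, UID_OFFSET };

struct FakeGraph { // stand-in for the cached fe::graph::Graph
  void execute(const std::unordered_map<int64_t, void*>& variant_pack) {
    (void)variant_pack; // a real graph would dereference variant_pack.at(UID_Q) etc.
  }
};

void run(FakeGraph& cached_graph, void* q, void* k, void* v, void* o,
         float* scale, void* seed, void* offset, bool has_dropout) {
  // Rebuilt from scratch each call; no Tensor_attributes handles are needed,
  // so only the graph itself has to live in the cache.
  std::unordered_map<int64_t, void*> variant_pack = {
      {UID_Q, q}, {UID_K, k}, {UID_V, v}, {UID_O, o}, {UID_SCALE, scale}};
  if (has_dropout) { // SEED/OFFSET nodes exist only in graphs built with dropout
    variant_pack[UID_SEED] = seed;
    variant_pack[UID_OFFSET] = offset;
  }
  cached_graph.execute(variant_pack);
}
```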
 // analogous to the same function in Descriptors.h for cuDNN Convolutions...
 auto fixSizeOneDimStrideSDPA(
     const IntArrayRef sizes,
@@ -403,7 +429,7 @@ bool same_strides(const Tensor& t1, const Tensor& t2) {
   }
 }
 } // namespace
-auto build_graph_and_tensors(
+auto build_graph(
     int64_t b,
     int64_t h,
     int64_t s_q,
@@ -436,46 +462,55 @@ auto build_graph_and_tensors(
           .set_compute_data_type(fe::DataType_t::FLOAT);
   auto attn_scale =
       mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(SCALE)
                             .set_name("Attn_scale")
                             .set_dim({1, 1, 1, 1})
                             .set_stride({1, 1, 1, 1})
                             .set_is_pass_by_value(true)
                             .set_data_type(fe::DataType_t::FLOAT));
-  auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                    .set_name("Seed")
-                                    .set_dim({1, 1, 1, 1})
-                                    .set_stride({1, 1, 1, 1})
-                                    .set_data_type(
-                                        dropoutseed.dtype() == kInt
-                                            ? fe::DataType_t::INT32
-                                            : fe::DataType_t::INT64));
-  auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                      .set_name("Offset")
-                                      .set_dim({1, 1, 1, 1})
-                                      .set_stride({1, 1, 1, 1})
-                                      .set_data_type(
-                                          dropoutoffset.dtype() == kInt
-                                              ? fe::DataType_t::INT32
-                                              : fe::DataType_t::INT64));
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA")
           .set_is_inference(return_softmaxstats == false)
           .set_causal_mask(is_causal)
-          .set_attn_scale(attn_scale)
-          .set_dropout(dropout_probability, seed, offset);
-  auto Q = mha_graph->tensor(
+          .set_attn_scale(attn_scale);
+  if (dropout_probability != 0.0f) {
+    auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                      .set_uid(SEED)
+                                      .set_name("Seed")
+                                      .set_dim({1, 1, 1, 1})
+                                      .set_stride({1, 1, 1, 1})
+                                      .set_data_type(
+                                          dropoutseed.dtype() == kInt
+                                              ? fe::DataType_t::INT32
+                                              : fe::DataType_t::INT64));
+    auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                        .set_uid(OFFSET)
+                                        .set_name("Offset")
+                                        .set_dim({1, 1, 1, 1})
+                                        .set_stride({1, 1, 1, 1})
+                                        .set_data_type(
+                                            dropoutoffset.dtype() == kInt
+                                                ? fe::DataType_t::INT32
+                                                : fe::DataType_t::INT64));
+    scaled_dot_product_flash_attention_options.set_dropout(
+        dropout_probability, seed, offset);
+  }
+  auto Q_ = mha_graph->tensor(
       fe::graph::Tensor_attributes()
+          .set_uid(Q)
           .set_name("Q")
           .set_dim(q.sizes().vec())
           .set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec())));
-  auto K = mha_graph->tensor(
+  auto K_ = mha_graph->tensor(
       fe::graph::Tensor_attributes()
+          .set_uid(K)
           .set_name("K")
           .set_dim(k.sizes().vec())
           .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec())));
-  auto V = mha_graph->tensor(
+  auto V_ = mha_graph->tensor(
       fe::graph::Tensor_attributes()
+          .set_uid(V)
           .set_name("V")
           .set_dim(v.sizes().vec())
           .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec())));
@@ -483,17 +518,20 @@ auto build_graph_and_tensors(
   std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
   if (attn_bias.has_value()) {
     bias =
         mha_graph->tensor(fe::graph::Tensor_attributes()
+                              .set_uid(BIAS)
                               .set_name("bias")
                               .set_dim(attn_bias.value().sizes().vec())
                               .set_stride(attn_bias.value().strides().vec()));
     scaled_dot_product_flash_attention_options.set_bias(bias.value());
   }
-  auto [O, Stats] =
-      mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
-  O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec());
+  auto [O_, Stats] =
+      mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options);
+  O_->set_uid(O);
+  O_->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec());
   if (Stats) {
+    Stats->set_uid(LSE);
     Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
   }
 
@@ -504,20 +542,10 @@
   AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
   AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
 
-  return std::make_tuple(
-      std::move(mha_graph),
-      std::move(Q),
-      std::move(K),
-      std::move(V),
-      std::move(bias),
-      std::move(attn_scale),
-      std::move(seed),
-      std::move(offset),
-      std::move(O),
-      std::move(Stats));
+  return mha_graph;
 }
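Because the SEED and OFFSET nodes are now created only when `dropout_probability != 0.0f`, two otherwise-identical calls that differ in dropout must resolve to different cached graphs. The real key type is `MHACacheKeyWrapper`; the sketch below is a hypothetical stand-in that shows why the dropout probability belongs in the hashed parameter pack:

```cpp
#include <cstdint>
#include <cstring>
#include <functional>
#include <string_view>

struct MHAKey { // plain-old-data, hashed bytewise like a cache key wrapper
  int64_t b, h, s_q, s_kv, d_qk, d_v;
  float dropout_probability; // distinguishes dropout vs. no-dropout graphs
  bool is_causal;
  bool return_softmaxstats;
};

inline size_t hash_key(const MHAKey& key) {
  // NB: padding bytes must be zeroed (e.g. value-initialize with MHAKey key{};
  // before assigning fields) for a bytewise hash to be deterministic.
  char bytes[sizeof(MHAKey)];
  std::memcpy(bytes, &key, sizeof(MHAKey));
  return std::hash<std::string_view>{}(std::string_view(bytes, sizeof(MHAKey)));
}
```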
-auto build_graph_and_tensors_nestedtensor(
+auto build_graph_nestedtensor(
     int64_t b,
     int64_t h_q,
     int64_t h_k,
@@ -554,28 +582,22 @@
           .set_compute_data_type(fe::DataType_t::FLOAT);
   auto attn_scale =
       mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(SCALE)
                             .set_name("Attn_scale")
                             .set_dim({1, 1, 1, 1})
                             .set_stride({1, 1, 1, 1})
                             .set_is_pass_by_value(true)
                             .set_data_type(fe::DataType_t::FLOAT));
-  auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                    .set_name("Seed")
-                                    .set_dim({1, 1, 1, 1})
-                                    .set_stride({1, 1, 1, 1})
-                                    .set_data_type(fe::DataType_t::INT32));
-  auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                      .set_name("Offset")
-                                      .set_dim({1, 1, 1, 1})
-                                      .set_stride({1, 1, 1, 1})
-                                      .set_data_type(fe::DataType_t::INT32));
-  auto SEQ_LEN_Q = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                         .set_name("Seq_q")
-                                         .set_dim({b, 1, 1, 1})
-                                         .set_stride({1, 1, 1, 1})
-                                         .set_data_type(fe::DataType_t::INT32));
-  auto SEQ_LEN_KV =
+  auto SEQ_LEN_Q_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(SEQ_LEN_Q)
+                            .set_name("Seq_q")
+                            .set_dim({b, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto SEQ_LEN_KV_ =
       mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(SEQ_LEN_KV)
                             .set_name("Seq_kv")
                             .set_dim({b, 1, 1, 1})
                             .set_stride({1, 1, 1, 1})
@@ -587,41 +609,66 @@
           .set_is_inference(return_softmaxstats == false)
           .set_causal_mask(is_causal)
           .set_attn_scale(attn_scale)
-          .set_dropout(dropout_probability, seed, offset)
-          .set_seq_len_q(SEQ_LEN_Q)
-          .set_seq_len_kv(SEQ_LEN_KV)
+          .set_seq_len_q(SEQ_LEN_Q_)
+          .set_seq_len_kv(SEQ_LEN_KV_)
           .set_padding_mask(true);
+  if (dropout_probability != 0.0f) {
+    auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                      .set_uid(SEED)
+                                      .set_name("Seed")
+                                      .set_dim({1, 1, 1, 1})
+                                      .set_stride({1, 1, 1, 1})
+                                      .set_data_type(
+                                          dropoutseed.dtype() == kInt
+                                              ? fe::DataType_t::INT32
+                                              : fe::DataType_t::INT64));
+    auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                        .set_uid(OFFSET)
+                                        .set_name("Offset")
+                                        .set_dim({1, 1, 1, 1})
+                                        .set_stride({1, 1, 1, 1})
+                                        .set_data_type(
+                                            dropoutoffset.dtype() == kInt
+                                                ? fe::DataType_t::INT32
+                                                : fe::DataType_t::INT64));
+    scaled_dot_product_flash_attention_options.set_dropout(
+        dropout_probability, seed, offset);
+  }
   // We hardcode BSHD to cuDNN even though the underlying layout is THD
   auto q_strides = q.strides();
   auto k_strides = k.strides();
   auto v_strides = v.strides();
+  // NB: cuDNN API shape is transposed
   constexpr int strideidx0 = 1;
   constexpr int strideidx1 = 0;
   constexpr int strideidx2 = 2;
-  auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("Q")
-                                 .set_dim({b, h_q, s_q, d_qk})
-                                 .set_stride(
-                                     {INT_MAX,
-                                      q_strides[strideidx0],
-                                      q_strides[strideidx1],
-                                      q_strides[strideidx2]}));
-  auto K = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("K")
-                                 .set_dim({b, h_k, s_kv, d_qk})
-                                 .set_stride(
-                                     {INT_MAX,
-                                      k_strides[strideidx0],
-                                      k_strides[strideidx1],
-                                      k_strides[strideidx2]}));
-  auto V = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("V")
-                                 .set_dim({b, h_v, s_kv, d_v})
-                                 .set_stride(
-                                     {INT_MAX,
-                                      v_strides[strideidx0],
-                                      v_strides[strideidx1],
-                                      v_strides[strideidx2]}));
+  auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(Q)
+                                  .set_name("Q")
+                                  .set_dim({b, h_q, s_q, d_qk})
+                                  .set_stride(
+                                      {INT_MAX,
+                                       q_strides[strideidx0],
+                                       q_strides[strideidx1],
+                                       q_strides[strideidx2]}));
+  auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(K)
+                                  .set_name("K")
+                                  .set_dim({b, h_k, s_kv, d_qk})
+                                  .set_stride(
+                                      {INT_MAX,
+                                       k_strides[strideidx0],
+                                       k_strides[strideidx1],
+                                       k_strides[strideidx2]}));
+  auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(V)
+                                  .set_name("V")
+                                  .set_dim({b, h_v, s_kv, d_v})
+                                  .set_stride(
+                                      {INT_MAX,
+                                       v_strides[strideidx0],
+                                       v_strides[strideidx1],
+                                       v_strides[strideidx2]}));
   std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
   if (attn_bias.has_value()) {
     TORCH_CHECK(
@@ -629,44 +676,48 @@
         "attn_bias not yet supported with cuDNN Attention and NestedTensor");
     bias =
         mha_graph->tensor(fe::graph::Tensor_attributes()
+                              .set_uid(BIAS)
                               .set_name("bias")
                               .set_dim(attn_bias.value().sizes().vec())
                               .set_stride(attn_bias.value().strides().vec()));
     scaled_dot_product_flash_attention_options.set_bias(bias.value());
   }
-  auto RAG_Q_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                         .set_name("cum_seq_q")
-                                         .set_dim({b + 1, 1, 1, 1})
-                                         .set_stride({1, 1, 1, 1})
-                                         .set_data_type(fe::DataType_t::INT32));
-  auto RAG_K_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                         .set_name("cum_seq_k")
-                                         .set_dim({b + 1, 1, 1, 1})
-                                         .set_stride({1, 1, 1, 1})
-                                         .set_data_type(fe::DataType_t::INT32));
-  auto RAG_V_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                         .set_name("cum_seq_v")
-                                         .set_dim({b + 1, 1, 1, 1})
-                                         .set_stride({1, 1, 1, 1})
-                                         .set_data_type(fe::DataType_t::INT32));
-  auto RAG_O_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                         .set_name("cum_seq_o")
-                                         .set_dim({b + 1, 1, 1, 1})
-                                         .set_stride({1, 1, 1, 1})
-                                         .set_data_type(fe::DataType_t::INT32));
-  // auto RAG_STATS_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
-  //                                            .set_name("cum_seq_stats")
-  //                                            .set_dim({b + 1, 1, 1, 1})
-  //                                            .set_stride({1, 1, 1, 1})
-  //                                            .set_data_type(fe::DataType_t::INT32));
-  auto RAG_STATS_OFF = nullptr;
-  Q->set_ragged_offset(RAG_Q_OFF);
-  K->set_ragged_offset(RAG_K_OFF);
-  V->set_ragged_offset(RAG_V_OFF);
-  auto [O, Stats] =
-      mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
+  auto RAG_Q_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_Q_OFF)
+                            .set_name("cum_seq_q")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto RAG_K_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_K_OFF)
+                            .set_name("cum_seq_k")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto RAG_V_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_V_OFF)
+                            .set_name("cum_seq_v")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto RAG_O_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_O_OFF)
+                            .set_name("cum_seq_o")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  Q_->set_ragged_offset(RAG_Q_OFF_);
+  K_->set_ragged_offset(RAG_K_OFF_);
+  V_->set_ragged_offset(RAG_V_OFF_);
+  auto [O_, Stats] =
+      mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options);
   auto o_strides = o.strides();
-  O->set_output(true)
+  O_->set_output(true)
+      .set_uid(O)
       .set_dim({b, h_q, s_q, d_v})
       .set_stride(
           {INT_MAX,
@@ -674,16 +725,20 @@
            o_strides[strideidx1],
            o_strides[strideidx2]});
 
-  O->set_ragged_offset(RAG_O_OFF);
+  O_->set_ragged_offset(RAG_O_OFF_);
   if (Stats) {
-    TORCH_CHECK(
-        false,
-        "cuDNN SDPA Nested Tensor does not yet handle backwards/logsumexp computation");
-    // TODO(eqy): fix when stats (backward) support is added
+    auto RAG_STATS_OFF =
+        mha_graph->tensor(fe::graph::Tensor_attributes()
+                              .set_uid(RAG_LSE_OFF)
+                              .set_name("cum_seq_stats")
+                              .set_dim({b + 1, 1, 1, 1})
+                              .set_stride({1, 1, 1, 1})
+                              .set_data_type(fe::DataType_t::INT32));
     Stats->set_output(true)
+        .set_uid(LSE)
         .set_data_type(fe::DataType_t::FLOAT)
         .set_dim({b, h_q, s_q, 1})
-        .set_stride({h_q * s_q * d_v, d_v, s_q * d_v, 1});
+        .set_stride({h_q * s_q, 1, h_q, 1});
     Stats->set_ragged_offset(RAG_STATS_OFF);
   }
   AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
@@ -692,27 +747,10 @@
       mha_graph->create_execution_plans({fe::HeurMode_t::A}));
   AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
   AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
-  return std::make_tuple(
-      std::move(mha_graph),
-      std::move(Q),
-      std::move(K),
-      std::move(V),
-      std::move(bias),
-      std::move(attn_scale),
-      std::move(seed),
-      std::move(offset),
-      std::move(O),
-      std::move(Stats),
-      std::move(RAG_Q_OFF),
-      std::move(RAG_K_OFF),
-      std::move(RAG_V_OFF),
-      std::move(RAG_O_OFF),
-      std::move(RAG_STATS_OFF),
-      std::move(SEQ_LEN_Q),
-      std::move(SEQ_LEN_KV));
+  return mha_graph;
 }
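The nested-tensor builder advertises dense `{b, h, s, d}` dims with a dummy `INT_MAX` batch stride because the per-sequence base addresses come from the ragged-offset tensors instead. A self-contained sketch of the packed-THD addressing this implies (illustrative helper, assuming row-major `[T, H, D]` packing; not library code):

```cpp
#include <cstdint>
#include <vector>

// A packed buffer holds all sequences back to back: token t of batch i lives
// at row (cum_seqlen[i] + t). The per-batch base offset (in elements) is the
// ragged offset; within one sequence the ordinary H/S/D strides apply, which
// is why the graph can pretend the layout is BSHD with an unused batch stride.
int64_t packed_qkv_index(
    const std::vector<int64_t>& cum_seqlen, // length b + 1
    int64_t h_total, int64_t d,             // heads / head dim of this tensor
    int64_t i, int64_t t, int64_t h, int64_t c) { // batch, token, head, channel
  const int64_t ragged_base = cum_seqlen[i] * h_total * d; // per-batch base
  return ragged_base + t * (h_total * d) + h * d + c;      // [T, H, D] row-major
}
```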
-auto build_graph_and_tensors_backward(
+auto build_graph_backward(
     int64_t b,
     int64_t h,
     int64_t s_q,
@@ -748,6 +786,7 @@ auto build_graph_and_tensors_backward(
           .set_compute_data_type(fe::DataType_t::FLOAT);
   auto attn_scale =
       mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(SCALE)
                             .set_name("Attn_scale")
                             .set_dim({1, 1, 1, 1})
                             .set_stride({1, 1, 1, 1})
@@ -757,87 +796,327 @@ auto build_graph_and_tensors_backward(
           .set_name("CUDNN_SDPA_BACKWARD")
           .set_causal_mask(is_causal)
           .set_attn_scale(attn_scale);
-  auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("Q")
-                                 .set_dim(q.sizes().vec())
-                                 .set_stride(q.strides().vec()));
-  auto K = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("K")
-                                 .set_dim(k.sizes().vec())
-                                 .set_stride(k.strides().vec()));
-  auto V = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("V")
-                                 .set_dim(v.sizes().vec())
-                                 .set_stride(v.strides().vec()));
+  auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(Q)
+                                  .set_name("Q")
+                                  .set_dim(q.sizes().vec())
+                                  .set_stride(q.strides().vec()));
+  auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(K)
+                                  .set_name("K")
+                                  .set_dim(k.sizes().vec())
+                                  .set_stride(k.strides().vec()));
+  auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(V)
+                                  .set_name("V")
+                                  .set_dim(v.sizes().vec())
+                                  .set_stride(v.strides().vec()));
   std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
   if (attn_bias.has_value()) {
     bias =
         mha_graph->tensor(fe::graph::Tensor_attributes()
+                              .set_uid(BIAS)
                               .set_name("bias")
                               .set_dim(attn_bias.value().sizes().vec())
                               .set_stride(attn_bias.value().strides().vec()));
     sdpa_backward_options.set_bias(bias.value());
   }
-  auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                    .set_name("Seed")
-                                    .set_dim({1, 1, 1, 1})
-                                    .set_stride({1, 1, 1, 1})
-                                    .set_data_type(
-                                        dropoutseed.dtype() == kInt
-                                            ? fe::DataType_t::INT32
-                                            : fe::DataType_t::INT64));
-
-  auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                      .set_name("Offset")
-                                      .set_dim({1, 1, 1, 1})
-                                      .set_stride({1, 1, 1, 1})
-                                      .set_data_type(
-                                          dropoutoffset.dtype() == kInt
-                                              ? fe::DataType_t::INT32
-                                              : fe::DataType_t::INT64));
+  if (dropout_probability != 0.0f) {
+    auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                      .set_uid(SEED)
+                                      .set_name("Seed")
+                                      .set_dim({1, 1, 1, 1})
+                                      .set_stride({1, 1, 1, 1})
+                                      .set_data_type(
+                                          dropoutseed.dtype() == kInt
+                                              ? fe::DataType_t::INT32
+                                              : fe::DataType_t::INT64));
+    auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                        .set_uid(OFFSET)
+                                        .set_name("Offset")
+                                        .set_dim({1, 1, 1, 1})
+                                        .set_stride({1, 1, 1, 1})
+                                        .set_data_type(
+                                            dropoutoffset.dtype() == kInt
+                                                ? fe::DataType_t::INT32
+                                                : fe::DataType_t::INT64));
+    sdpa_backward_options.set_dropout(dropout_probability, seed, offset);
+  }
 
-  auto O = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("O")
-                                 .set_dim(o.sizes().vec())
-                                 .set_stride(o.strides().vec()));
-  auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes()
+  auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(O)
+                                  .set_name("O")
+                                  .set_dim(o.sizes().vec())
+                                  .set_stride(o.strides().vec()));
+  auto Stats = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                     .set_uid(LSE)
                                      .set_name("Stats")
                                      .set_dim(softmaxstats.sizes().vec())
                                      .set_stride(softmaxstats.strides().vec())
                                      .set_data_type(fe::DataType_t::FLOAT));
-  auto DO = mha_graph->tensor(fe::graph::Tensor_attributes()
+  auto Do = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(DO)
                                   .set_name("DO")
                                   .set_dim(dO.sizes().vec())
                                   .set_stride(dO.strides().vec()));
-  if (dropout_probability != 0.0f) {
-    sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
-  }
-  auto [DQ, DK, DV] =
-      mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
-  DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec());
-  DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec());
-  DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec());
+  auto [Dq, Dk, Dv] = mha_graph->sdpa_backward(
+      Q_, K_, V_, O_, Do, Stats, sdpa_backward_options);
+  Dq->set_uid(DQ);
+  Dq->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec());
+  Dk->set_uid(DK);
+  Dk->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec());
+  Dv->set_uid(DV);
+  Dv->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec());
   AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
   AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
   AT_CUDNN_FRONTEND_CHECK(
       mha_graph->create_execution_plans({fe::HeurMode_t::A}));
   AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
   AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
-  return std::make_tuple(
-      std::move(mha_graph),
-      std::move(Q),
-      std::move(K),
-      std::move(V),
-      std::move(bias),
-      std::move(attn_scale),
-      std::move(Seed),
-      std::move(Offset),
-      std::move(O),
-      std::move(DO),
-      std::move(STATS),
-      std::move(DQ),
-      std::move(DK),
-      std::move(DV));
+  return mha_graph;
+}
+
+auto build_graph_backward_nestedtensor(
+    int64_t b,
+    int64_t h_q,
+    int64_t h_k,
+    int64_t h_v,
+    int64_t s_q,
+    int64_t s_kv,
+    int64_t d_qk,
+    int64_t d_v,
+    float scaling_factor,
+    bool is_causal,
+    float dropout_probability,
+    const Tensor& cum_seqlen_q,
+    const Tensor& cum_seqlen_kv,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const std::optional<Tensor>& attn_bias,
+    const Tensor& o,
+    const Tensor& dO,
+    const Tensor& softmaxstats,
+    Tensor& dQ,
+    Tensor& dK,
+    Tensor& dV,
+    const Tensor& dropoutseed,
+    const Tensor& dropoutoffset,
+    cudnnHandle_t& handle) {
+  auto dtype = fe::DataType_t::HALF;
+  if (q.scalar_type() == kBFloat16) {
+    dtype = fe::DataType_t::BFLOAT16;
+  }
+  auto mha_graph = std::make_shared<fe::graph::Graph>();
+  // We're baking in float accumulation and scale types
+  // in theory the graph may support other types, but they
+  // have not been tested
+  mha_graph->set_io_data_type(dtype)
+      .set_intermediate_data_type(fe::DataType_t::FLOAT)
+      .set_compute_data_type(fe::DataType_t::FLOAT);
+  auto attn_scale =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(SCALE)
+                            .set_name("Attn_scale")
+                            .set_dim({1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_is_pass_by_value(true)
+                            .set_data_type(fe::DataType_t::FLOAT));
+
+  auto SEQ_LEN_Q_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(SEQ_LEN_Q)
+                            .set_name("Seq_q")
+                            .set_dim({b, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto SEQ_LEN_KV_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(SEQ_LEN_KV)
+                            .set_name("Seq_kv")
+                            .set_dim({b, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto sdpa_backward_options =
+      fe::graph::SDPA_backward_attributes()
+          .set_name("CUDNN_SDPA_NESTEDTENSOR_BACKWARD")
+          .set_causal_mask(is_causal)
+          .set_attn_scale(attn_scale)
+          .set_seq_len_q(SEQ_LEN_Q_)
+          .set_seq_len_kv(SEQ_LEN_KV_)
+          .set_padding_mask(true);
+  if (dropout_probability != 0.0f) {
+    auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                      .set_uid(SEED)
+                                      .set_name("Seed")
+                                      .set_dim({1, 1, 1, 1})
+                                      .set_stride({1, 1, 1, 1})
+                                      .set_data_type(
+                                          dropoutseed.dtype() == kInt
+                                              ? fe::DataType_t::INT32
+                                              : fe::DataType_t::INT64));
+    auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                        .set_uid(OFFSET)
+                                        .set_name("Offset")
+                                        .set_dim({1, 1, 1, 1})
+                                        .set_stride({1, 1, 1, 1})
+                                        .set_data_type(
+                                            dropoutoffset.dtype() == kInt
+                                                ? fe::DataType_t::INT32
+                                                : fe::DataType_t::INT64));
+    sdpa_backward_options.set_dropout(dropout_probability, seed, offset);
+  }
+  auto q_strides = q.strides();
+  auto k_strides = k.strides();
+  auto v_strides = v.strides();
+  // NB: cuDNN API shape is transposed
+  constexpr int strideidx0 = 1;
+  constexpr int strideidx1 = 0;
+  constexpr int strideidx2 = 2;
+  auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(Q)
+                                  .set_name("Q")
+                                  .set_dim({b, h_q, s_q, d_qk})
+                                  .set_stride(
+                                      {INT_MAX,
+                                       q_strides[strideidx0],
+                                       q_strides[strideidx1],
+                                       q_strides[strideidx2]}));
+  auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(K)
+                                  .set_name("K")
+                                  .set_dim({b, h_k, s_kv, d_qk})
+                                  .set_stride(
+                                      {INT_MAX,
+                                       k_strides[strideidx0],
+                                       k_strides[strideidx1],
+                                       k_strides[strideidx2]}));
+  auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(V)
+                                  .set_name("V")
+                                  .set_dim({b, h_v, s_kv, d_v})
+                                  .set_stride(
+                                      {INT_MAX,
+                                       v_strides[strideidx0],
+                                       v_strides[strideidx1],
+                                       v_strides[strideidx2]}));
+  auto o_strides = o.strides();
+  auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                  .set_uid(O)
+                                  .set_name("O")
+                                  .set_dim({b, h_q, s_q, d_v})
+                                  .set_stride(
+                                      {INT_MAX,
+                                       o_strides[strideidx0],
+                                       o_strides[strideidx1],
+                                       o_strides[strideidx2]}));
+
+  std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
+  if (attn_bias.has_value()) {
+    TORCH_CHECK(
+        false,
+        "attn_bias not yet supported with cuDNN Attention and NestedTensor");
+    bias =
+        mha_graph->tensor(fe::graph::Tensor_attributes()
+                              .set_uid(BIAS)
+                              .set_name("bias")
+                              .set_dim(attn_bias.value().sizes().vec())
+                              .set_stride(attn_bias.value().strides().vec()));
+    sdpa_backward_options.set_bias(bias.value());
+  }
+  auto RAG_Q_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_Q_OFF)
+                            .set_name("cum_seq_q")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto RAG_K_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_K_OFF)
+                            .set_name("cum_seq_k")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto RAG_V_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_V_OFF)
+                            .set_name("cum_seq_v")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto RAG_O_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_O_OFF)
+                            .set_name("cum_seq_o")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  auto RAG_STATS_OFF_ =
+      mha_graph->tensor(fe::graph::Tensor_attributes()
+                            .set_uid(RAG_LSE_OFF)
+                            .set_name("cum_seq_stats")
+                            .set_dim({b + 1, 1, 1, 1})
+                            .set_stride({1, 1, 1, 1})
+                            .set_data_type(fe::DataType_t::INT32));
+  O_->set_ragged_offset(RAG_O_OFF_);
+  Q_->set_ragged_offset(RAG_Q_OFF_);
+  K_->set_ragged_offset(RAG_K_OFF_);
+  V_->set_ragged_offset(RAG_V_OFF_);
+  auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                     .set_uid(LSE)
+                                     .set_name("stats")
+                                     .set_dim({b, h_q, s_q, 1})
+                                     .set_stride({s_q * h_q, 1, h_q, 1})
+                                     .set_data_type(fe::DataType_t::FLOAT));
+  STATS->set_ragged_offset(RAG_STATS_OFF_);
+  auto do_strides = dO.strides();
+  auto DO_ = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                   .set_ragged_offset(RAG_O_OFF_)
+                                   .set_uid(DO)
+                                   .set_name("DO")
+                                   .set_dim({b, h_q, s_q, d_v})
+                                   .set_stride(
+                                       {INT_MAX,
+                                        do_strides[strideidx0],
+                                        do_strides[strideidx1],
+                                        do_strides[strideidx2]}));
+  auto [Dq, Dk, Dv] = mha_graph->sdpa_backward(
+      Q_, K_, V_, O_, DO_, STATS, sdpa_backward_options);
+  Dq->set_output(true)
+      .set_uid(DQ)
+      .set_ragged_offset(RAG_Q_OFF_)
+      .set_dim({b, h_q, s_q, d_qk})
+      .set_stride(
+          {INT_MAX,
+           q_strides[strideidx0],
+           q_strides[strideidx1],
+           q_strides[strideidx2]});
+  Dk->set_output(true)
+      .set_uid(DK)
+      .set_ragged_offset(RAG_K_OFF_)
+      .set_dim({b, h_k, s_kv, d_qk})
+      .set_stride(
+          {INT_MAX,
+           k_strides[strideidx0],
+           k_strides[strideidx1],
+           k_strides[strideidx2]});
+  Dv->set_output(true)
+      .set_uid(DV)
+      .set_ragged_offset(RAG_V_OFF_)
+      .set_dim({b, h_v, s_kv, d_v})
+      .set_stride(
+          {INT_MAX,
+           v_strides[strideidx0],
+           v_strides[strideidx1],
+           v_strides[strideidx2]});
+
+  AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
+  AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
+  AT_CUDNN_FRONTEND_CHECK(
+      mha_graph->create_execution_plans({fe::HeurMode_t::A}));
+  AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
+  AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
+  return mha_graph;
+}
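The backward builder declares the softmax stats as `{b, h_q, s_q, 1}` with strides `{s_q * h_q, 1, h_q, 1}`, i.e. a packed `[T, H]` buffer whose per-batch bases come from `cum_seqlen_q * h_q` (the `RAG_LSE_OFF` values). A sketch of the resulting indexing (illustrative helper, not library code):

```cpp
#include <cstdint>
#include <vector>

int64_t packed_lse_index(
    const std::vector<int64_t>& cum_seqlen_q, // length b + 1
    int64_t h_q,
    int64_t i, int64_t t, int64_t h) { // batch, token, head
  const int64_t ragged_base = cum_seqlen_q[i] * h_q; // the RAG_LSE_OFF entry
  return ragged_base + t * h_q + h; // stride_s = h_q, stride_h = 1, row-major [T, H]
}
```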
 
 void run_cudnn_SDP_fprop(
@@ -898,12 +1177,12 @@ void run_cudnn_SDP_fprop(
       dropout_probability,
       is_causal,
       return_softmaxstats);
-  auto graph_and_tensors_ptr = mhagraphcache.find(key);
-  graph_and_tensors graph_and_tensors_values;
-  if (graph_and_tensors_ptr) {
-    graph_and_tensors_values = *graph_and_tensors_ptr;
+  auto graph_ptr = mhagraphcache.find(key);
+  std::shared_ptr<fe::graph::Graph> mha_graph;
+  if (graph_ptr) {
+    mha_graph = *graph_ptr;
   } else {
-    graph_and_tensors_values = build_graph_and_tensors(
+    mha_graph = build_graph(
         b,
         h,
         s_q,
@@ -924,29 +1203,28 @@ void run_cudnn_SDP_fprop(
         _dropoutoffset,
         handle);
   }
-  auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] =
-      graph_and_tensors_values;
-  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
-      variant_pack = {
-          {Q, q.data_ptr()},
-          {K, k.data_ptr()},
-          {V, v.data_ptr()},
-          {attn_scale, &scaling_factor},
-          {seed, _dropoutseed.data_ptr()},
-          {offset, _dropoutoffset.data_ptr()},
-          {O, o.data_ptr()}};
+  std::unordered_map<int64_t, void*> variant_pack = {
+      {Q, q.data_ptr()},
+      {K, k.data_ptr()},
+      {V, v.data_ptr()},
+      {SCALE, &scaling_factor},
+      {O, o.data_ptr()}};
   if (return_softmaxstats) {
-    variant_pack[Stats] = softmaxstats.data_ptr();
+    variant_pack[LSE] = softmaxstats.data_ptr();
   }
   if (attn_bias.has_value()) {
-    variant_pack[bias.value()] = attn_bias.value().data_ptr();
+    variant_pack[BIAS] = attn_bias.value().data_ptr();
+  }
+  if (dropout_probability != 0.0f) {
+    variant_pack[SEED] = _dropoutseed.data_ptr();
+    variant_pack[OFFSET] = _dropoutoffset.data_ptr();
   }
   auto workspace_size = mha_graph->get_workspace_size();
   auto workspace_ptr =
       c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
   TORCH_CHECK(
       mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
-  mhagraphcache.update(key, graph_and_tensors_values);
+  mhagraphcache.update(key, mha_graph);
 }
 
 void run_cudnn_SDP_fprop_nestedtensor(
@@ -985,72 +1263,55 @@ void run_cudnn_SDP_fprop_nestedtensor(
   if (return_softmaxstats && !softmaxstats.defined()) {
     softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat));
   }
-  auto
-      [mha_graph,
-       Q,
-       K,
-       V,
-       bias,
-       attn_scale,
-       seed,
-       offset,
-       O,
-       Stats,
-       RAG_Q_OFF,
-       RAG_K_OFF,
-       RAG_V_OFF,
-       RAG_O_OFF,
-       RAG_STATS_OFF,
-       SEQ_LEN_Q,
-       SEQ_LEN_KV] =
-          build_graph_and_tensors_nestedtensor(
-              b,
-              h_q,
-              h_k,
-              h_v,
-              s_q,
-              s_kv,
-              d_qk,
-              d_v,
-              scaling_factor,
-              return_softmaxstats,
-              is_causal,
-              dropout_probability,
-              cum_seqlen_q,
-              cum_seqlen_kv,
-              q,
-              k,
-              v,
-              attn_bias,
-              softmaxstats,
-              o,
-              dropoutseed,
-              dropoutoffset,
-              handle);
+  auto mha_graph = build_graph_nestedtensor(
+      b,
+      h_q,
+      h_k,
+      h_v,
+      s_q,
+      s_kv,
+      d_qk,
+      d_v,
+      scaling_factor,
+      return_softmaxstats,
+      is_causal,
+      dropout_probability,
+      cum_seqlen_q,
+      cum_seqlen_kv,
+      q,
+      k,
+      v,
+      attn_bias,
+      softmaxstats,
+      o,
+      dropoutseed,
+      dropoutoffset,
+      handle);
   auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
   auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
   auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
-  auto rag_k_off = cum_seqlen_kv.mul(h_k * d_qk);
+  auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
   auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
   auto rag_stats_off = cum_seqlen_q.mul(h_q);
-  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
-      variant_pack = {
-          {Q, q.data_ptr()},
-          {K, k.data_ptr()},
-          {V, v.data_ptr()},
-          {attn_scale, &scaling_factor},
-          {seed, dropoutseed.data_ptr()},
-          {offset, dropoutoffset.data_ptr()},
-          {O, o.data_ptr()},
-          {RAG_Q_OFF, rag_q_off.data_ptr()},
-          {RAG_O_OFF, rag_q_off.data_ptr()},
-          {RAG_K_OFF, rag_k_off.data_ptr()},
-          {RAG_V_OFF, rag_v_off.data_ptr()},
-          {SEQ_LEN_Q, seqlen_q.data_ptr()},
-          {SEQ_LEN_KV, seqlen_kv.data_ptr()}};
+  std::unordered_map<int64_t, void*> variant_pack = {
+      {Q, q.data_ptr()},
+      {K, k.data_ptr()},
+      {V, v.data_ptr()},
+      {SCALE, &scaling_factor},
+      {O, o.data_ptr()},
+      {RAG_Q_OFF, rag_q_off.data_ptr()},
+      {RAG_O_OFF, rag_q_off.data_ptr()},
+      {RAG_K_OFF, rag_k_off.data_ptr()},
+      {RAG_V_OFF, rag_v_off.data_ptr()},
+      {SEQ_LEN_Q, seqlen_q.data_ptr()},
+      {SEQ_LEN_KV, seqlen_kv.data_ptr()}};
   if (return_softmaxstats) {
-    variant_pack[Stats] = softmaxstats.data_ptr();
-    variant_pack[RAG_STATS_OFF] = cum_seqlen_q.data_ptr();
+    variant_pack[LSE] = softmaxstats.data_ptr();
+    variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr();
+  }
+  if (dropout_probability != 0.0f) {
+    variant_pack[SEED] = dropoutseed.data_ptr();
+    variant_pack[OFFSET] = dropoutoffset.data_ptr();
   }
   if (attn_bias.has_value()) {
     TORCH_CHECK(false, "bias not supported with nestedtensor");
   }
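A worked example of the offset tensors the nested-tensor forward assembles, with hypothetical sizes (plain C++ mirroring the `at::diff`/`mul` calls above):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Three sequences of lengths 3, 1, 4 -> cumulative lengths [0, 3, 4, 8].
  const std::vector<int64_t> cum_seqlen_q = {0, 3, 4, 8};
  const int64_t h_q = 2, d_qk = 8;
  std::vector<int64_t> seqlen_q, rag_q_off;
  for (size_t i = 0; i + 1 < cum_seqlen_q.size(); ++i) {
    seqlen_q.push_back(cum_seqlen_q[i + 1] - cum_seqlen_q[i]); // at::diff
  }
  for (int64_t c : cum_seqlen_q) {
    rag_q_off.push_back(c * h_q * d_qk); // cum_seqlen_q.mul(h_q * d_qk)
  }
  // seqlen_q  = {3, 1, 4}        (per-batch SEQ_LEN_Q entries)
  // rag_q_off = {0, 48, 64, 128} (element offset of each sequence's first row)
  std::printf("rag_q_off[1] = %lld\n", (long long)rag_q_off[1]); // prints 48
  return 0;
}
```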
@@ -1134,12 +1395,12 @@ void run_cudnn_SDP_bprop(
       dropout_probability,
       is_causal,
       true);
-  auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
-  graph_and_tensors_backward graph_and_tensors_backward_values;
-  if (graph_and_tensors_backward_ptr) {
-    graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr;
+  auto graph_backward_ptr = mhagraphbackwardcache.find(key);
+  std::shared_ptr<fe::graph::Graph> mha_graph;
+  if (graph_backward_ptr) {
+    mha_graph = *graph_backward_ptr;
   } else {
-    graph_and_tensors_backward_values = build_graph_and_tensors_backward(
+    mha_graph = build_graph_backward(
         b,
         h,
         s_q,
@@ -1163,41 +1424,25 @@ void run_cudnn_SDP_bprop(
        _dropoutoffset,
        handle);
   }
-  auto
-      [mha_graph,
-       Q,
-       K,
-       V,
-       bias,
-       attn_scale,
-       Seed,
-       Offset,
-       O,
-       Do,
-       Stats,
-       Dq,
-       Dk,
-       Dv] = graph_and_tensors_backward_values;
-  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
-      variant_pack = {// inputs
-                      {Q, q.data_ptr()},
-                      {K, k.data_ptr()},
-                      {V, v.data_ptr()},
-                      {O, o.data_ptr()},
-                      {Do, dO_.data_ptr()},
-                      {Stats, softmaxstats.data_ptr()},
-                      // outputs
-                      {Dq, dQ.data_ptr()},
-                      {Dk, dK.data_ptr()},
-                      {Dv, dV.data_ptr()},
-                      // pass by value
-                      {attn_scale, &scaling_factor}};
+  std::unordered_map<int64_t, void*> variant_pack = {
+      // inputs
+      {Q, q.data_ptr()},
+      {K, k.data_ptr()},
+      {V, v.data_ptr()},
+      {O, o.data_ptr()},
+      {DO, dO_.data_ptr()},
+      {LSE, softmaxstats.data_ptr()},
+      // outputs
+      {DQ, dQ.data_ptr()},
+      {DK, dK.data_ptr()},
+      {DV, dV.data_ptr()},
+      {SCALE, &scaling_factor}};
   if (dropout_probability != 0.0f) {
-    variant_pack[Seed] = _dropoutseed.data_ptr();
-    variant_pack[Offset] = _dropoutoffset.data_ptr();
+    variant_pack[SEED] = _dropoutseed.data_ptr();
+    variant_pack[OFFSET] = _dropoutoffset.data_ptr();
   }
   if (attn_bias.has_value()) {
-    variant_pack[bias.value()] = attn_bias.value().data_ptr();
+    variant_pack[BIAS] = attn_bias.value().data_ptr();
   }
   auto workspace_size = mha_graph->get_workspace_size();
   auto workspace_ptr =
@@ -1205,7 +1450,127 @@ void run_cudnn_SDP_bprop(
   TORCH_CHECK(!workspace_size || workspace_ptr.get());
   TORCH_CHECK(
       mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
-  mhagraphbackwardcache.update(key, graph_and_tensors_backward_values);
+  mhagraphbackwardcache.update(key, mha_graph);
+}
+
+void run_cudnn_SDP_bprop_nestedtensor(
+    int64_t b,
+    int64_t h_q,
+    int64_t h_k,
+    int64_t h_v,
+    int64_t s_q,
+    int64_t s_kv,
+    int64_t d_qk,
+    int64_t d_v,
+    float scaling_factor,
+    bool is_causal,
+    float dropout_probability,
+    const Tensor& cum_seqlen_q,
+    const Tensor& cum_seqlen_kv,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const std::optional<Tensor>& attn_bias,
+    const Tensor& o,
+    const Tensor& dO,
+    const Tensor& softmaxstats,
+    Tensor& dQ,
+    Tensor& dK,
+    Tensor& dV,
+    const Tensor& dropoutseed,
+    const Tensor& dropoutoffset) {
+  // do nothing if we got 0-element tensors
+  if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() ||
+      !softmaxstats.numel()) {
+    return;
+  }
+
+  Tensor dO_ = dO;
+  const auto innermost_dO_stride = dO.strides()[dO.strides().size() - 1];
+  if (innermost_dO_stride != 1) {
+    permute_to_matching_layout(o, dO_);
+  }
+
+  auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
+  auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
+  auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
+  auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
+  auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
+  auto rag_stats_off = cum_seqlen_q.mul(h_q);
+
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  auto _dropoutseed = dropoutseed;
+  auto _dropoutoffset = dropoutoffset;
+  // cuDNN dropout bug requires these to be in int64
+  if (dprops->major == 10 && dprops->minor == 0) {
+    _dropoutseed = dropoutseed.to(kLong);
+    _dropoutoffset = dropoutoffset.to(kLong);
+  }
+
+  cudnnHandle_t handle = getCudnnHandle();
+
+  auto mha_graph = build_graph_backward_nestedtensor(
+      b,
+      h_q,
+      h_k,
+      h_v,
+      s_q,
+      s_kv,
+      d_qk,
+      d_v,
+      scaling_factor,
+      is_causal,
+      dropout_probability,
+      cum_seqlen_q,
+      cum_seqlen_kv,
+      q,
+      k,
+      v,
+      attn_bias,
+      o,
+      dO_,
+      softmaxstats,
+      dQ,
+      dK,
+      dV,
+      dropoutseed,
+      dropoutoffset,
+      handle);
+
+  std::unordered_map<int64_t, void*> variant_pack = {
+      // inputs
+      {Q, q.data_ptr()},
+      {K, k.data_ptr()},
+      {V, v.data_ptr()},
+      {O, o.data_ptr()},
+      {DO, dO_.data_ptr()},
+      {LSE, softmaxstats.data_ptr()},
+      // outputs
+      {DQ, dQ.data_ptr()},
+      {DK, dK.data_ptr()},
+      {DV, dV.data_ptr()},
+      {SCALE, &scaling_factor},
+      {RAG_Q_OFF, rag_q_off.data_ptr()},
+      {RAG_O_OFF, rag_q_off.data_ptr()},
+      {RAG_K_OFF, rag_k_off.data_ptr()},
+      {RAG_V_OFF, rag_v_off.data_ptr()},
+      {RAG_LSE_OFF, rag_stats_off.data_ptr()},
+      {SEQ_LEN_Q, seqlen_q.data_ptr()},
+      {SEQ_LEN_KV, seqlen_kv.data_ptr()}};
+  if (dropout_probability != 0.0f) {
+    variant_pack[SEED] = _dropoutseed.data_ptr();
+    variant_pack[OFFSET] = _dropoutoffset.data_ptr();
+  }
+  TORCH_CHECK(
+      !attn_bias.has_value(),
+      "attn_bias not yet supported with cuDNN Attention and NestedTensor");
+
+  auto workspace_size = mha_graph->get_workspace_size();
+  auto workspace_ptr =
+      c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
+  TORCH_CHECK(!workspace_size || workspace_ptr.get());
+  TORCH_CHECK(
+      mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
 }
 
 } // namespace native
diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h
index 045e8cf6dee..620abc1aa0a 100644
--- a/aten/src/ATen/native/cudnn/MHA.h
+++ b/aten/src/ATen/native/cudnn/MHA.h
@@ -70,4 +70,31 @@ void run_cudnn_SDP_bprop(
     const Tensor& dropoutseed,
     const Tensor& dropoutoffset);
 
+void run_cudnn_SDP_bprop_nestedtensor(
+    int64_t b,
+    int64_t h_q,
+    int64_t h_k,
+    int64_t h_v,
+    int64_t s_q,
+    int64_t s_kv,
+    int64_t d_qk,
+    int64_t d_v,
+    float scaling_factor,
+    bool is_causal,
+    float dropout_probability,
+    const Tensor& cum_seqlen_q,
+    const Tensor& cum_seqlen_kv,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const std::optional<Tensor>& attn_bias,
+    const Tensor& o,
+    const Tensor& dO,
+    const Tensor& softmaxstats,
+    Tensor& dQ,
+    Tensor& dK,
+    Tensor& dV,
+    const Tensor& dropoutseed,
+    const Tensor& dropoutoffset);
+
 } // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a1f38b64a32..8757daebf9b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14958,6 +14958,7 @@
 - func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
   dispatch:
     CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
+    NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda
   tags: nondeterministic_seeded
 
 - func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
@@ -14990,6 +14991,11 @@
     CUDA: _cudnn_attention_forward
   tags: nondeterministic_seeded
 
+- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: _cudnn_attention_backward
+  tags: nondeterministic_seeded
+
 - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
   variants: function
   dispatch:
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
index 5b747645340..2dd423f8ab4 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
@@ -349,6 +349,63 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda(
   return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
 }
 
+std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda(
+    const Tensor& grad_out,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const Tensor& out,
+    const Tensor& logsumexp,
+    const Tensor& philox_seed,
+    const Tensor& philox_offset,
+    const Tensor& attn_bias,
+    const Tensor& cum_seq_q,
+    const Tensor& cum_seq_k,
+    const int64_t max_q,
+    const int64_t max_k,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale) {
+  if (!grad_out.defined()) {
+    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
+  }
+  auto [
+      grad_out_buffer_reshaped,
+      query_buffer_reshaped,
+      key_buffer_reshaped,
+      value_buffer_reshaped,
+      output_buffer_reshaped] =
+      preprocessing::sdpa_nested_preprocessing_backward(
+          grad_out,
+          query,
+          key,
+          value,
+          out,
+          cum_seq_q,
+          cum_seq_k,
+          max_q,
+          max_k);
+
+  auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped,
+      query_buffer_reshaped,
+      key_buffer_reshaped,
+      value_buffer_reshaped,
+      output_buffer_reshaped,
+      logsumexp,
+      philox_seed,
+      philox_offset,
+      attn_bias,
+      cum_seq_q,
+      cum_seq_k,
+      max_q,
+      max_k,
+      dropout_p,
+      is_causal,
+      scale);
+  return std::make_tuple(dq, dk, dv);
+}
+
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
     const at::Tensor& grad_out_,
     const at::Tensor& query,
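The NestedTensorCUDA backward above flattens the jagged inputs with `preprocessing::sdpa_nested_preprocessing_backward` and then reuses the dense `_cudnn_attention_backward` entry point. A hedged sketch of the packing contract that helper satisfies (`pack_jagged` is a hypothetical simplification, not the actual helper):

```cpp
#include <torch/torch.h>
#include <utility>
#include <vector>

// Jagged [s_i, h, d] pieces become one packed [T, h, d] buffer plus an int32
// cumulative-sequence-length tensor of length b + 1 (what cum_seq_q/cum_seq_k
// carry into the kernel above).
std::pair<torch::Tensor, torch::Tensor> pack_jagged(
    const std::vector<torch::Tensor>& seqs) { // each [s_i, h, d]
  std::vector<int64_t> lens = {0};
  for (const auto& t : seqs) {
    lens.push_back(lens.back() + t.size(0)); // running total -> cum_seq
  }
  auto packed = torch::cat(seqs, /*dim=*/0);       // [T, h, d], T = sum(s_i)
  auto cum_seq = torch::tensor(lens, torch::kInt); // length b + 1
  return {packed, cum_seq};
}
```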
diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu
index b4aeab10a75..70566739bab 100644
--- a/aten/src/ATen/native/transformers/cuda/attention.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention.cu
@@ -848,16 +848,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor>
diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu
--- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_cudnn_attention_backward.h>
+#include <ATen/ops/_cudnn_attention_backward_native.h>
 #include <ATen/ops/_flash_attention_backward.h>
 #include <ATen/ops/_flash_attention_backward_native.h>
 #include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_native.h>
@@ -170,7 +172,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
   return std::make_tuple(Tensor(), Tensor(), Tensor());
 }
 
-std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
+std::tuple<Tensor, Tensor, Tensor> _cudnn_attention_backward(
     const Tensor& grad_out,
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
@@ -197,57 +199,117 @@
     }
   }
 
-  const int64_t batch_size = query.size(0);
-  const int64_t num_heads = query.size(1);
-  const int64_t head_dim_qk = query.size(3);
-  const int64_t head_dim_v = value.size(3);
+  const bool is_nested = cum_seq_q.defined();
   const int64_t max_seqlen_batch_q = query.size(2);
   const int64_t max_seqlen_batch_k = key.size(2);
 
-  // This is needed because SaveVariable automatically converts
-  // std::optional to undefined tensor
-  std::optional<Tensor> attn_bias_;
-  if (attn_bias.defined()) {
-    attn_bias_ = attn_bias;
-  }
-  if (attn_bias_.has_value()) {
-    const auto bias_dim = attn_bias_.value().dim();
-    if (bias_dim == 2) {
-      attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
-    } else if (bias_dim == 3) {
-      attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
-    } else {
-      TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
-      attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
-    }
-  }
+  if (!is_nested) {
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t head_dim_qk = query.size(3);
+    const int64_t head_dim_v = value.size(3);
 
-  const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
-  auto dq = at::empty_like(query);
-  auto dk = at::empty_like(key);
-  auto dv = at::empty_like(value);
-  run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
-                      num_heads /*int64_t h*/,
-                      max_q/*int64_t s_q*/,
-                      max_k/*int64_t s_kv*/,
-                      head_dim_qk /*int64_t d_qk*/,
-                      head_dim_v /*int64_t d_v*/,
-                      softmax_scale /*float scaling_factor*/,
-                      is_causal /*bool is_causal*/,
-                      dropout_p /*float dropout_probability*/,
-                      query /*const Tensor& q*/,
-                      key /*const Tensor& k*/,
-                      value /*const Tensor& v*/,
-                      attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
-                      out /*const Tensor& o*/,
-                      grad_out/*const Tensor& dO*/,
-                      logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
-                      dq/*Tensor& dQ*/,
-                      dk/*Tensor& dK*/,
-                      dv/*Tensor& dV*/,
-                      philox_seed/*Tensor& dropoutseed*/,
-                      philox_offset/*Tensor& dropoutoffset*/);
-  return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+    // This is needed because SaveVariable automatically converts
+    // std::optional to undefined tensor
+    std::optional<Tensor> attn_bias_;
+    if (attn_bias.defined()) {
+      attn_bias_ = attn_bias;
+    }
+    if (attn_bias_.has_value()) {
+      const auto bias_dim = attn_bias_.value().dim();
+      if (bias_dim == 2) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else if (bias_dim == 3) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else {
+        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
+        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
+      }
+    }
+
+    const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
+    auto dq = at::empty_like(query);
+    auto dk = at::empty_like(key);
+    auto dv = at::empty_like(value);
+    run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
+                        num_heads /*int64_t h*/,
+                        max_q/*int64_t s_q*/,
+                        max_k/*int64_t s_kv*/,
+                        head_dim_qk /*int64_t d_qk*/,
+                        head_dim_v /*int64_t d_v*/,
+                        softmax_scale /*float scaling_factor*/,
+                        is_causal /*bool is_causal*/,
+                        dropout_p /*float dropout_probability*/,
+                        query /*const Tensor& q*/,
+                        key /*const Tensor& k*/,
+                        value /*const Tensor& v*/,
+                        attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
+                        out /*const Tensor& o*/,
+                        grad_out/*const Tensor& dO*/,
+                        logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
+                        dq/*Tensor& dQ*/,
+                        dk/*Tensor& dK*/,
+                        dv/*Tensor& dV*/,
+                        philox_seed/*Tensor& dropoutseed*/,
+                        philox_offset/*Tensor& dropoutoffset*/);
+    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+  } else {
+    // BHSD ...
+    const int64_t batch_size = cum_seq_q.size(0) - 1;
+    const int64_t num_heads_q = query.size(-2);
+    const int64_t num_heads_k = key.size(-2);
+    const int64_t num_heads_v = value.size(-2);
+    const int64_t head_dim_qk = query.size(-1);
+    const int64_t head_dim_v = value.size(-1);
+    std::optional<Tensor> attn_bias_;
+    if (attn_bias.defined()) {
+      attn_bias_ = attn_bias;
+    }
+    if (attn_bias_.has_value()) {
+      const auto bias_dim = attn_bias_.value().dim();
+      if (bias_dim == 2) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else if (bias_dim == 3) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else {
+        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
+        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
+      }
+    }
+
+    auto dq = at::empty_like(query);
+    auto dk = at::empty_like(key);
+    auto dv = at::empty_like(value);
+
+    const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
+    run_cudnn_SDP_bprop_nestedtensor(
+        batch_size,
+        num_heads_q,
+        num_heads_k,
+        num_heads_v,
+        max_seqlen_batch_q,
+        max_seqlen_batch_k,
+        head_dim_qk,
+        head_dim_v,
+        softmax_scale,
+        is_causal,
+        dropout_p,
+        cum_seq_q,
+        cum_seq_k,
+        query,
+        key,
+        value,
+        attn_bias_,
+        out,
+        grad_out,
+        logsumexp,
+        dq,
+        dk,
+        dv,
+        philox_seed,
+        philox_offset);
+    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+  }
 }
 
 std::tuple<Tensor, Tensor, Tensor, Tensor>
@@ -950,4 +1012,40 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_backward_cuda(
       grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
 }
 
+std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
+    const Tensor& grad_out,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const Tensor& out,
+    const Tensor& logsumexp,
+    const Tensor& philox_seed,
+    const Tensor& philox_offset,
+    const Tensor& attn_bias,
+    const Tensor& cum_seq_q,
+    const Tensor& cum_seq_k,
+    const int64_t max_q,
+    const int64_t max_k,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale) {
+  return at::_cudnn_attention_backward(
+      grad_out,
+      query,
+      key,
+      value,
+      out,
+      logsumexp,
+      philox_seed,
+      philox_offset,
+      attn_bias,
+      cum_seq_q,
+      cum_seq_k,
+      max_q,
+      max_k,
+      dropout_p,
+      is_causal,
+      scale);
+}
+
 } // namespace at::native
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 45b4cf118c1..6e0ae12433d 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -57,12 +57,21 @@ namespace sdp {
 
 namespace {
 
+// tracks whether we've set the default priority order once, to avoid setting
+// it redundantly or overwriting a user-specified priority order
+// when the priority order context manager is used before the default priority
+// order is initialized the following happens:
+// (1) the current priority order is queried
+// (2) priority_order() is called, which initializes it to the default as init_ is false
+// (3) the user-specified priority order is set
+// (3.1) we are in the priority context...
+// (3.2) we exit the priority context...
+// (4) the previous priority order (default) is restored
+bool priority_order_init_ = false;
+
 // TODO(eqy): more benchmarking to determine whether this should include sm86/89
 // Needs to be kept in-sync with test_fused_choice in test_transformers.py
 bool check_prefer_cudnn_attention() {
-  // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0
-  // see context: https://github.com/pytorch/pytorch/issues/138340
-  // return false;
 #if defined(CUDNN_VERSION)
 #if CUDNN_VERSION > 90000
@@ -79,6 +88,16 @@
 
 // flash_attention V2 is universally faster than efficient_attention and Math
 std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
+  if (!priority_order_init_) {
+    priority_order_init_ = true;
+    if (check_prefer_cudnn_attention()) {
+      const std::vector<int64_t> cudnn_order = {
+          static_cast<int64_t>(at::SDPBackend::cudnn_attention),
+          static_cast<int64_t>(at::SDPBackend::flash_attention),
+          static_cast<int64_t>(at::SDPBackend::efficient_attention),
+          static_cast<int64_t>(at::SDPBackend::math)};
+      at::globalContext().setSDPPriorityOrder(cudnn_order);
+    }
+  }
   return at::globalContext().sDPPriorityOrder();
 }
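The comment above describes a subtle ordering hazard; the sketch below reproduces it in miniature (`get_order` and the integer orders are illustrative stand-ins, not the Context API). Because the first query installs the default, a save/restore context entered before any query saves, and later restores, the cuDNN-preferring default rather than some pre-init state:

```cpp
#include <cassert>
#include <vector>

static bool init_done = false;
static std::vector<int> order = {1, 0, 2, 3}; // flash-first build default

std::vector<int> get_order() {
  if (!init_done) {        // step (2) in the comment above
    init_done = true;
    order = {0, 1, 2, 3};  // cudnn-first default, applied exactly once
  }
  return order;
}

int main() {
  auto saved = get_order();  // (1)+(2): the query itself installs the default
  order = {2, 1, 0, 3};      // (3): user-specified order inside the context
  order = saved;             // (4): exiting restores the *default*, as intended
  assert(order == std::vector<int>({0, 1, 2, 3}));
  return 0;
}
```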
@@ -453,9 +472,15 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
       return false;
     }
   }
-  if (s_q == 1 || s_k == 1) {
+  if (s_k == 1) {
     if (debug) {
-      TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1.");
+      TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1.");
+    }
+    return false;
+  }
+  if (s_q == 1 && params.dropout != 0.0) {
+    if (debug) {
+      TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout.");
     }
     return false;
   }
@@ -563,9 +588,9 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
   const auto dprop = at::cuda::getCurrentDeviceProperties();
   // Check that the input is nested
-  if (dprop->major != 9 && has_for_nested_inputs(params)) {
+  if (dprop->major != 9 && dprop->major != 10 && has_for_nested_inputs(params)) {
     if (debug) {
-      TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0.");
+      TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0.");
     }
     return false;
   }
@@ -589,7 +614,7 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
   // sdp kernels
   if (!at::globalContext().userEnabledCuDNNSDP()) {
     if (debug) {
-      TORCH_WARN("CuDNN attention has been runtime disabled.");
+      TORCH_WARN("cuDNN attention has been runtime disabled.");
     }
     return false;
   }
@@ -620,7 +645,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
 #endif
 #if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000
   if (debug) {
-    TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)");
+    TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)");
   }
   return false;
 #endif
@@ -630,10 +655,8 @@
       c10::array_of<bool (*)(sdp_params const&, bool)>(
           check_runtime_disabled_cudnn,
           check_for_nested_inputs,
-          check_nonzero_sequence_lengths_dense,
           check_all_tensors_on_device,
           check_tensor_shapes,
-          check_cudnn_tensor_shapes,
           check_cudnn_deterministic,
           check_dtypes_low_precision,
           check_attn_mask_shape,
@@ -646,8 +669,10 @@
   }
 
   constexpr auto dense_constraints =
       c10::array_of<bool (*)(sdp_params const&, bool)>(
+          check_nonzero_sequence_lengths_dense,
           check_last_dim_stride_equals_1_dense,
-          check_batch_size_and_num_heads_dense
+          check_batch_size_and_num_heads_dense,
+          check_cudnn_tensor_shapes
       );
 
   if (has_only_dense_inputs(params)) {
@@ -859,7 +884,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
       sdp::can_use_mem_efficient_attention(kernel_params, print_debug);
       TORCH_WARN("Flash attention kernel not used because:");
       sdp::can_use_flash_attention(kernel_params, print_debug);
-      TORCH_WARN("CuDNN attention kernel not used because:");
+      TORCH_WARN("cuDNN attention kernel not used because:");
       sdp::can_use_cudnn_attention(kernel_params, print_debug);
       TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
       return SDPBackend::error;
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 575221406cd..4585a73658e 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -75,6 +75,7 @@ aten::_ctc_loss.out
 aten::_ctc_loss_backward
 aten::_ctc_loss_backward.Tensor
 aten::_ctc_loss_backward.out
+aten::_cudnn_attention_backward
 aten::_cudnn_attention_forward
 aten::_cudnn_ctc_loss
 aten::_cudnn_ctc_loss.Tensor
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index f53268cb24d..951289fe992 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -6746,11 +6746,10 @@ torch.cuda.synchronize()
             and check_cudnn
             and (dtype == torch.float16 or dtype == torch.bfloat16)
         ):
-            with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"):
-                with torch.nn.attention.sdpa_kernel(
-                    torch.nn.attention.SDPBackend.CUDNN_ATTENTION
-                ):
-                    check_forward_backward()
+            with torch.nn.attention.sdpa_kernel(
+                torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+            ):
+                check_forward_backward()
 
     @skipIfTorchDynamo("SDPA test compiles internally")
     @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
diff --git a/test/test_transformers.py b/test/test_transformers.py
index a8e63d91b1a..051e5d1d5b8 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -49,7 +49,6 @@ from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_FUSED_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
-    SM90OrLater,
     tf32_on_and_off,
     tf32_enabled,
 )
@@ -2994,15 +2993,18 @@ class TestSDPACudaOnly(NNTestCase):
         value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
+        device_capability = None
+        if "cuda" in str(device):
+            device_capability = torch.cuda.get_device_capability()
+        prefer_cudnn = device_capability and (device_capability == (9, 0) or device_capability == (10, 0))
+
         # TODO we are currently disabling this by default; let's assert that this returns
         # FlashAttention, which we need to change when we remove opt-in for cudnn
-        if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
-            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
-            with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
-                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
+        if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and prefer_cudnn:
+            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
         elif PLATFORM_SUPPORTS_FLASH_ATTENTION:
             self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
-        elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION:  # e.g., we're on Windows
+        elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn:  # e.g., we're on Windows
             self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
         with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
             self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index fe4dd72b247..ba3b7e36054 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2896,6 +2896,10 @@
   output_differentiability: [True, False, False, False, False, False]
   query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
 
+- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  output_differentiability: [True, False, False, False, False, False, False, False, False]
+  query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)
+
 - name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False, False, False, False, False]
   query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)