diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 54320cd4656..c2f7ce2ac2d 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -482,7 +482,7 @@ auto build_graph( auto scaled_dot_product_flash_attention_options = fe::graph::SDPA_attributes() .set_name("CUDNN_SDPA") - .set_is_inference(return_softmaxstats == false) + .set_generate_stats(return_softmaxstats) .set_causal_mask(is_causal) .set_attn_scale(attn_scale); if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { @@ -702,7 +702,7 @@ auto build_graph_nestedtensor( auto scaled_dot_product_flash_attention_options = fe::graph::SDPA_attributes() .set_name("CUDNN_SDPA_NESTEDTENSOR") - .set_is_inference(return_softmaxstats == false) + .set_generate_stats(return_softmaxstats) .set_causal_mask(is_causal) .set_attn_scale(attn_scale) .set_seq_len_q(SEQ_LEN_Q_) diff --git a/third_party/cudnn_frontend b/third_party/cudnn_frontend index f937055efc6..1a7b4b78db4 160000 --- a/third_party/cudnn_frontend +++ b/third_party/cudnn_frontend @@ -1 +1 @@ -Subproject commit f937055efc6d414d11f4c6577e3977fe74f35fb6 +Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8