From fa8e3a28a7bab47742d7f791a24eeb87b04469ca Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 22 Feb 2025 17:28:12 +0000 Subject: [PATCH] Revert "[cuDNN][SDPA][Nested Tensor] Experimental cuDNN Nested Tensor SDPA Support (forward only) (#141178)" This reverts commit 533b884870acd951e684e0bf551eb76904dec047. Reverted https://github.com/pytorch/pytorch/pull/141178 on behalf of https://github.com/jeanschmidt due to Broke internal arvr signals, see D69971019. @jbschlosser please help the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/141178#issuecomment-2676317470)) --- aten/src/ATen/cuda/detail/UnpackRaw.cuh | 6 - aten/src/ATen/native/cudnn/MHA.cpp | 345 +----------------- aten/src/ATen/native/cudnn/MHA.h | 24 -- aten/src/ATen/native/native_functions.yaml | 6 - .../nested/NestedTensorTransformerUtils.h | 3 +- .../cuda/NestedTensorTransformerFunctions.cpp | 29 -- .../native/transformers/cuda/attention.cu | 341 ++++------------- .../native/transformers/cuda/sdp_utils.cpp | 30 +- ...asDecompTest.test_has_decomposition.expect | 1 - test/test_nestedtensor.py | 80 +--- test/test_transformers.py | 34 -- torch/nested/_internal/sdpa.py | 78 +--- 12 files changed, 114 insertions(+), 863 deletions(-) diff --git a/aten/src/ATen/cuda/detail/UnpackRaw.cuh b/aten/src/ATen/cuda/detail/UnpackRaw.cuh index 3a458c756da..70cd222a484 100644 --- a/aten/src/ATen/cuda/detail/UnpackRaw.cuh +++ b/aten/src/ATen/cuda/detail/UnpackRaw.cuh @@ -25,10 +25,4 @@ unpack(at::PhiloxCudaState arg) { } } -// Adapted from TE -// extract seed and offset from PhiloxCudaState -__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr); - -void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream); - } // namespace at::cuda::philox diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index c38d4a095c0..26f5c4931ba 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -31,33 +31,6 @@ void run_cudnn_SDP_fprop( false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); } -void run_cudnn_SDP_fprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool return_softmaxstats, - bool is_causal, - double dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - Tensor& softmaxstats, - Tensor& o, - Tensor& dropoutseed, - Tensor& dropoutoffset) { - TORCH_CHECK( - false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); -} - void run_cudnn_SDP_bprop( int64_t b, int64_t h, @@ -488,6 +461,16 @@ auto build_graph_and_tensors( .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } + auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); auto [O, Stats] = mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); @@ -517,201 +500,6 @@ auto build_graph_and_tensors( std::move(Stats)); } -auto build_graph_and_tensors_nestedtensor( - int64_t 
b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool return_softmaxstats, - bool is_causal, - double dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - Tensor& softmaxstats, - Tensor& o, - Tensor& dropoutseed, - Tensor& dropoutoffset, - cudnnHandle_t& handle) { - auto dtype = fe::DataType_t::HALF; - if (q.scalar_type() == kBFloat16) { - dtype = fe::DataType_t::BFLOAT16; - } - auto mha_graph = std::make_shared(); - // We're baking in float accumulation and scale types - // in theory the graph may support other types, but they - // have not been tested - mha_graph->set_io_data_type(dtype) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - auto attn_scale = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_KV = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seq_kv") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - - auto scaled_dot_product_flash_attention_options = - fe::graph::SDPA_attributes() - .set_name("CUDNN_SDPA_NESTEDTENSOR") - .set_is_inference(return_softmaxstats == false) - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale) - .set_dropout(dropout_probability, seed, offset) - .set_seq_len_q(SEQ_LEN_Q) - .set_seq_len_kv(SEQ_LEN_KV) - .set_padding_mask(true); - // We hardcode BSHD to cuDNN even though the underlying layout is THD - auto q_strides = q.strides(); - auto k_strides = k.strides(); - auto v_strides = v.strides(); - constexpr int strideidx0 = 1; - constexpr int strideidx1 = 0; - constexpr int strideidx2 = 2; - auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]})); - auto K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]})); - auto V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]})); - std::optional> bias; - if (attn_bias.has_value()) { - TORCH_CHECK( - false, - "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); - bias = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim(attn_bias.value().sizes().vec()) - .set_stride(attn_bias.value().strides().vec())); - scaled_dot_product_flash_attention_options.set_bias(bias.value()); - } - auto RAG_Q_OFF = 
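For reference, the stride remapping above hands cuDNN a BHSD view of a packed THD buffer of shape (total_tokens, num_heads, head_dim): the head and token strides fill the H and S slots, and the batch stride is only a placeholder because the ragged offsets locate each sequence. A minimal Python sketch with illustrative sizes:

    import torch

    # Packed THD buffer: (total_tokens, num_heads, head_dim); sizes are illustrative.
    total_tokens, num_heads, head_dim = 10, 4, 64
    q_packed = torch.randn(total_tokens, num_heads, head_dim)

    t_stride, h_stride, d_stride = q_packed.stride()  # (256, 64, 1), in elements

    # BHSD strides handed to the graph API: placeholder batch stride (INT_MAX in
    # the C++ above), head stride in the H slot, token stride in the S slot.
    bhsd_strides = (2**31 - 1, h_stride, t_stride, d_stride)
    print(bhsd_strides)  # (2147483647, 64, 256, 1)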
mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_K_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_V_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_O_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - // auto RAG_STATS_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("cum_seq_stats") - // .set_dim({b + 1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - auto RAG_STATS_OFF = nullptr; - Q->set_ragged_offset(RAG_Q_OFF); - K->set_ragged_offset(RAG_K_OFF); - V->set_ragged_offset(RAG_V_OFF); - auto [O, Stats] = - mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); - auto o_strides = o.strides(); - O->set_output(true) - .set_dim({b, h_q, s_q, d_v}) - .set_stride( - {INT_MAX, - o_strides[strideidx0], - o_strides[strideidx1], - o_strides[strideidx2]}); - - O->set_ragged_offset(RAG_O_OFF); - if (Stats) { - TORCH_CHECK( - false, - "cuDNN SDPA Nested Tensor does not yet handle backwards/logsumexp computation"); - // TODO(eqy): fix when stats (backward) support is added - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h_q, s_q, 1}) - .set_stride({h_q * s_q * d_v, d_v, s_q * d_v, 1}); - Stats->set_ragged_offset(RAG_STATS_OFF); - } - AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); - AT_CUDNN_FRONTEND_CHECK( - mha_graph->create_execution_plans({fe::HeurMode_t::A})); - AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return std::make_tuple( - std::move(mha_graph), - std::move(Q), - std::move(K), - std::move(V), - std::move(bias), - std::move(attn_scale), - std::move(seed), - std::move(offset), - std::move(O), - std::move(Stats), - std::move(RAG_Q_OFF), - std::move(RAG_K_OFF), - std::move(RAG_V_OFF), - std::move(RAG_O_OFF), - std::move(RAG_STATS_OFF), - std::move(SEQ_LEN_Q), - std::move(SEQ_LEN_KV)); -} - auto build_graph_and_tensors_backward( int64_t b, int64_t h, @@ -949,119 +737,6 @@ void run_cudnn_SDP_fprop( mhagraphcache.update(key, graph_and_tensors_values); } -void run_cudnn_SDP_fprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool return_softmaxstats, - bool is_causal, - double dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - Tensor& softmaxstats, - Tensor& o, - Tensor& dropoutseed, - Tensor& dropoutoffset) { - cudnnHandle_t handle = getCudnnHandle(); - // do nothing if we got 0-element tensors - if (!q.numel() || !k.numel() || !v.numel()) { - return; - } - - if (!o.defined()) { - o = at::empty({q.size(0), h_q, d_v}, q.options()); - } - - if (return_softmaxstats && !softmaxstats.defined()) { - softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat)); - } - auto 
- [mha_graph, - Q, - K, - V, - bias, - attn_scale, - seed, - offset, - O, - Stats, - RAG_Q_OFF, - RAG_K_OFF, - RAG_V_OFF, - RAG_O_OFF, - RAG_STATS_OFF, - SEQ_LEN_Q, - SEQ_LEN_KV] = - build_graph_and_tensors_nestedtensor( - b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - scaling_factor, - return_softmaxstats, - is_causal, - dropout_probability, - cum_seqlen_q, - cum_seqlen_kv, - q, - k, - v, - attn_bias, - softmaxstats, - o, - dropoutseed, - dropoutoffset, - handle); - auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); - auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); - auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); - auto rag_k_off = cum_seqlen_kv.mul(h_k * d_qk); - auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); - auto rag_stats_off = cum_seqlen_q.mul(h_q); - std::unordered_map, void*> - variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {attn_scale, &scaling_factor}, - {seed, dropoutseed.data_ptr()}, - {offset, dropoutoffset.data_ptr()}, - {O, o.data_ptr()}, - {RAG_Q_OFF, rag_q_off.data_ptr()}, - {RAG_O_OFF, rag_q_off.data_ptr()}, - {RAG_K_OFF, rag_k_off.data_ptr()}, - {RAG_V_OFF, rag_v_off.data_ptr()}, - {SEQ_LEN_Q, seqlen_q.data_ptr()}, - {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; - if (return_softmaxstats) { - variant_pack[Stats] = softmaxstats.data_ptr(); - variant_pack[RAG_STATS_OFF] = cum_seqlen_q.data_ptr(); - } - if (attn_bias.has_value()) { - TORCH_CHECK("bias not supported with nestedtensor"); - } - auto workspace_size = mha_graph->get_workspace_size(); - auto workspace_ptr = - c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); - TORCH_CHECK( - mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); -} - void run_cudnn_SDP_bprop( int64_t b, int64_t h, diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 045e8cf6dee..1bd3deb2be3 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -23,30 +23,6 @@ void run_cudnn_SDP_fprop( Tensor& dropoutseed, Tensor& dropoutoffset); -void run_cudnn_SDP_fprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t max_s_q, - int64_t max_s_kv, - int64_t d_k, - int64_t d_v, - float scaling_factor, - bool isTraining, - bool is_causal, - double dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - Tensor& softmaxstats, - Tensor& o, - Tensor& dropoutseed, - Tensor& dropoutoffset); - void run_cudnn_SDP_bprop( int64_t b, int64_t h, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 429457633e6..52f9547d470 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14902,7 +14902,6 @@ - func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
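The variant pack above derives everything the ragged layout needs from the two cumulative-sequence-length tensors: per-sequence lengths via a first difference, and ragged offsets as the cumulative lengths scaled by heads * head_dim. A short Python sketch of that arithmetic with assumed sizes:

    import torch

    # Illustrative per-sequence lengths for a packed batch of three sequences.
    seq_lens_q = torch.tensor([3, 5, 2], dtype=torch.int32)
    h_q, d_qk = 4, 64  # assumed head count and head dim

    # cum_seqlen_q = [0, 3, 8, 10]
    cum_seqlen_q = torch.cat(
        [torch.zeros(1, dtype=torch.int32), seq_lens_q.cumsum(0, dtype=torch.int32)])

    # Per-sequence lengths recovered as in `at::diff(cum_seqlen_q, 1, 0)`.
    assert torch.equal(torch.diff(cum_seqlen_q), seq_lens_q)

    # Ragged element offsets into the flat (total_tokens, h_q, d_qk) buffer,
    # as in `cum_seqlen_q.mul(h_q * d_qk)`.
    rag_q_off = cum_seqlen_q * (h_q * d_qk)
    print(rag_q_off.tolist())  # [0, 768, 2048, 2560]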
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) dispatch: CUDA: _scaled_dot_product_cudnn_attention_cuda - NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_cuda tags: nondeterministic_seeded - func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) @@ -14935,11 +14934,6 @@ dispatch: CUDA: _efficient_attention_backward -- func: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) - dispatch: - CUDA: _cudnn_attention_forward - tags: nondeterministic_seeded - - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function dispatch: diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerUtils.h b/aten/src/ATen/native/nested/NestedTensorTransformerUtils.h index a9082a7dfa4..d3acf229a23 100644 --- a/aten/src/ATen/native/nested/NestedTensorTransformerUtils.h +++ b/aten/src/ATen/native/nested/NestedTensorTransformerUtils.h @@ -1,5 +1,4 @@ -#pragma once -#include +#include namespace at::native::preprocessing { diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 5b747645340..5aa34bd10f6 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -18,8 +18,6 @@ #include #include -#include -#include namespace at::native { namespace { @@ -322,33 +320,6 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda( return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset)); } -std::tuple -_scaled_dot_product_cudnn_attention_nestedtensor_cuda( - const Tensor& query, - const Tensor& key, - const Tensor& value, - const std::optional& attn_bias, - bool compute_logsumexp, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - - auto [ - query_buffer_reshaped, - key_buffer_reshaped, - value_buffer_reshaped, - cumulative_sequence_length_q, - cumulative_sequence_length_kv, - max_seqlen_batch_q, - max_seqlen_batch_kv, - output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value); - auto [attention, log_sumexp, ignore1, ignore2, ignore3, ignore4, cudnn_seed, cudnn_offset, ignore5] = at::_cudnn_attention_forward(query_buffer_reshaped, key_buffer_reshaped, value_buffer_reshaped, attn_bias, cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, compute_logsumexp, dropout_p, is_causal, return_debug_mask, scale); - - attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2); - return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, 
cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); -} - std::tuple _scaled_dot_product_flash_attention_backward_nested( const at::Tensor& grad_out_, const at::Tensor& query, diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 4275b2370fc..7fe7ee7a1ba 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -26,8 +26,6 @@ #include #include #else -#include -#include #include #include #include @@ -65,7 +63,6 @@ #include #include -#include #include #include #include @@ -90,25 +87,6 @@ namespace at { -namespace cuda::philox { - -__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) { - if (arg.captured_) { - *seed_ptr = static_cast(*arg.seed_.ptr); - *offset_ptr = static_cast( - *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); - } else { - *seed_ptr = static_cast(arg.seed_.val); - *offset_ptr = static_cast(arg.offset_.val); - } -} - -void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream) { -at::cuda::philox::unpack_cudnn<<<1, 1, 0, stream>>>(arg, seed_ptr, offset_ptr); -} - -} // namespace cuda::philox - namespace native { namespace { @@ -754,177 +732,16 @@ std::tuple _cudnn_attention_forward( - const Tensor& query, - const Tensor& key, - const Tensor& value, - const std::optional& attn_bias, - const std::optional& cumulative_sequence_length_q, - const std::optional& cumulative_sequence_length_kv, - long max_seqlen_batch_q, - long max_seqlen_batch_kv, - bool compute_logsumexp, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - // TODO(eqy): debug mask support - // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) - // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) - // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) - const bool is_nested = cumulative_sequence_length_q.has_value(); - if (!is_nested) { - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_qk = query.size(3); - const int64_t head_dim_v = value.size(3); - auto attn_bias_ = attn_bias; - if (attn_bias_.has_value()) { - const auto bias_dim = attn_bias_.value().dim(); - if (bias_dim == 2) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_kv}); - } else if (bias_dim == 3) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_kv}); - } else { - TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_kv}); - } - } - - Tensor attention, log_sumexp; - at::Tensor cudnn_seed, cudnn_offset; - cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - - const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; - - // See Note [Seed and Offset Device] in _efficient_attention_forward - at::PhiloxCudaState philox_state; - const bool in_capture_stream = - at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None; - if (use_dropout) { - // Device - auto gen = at::get_generator_or_default( - std::nullopt, 
at::cuda::detail::getDefaultCUDAGenerator()); - - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - // if using dropout, we produce 1 random number for each element of the - // attention tensor - // TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount - philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_kv); - at::cuda::philox::unpack_cudnn_wrapper( - philox_state, static_cast(cudnn_seed.data_ptr()), static_cast(cudnn_offset.data_ptr()), at::cuda::getCurrentCUDAStream()); - } - - const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); - Tensor debugmask; - - run_cudnn_SDP_fprop(batch_size/*int64_t b*/, - num_heads/*int64_t h*/, - max_seqlen_batch_q/*int64_t s_q*/, - max_seqlen_batch_kv/*int64_t s_kv*/, - head_dim_qk/*int64_t d_qk*/, - head_dim_v/*int64_t d_v*/, - softmax_scale/*float scaling_factor*/, - compute_logsumexp/* bool */, - is_causal/* bool */, - dropout_p/*double dropout_probability*/, - query/* Tensor q*/, - key/* Tensor k*/, - value/* Tensor v*/, - attn_bias_ /* std::optional */, - log_sumexp/*Tensor softmaxstats*/, - attention/*Tensor o*/, - cudnn_seed/*Tensor dropoutseed*/, - cudnn_offset/*Tensor dropoutoffset*/); - - // TODO(eqy): support debug_attn_mask - return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); +// Adapted from TE +// extract seed and offset from PhiloxCudaState +__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) { + if (arg.captured_) { + *seed_ptr = static_cast(*arg.seed_.ptr); + *offset_ptr = static_cast( + *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); } else { - //auto [ - // query_buffer_reshaped, - // key_buffer_reshaped, - // value_buffer_reshaped, - // cumulative_sequence_length_q, - // cumulative_sequence_length_kv, - // max_seqlen_batch_q, - // max_seqlen_batch_kv, - // output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value); - // C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn"); - // TODO(eqy): debug mask support - // BHSD ... 
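The Philox seed and offset unpacked on-device above are surfaced to the caller as CUDA scalars, which keeps the op replayable under CUDA graph capture. A hedged sketch of inspecting them through the dispatcher op whose schema appears in the native_functions.yaml hunk earlier in this patch (assumes a CUDA build and a GPU/dtype combination that cuDNN attention supports):

    import torch

    b, h, s, d = 2, 4, 64, 64
    q, k, v = (torch.randn(b, h, s, d, device="cuda", dtype=torch.float16)
               for _ in range(3))

    # Schema: (query, key, value, attn_bias, compute_log_sumexp, dropout_p,
    #          is_causal, return_debug_mask, *, scale) -> 9 outputs.
    (out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
     philox_seed, philox_offset, debug_mask) = \
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            q, k, v, None, True, 0.1, False, False)

    # Seed and offset live on the GPU rather than on the host.
    print(philox_seed.device, philox_offset.device)  # cuda:0 cuda:0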
- const int64_t batch_size = cumulative_sequence_length_q.value().size(0) - 1; - const int64_t num_heads_q = query.size(-2); - const int64_t num_heads_k = key.size(-2); - const int64_t num_heads_v = value.size(-2); - const int64_t head_dim_qk = query.size(-1); - const int64_t head_dim_v = value.size(-1); - auto attn_bias_ = attn_bias; - if (attn_bias_.has_value()) { - const auto bias_dim = attn_bias_.value().dim(); - if (bias_dim == 2) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_kv}); - } else if (bias_dim == 3) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_kv}); - } else { - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_kv}); - TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - } - } - - Tensor attention, log_sumexp; - - at::Tensor cudnn_seed, cudnn_offset; - cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - - const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; - - // See Note [Seed and Offset Device] in _efficient_attention_forward - at::PhiloxCudaState philox_state; - const bool in_capture_stream = - at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None; - if (use_dropout) { - // Device - auto gen = at::get_generator_or_default( - std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - // if using dropout, we produce 1 random number for each element of the - // attention tensor - // TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount - philox_state = gen->philox_cuda_state(batch_size * num_heads_q * max_seqlen_batch_q * max_seqlen_batch_kv); - at::cuda::philox::unpack_cudnn_wrapper(philox_state, static_cast(cudnn_seed.data_ptr()), static_cast(cudnn_offset.data_ptr()), at::cuda::getCurrentCUDAStream()); - } - - const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); - - run_cudnn_SDP_fprop_nestedtensor(batch_size/*int64_t b*/, - num_heads_q/*int64_t h*/, - num_heads_k, - num_heads_v, - max_seqlen_batch_q/*int64_t s_q*/, - max_seqlen_batch_kv/*int64_t s_kv*/, - head_dim_qk/*int64_t d_qk*/, - head_dim_v/*int64_t d_v*/, - softmax_scale/*float scaling_factor*/, - compute_logsumexp/* bool */, - is_causal/* bool */, - dropout_p/*double dropout_probability*/, - cumulative_sequence_length_q.value(), - cumulative_sequence_length_kv.value(), - query/* Tensor q*/, - key/* Tensor k*/, - value/* Tensor v*/, - attn_bias_ /* std::optional */, - log_sumexp/*Tensor softmaxstats*/, - attention/*Tensor o*/, - cudnn_seed/*Tensor dropoutseed*/, - cudnn_offset/*Tensor dropoutoffset*/); - //attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2); - return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q.value(), cumulative_sequence_length_kv.value(), max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); + *seed_ptr = static_cast(arg.seed_.val); + *offset_ptr = static_cast(arg.offset_.val); } } @@ -940,88 +757,84 @@ std::tuple scale) { // Used for tracking usage statistics C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn"); + // TODO(eqy): debug mask support + // 
Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); const int64_t max_seqlen_batch_q = query.size(2); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); const int64_t max_seqlen_batch_k = key.size(2); + const int64_t max_seqlen_batch_v = value.size(2); + TORCH_CHECK( + max_seqlen_batch_k == max_seqlen_batch_v, + "Key and Value must have the same sequence length"); + auto attn_bias_ = attn_bias; + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + } + } - return at::_cudnn_attention_forward(query, key, value, attn_bias, std::nullopt, std::nullopt, max_seqlen_batch_q, max_seqlen_batch_k, compute_logsumexp, dropout_p, is_causal, return_debug_mask, scale); - //// TODO(eqy): debug mask support - //// Query (Batch x Num_heads x Q_seq_len x Dim_per_head) - //// Key (Batch x Num_heads x KV_seq_len x Dim_per_head) - //// Value (Batch x Num_heads x KV_seq_len x Dim_per_head) - //const int64_t batch_size = query.size(0); - //const int64_t num_heads = query.size(1); - //const int64_t max_seqlen_batch_q = query.size(2); - //const int64_t head_dim_qk = query.size(3); - //const int64_t head_dim_v = value.size(3); - //const int64_t max_seqlen_batch_k = key.size(2); - //const int64_t max_seqlen_batch_v = value.size(2); - //TORCH_CHECK( - // max_seqlen_batch_k == max_seqlen_batch_v, - // "Key and Value must have the same sequence length"); - //auto attn_bias_ = attn_bias; - //if (attn_bias_.has_value()) { - // const auto bias_dim = attn_bias_.value().dim(); - // if (bias_dim == 2) { - // attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - // } else if (bias_dim == 3) { - // attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - // } else { - // TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - // attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); - // } - //} + Tensor attention, log_sumexp; - //Tensor attention, log_sumexp; + at::Tensor cudnn_seed, cudnn_offset; + cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - //at::Tensor cudnn_seed, cudnn_offset; - //cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - //cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; - //const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; + // See Note [Seed and Offset Device] in _efficient_attention_forward + at::PhiloxCudaState philox_state; + const bool in_capture_stream = + at::cuda::currentStreamCaptureStatus() != 
at::cuda::CaptureStatus::None; + if (use_dropout) { + // Device + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - //// See Note [Seed and Offset Device] in _efficient_attention_forward - //at::PhiloxCudaState philox_state; - //const bool in_capture_stream = - // at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None; - //if (use_dropout) { - // // Device - // auto gen = at::get_generator_or_default( - // std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + // if using dropout, we produce 1 random number for each element of the + // attention tensor + // TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount + philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k); + unpack_cudnn<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_state, static_cast(cudnn_seed.data_ptr()), static_cast(cudnn_offset.data_ptr())); + } - // // See Note [Acquire lock when using random generators] - // std::lock_guard lock(gen->mutex_); - // // if using dropout, we produce 1 random number for each element of the - // // attention tensor - // // TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount - // philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k); - // at::cuda::philox::unpack_cudnn_wrapper( - // philox_state, static_cast(cudnn_seed.data_ptr()), static_cast(cudnn_offset.data_ptr()), at::cuda::getCurrentCUDAStream()); - //} + const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); + Tensor debugmask; - //const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); - //Tensor debugmask; + run_cudnn_SDP_fprop(batch_size/*int64_t b*/, + num_heads/*int64_t h*/, + max_seqlen_batch_q/*int64_t s_q*/, + max_seqlen_batch_k/*int64_t s_kv*/, + head_dim_qk/*int64_t d_qk*/, + head_dim_v/*int64_t d_v*/, + softmax_scale/*float scaling_factor*/, + compute_logsumexp/* bool */, + is_causal/* bool */, + dropout_p/*double dropout_probability*/, + query/* Tensor q*/, + key/* Tensor k*/, + value/* Tensor v*/, + attn_bias_ /* std::optional */, + log_sumexp/*Tensor softmaxstats*/, + attention/*Tensor o*/, + cudnn_seed/*Tensor dropoutseed*/, + cudnn_offset/*Tensor dropoutoffset*/); - //run_cudnn_SDP_fprop(batch_size/*int64_t b*/, - // num_heads/*int64_t h*/, - // max_seqlen_batch_q/*int64_t s_q*/, - // max_seqlen_batch_k/*int64_t s_kv*/, - // head_dim_qk/*int64_t d_qk*/, - // head_dim_v/*int64_t d_v*/, - // softmax_scale/*float scaling_factor*/, - // compute_logsumexp/* bool */, - // is_causal/* bool */, - // dropout_p/*double dropout_probability*/, - // query/* Tensor q*/, - // key/* Tensor k*/, - // value/* Tensor v*/, - // attn_bias_ /* std::optional */, - // log_sumexp/*Tensor softmaxstats*/, - // attention/*Tensor o*/, - // cudnn_seed/*Tensor dropoutseed*/, - // cudnn_offset/*Tensor dropoutoffset*/); - - //// TODO(eqy): support debug_attn_mask - //return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); + // TODO(eqy): support debug_attn_mask + return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, std::move(cudnn_seed), 
std::move(cudnn_offset), Tensor()); } std::tuple _scaled_dot_product_efficient_attention_cuda( diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index a3be356b8aa..76fdae7b0e6 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -505,23 +505,10 @@ bool check_cudnn_hardware_support(sdp_params const& params, bool debug) { } bool check_for_nested_inputs(sdp_params const& params, bool debug) { - static const bool enable_cudnn_nested = c10::utils::check_env("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED") == true; - if (has_for_nested_inputs(params) && !enable_cudnn_nested) { - if (debug) { - TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled."); - } - return false; - } else if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) { - if (debug) { - TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward."); - } - } - - const auto dprop = at::cuda::getCurrentDeviceProperties(); // Check that the input is nested - if (dprop->major != 9 && has_for_nested_inputs(params)) { + if (has_for_nested_inputs(params)) { if (debug) { - TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0."); + TORCH_WARN("CuDNN currently does not support nested inputs."); } return false; } @@ -587,6 +574,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { check_runtime_disabled_cudnn, check_for_nested_inputs, check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense*/>, check_all_tensors_on_device, check_tensor_shapes, check_cudnn_tensor_shapes, @@ -600,18 +588,6 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { return false; } } - constexpr auto dense_constraints = - c10::array_of( - check_last_dim_stride_equals_1_dense - ); - - if (has_only_dense_inputs(params)) { - for (auto& constraint : dense_constraints) { - if (!constraint(params, debug)) { - return false; - } - } - } return true; } diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 561494e6a44..fee0c3174ad 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -75,7 +75,6 @@ aten::_ctc_loss.out aten::_ctc_loss_backward aten::_ctc_loss_backward.Tensor aten::_ctc_loss_backward.out -aten::_cudnn_attention_forward aten::_cudnn_ctc_loss aten::_cudnn_ctc_loss.Tensor aten::_cudnn_ctc_loss.out diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 19c519d876c..beeb941af30 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4,7 +4,6 @@ import ast import io import itertools import math -import os import random import sys import tempfile @@ -31,7 +30,6 @@ from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FUSED_ATTENTION, SM70OrLater, SM80OrLater, - tf32_on_and_off, ) from torch.testing._internal.common_device_type import ( dtypes, @@ -59,7 +57,6 @@ from torch.testing._internal.common_utils import ( NestedTensorTestCase, parametrize, run_tests, - serialTest, skipIfRocm, skipIfSlowGradcheckEnv, skipIfTorchDynamo, @@ -1997,7 +1994,6 @@ class TestNestedTensorDeviceType(NestedTensorTestCase): @onlyCUDA @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16) - @tf32_on_and_off(0.005) def test_bmm_cuda(self, device, dtype): self._test_bmm(device, dtype) @@ -2022,7 
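With this revert, check_for_nested_inputs rejects nested tensors outright, so the cuDNN backend is selectable only for dense BHSD inputs. A sketch of exercising that dense path through the public API, using the same tolerances as the test removed later in this patch (assumes a CUDA build where cuDNN attention supports the shapes and dtype):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    # Dense BHSD inputs: (batch, num_heads, seq_len, head_dim).
    b, h, s, d = 8, 16, 64, 64
    q, k, v = (torch.randn(b, h, s, d, device="cuda", dtype=torch.float16)
               for _ in range(3))

    with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
        actual = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    with sdpa_kernel(backends=[SDPBackend.MATH]):
        math_ref = F.scaled_dot_product_attention(
            q.float(), k.float(), v.float(), is_causal=True)

    torch.testing.assert_close(actual.float(), math_ref, atol=2e-3, rtol=1e-2)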
+2018,6 @@ class TestNestedTensorDeviceType(NestedTensorTestCase): ) @dtypes(torch.float, torch.double) - @tf32_on_and_off(0.005) def test_matmul_with_bmm_path(self, device, dtype): def unbind_rebind_matmul(nt1, nt2): t1s = nt1.unbind() @@ -2647,7 +2642,6 @@ class TestNestedTensorDeviceType(NestedTensorTestCase): nt_noncont.narrow(dim=0, start=0, length=1) @parametrize("input_dim", [3, 4]) - @tf32_on_and_off(0.005) def test_scaled_dot_product_attention(self, device, input_dim): def rand_tensor(*shape): return torch.randn(shape, device=device) @@ -3055,7 +3049,6 @@ class TestNestedTensorAutograd(NestedTensorTestCase): data = (a, b, c, d) assert torch.autograd.gradcheck(grad_test_func, inputs=data) - @tf32_on_and_off(0.008) def test_nested_tensor_bmm_backward(self, device): nt0 = torch.nested.nested_tensor( [torch.randn((2, 6)), torch.randn((3, 6))], @@ -3801,13 +3794,11 @@ class TestNestedTensorSubclass(NestedTensorTestCase): @onlyCUDA @dtypes(torch.float32) - @serialTest() def test_linear_backward_memory_usage(self, device, dtype): # Verify that linear_backward() doesn't use more memory than it should # for higher dim input sizes. # See https://github.com/pytorch/pytorch/issues/141112 B, D, max_seq_len = 64, 512, 100 - torch._C._cuda_clearCublasWorkspaces() m = torch.nn.Linear(D, D, device=device) nt = torch.nested.as_nested_tensor( [ @@ -6469,7 +6460,6 @@ torch.cuda.synchronize() TEST_WITH_ROCM, "ROCm doesn't support flash attention or mem_efficient attention for NT", ) - @tf32_on_and_off(0.005) @dtypes( *( [torch.float16, torch.bfloat16, torch.float32] @@ -6654,24 +6644,10 @@ torch.cuda.synchronize() ) self.assertEqual(attn_out.shape, q_nt_3.shape) - @parametrize("skip_backward", [True, False]) - def check_forward_backward(skip_backward=False): - if not skip_backward: - attn_nt = torch.nn.functional.scaled_dot_product_attention( - q_nt_t, k_nt_t, v_nt_t - ).transpose(1, 2) - else: - x_nt.requires_grad = False - q_nt.requires_grad = False - k_nt.requires_grad = False - v_nt.requires_grad = False - tq = q_nt_t.detach() - tk = k_nt_t.detach() - tv = v_nt_t.detach() - with torch.no_grad(): - attn_nt = torch.nn.functional.scaled_dot_product_attention( - tq, tk, tv - ).transpose(1, 2) + def check_forward_backward(): + attn_nt = torch.nn.functional.scaled_dot_product_attention( + q_nt_t, k_nt_t, v_nt_t + ).transpose(1, 2) attn_nts = attn_nt.unbind() self.assertEqual( @@ -6687,26 +6663,23 @@ torch.cuda.synchronize() rtol=output_ref_rtol, ) - if not skip_backward: - nt_grads = torch.autograd.grad( - attn_nt.values().sum(), (q_nt, k_nt, v_nt) + nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt)) + for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip( + nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols + ): + unbound_nt_grads = nt_grad.unbind() + self.assertEqual( + d1_grad, + unbound_nt_grads[0].unsqueeze(0), + atol=grad_atol, + rtol=grad_rtol, + ) + self.assertEqual( + d2_grad, + unbound_nt_grads[1].unsqueeze(0), + atol=grad_atol, + rtol=grad_rtol, ) - for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip( - nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols - ): - unbound_nt_grads = nt_grad.unbind() - self.assertEqual( - d1_grad, - unbound_nt_grads[0].unsqueeze(0), - atol=grad_atol, - rtol=grad_rtol, - ) - self.assertEqual( - d2_grad, - unbound_nt_grads[1].unsqueeze(0), - atol=grad_atol, - rtol=grad_rtol, - ) # Default check_forward_backward() @@ -6725,17 +6698,6 @@ torch.cuda.synchronize() # "group_gemm_dispatch" not implemented for 'BFloat16' if not 
(str(device).startswith("cuda") and dtype == torch.bfloat16): check_forward_backward() - check_cudnn = os.getenv("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED", "0") == "1" - if ( - "cuda" in str(device) - and check_cudnn - and (dtype == torch.float16 or dtype == torch.bfloat16) - ): - with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"): - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.CUDNN_ATTENTION - ): - check_forward_backward() @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @@ -8627,7 +8589,6 @@ class TestNestedTensorOpInfo(NestedTensorTestCase): [op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,), ) - @tf32_on_and_off(0.005) @sample_skips_and_xfails(FORWARD_SKIPS_AND_XFAILS) def test_forward(self, device, dtype, op): for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs( @@ -8656,7 +8617,6 @@ class TestNestedTensorOpInfo(NestedTensorTestCase): [op for op in njt_op_db if op.supports_njt and op.supports_autograd], allowed_dtypes=(torch.float32,), ) - @tf32_on_and_off(0.005) @sample_skips_and_xfails(BACKWARD_SKIPS_AND_XFAILS) def test_backward(self, device, dtype, op): for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs( diff --git a/test/test_transformers.py b/test/test_transformers.py index 1f302c53a64..af711a6fb67 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2725,40 +2725,6 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2) - @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "Fused SDPA was not built for this system") - @unittest.skipIf("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED" not in os.environ, "cuDNN Nested Tensor support not enabled") - @parametrize("type", ["nested"]) - @parametrize("is_contiguous", [True]) - def test_scaled_dot_product_attention_cudnn_nested(self, device, type: str, is_contiguous: bool): - if TEST_WITH_ROCM and type == 'nested': - self.skipTest("ROCM does not support efficient attention on nested tensors, for now") - make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) - - batch_size, seq_len, num_heads, head_dim = 8, 64, 16, 64 - shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) - - # Test Packed - qkv = make_tensor(shape) - query, key, value = qkv.chunk(3, dim=-1) - - query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - - if is_contiguous: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): - actual = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) - with sdpa_kernel(backends=[SDPBackend.MATH]): - math_ref = torch.nn.functional.scaled_dot_product_attention( - query.contiguous(), key.contiguous(), value.contiguous(), - attn_mask=None, dropout_p=0.0, is_causal=False) - self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2) - @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if diff --git 
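The nested-tensor SDPA coverage that remains after this revert runs through the flash / memory-efficient backends on the jagged layout. A sketch of the usage pattern those tests exercise, built with shared offsets so q, k, and v have the same ragged structure (illustrative shapes; assumes a CUDA device with a supported fused backend):

    import torch
    import torch.nn.functional as F

    device, dtype = "cuda", torch.float16
    n_heads, head_dim = 4, 64

    # Two sequences of lengths 3 and 5 packed into one buffer with shared offsets.
    offsets = torch.tensor([0, 3, 8], device=device)

    def make_njt():
        values = torch.randn(8, n_heads, head_dim, device=device, dtype=dtype)
        return torch.nested.nested_tensor_from_jagged(values, offsets)

    # SDPA expects (B, H, S*, D), so move the jagged dim behind the heads.
    q, k, v = (make_njt().transpose(1, 2) for _ in range(3))
    attn = F.scaled_dot_product_attention(q, k, v)
    print(attn.shape)  # (2, n_heads, S*, head_dim) with a symbolic jagged dim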
a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 8ac4cc86a58..cad4a94c26e 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -6,10 +6,8 @@ import torch import torch.nn import torch.nn.functional as F from torch.backends.cuda import ( - can_use_cudnn_attention, can_use_efficient_attention, can_use_flash_attention, - cudnn_sdp_enabled, flash_sdp_enabled, math_sdp_enabled, mem_efficient_sdp_enabled, @@ -113,32 +111,6 @@ def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool: return True -def _check_head_dim_size_cudnn_nested(params: SDPAParams, debug=False) -> bool: - max_size = 128 - query_size_last = params.query.size(-1) - key_size_last = params.key.size(-1) - value_size_last = params.value.size(-1) - same_head_dim_size = ( - query_size_last == key_size_last and query_size_last == value_size_last - ) - if not ( - same_head_dim_size - and (query_size_last % 8 == 0) - and (query_size_last <= max_size) - ): - if debug: - log.warning( - "For NestedTensor inputs, cuDNN attention requires q,k,v to have the same " - "last dimension and to be a multiple of 8 and less than or equal to 128. " - "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.", - query_size_last, - key_size_last, - value_size_last, - ) - return False - return True - - def _check_for_seq_len_0_and_consistent_head_dim_nested_helper( param: torch.Tensor, param_name: str, debug=False ) -> bool: @@ -295,7 +267,6 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable not flash_sdp_enabled() and not mem_efficient_sdp_enabled() and not math_sdp_enabled() - and not cudnn_sdp_enabled() ): return SDPBackend.ERROR @@ -303,15 +274,11 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, - SDPBackend.CUDNN_ATTENTION, ) params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa) for backend in ordering: - if backend == SDPBackend.CUDNN_ATTENTION: - if can_use_cudnn_attention(params): - return SDPBackend.CUDNN_ATTENTION if backend == SDPBackend.FLASH_ATTENTION: if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params): return SDPBackend.FLASH_ATTENTION @@ -332,8 +299,6 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable _can_use_flash_sdpa_jagged(params, debug=True) log.warning("Math attention kernel not used because:") _can_use_math_sdpa_jagged(params, debug=True) - log.warning("cuDNN attention kernel not used because:") - can_use_cudnn_attention(params, debug=True) return SDPBackend.ERROR @@ -785,6 +750,7 @@ def jagged_scaled_dot_product_attention( max_seqlen_batch_kv, output_nt_info, ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded) + ( attention, _logsumexp, @@ -804,6 +770,7 @@ def jagged_scaled_dot_product_attention( False, scale=og_scale, ) + # Reshape output to convert nnz to batch_size and seq_len attention = nested_view_from_values_offsets_lengths( attention, # output from flash_attn is [total_q, num_heads, head_size_og] @@ -842,51 +809,12 @@ def jagged_scaled_dot_product_attention( compute_logsumexp, scale=scale, ) + # Reshape output to convert nnz to batch_size and seq_len return nested_view_from_values_offsets_lengths( attention.squeeze(0), **output_nt_info, ).transpose(1, 2) - elif backend_choice == SDPBackend.CUDNN_ATTENTION: - ( - query_reshaped, - key_reshaped, - value_reshaped, - cumulative_sequence_length_q, 
- cumulative_sequence_length_kv, - max_seqlen_batch_q, - max_seqlen_batch_kv, - output_nt_info, - ) = _sdpa_nested_preprocessing(query, key, value) - ( - attention, - logsumexp, - cum_seqlen_q, - cum_seqlen_kv, - max_seqlen_q, - max_seqlen_kv, - seed, - offset, - _, - ) = torch.ops.aten._cudnn_attention_forward( - query_reshaped, - key_reshaped, - value_reshaped, - attn_mask, - cumulative_sequence_length_q, - cumulative_sequence_length_kv, - max_seqlen_batch_q, - max_seqlen_batch_kv, - compute_logsumexp, - dropout_p, - is_causal, - False, - scale=scale, - ) - return nested_view_from_values_offsets_lengths( - attention, - **output_nt_info, - ).transpose(1, 2) elif backend_choice == SDPBackend.MATH: # save the offsets and shape of the inputs, so we can reshape the final output # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1]
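The math-path comment above describes the padded shape bookkeeping, [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]. A padded-dense sketch of the same idea, padding to the maximum length and masking padded key positions (illustrative sizes; this mirrors the bookkeeping rather than reproducing the library's math branch):

    import torch
    import torch.nn.functional as F

    n_heads, head_dim = 4, 8
    lengths = [3, 5]
    max_len = max(lengths)

    nts = [torch.nested.nested_tensor(
               [torch.randn(l, n_heads, head_dim) for l in lengths],
               layout=torch.jagged)
           for _ in range(3)]
    # Pad to (B, max_len, H, D), then transpose to (B, H, max_len, D) for SDPA.
    q, k, v = (torch.nested.to_padded_tensor(nt, 0.0).transpose(1, 2) for nt in nts)

    # Boolean mask: True where the key position is real; broadcasts over heads and queries.
    key_mask = torch.arange(max_len)[None, :] < torch.tensor(lengths)[:, None]  # (B, S)
    attn_mask = key_mask[:, None, None, :]                                      # (B, 1, 1, S)

    attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    print(attn.shape)  # (2, 4, 5, 8); rows beyond each sequence's length are padding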