[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 (#149282)

Clean up tuple/tensor boilerplate in cuDNN SDPA, in preparation for nested/ragged tensor backward.
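For context, a minimal usage sketch (not part of this commit) of the path the change enables: running SDPA on jagged nested tensors with the cuDNN backend and backpropagating through it. Shapes, dtype, and the assumption of an SM 9.0/10.0 GPU with a recent cuDNN are illustrative only.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def make_njt():
    # three sequences of different lengths, 8 heads, head_dim 64, packed as a jagged nested tensor
    return torch.nested.nested_tensor(
        [torch.randn(s, 8, 64) for s in (12, 31, 7)],
        layout=torch.jagged, device="cuda", dtype=torch.float16, requires_grad=True,
    ).transpose(1, 2)  # (B, H, S*, D), the layout SDPA expects

q, k, v = make_njt(), make_njt(), make_njt()
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
out.values().sum().backward()  # exercises the nested-tensor backward added in this change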

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149282
Approved by: https://github.com/drisspg

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Authored by Eddie Yan on 2025-08-08 22:22:48 +00:00, committed by PyTorch MergeBot
parent 334ecbd4ff
commit 1128f4c2a8
12 changed files with 1028 additions and 449 deletions

File diff suppressed because it is too large

View File

@@ -70,4 +70,31 @@ void run_cudnn_SDP_bprop(
const Tensor& dropoutseed,
const Tensor& dropoutoffset);
void run_cudnn_SDP_bprop_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
int64_t h_v,
int64_t s_q,
int64_t s_kv,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool is_causal,
float dropout_probability,
const Tensor& cum_seqlen_q,
const Tensor& cum_seqlen_kv,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
Tensor& dQ,
Tensor& dK,
Tensor& dV,
const Tensor& dropoutseed,
const Tensor& dropoutoffset);
} // namespace at::native
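The cum_seqlen_q / cum_seqlen_kv arguments follow the usual varlen convention: cumulative sequence lengths over a packed batch. A small illustrative sketch (the numbers are my assumptions, not from this header):

import torch

seq_lens_q = torch.tensor([12, 31, 7])                  # per-sequence lengths in the batch
cum_seq_q = torch.cat([torch.zeros(1, dtype=torch.int64),
                       seq_lens_q.cumsum(0)])           # tensor([ 0, 12, 43, 50])
total_q = int(cum_seq_q[-1])                            # rows of the packed q / dQ buffers
# sequence i occupies rows cum_seq_q[i] : cum_seq_q[i + 1] of the packed tensors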

View File

@@ -15013,6 +15013,7 @@
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda
tags: nondeterministic_seeded
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
@@ -15045,6 +15046,11 @@
CUDA: _cudnn_attention_forward
tags: nondeterministic_seeded
- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _cudnn_attention_backward
tags: nondeterministic_seeded
- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
variants: function
dispatch:
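Once the new entry above is registered, the op is reachable from Python through torch.ops; a quick, hedged way to confirm the schema on a local build with this change applied:

import torch
# prints the registered schema of the new backward op (assumes a build containing this change)
print(torch.ops.aten._cudnn_attention_backward.default._schema)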

View File

@@ -349,6 +349,63 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda(
return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
}
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda(
const Tensor& grad_out,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const Tensor& out,
const Tensor& logsumexp,
const Tensor& philox_seed,
const Tensor& philox_offset,
const Tensor& attn_bias,
const Tensor& cum_seq_q,
const Tensor& cum_seq_k,
const int64_t max_q,
const int64_t max_k,
double dropout_p,
bool is_causal,
std::optional<double> scale) {
if (!grad_out.defined()) {
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
}
auto [
grad_out_buffer_reshaped,
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
output_buffer_reshaped] =
preprocessing::sdpa_nested_preprocessing_backward(
grad_out,
query,
key,
value,
out,
cum_seq_q,
cum_seq_k,
max_q,
max_k);
auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped,
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
output_buffer_reshaped,
logsumexp,
philox_seed,
philox_offset,
attn_bias,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
dropout_p,
is_causal,
scale);
return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
const at::Tensor& grad_out_,
const at::Tensor& query,

View File

@@ -849,16 +849,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
// TODO(eqy): support debug_attn_mask
return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
} else {
//auto [
// query_buffer_reshaped,
// key_buffer_reshaped,
// value_buffer_reshaped,
// cumulative_sequence_length_q,
// cumulative_sequence_length_kv,
// max_seqlen_batch_q,
// max_seqlen_batch_kv,
// output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);
// C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
// TODO(eqy): debug mask support
// BHSD ...
const int64_t batch_size = cumulative_sequence_length_q.value().size(0) - 1;

View File

@@ -26,6 +26,8 @@
#else
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/_cudnn_attention_backward.h>
#include <ATen/ops/_cudnn_attention_backward_native.h>
#include <ATen/ops/_flash_attention_backward.h>
#include <ATen/ops/_flash_attention_backward_native.h>
#include <ATen/ops/_efficient_attention_backward.h>
@@ -184,7 +186,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
return std::make_tuple(Tensor(), Tensor(), Tensor());
}
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
std::tuple<Tensor, Tensor, Tensor> _cudnn_attention_backward(
const Tensor& grad_out,
const Tensor& query,
const Tensor& key,
@@ -211,57 +213,117 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
}
}
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const bool is_nested = cum_seq_q.defined();
const int64_t max_seqlen_batch_q = query.size(2);
const int64_t max_seqlen_batch_k = key.size(2);
// This is needed because SaveVariable automatically converts
// std::optional to undefined tensor
std::optional<Tensor> attn_bias_;
if (attn_bias.defined()) {
attn_bias_ = attn_bias;
}
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else {
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
}
}
if (!is_nested) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
auto dq = at::empty_like(query);
auto dk = at::empty_like(key);
auto dv = at::empty_like(value);
run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
num_heads /*int64_t h*/,
max_q/*int64_t s_q*/,
max_k/*int64_t s_kv*/,
head_dim_qk /*int64_t d_qk*/,
head_dim_v /*int64_t d_v*/,
softmax_scale /*float scaling_factor*/,
is_causal /*bool is_causal*/,
dropout_p /*float dropout_probability*/,
query /*const Tensor& q*/,
key /*const Tensor& k*/,
value /*const Tensor& v*/,
attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
out /*const Tensor& o*/,
grad_out/*const Tensor& dO*/,
logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
dq/*Tensor& dQ*/,
dk/*Tensor& dK*/,
dv/*Tensor& dV*/,
philox_seed/*Tensor& dropoutseed*/,
philox_offset/*Tensor& dropoutoffset*/);
return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
// This is needed because SaveVariable automatically converts
// std::optional to undefined tensor
std::optional<Tensor> attn_bias_;
if (attn_bias.defined()) {
attn_bias_ = attn_bias;
}
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else {
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
}
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
auto dq = at::empty_like(query);
auto dk = at::empty_like(key);
auto dv = at::empty_like(value);
run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
num_heads /*int64_t h*/,
max_q/*int64_t s_q*/,
max_k/*int64_t s_kv*/,
head_dim_qk /*int64_t d_qk*/,
head_dim_v /*int64_t d_v*/,
softmax_scale /*float scaling_factor*/,
is_causal /*bool is_causal*/,
dropout_p /*float dropout_probability*/,
query /*const Tensor& q*/,
key /*const Tensor& k*/,
value /*const Tensor& v*/,
attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
out /*const Tensor& o*/,
grad_out/*const Tensor& dO*/,
logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
dq/*Tensor& dQ*/,
dk/*Tensor& dK*/,
dv/*Tensor& dV*/,
philox_seed/*Tensor& dropoutseed*/,
philox_offset/*Tensor& dropoutoffset*/);
return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
} else {
// BHSD ...
const int64_t batch_size = cum_seq_q.size(0) - 1;
const int64_t num_heads_q = query.size(-2);
const int64_t num_heads_k = key.size(-2);
const int64_t num_heads_v = value.size(-2);
const int64_t head_dim_qk = query.size(-1);
const int64_t head_dim_v = value.size(-1);
std::optional<Tensor> attn_bias_;
if (attn_bias.defined()) {
attn_bias_ = attn_bias;
}
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else {
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
}
}
auto dq = at::empty_like(query);
auto dk = at::empty_like(key);
auto dv = at::empty_like(value);
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
run_cudnn_SDP_bprop_nestedtensor(
batch_size,
num_heads_q,
num_heads_k,
num_heads_v,
max_seqlen_batch_q,
max_seqlen_batch_k,
head_dim_qk,
head_dim_v,
softmax_scale,
is_causal,
dropout_p,
cum_seq_q,
cum_seq_k,
query,
key,
value,
attn_bias_,
out,
grad_out,
logsumexp,
dq,
dk,
dv,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -1063,4 +1125,40 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
}
}
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
const Tensor& grad_out,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const Tensor& out,
const Tensor& logsumexp,
const Tensor& philox_seed,
const Tensor& philox_offset,
const Tensor& attn_bias,
const Tensor& cum_seq_q,
const Tensor& cum_seq_k,
const int64_t max_q,
const int64_t max_k,
double dropout_p,
bool is_causal,
std::optional<double> scale) {
return at::_cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
logsumexp,
philox_seed,
philox_offset,
attn_bias,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
dropout_p,
is_causal,
scale);
}
} // namespace at::native

View File

@@ -57,21 +57,28 @@
namespace sdp {
namespace {
// Tracks whether we've set the default priority order once, to avoid setting
// it redundantly or overwriting a user-specified priority order.
// When the priority order context manager is used before the default priority
// order is initialized, the following happens:
// (1) the current priority order is queried
// (2) priority_order() is called, which initializes it to the default as init_ is false
// (3) the user-specified priority order is set
// (3.1) we are in the priority context...
// (3.2) we exit the priority context...
// (4) the previous priority order (default) is restored
bool priority_order_init_ = false;
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
// Needs to be kept in sync with test_fused_sdp_choice in test_transformers.py
bool check_prefer_cudnn_attention() {
// TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0
// see context: https://github.com/pytorch/pytorch/issues/138340
// return false;
#if defined(CUDNN_VERSION)
#if CUDNN_VERSION > 90000
static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true;
if (!prefer_cudnn) {
return false;
}
#if (defined(CUDNN_VERSION) && (CUDNN_VERSION > 90000))
auto dprops = at::cuda::getCurrentDeviceProperties();
return dprops->major >= 9;
#else
return false;
#endif
return dprops->major >= 9 && !dprops->minor;
#else
return false;
#endif
@@ -79,6 +86,16 @@ bool check_prefer_cudnn_attention() {
// flash_attention V2 is universally faster than efficient_attention and Math
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
if (!priority_order_init_) {
priority_order_init_ = true;
if (check_prefer_cudnn_attention()) {
const std::vector<int64_t> cudnn_order = {static_cast<int64_t>(at::SDPBackend::cudnn_attention),
static_cast<int64_t>(at::SDPBackend::flash_attention),
static_cast<int64_t>(at::SDPBackend::efficient_attention),
static_cast<int64_t>(at::SDPBackend::math)};
at::globalContext().setSDPPriorityOrder(cudnn_order);
}
}
return at::globalContext().sDPPriorityOrder();
}
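A sketch of the opt-in this initialization implements (the env var and _fused_sdp_choice come from this diff; the shapes and expectation are my assumptions): with TORCH_CUDNN_SDPA_PREFERRED set before the first SDPA dispatch, on a device that passes check_prefer_cudnn_attention, the default priority order places cuDNN attention first.

import os
os.environ["TORCH_CUDNN_SDPA_PREFERRED"] = "1"   # read on first dispatch, so set it early
import torch
from torch.nn.attention import SDPBackend

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
choice = torch._fused_sdp_choice(q, k, v)
print(choice == SDPBackend.CUDNN_ATTENTION.value)  # expected True on an eligible SM 9.0/10.0 build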
@@ -414,12 +431,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
return false;
}
auto head_dim_limit = 128;
if (cudnn_version >= 90501) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major == 9 && !dprops->minor) {
head_dim_limit = 256;
}
}
// TODO(eqy): add head dim >= 256 cases once support is finalized
if (d_qk > head_dim_limit || d_v > head_dim_limit) {
if (debug) {
TORCH_WARN("head_dim should be no more than ", head_dim_limit);
@@ -453,9 +465,15 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
return false;
}
}
if (s_q == 1 || s_k == 1) {
if (s_k == 1) {
if (debug) {
TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1.");
TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1.");
}
return false;
}
if (s_q == 1 && params.dropout != 0.0) {
if (debug) {
TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout.");
}
return false;
}
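A hedged illustration of how constraints like the ones above surface to users (shapes are arbitrary): when only the cuDNN backend is allowed and a check fails, e.g. a key/value sequence length of 1, dispatch falls through to the "No available kernel" error raised later in this file.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 2, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 2, 1, 64, device="cuda", dtype=torch.float16)  # s_kv == 1 is rejected
v = torch.randn(1, 2, 1, 64, device="cuda", dtype=torch.float16)
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as e:
        print(e)  # "No available kernel. Aborting execution." (with debug warnings preceding it)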
@@ -563,9 +581,9 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
const auto dprop = at::cuda::getCurrentDeviceProperties();
// Check that the input is nested
if (dprop->major != 9 && has_for_nested_inputs(params)) {
if (!(dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {
if (debug) {
TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0.");
TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0.");
}
return false;
}
@@ -589,7 +607,7 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
// sdp kernels
if (!at::globalContext().userEnabledCuDNNSDP()) {
if (debug) {
TORCH_WARN("CuDNN attention has been runtime disabled.");
TORCH_WARN("cuDNN attention has been runtime disabled.");
}
return false;
}
@@ -620,7 +638,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
#endif
#if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000
if (debug) {
TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)");
TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)");
}
return false;
#endif
@@ -630,10 +648,8 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
c10::array_of<bool (*)(sdp_params const&, bool)>(
check_runtime_disabled_cudnn,
check_for_nested_inputs,
check_nonzero_sequence_lengths_dense,
check_all_tensors_on_device,
check_tensor_shapes,
check_cudnn_tensor_shapes,
check_cudnn_deterministic,
check_dtypes_low_precision,
check_attn_mask_shape,
@@ -646,8 +662,10 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
}
constexpr auto dense_constraints =
c10::array_of<bool (*)(sdp_params const&, bool)>(
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>,
check_batch_size_and_num_heads_dense<true /*enable_gqa*/, false /*requires_same_num_heads*/>
check_batch_size_and_num_heads_dense<true /*enable_gqa*/, false /*requires_same_num_heads*/>,
check_cudnn_tensor_shapes
);
if (has_only_dense_inputs(params)) {
@@ -864,7 +882,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
sdp::can_use_mem_efficient_attention(kernel_params, print_debug);
TORCH_WARN("Flash attention kernel not used because:");
sdp::can_use_flash_attention(kernel_params, print_debug);
TORCH_WARN("CuDNN attention kernel not used because:");
TORCH_WARN("cuDNN attention kernel not used because:");
sdp::can_use_cudnn_attention(kernel_params, print_debug);
TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
return SDPBackend::error;

View File

@@ -75,6 +75,7 @@ aten::_ctc_loss.out
aten::_ctc_loss_backward
aten::_ctc_loss_backward.Tensor
aten::_ctc_loss_backward.out
aten::_cudnn_attention_backward
aten::_cudnn_attention_forward
aten::_cudnn_ctc_loss
aten::_cudnn_ctc_loss.Tensor

View File

@@ -26,6 +26,7 @@ from torch._inductor.utils import (
run_fw_bw_and_get_code,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -177,9 +178,10 @@ class CudaReproTests(TestCase):
inputs = [q, k, v, mask]
def f(q, k, v, mask):
return F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0
)
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
return F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0
)
f_compiled = torch.compile(f)

View File

@@ -6760,11 +6760,10 @@ torch.cuda.synchronize()
and check_cudnn
and (dtype == torch.float16 or dtype == torch.bfloat16)
):
with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"):
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
):
check_forward_backward()
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
):
check_forward_backward()
@skipIfTorchDynamo("SDPA test compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")

View File

@@ -49,7 +49,6 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
SM90OrLater,
tf32_on_and_off,
tf32_enabled,
)
@@ -2657,6 +2656,7 @@ class TestSDPACudaOnly(NNTestCase):
@skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
@unittest.expectedFailure # cuDNN currently doesn't support this on SM100+/fails graph validation
def test_cudnn_attention_d256_heuristic(self, device):
dtype = torch.bfloat16
make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
@@ -2667,7 +2667,7 @@ class TestSDPACudaOnly(NNTestCase):
v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v)
query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True):
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION], set_priority=True):
actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
actual.backward(torch.randn_like(actual))
@@ -2705,7 +2705,7 @@ class TestSDPACudaOnly(NNTestCase):
@skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
@unittest.skipIf(True, "broken as of cuDNN 9.10")
def test_cudnn_attention_fail_d128(self, device):
# Test that cuDNN attention dispatching correctly bails out on d > 128
b, h = 1, 2
@@ -2720,7 +2720,6 @@ class TestSDPACudaOnly(NNTestCase):
ISSM90 = device_cap == (9, 0)
ISSM100 = device_cap == (10, 0)
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
# SM90/100 support d <= 256 as of cuDNN 9.5.1+
if (ISSM90 or ISSM100) and torch.backends.cudnn.version() >= 90501:
torch.nn.functional.scaled_dot_product_attention(q, k, v)
else:
@@ -3156,15 +3155,19 @@ class TestSDPACudaOnly(NNTestCase):
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
device_capability = None
if "cuda" in str(device):
device_capability = torch.cuda.get_device_capability()
prefer_cudnn = "TORCH_CUDNN_SDPA_PREFERRED" in os.environ
prefer_cudnn = prefer_cudnn and device_capability and (device_capability == (9, 0) or device_capability == (10, 0))
# TODO: we are currently disabling this by default, so assert that this returns
# FlashAttention; this needs to change when we remove the opt-in for cudnn
if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and prefer_cudnn:
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
elif PLATFORM_SUPPORTS_FLASH_ATTENTION:
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION: # e.g., we're on Windows
elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn: # e.g., we're on Windows
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
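A standalone sketch of what the updated assertions check (device/dtype/shape choices are mine): under an explicit cuDNN-only context the dispatcher should pick the cuDNN backend for eligible dense fp16 inputs.

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    assert torch._fused_sdp_choice(q, k, v) == SDPBackend.CUDNN_ATTENTION.value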

View File

@@ -2904,6 +2904,10 @@
output_differentiability: [True, False, False, False, False, False]
query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
output_differentiability: [True, False, False, False, False, False, False, False, False]
query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)
- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
output_differentiability: [True, False, False, False, False, False, False, False, False]
query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)
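As a sanity check of the new derivative entry, a hedged end-to-end comparison (not part of the change, dense inputs only) of cuDNN SDPA gradients against the math backend:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))
grads = {}
for name, backend in (("cudnn", SDPBackend.CUDNN_ATTENTION), ("math", SDPBackend.MATH)):
    with sdpa_kernel(backends=[backend]):
        out = F.scaled_dot_product_attention(q, k, v)
    grads[name] = torch.autograd.grad(out.sum(), (q, k, v))
for gc, gm in zip(grads["cudnn"], grads["math"]):
    print((gc - gm).abs().max())  # expect small bf16-level differences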