[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 (#149282)
Cleans up tuple/tensor boilerplate in cuDNN SDPA and prepares for the nested/ragged tensor backward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149282
Approved by: https://github.com/drisspg
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
parent 334ecbd4ff
commit 1128f4c2a8
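Editor's note: for orientation, the dense code path touched by this PR is the one exercised when the cuDNN backend is selected for scaled_dot_product_attention. A minimal sketch, assuming a CUDA build with cuDNN >= 9.0 and an SM80-or-newer GPU; shapes are illustrative and none of this code is from the PR itself:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # With this PR, dense inputs route the backward through the new
    # _cudnn_attention_backward op; nested inputs take the nestedtensor variant.
    out.sum().backward()

print(q.grad.shape, k.grad.shape, v.grad.shape)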
File diff suppressed because it is too large
@@ -70,4 +70,31 @@ void run_cudnn_SDP_bprop(
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset);

void run_cudnn_SDP_bprop_nestedtensor(
    int64_t b,
    int64_t h_q,
    int64_t h_k,
    int64_t h_v,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& cum_seqlen_q,
    const Tensor& cum_seqlen_kv,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const std::optional<Tensor>& attn_bias,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset);

} // namespace at::native
@@ -15013,6 +15013,7 @@
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
    NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda
  tags: nondeterministic_seeded

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
@@ -15045,6 +15046,11 @@
    CUDA: _cudnn_attention_forward
  tags: nondeterministic_seeded

- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _cudnn_attention_backward
  tags: nondeterministic_seeded

- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
  variants: function
  dispatch:
@@ -349,6 +349,63 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda(
  return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
}

std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    const Tensor& philox_seed,
    const Tensor& philox_offset,
    const Tensor& attn_bias,
    const Tensor& cum_seq_q,
    const Tensor& cum_seq_k,
    const int64_t max_q,
    const int64_t max_k,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  if (!grad_out.defined()) {
    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
  }
  auto [
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped] =
      preprocessing::sdpa_nested_preprocessing_backward(
          grad_out,
          query,
          key,
          value,
          out,
          cum_seq_q,
          cum_seq_k,
          max_q,
          max_k);

  auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped,
      logsumexp,
      philox_seed,
      philox_offset,
      attn_bias,
      cum_seq_q,
      cum_seq_k,
      max_q,
      max_k,
      dropout_p,
      is_causal,
      scale);
  return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
    const at::Tensor& grad_out_,
    const at::Tensor& query,
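Editor's note: the wrapper above makes backward work for nested (ragged) inputs by densifying them and calling at::_cudnn_attention_backward. A rough end-to-end sketch from the Python side; assumptions: SM 9.0/10.0 hardware, a recent cuDNN, and that the jagged-layout shapes below are accepted by the backend (this snippet is illustrative and not part of the PR):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Two sequences of different length; layout (batch, seq*, heads, head_dim),
# transposed to (batch, heads, seq*, head_dim) as SDPA expects.
x = torch.nested.nested_tensor(
    [torch.randn(s, 8, 64, device="cuda", dtype=torch.float16) for s in (128, 77)],
    layout=torch.jagged,
    requires_grad=True,
)
q = k = v = x.transpose(1, 2)

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
    # Backward on nested inputs now dispatches to
    # _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda.
    out.values().sum().backward()

print(x.grad is not None)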
@@ -849,16 +849,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
    // TODO(eqy): support debug_attn_mask
    return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
  } else {
    //auto [
    //  query_buffer_reshaped,
    //  key_buffer_reshaped,
    //  value_buffer_reshaped,
    //  cumulative_sequence_length_q,
    //  cumulative_sequence_length_kv,
    //  max_seqlen_batch_q,
    //  max_seqlen_batch_kv,
    //  output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);
    // C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
    // TODO(eqy): debug mask support
    // BHSD ...
    const int64_t batch_size = cumulative_sequence_length_q.value().size(0) - 1;
@@ -26,6 +26,8 @@
#else
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/_cudnn_attention_backward.h>
#include <ATen/ops/_cudnn_attention_backward_native.h>
#include <ATen/ops/_flash_attention_backward.h>
#include <ATen/ops/_flash_attention_backward_native.h>
#include <ATen/ops/_efficient_attention_backward.h>
@@ -184,7 +186,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
  return std::make_tuple(Tensor(), Tensor(), Tensor());
}

std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
std::tuple<Tensor, Tensor, Tensor> _cudnn_attention_backward(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
@@ -211,57 +213,117 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
    }
  }

  const int64_t batch_size = query.size(0);
  const int64_t num_heads = query.size(1);
  const int64_t head_dim_qk = query.size(3);
  const int64_t head_dim_v = value.size(3);
  const bool is_nested = cum_seq_q.defined();
  const int64_t max_seqlen_batch_q = query.size(2);
  const int64_t max_seqlen_batch_k = key.size(2);

  // This is needed because SaveVariable automatically converts
  // std::optional to undefined tensor
  std::optional<Tensor> attn_bias_;
  if (attn_bias.defined()) {
    attn_bias_ = attn_bias;
  }
  if (attn_bias_.has_value()) {
    const auto bias_dim = attn_bias_.value().dim();
    if (bias_dim == 2) {
      attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
    } else if (bias_dim == 3) {
      attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
    } else {
      TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
      attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
    }
  }
  if (!is_nested) {
    const int64_t batch_size = query.size(0);
    const int64_t num_heads = query.size(1);
    const int64_t head_dim_qk = query.size(3);
    const int64_t head_dim_v = value.size(3);

    const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
    auto dq = at::empty_like(query);
    auto dk = at::empty_like(key);
    auto dv = at::empty_like(value);
    run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
        num_heads /*int64_t h*/,
        max_q/*int64_t s_q*/,
        max_k/*int64_t s_kv*/,
        head_dim_qk /*int64_t d_qk*/,
        head_dim_v /*int64_t d_v*/,
        softmax_scale /*float scaling_factor*/,
        is_causal /*bool is_causal*/,
        dropout_p /*float dropout_probability*/,
        query /*const Tensor& q*/,
        key /*const Tensor& k*/,
        value /*const Tensor& v*/,
        attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
        out /*const Tensor& o*/,
        grad_out/*const Tensor& dO*/,
        logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
        dq/*Tensor& dQ*/,
        dk/*Tensor& dK*/,
        dv/*Tensor& dV*/,
        philox_seed/*Tensor& dropoutseed*/,
        philox_offset/*Tensor& dropoutoffset*/);
    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
    // This is needed because SaveVariable automatically converts
    // std::optional to undefined tensor
    std::optional<Tensor> attn_bias_;
    if (attn_bias.defined()) {
      attn_bias_ = attn_bias;
    }
    if (attn_bias_.has_value()) {
      const auto bias_dim = attn_bias_.value().dim();
      if (bias_dim == 2) {
        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
      } else if (bias_dim == 3) {
        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
      } else {
        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
      }
    }

    const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
    auto dq = at::empty_like(query);
    auto dk = at::empty_like(key);
    auto dv = at::empty_like(value);
    run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
        num_heads /*int64_t h*/,
        max_q/*int64_t s_q*/,
        max_k/*int64_t s_kv*/,
        head_dim_qk /*int64_t d_qk*/,
        head_dim_v /*int64_t d_v*/,
        softmax_scale /*float scaling_factor*/,
        is_causal /*bool is_causal*/,
        dropout_p /*float dropout_probability*/,
        query /*const Tensor& q*/,
        key /*const Tensor& k*/,
        value /*const Tensor& v*/,
        attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
        out /*const Tensor& o*/,
        grad_out/*const Tensor& dO*/,
        logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
        dq/*Tensor& dQ*/,
        dk/*Tensor& dK*/,
        dv/*Tensor& dV*/,
        philox_seed/*Tensor& dropoutseed*/,
        philox_offset/*Tensor& dropoutoffset*/);
    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
  } else {
    // BHSD ...
    const int64_t batch_size = cum_seq_q.size(0) - 1;
    const int64_t num_heads_q = query.size(-2);
    const int64_t num_heads_k = key.size(-2);
    const int64_t num_heads_v = value.size(-2);
    const int64_t head_dim_qk = query.size(-1);
    const int64_t head_dim_v = value.size(-1);
    std::optional<Tensor> attn_bias_;
    if (attn_bias.defined()) {
      attn_bias_ = attn_bias;
    }
    if (attn_bias_.has_value()) {
      const auto bias_dim = attn_bias_.value().dim();
      if (bias_dim == 2) {
        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
      } else if (bias_dim == 3) {
        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
      } else {
        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
      }
    }

    auto dq = at::empty_like(query);
    auto dk = at::empty_like(key);
    auto dv = at::empty_like(value);

    const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
    run_cudnn_SDP_bprop_nestedtensor(
        batch_size,
        num_heads_q,
        num_heads_k,
        num_heads_v,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        head_dim_qk,
        head_dim_v,
        softmax_scale,
        is_causal,
        dropout_p,
        cum_seq_q,
        cum_seq_k,
        query,
        key,
        value,
        attn_bias_,
        out,
        grad_out,
        logsumexp,
        dq,
        dk,
        dv,
        philox_seed,
        philox_offset);
    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
  }
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
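Editor's note: as the bias handling above shows, a 2D or 3D attn_bias is expanded to (batch, 1, s_q, s_kv) and a 4D bias keeps its head dimension before being handed to cuDNN. A small sketch of the user-visible behaviour; shapes are illustrative, not from the PR:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16, requires_grad=True)
# 2D additive bias of shape (s_q, s_kv); the backward expands it to (B, 1, s_q, s_kv).
bias = torch.randn(256, 256, device="cuda", dtype=torch.float16)

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    out.sum().backward()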
@@ -1063,4 +1125,40 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
  }
}

std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    const Tensor& philox_seed,
    const Tensor& philox_offset,
    const Tensor& attn_bias,
    const Tensor& cum_seq_q,
    const Tensor& cum_seq_k,
    const int64_t max_q,
    const int64_t max_k,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  return at::_cudnn_attention_backward(
      grad_out,
      query,
      key,
      value,
      out,
      logsumexp,
      philox_seed,
      philox_offset,
      attn_bias,
      cum_seq_q,
      cum_seq_k,
      max_q,
      max_k,
      dropout_p,
      is_causal,
      scale);
}

} // namespace at::native
@@ -57,21 +57,28 @@
namespace sdp {
namespace {

// tracks whether we've set the default priority order once, to avoid setting
// it redundantly or overwriting a user-specified priority order
// when the priority order context manager is used before the default priority
// order is initialized the following happens:
// (1) the current priority order is queried
// (2) priority_order() is called, which initializes it to the default as init_ is false
// (3) the user-specified priority order is set
// (3.1) we are in the priority context...
// (3.2) we exit the priority context...
// (4) the previous priority order (default) is restored
bool priority_order_init_ = false;

// TODO(eqy): more benchmarking to determine whether this should include sm86/89
// Needs to be kept in-sync with test_fused_sdp_choice in test_transformers.py
bool check_prefer_cudnn_attention() {
  // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0
  // see context: https://github.com/pytorch/pytorch/issues/138340
  // return false;
#if defined(CUDNN_VERSION)

#if CUDNN_VERSION > 90000
  static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true;
  if (!prefer_cudnn) {
    return false;
  }
#if (defined(CUDNN_VERSION) && (CUDNN_VERSION > 90000))
  auto dprops = at::cuda::getCurrentDeviceProperties();
  return dprops->major >= 9;
#else
  return false;
#endif

  return dprops->major >= 9 && !dprops->minor;
#else
  return false;
#endif
@@ -79,6 +86,16 @@ bool check_prefer_cudnn_attention() {

// flash_attention V2 is universally faster than efficient_attention and Math
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
  if (!priority_order_init_) {
    priority_order_init_ = true;
    if (check_prefer_cudnn_attention()) {
      const std::vector<int64_t> cudnn_order = {static_cast<int64_t>(at::SDPBackend::cudnn_attention),
                                                static_cast<int64_t>(at::SDPBackend::flash_attention),
                                                static_cast<int64_t>(at::SDPBackend::efficient_attention),
                                                static_cast<int64_t>(at::SDPBackend::math)};
      at::globalContext().setSDPPriorityOrder(cudnn_order);
    }
  }
  return at::globalContext().sDPPriorityOrder();
}
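Editor's note: the opt-in default preference above is gated on the TORCH_CUDNN_SDPA_PREFERRED environment variable plus check_prefer_cudnn_attention(); independent of that, the priority order consulted by priority_order() can be set explicitly from Python with sdpa_kernel(..., set_priority=True), as the updated tests later in this diff do. A sketch with illustrative shapes:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.bfloat16)

# Put cuDNN first but keep MATH as a fallback; on exiting the context the
# previous priority order is restored, matching steps (1)-(4) in the comment above.
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True):
    out = F.scaled_dot_product_attention(q, k, v)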
@@ -414,12 +431,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
    return false;
  }
  auto head_dim_limit = 128;
  if (cudnn_version >= 90501) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    if (dprops->major == 9 && !dprops->minor) {
      head_dim_limit = 256;
    }
  }
  // TODO(eqy): add head dim >= 256 cases once support is finalized
  if (d_qk > head_dim_limit || d_v > head_dim_limit) {
    if (debug) {
      TORCH_WARN("head_dim should be no more than ", head_dim_limit);
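Editor's note: the gate above means head dims up to 256 are only attempted on SM 9.0 with cuDNN 9.5.1 or newer, and 128 otherwise. A rough Python-side mirror of that check, for illustration only (this is not the code the dispatcher runs):

import torch

cap = torch.cuda.get_device_capability()
cudnn_ver = torch.backends.cudnn.version()  # e.g. 90501 for cuDNN 9.5.1

head_dim_limit = 128
if cudnn_ver is not None and cudnn_ver >= 90501 and cap == (9, 0):
    head_dim_limit = 256
print(f"cuDNN SDPA head_dim limit on this device: {head_dim_limit}")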
@@ -453,9 +465,15 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
      return false;
    }
  }
  if (s_q == 1 || s_k == 1) {
  if (s_k == 1) {
    if (debug) {
      TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1.");
      TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1.");
    }
    return false;
  }
  if (s_q == 1 && params.dropout != 0.0) {
    if (debug) {
      TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout.");
    }
    return false;
  }
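Editor's note: with the refinement above, cuDNN declines a key/value sequence length of 1, and a query sequence length of 1 only when dropout is nonzero. A quick sketch of the user-visible effect, assuming another allowed backend can pick up the call:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1, 64, device="cuda", dtype=torch.float16)  # s_kv == 1
v = torch.randn(2, 8, 1, 64, device="cuda", dtype=torch.float16)

# cuDNN is skipped for s_kv == 1, so the dispatcher falls through to MATH here.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)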
@@ -563,9 +581,9 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {

  const auto dprop = at::cuda::getCurrentDeviceProperties();
  // Check that the input is nested
  if (dprop->major != 9 && has_for_nested_inputs(params)) {
  if ((dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {
    if (debug) {
      TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0.");
      TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0.");
    }
    return false;
  }
@@ -589,7 +607,7 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
  // sdp kernels
  if (!at::globalContext().userEnabledCuDNNSDP()) {
    if (debug) {
      TORCH_WARN("CuDNN attention has been runtime disabled.");
      TORCH_WARN("cuDNN attention has been runtime disabled.");
    }
    return false;
  }
@@ -620,7 +638,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
#endif
#if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000
  if (debug) {
    TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)");
    TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)");
  }
  return false;
#endif
@@ -630,10 +648,8 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
      c10::array_of<bool (*)(sdp_params const&, bool)>(
          check_runtime_disabled_cudnn,
          check_for_nested_inputs,
          check_nonzero_sequence_lengths_dense,
          check_all_tensors_on_device,
          check_tensor_shapes,
          check_cudnn_tensor_shapes,
          check_cudnn_deterministic,
          check_dtypes_low_precision,
          check_attn_mask_shape,
@@ -646,8 +662,10 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
  }
  constexpr auto dense_constraints =
      c10::array_of<bool (*)(sdp_params const&, bool)>(
          check_nonzero_sequence_lengths_dense,
          check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>,
          check_batch_size_and_num_heads_dense<true /*enable_gqa*/, false /*requires_same_num_heads*/>
          check_batch_size_and_num_heads_dense<true /*enable_gqa*/, false /*requires_same_num_heads*/>,
          check_cudnn_tensor_shapes
      );

  if (has_only_dense_inputs(params)) {
@@ -864,7 +882,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
  sdp::can_use_mem_efficient_attention(kernel_params, print_debug);
  TORCH_WARN("Flash attention kernel not used because:");
  sdp::can_use_flash_attention(kernel_params, print_debug);
  TORCH_WARN("CuDNN attention kernel not used because:");
  TORCH_WARN("cuDNN attention kernel not used because:");
  sdp::can_use_cudnn_attention(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
  return SDPBackend::error;
@@ -75,6 +75,7 @@ aten::_ctc_loss.out
aten::_ctc_loss_backward
aten::_ctc_loss_backward.Tensor
aten::_ctc_loss_backward.out
aten::_cudnn_attention_backward
aten::_cudnn_attention_forward
aten::_cudnn_ctc_loss
aten::_cudnn_ctc_loss.Tensor
@@ -26,6 +26,7 @@ from torch._inductor.utils import (
    run_fw_bw_and_get_code,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -177,9 +178,10 @@ class CudaReproTests(TestCase):
        inputs = [q, k, v, mask]

        def f(q, k, v, mask):
            return F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=0.0
            )
            with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
                return F.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask, dropout_p=0.0
                )

        f_compiled = torch.compile(f)
@@ -6760,11 +6760,10 @@ torch.cuda.synchronize()
            and check_cudnn
            and (dtype == torch.float16 or dtype == torch.bfloat16)
        ):
            with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"):
                with torch.nn.attention.sdpa_kernel(
                    torch.nn.attention.SDPBackend.CUDNN_ATTENTION
                ):
                    check_forward_backward()
            with torch.nn.attention.sdpa_kernel(
                torch.nn.attention.SDPBackend.CUDNN_ATTENTION
            ):
                check_forward_backward()

    @skipIfTorchDynamo("SDPA test compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@@ -49,7 +49,6 @@ from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    SM90OrLater,
    tf32_on_and_off,
    tf32_enabled,
)
@@ -2657,6 +2656,7 @@ class TestSDPACudaOnly(NNTestCase):

    @skipIfRocm  # No cuDNN Attention
    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
    @unittest.expectedFailure  # cuDNN currently doesn't support this on SM100+/fails graph validation
    def test_cudnn_attention_d256_heuristic(self, device):
        dtype = torch.bfloat16
        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
@@ -2667,7 +2667,7 @@ class TestSDPACudaOnly(NNTestCase):
        v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v)
        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)

        with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True):
        with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION], set_priority=True):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
            actual.backward(torch.randn_like(actual))
@@ -2705,7 +2705,7 @@ class TestSDPACudaOnly(NNTestCase):


    @skipIfRocm  # No cuDNN Attention
    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
    @unittest.skipIf(True, "broken as of cuDNN 9.10")
    def test_cudnn_attention_fail_d128(self, device):
        # Test that cuDNN attention dispatching correctly bails out on d > 128
        b, h = 1, 2
@@ -2720,7 +2720,6 @@ class TestSDPACudaOnly(NNTestCase):
        ISSM90 = device_cap == (9, 0)
        ISSM100 = device_cap == (10, 0)
        with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
            # SM90/100 support d <= 256 as of cuDNN 9.5.1+
            if (ISSM90 or ISSM100) and torch.backends.cudnn.version() >= 90501:
                torch.nn.functional.scaled_dot_product_attention(q, k, v)
            else:
@@ -3156,15 +3155,19 @@ class TestSDPACudaOnly(NNTestCase):
        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        device_capability = None
        if "cuda" in str(device):
            device_capability = torch.cuda.get_device_capability()
        prefer_cudnn = "TORCH_CUDNN_SDPA_PREFERRED" in os.environ
        prefer_cudnn = prefer_cudnn and device_capability and (device_capability == (9, 0) or device_capability == (10, 0))

        # TODO we are currently disabling this by default; let's assert that this returns
        # FlashAttention. We need to change this when we remove the opt-in for cudnn.
        if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
            with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
        if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and prefer_cudnn:
            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
        elif PLATFORM_SUPPORTS_FLASH_ATTENTION:
            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
        elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION:  # e.g., we're on Windows
        elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn:  # e.g., we're on Windows
            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
            with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
@@ -2904,6 +2904,10 @@
  output_differentiability: [True, False, False, False, False, False]
  query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)

- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  output_differentiability: [True, False, False, False, False, False, False, False, False]
  query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)

- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  output_differentiability: [True, False, False, False, False, False, False, False, False]
  query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)