Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 (#149282)"
This reverts commit 9386701b51.
Reverted https://github.com/pytorch/pytorch/pull/149282 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, see [D74729259](https://www.internalfb.com/diff/D74729259). @drisspg may you help out the author have their PR merged? ([comment](https://github.com/pytorch/pytorch/pull/149282#issuecomment-2881546951))
parent c92ea3bc98
commit f363a3f51a
File diff suppressed because it is too large.
@@ -70,31 +70,4 @@ void run_cudnn_SDP_bprop(
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset);

void run_cudnn_SDP_bprop_nestedtensor(
    int64_t b,
    int64_t h_q,
    int64_t h_k,
    int64_t h_v,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& cum_seqlen_q,
    const Tensor& cum_seqlen_kv,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const std::optional<Tensor>& attn_bias,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset);

} // namespace at::native
@@ -14958,7 +14958,6 @@
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
    NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda
  tags: nondeterministic_seeded

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
@@ -14991,11 +14990,6 @@
    CUDA: _cudnn_attention_forward
  tags: nondeterministic_seeded

- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _cudnn_attention_backward
  tags: nondeterministic_seeded

- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
  variants: function
  dispatch:
@@ -349,63 +349,6 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda(
  return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
}

std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    const Tensor& philox_seed,
    const Tensor& philox_offset,
    const Tensor& attn_bias,
    const Tensor& cum_seq_q,
    const Tensor& cum_seq_k,
    const int64_t max_q,
    const int64_t max_k,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  if (!grad_out.defined()) {
    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
  }
  auto [
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped] =
      preprocessing::sdpa_nested_preprocessing_backward(
          grad_out,
          query,
          key,
          value,
          out,
          cum_seq_q,
          cum_seq_k,
          max_q,
          max_k);

  auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped,
      logsumexp,
      philox_seed,
      philox_offset,
      attn_bias,
      cum_seq_q,
      cum_seq_k,
      max_q,
      max_k,
      dropout_p,
      is_causal,
      scale);
  return std::make_tuple(dq, dk, dv);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
    const at::Tensor& grad_out_,
    const at::Tensor& query,
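For context, a minimal Python-side sketch of the path this nested-tensor backward served, assuming a CUDA build with cuDNN attention available; shapes, sequence lengths, and dtypes below are made up for illustration. With this revert applied, forcing the cuDNN backend on nested inputs is expected to raise (see the "cuDNN SDPA Nested Tensor" assertion in the updated test further down), so this shows what the reverted PR enabled rather than post-revert behavior.

# Hypothetical sketch: jagged nested-tensor SDPA forward + backward with the
# cuDNN backend forced. All sizes are illustrative.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

device, dtype = "cuda", torch.float16
heads, head_dim = 8, 64
seq_lens = [128, 256, 192]  # ragged sequence lengths per batch element

def make_jagged():
    # (s_i, heads, head_dim) components -> (B, j, heads, head_dim) jagged tensor,
    # then transpose to the (B, heads, j, head_dim) layout SDPA expects.
    parts = [torch.randn(s, heads, head_dim, device=device, dtype=dtype) for s in seq_lens]
    nt = torch.nested.nested_tensor(parts, layout=torch.jagged, requires_grad=True)
    return nt.transpose(1, 2)

q, k, v = make_jagged(), make_jagged(), make_jagged()

with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
    # Reducing over the underlying values buffer exercises the nested-tensor backward.
    out.values().sum().backward()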
@@ -848,6 +848,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
  // TODO(eqy): support debug_attn_mask
  return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
} else {
  //auto [
  //  query_buffer_reshaped,
  //  key_buffer_reshaped,
  //  value_buffer_reshaped,
  //  cumulative_sequence_length_q,
  //  cumulative_sequence_length_kv,
  //  max_seqlen_batch_q,
  //  max_seqlen_batch_kv,
  //  output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);
  // C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
  // TODO(eqy): debug mask support
  // BHSD ...
  const int64_t batch_size = cumulative_sequence_length_q.value().size(0) - 1;
|
|
@ -24,8 +24,6 @@
|
|||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_cudnn_attention_backward.h>
|
||||
#include <ATen/ops/_cudnn_attention_backward_native.h>
|
||||
#include <ATen/ops/_flash_attention_backward.h>
|
||||
#include <ATen/ops/_flash_attention_backward_native.h>
|
||||
#include <ATen/ops/_efficient_attention_backward.h>
|
||||
|
|
@@ -172,7 +170,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
  return std::make_tuple(Tensor(), Tensor(), Tensor());
}

std::tuple<Tensor, Tensor, Tensor> _cudnn_attention_backward(
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
@@ -199,15 +197,12 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_attention_backward(
    }
  }

  const bool is_nested = cum_seq_q.defined();
  const int64_t max_seqlen_batch_q = query.size(2);
  const int64_t max_seqlen_batch_k = key.size(2);

  if (!is_nested) {
    const int64_t batch_size = query.size(0);
    const int64_t num_heads = query.size(1);
    const int64_t head_dim_qk = query.size(3);
    const int64_t head_dim_v = value.size(3);
    const int64_t max_seqlen_batch_q = query.size(2);
    const int64_t max_seqlen_batch_k = key.size(2);

    // This is needed because SaveVariable automatically converts
    // std::optional to undefined tensor
@@ -253,63 +248,6 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_attention_backward(
        philox_seed/*Tensor& dropoutseed*/,
        philox_offset/*Tensor& dropoutoffset*/);
    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
  } else {
    // BHSD ...
    const int64_t batch_size = cum_seq_q.size(0) - 1;
    const int64_t num_heads_q = query.size(-2);
    const int64_t num_heads_k = key.size(-2);
    const int64_t num_heads_v = value.size(-2);
    const int64_t head_dim_qk = query.size(-1);
    const int64_t head_dim_v = value.size(-1);
    std::optional<Tensor> attn_bias_;
    if (attn_bias.defined()) {
      attn_bias_ = attn_bias;
    }
    if (attn_bias_.has_value()) {
      const auto bias_dim = attn_bias_.value().dim();
      if (bias_dim == 2) {
        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
      } else if (bias_dim == 3) {
        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
      } else {
        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
      }
    }

    auto dq = at::empty_like(query);
    auto dk = at::empty_like(key);
    auto dv = at::empty_like(value);

    const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
    run_cudnn_SDP_bprop_nestedtensor(
        batch_size,
        num_heads_q,
        num_heads_k,
        num_heads_v,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        head_dim_qk,
        head_dim_v,
        softmax_scale,
        is_causal,
        dropout_p,
        cum_seq_q,
        cum_seq_k,
        query,
        key,
        value,
        attn_bias_,
        out,
        grad_out,
        logsumexp,
        dq,
        dk,
        dv,
        philox_seed,
        philox_offset);
    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
  }
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -1012,40 +950,4 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
      grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
}

std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    const Tensor& philox_seed,
    const Tensor& philox_offset,
    const Tensor& attn_bias,
    const Tensor& cum_seq_q,
    const Tensor& cum_seq_k,
    const int64_t max_q,
    const int64_t max_k,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  return at::_cudnn_attention_backward(
      grad_out,
      query,
      key,
      value,
      out,
      logsumexp,
      philox_seed,
      philox_offset,
      attn_bias,
      cum_seq_q,
      cum_seq_k,
      max_q,
      max_k,
      dropout_p,
      is_causal,
      scale);
}

} // namespace at::native
|
@ -57,21 +57,12 @@
|
|||
namespace sdp {
|
||||
namespace {
|
||||
|
||||
// tracks whether we've set the default priority order once, to avoid setting
|
||||
// it redundantly or overwriting a user-specified priority order
|
||||
// when the priority order context manager is used before the default priority
|
||||
// order is initialized the following happens:
|
||||
// (1) the current priority order is queried
|
||||
// (2) priority_order() is called, which initializes it to the default as init_ is false
|
||||
// (3) the user-specified priority order is set
|
||||
// (3.1) we are in the priority context...
|
||||
// (3.2) we exit the priority context...
|
||||
// (4) the previous priority order (default) is restored
|
||||
bool priority_order_init_ = false;
|
||||
|
||||
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
|
||||
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
|
||||
bool check_prefer_cudnn_attention() {
|
||||
// TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0
|
||||
// see context: https://github.com/pytorch/pytorch/issues/138340
|
||||
// return false;
|
||||
#if defined(CUDNN_VERSION)
|
||||
|
||||
#if CUDNN_VERSION > 90000
|
||||
|
|
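As a user-facing illustration of the priority-order behavior described in the comment block above, here is a minimal sketch (arbitrary shapes, assuming a CUDA build with the fused backends available) that queries the current backend choice and then restricts selection with the sdpa_kernel context manager, mirroring the pattern used in the tests further down:

# Sketch only: query the default backend choice, then constrain selection to the
# cuDNN backend inside the context manager; the previous selection state is
# restored on exit, in the spirit of steps (1)-(4) in the comment above.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

print(torch._fused_sdp_choice(q, k, v))  # backend picked by the current priority order

with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
    print(torch._fused_sdp_choice(q, k, v))  # now constrained to cuDNN attention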
@@ -88,16 +79,6 @@ bool check_prefer_cudnn_attention() {

// flash_attention V2 is universally faster than efficient_attention and Math
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
  if (!priority_order_init_) {
    priority_order_init_ = true;
    if (check_prefer_cudnn_attention()) {
      const std::vector<int64_t> cudnn_order = {static_cast<int64_t>(at::SDPBackend::cudnn_attention),
                                                static_cast<int64_t>(at::SDPBackend::flash_attention),
                                                static_cast<int64_t>(at::SDPBackend::efficient_attention),
                                                static_cast<int64_t>(at::SDPBackend::math)};
      at::globalContext().setSDPPriorityOrder(cudnn_order);
    }
  }
  return at::globalContext().sDPPriorityOrder();
}
@@ -472,15 +453,9 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
      return false;
    }
  }
  if (s_k == 1) {
  if (s_q == 1 || s_k == 1) {
    if (debug) {
      TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1.");
    }
    return false;
  }
  if (s_q == 1 && params.dropout != 0.0) {
    if (debug) {
      TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout.");
      TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1.");
    }
    return false;
  }
@@ -588,9 +563,9 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {

  const auto dprop = at::cuda::getCurrentDeviceProperties();
  // Check that the input is nested
  if ((dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {
  if (dprop->major != 9 && has_for_nested_inputs(params)) {
    if (debug) {
      TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0.");
      TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0.");
    }
    return false;
  }
@@ -614,7 +589,7 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
  // sdp kernels
  if (!at::globalContext().userEnabledCuDNNSDP()) {
    if (debug) {
      TORCH_WARN("cuDNN attention has been runtime disabled.");
      TORCH_WARN("CuDNN attention has been runtime disabled.");
    }
    return false;
  }
@@ -645,7 +620,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
#endif
#if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000
  if (debug) {
    TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)");
    TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)");
  }
  return false;
#endif
@@ -655,8 +630,10 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
      c10::array_of<bool (*)(sdp_params const&, bool)>(
          check_runtime_disabled_cudnn,
          check_for_nested_inputs,
          check_nonzero_sequence_lengths_dense,
          check_all_tensors_on_device,
          check_tensor_shapes,
          check_cudnn_tensor_shapes,
          check_cudnn_deterministic,
          check_dtypes_low_precision,
          check_attn_mask_shape,
@@ -669,10 +646,8 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
  }
  constexpr auto dense_constraints =
      c10::array_of<bool (*)(sdp_params const&, bool)>(
          check_nonzero_sequence_lengths_dense,
          check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>,
          check_batch_size_and_num_heads_dense<true /*enable_gqa*/, false /*requires_same_num_heads*/>,
          check_cudnn_tensor_shapes
          check_batch_size_and_num_heads_dense<true /*enable_gqa*/, false /*requires_same_num_heads*/>
      );

  if (has_only_dense_inputs(params)) {
@@ -884,7 +859,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
  sdp::can_use_mem_efficient_attention(kernel_params, print_debug);
  TORCH_WARN("Flash attention kernel not used because:");
  sdp::can_use_flash_attention(kernel_params, print_debug);
  TORCH_WARN("cuDNN attention kernel not used because:");
  TORCH_WARN("CuDNN attention kernel not used because:");
  sdp::can_use_cudnn_attention(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
  return SDPBackend::error;
@@ -75,7 +75,6 @@ aten::_ctc_loss.out
aten::_ctc_loss_backward
aten::_ctc_loss_backward.Tensor
aten::_ctc_loss_backward.out
aten::_cudnn_attention_backward
aten::_cudnn_attention_forward
aten::_cudnn_ctc_loss
aten::_cudnn_ctc_loss.Tensor
@@ -6746,6 +6746,7 @@ torch.cuda.synchronize()
    and check_cudnn
    and (dtype == torch.float16 or dtype == torch.bfloat16)
):
    with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"):
        with torch.nn.attention.sdpa_kernel(
            torch.nn.attention.SDPBackend.CUDNN_ATTENTION
        ):
|
@ -49,6 +49,7 @@ from torch.testing._internal.common_cuda import (
|
|||
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
|
||||
PLATFORM_SUPPORTS_FUSED_ATTENTION,
|
||||
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
|
||||
SM90OrLater,
|
||||
tf32_on_and_off,
|
||||
tf32_enabled,
|
||||
)
|
||||
|
|
@@ -2993,18 +2994,15 @@ class TestSDPACudaOnly(NNTestCase):
        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        device_capability = None
        if "cuda" in str(device):
            device_capability = torch.cuda.get_device_capability()
        prefer_cudnn = device_capability and (device_capability == (9, 0) or device_capability == (10, 0))

        # TODO we are currently disabling this by default, lets assert that this returns
        # FlashAttention, we need to change when we make remove opt-in for cudnn
        if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and prefer_cudnn:
        if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
            with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
        elif PLATFORM_SUPPORTS_FLASH_ATTENTION:
            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
        elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn:  # e.g., we're on Windows
        elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION:  # e.g., we're on Windows
            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
            with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
@@ -2896,10 +2896,6 @@
  output_differentiability: [True, False, False, False, False, False]
  query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)

- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  output_differentiability: [True, False, False, False, False, False, False, False, False]
  query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)

- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  output_differentiability: [True, False, False, False, False, False, False, False, False]
  query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)
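To illustrate what these derivatives.yaml entries wire up end to end, a minimal sketch (arbitrary dense shapes, assuming a CUDA build where cuDNN attention is eligible): differentiating scaled_dot_product_attention with the cuDNN backend forced invokes the registered cuDNN attention backward formula.

# Sketch only: forward + backward through SDPA with the cuDNN backend forced,
# exercising the autograd registration above. Sizes are illustrative.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))

with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

out.sum().backward()  # gradients for q, k, v come from the cuDNN attention backward
assert all(t.grad is not None for t in (q, k, v))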