Revert "[cuDNN][SDPA][Nested Tensor] Experimental cuDNN Nested Tensor SDPA Support (forward only) (#141178)"

This reverts commit 533b884870.

Reverted https://github.com/pytorch/pytorch/pull/141178 on behalf of https://github.com/jeanschmidt due to Broke internal arvr signals, see D69971019. @jbschlosser please help the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/141178#issuecomment-2676317470))
PyTorch MergeBot 2025-02-22 17:28:12 +00:00
parent bea72180ed
commit fa8e3a28a7
12 changed files with 114 additions and 863 deletions
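For context: the reverted feature was experimental, forward-only, and gated behind the TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED environment variable (see the removed check_for_nested_inputs logic and test_scaled_dot_product_attention_cudnn_nested below). The following is a rough, hypothetical usage sketch reconstructed from those removed tests and checks, not code from this PR; the jagged-tensor construction and shapes here are illustrative assumptions.

import os
os.environ["TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED"] = "1"  # opt-in gate checked by the removed sdp_utils.cpp path

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Per the removed checks: fp16/bf16 inputs, SM 9.0 hardware, and q/k/v sharing a
# head dim that is a multiple of 8 and at most 128.
device, dtype = "cuda", torch.float16
heads, head_dim = 8, 64
seq_lens = [32, 48]  # ragged sequence lengths across the batch

def make_njt():
    # Build a (B, j_seq, H, D) jagged nested tensor, then transpose to the
    # (B, H, j_seq, D) layout scaled_dot_product_attention expects.
    return torch.nested.nested_tensor(
        [torch.randn(s, heads, head_dim, device=device, dtype=dtype) for s in seq_lens],
        layout=torch.jagged,
    ).transpose(1, 2)

q, k, v = make_njt(), make_njt(), make_njt()

with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)  # forward only; backward raised in the reverted code

After this revert, check_for_nested_inputs rejects nested inputs for the cuDNN backend again, so the same call restricted to SDPBackend.CUDNN_ATTENTION refuses to dispatch rather than reaching the removed kernels.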

View File

@ -25,10 +25,4 @@ unpack(at::PhiloxCudaState arg) {
}
}
// Adapted from TE
// extract seed and offset from PhiloxCudaState
__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr);
void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream);
} // namespace at::cuda::philox

View File

@ -31,33 +31,6 @@ void run_cudnn_SDP_fprop(
false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
}
void run_cudnn_SDP_fprop_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
int64_t h_v,
int64_t s_q,
int64_t s_kv,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool return_softmaxstats,
bool is_causal,
double dropout_probability,
const Tensor& cum_seqlen_q,
const Tensor& cum_seqlen_kv,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
Tensor& dropoutoffset) {
TORCH_CHECK(
false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
}
void run_cudnn_SDP_bprop(
int64_t b,
int64_t h,
@ -488,6 +461,16 @@ auto build_graph_and_tensors(
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
}
auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto [O, Stats] =
mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
@ -517,201 +500,6 @@ auto build_graph_and_tensors(
std::move(Stats));
}
auto build_graph_and_tensors_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
int64_t h_v,
int64_t s_q,
int64_t s_kv,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool return_softmaxstats,
bool is_causal,
double dropout_probability,
const Tensor& cum_seqlen_q,
const Tensor& cum_seqlen_kv,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
Tensor& dropoutoffset,
cudnnHandle_t& handle) {
auto dtype = fe::DataType_t::HALF;
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto SEQ_LEN_Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto SEQ_LEN_KV =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("CUDNN_SDPA_NESTEDTENSOR")
.set_is_inference(return_softmaxstats == false)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
.set_dropout(dropout_probability, seed, offset)
.set_seq_len_q(SEQ_LEN_Q)
.set_seq_len_kv(SEQ_LEN_KV)
.set_padding_mask(true);
// We hardcode BSHD to cuDNN even though the underlying layout is THD
auto q_strides = q.strides();
auto k_strides = k.strides();
auto v_strides = v.strides();
constexpr int strideidx0 = 1;
constexpr int strideidx1 = 0;
constexpr int strideidx2 = 2;
auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h_q, s_q, d_qk})
.set_stride(
{INT_MAX,
q_strides[strideidx0],
q_strides[strideidx1],
q_strides[strideidx2]}));
auto K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, h_k, s_kv, d_qk})
.set_stride(
{INT_MAX,
k_strides[strideidx0],
k_strides[strideidx1],
k_strides[strideidx2]}));
auto V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, h_v, s_kv, d_v})
.set_stride(
{INT_MAX,
v_strides[strideidx0],
v_strides[strideidx1],
v_strides[strideidx2]}));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
TORCH_CHECK(
false,
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
bias =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
}
auto RAG_Q_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("cum_seq_q")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto RAG_K_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("cum_seq_k")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto RAG_V_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("cum_seq_v")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto RAG_O_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("cum_seq_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
// auto RAG_STATS_OFF = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("cum_seq_stats")
// .set_dim({b + 1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
auto RAG_STATS_OFF = nullptr;
Q->set_ragged_offset(RAG_Q_OFF);
K->set_ragged_offset(RAG_K_OFF);
V->set_ragged_offset(RAG_V_OFF);
auto [O, Stats] =
mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
auto o_strides = o.strides();
O->set_output(true)
.set_dim({b, h_q, s_q, d_v})
.set_stride(
{INT_MAX,
o_strides[strideidx0],
o_strides[strideidx1],
o_strides[strideidx2]});
O->set_ragged_offset(RAG_O_OFF);
if (Stats) {
TORCH_CHECK(
false,
"cuDNN SDPA Nested Tensor does not yet handle backwards/logsumexp computation");
// TODO(eqy): fix when stats (backward) support is added
Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h_q, s_q, 1})
.set_stride({h_q * s_q * d_v, d_v, s_q * d_v, 1});
Stats->set_ragged_offset(RAG_STATS_OFF);
}
AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
AT_CUDNN_FRONTEND_CHECK(
mha_graph->create_execution_plans({fe::HeurMode_t::A}));
AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
return std::make_tuple(
std::move(mha_graph),
std::move(Q),
std::move(K),
std::move(V),
std::move(bias),
std::move(attn_scale),
std::move(seed),
std::move(offset),
std::move(O),
std::move(Stats),
std::move(RAG_Q_OFF),
std::move(RAG_K_OFF),
std::move(RAG_V_OFF),
std::move(RAG_O_OFF),
std::move(RAG_STATS_OFF),
std::move(SEQ_LEN_Q),
std::move(SEQ_LEN_KV));
}
auto build_graph_and_tensors_backward(
int64_t b,
int64_t h,
@ -949,119 +737,6 @@ void run_cudnn_SDP_fprop(
mhagraphcache.update(key, graph_and_tensors_values);
}
void run_cudnn_SDP_fprop_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
int64_t h_v,
int64_t s_q,
int64_t s_kv,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool return_softmaxstats,
bool is_causal,
double dropout_probability,
const Tensor& cum_seqlen_q,
const Tensor& cum_seqlen_kv,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
Tensor& dropoutoffset) {
cudnnHandle_t handle = getCudnnHandle();
// do nothing if we got 0-element tensors
if (!q.numel() || !k.numel() || !v.numel()) {
return;
}
if (!o.defined()) {
o = at::empty({q.size(0), h_q, d_v}, q.options());
}
if (return_softmaxstats && !softmaxstats.defined()) {
softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat));
}
auto
[mha_graph,
Q,
K,
V,
bias,
attn_scale,
seed,
offset,
O,
Stats,
RAG_Q_OFF,
RAG_K_OFF,
RAG_V_OFF,
RAG_O_OFF,
RAG_STATS_OFF,
SEQ_LEN_Q,
SEQ_LEN_KV] =
build_graph_and_tensors_nestedtensor(
b,
h_q,
h_k,
h_v,
s_q,
s_kv,
d_qk,
d_v,
scaling_factor,
return_softmaxstats,
is_causal,
dropout_probability,
cum_seqlen_q,
cum_seqlen_kv,
q,
k,
v,
attn_bias,
softmaxstats,
o,
dropoutseed,
dropoutoffset,
handle);
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_qk);
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
auto rag_stats_off = cum_seqlen_q.mul(h_q);
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
variant_pack = {
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{attn_scale, &scaling_factor},
{seed, dropoutseed.data_ptr()},
{offset, dropoutoffset.data_ptr()},
{O, o.data_ptr()},
{RAG_Q_OFF, rag_q_off.data_ptr()},
{RAG_O_OFF, rag_q_off.data_ptr()},
{RAG_K_OFF, rag_k_off.data_ptr()},
{RAG_V_OFF, rag_v_off.data_ptr()},
{SEQ_LEN_Q, seqlen_q.data_ptr()},
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
if (return_softmaxstats) {
variant_pack[Stats] = softmaxstats.data_ptr();
variant_pack[RAG_STATS_OFF] = cum_seqlen_q.data_ptr();
}
if (attn_bias.has_value()) {
TORCH_CHECK("bias not supported with nestedtensor");
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
void run_cudnn_SDP_bprop(
int64_t b,
int64_t h,

View File

@ -23,30 +23,6 @@ void run_cudnn_SDP_fprop(
Tensor& dropoutseed,
Tensor& dropoutoffset);
void run_cudnn_SDP_fprop_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
int64_t h_v,
int64_t max_s_q,
int64_t max_s_kv,
int64_t d_k,
int64_t d_v,
float scaling_factor,
bool isTraining,
bool is_causal,
double dropout_probability,
const Tensor& cum_seqlen_q,
const Tensor& cum_seqlen_kv,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
Tensor& dropoutoffset);
void run_cudnn_SDP_bprop(
int64_t b,
int64_t h,

View File

@ -14902,7 +14902,6 @@
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CUDA: _scaled_dot_product_cudnn_attention_cuda
NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_cuda
tags: nondeterministic_seeded
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
@ -14935,11 +14934,6 @@
dispatch:
CUDA: _efficient_attention_backward
- func: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CUDA: _cudnn_attention_forward
tags: nondeterministic_seeded
- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
variants: function
dispatch:

View File

@ -1,5 +1,4 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/ATen.h>
namespace at::native::preprocessing {

View File

@ -18,8 +18,6 @@
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
namespace at::native {
namespace {
@ -322,33 +320,6 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
}
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor>
_scaled_dot_product_cudnn_attention_nestedtensor_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
const std::optional<Tensor>& attn_bias,
bool compute_logsumexp,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
auto [
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);
auto [attention, log_sumexp, ignore1, ignore2, ignore3, ignore4, cudnn_seed, cudnn_offset, ignore5] = at::_cudnn_attention_forward(query_buffer_reshaped, key_buffer_reshaped, value_buffer_reshaped, attn_bias, cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, compute_logsumexp, dropout_p, is_causal, return_debug_mask, scale);
attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
const at::Tensor& grad_out_,
const at::Tensor& query,

View File

@ -26,8 +26,6 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_cudnn_attention_forward.h>
#include <ATen/ops/_cudnn_attention_forward_native.h>
#include <ATen/ops/_efficient_attention_forward.h>
#include <ATen/ops/_efficient_attention_forward_native.h>
#include <ATen/ops/_fill_mem_eff_dropout_mask_native.h>
@ -65,7 +63,6 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/nested/NestedTensorTransformerUtils.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
@ -90,25 +87,6 @@
namespace at {
namespace cuda::philox {
__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) {
if (arg.captured_) {
*seed_ptr = static_cast<int64_t>(*arg.seed_.ptr);
*offset_ptr = static_cast<int64_t>(
*(arg.offset_.ptr) + static_cast<int64_t>(arg.offset_intragraph_));
} else {
*seed_ptr = static_cast<int64_t>(arg.seed_.val);
*offset_ptr = static_cast<int64_t>(arg.offset_.val);
}
}
void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream) {
at::cuda::philox::unpack_cudnn<<<1, 1, 0, stream>>>(arg, seed_ptr, offset_ptr);
}
} // namespace cuda::philox
namespace native {
namespace {
@ -754,177 +732,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
}
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _cudnn_attention_forward(
const Tensor& query,
const Tensor& key,
const Tensor& value,
const std::optional<Tensor>& attn_bias,
const std::optional<Tensor>& cumulative_sequence_length_q,
const std::optional<Tensor>& cumulative_sequence_length_kv,
long max_seqlen_batch_q,
long max_seqlen_batch_kv,
bool compute_logsumexp,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
// TODO(eqy): debug mask support
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
const bool is_nested = cumulative_sequence_length_q.has_value();
if (!is_nested) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
auto attn_bias_ = attn_bias;
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_kv});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_kv});
} else {
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_kv});
}
}
Tensor attention, log_sumexp;
at::Tensor cudnn_seed, cudnn_offset;
cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
// See Note [Seed and Offset Device] in _efficient_attention_forward
at::PhiloxCudaState philox_state;
const bool in_capture_stream =
at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None;
if (use_dropout) {
// Device
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
// if using dropout, we produce 1 random number for each element of the
// attention tensor
// TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount
philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_kv);
at::cuda::philox::unpack_cudnn_wrapper(
philox_state, static_cast<int64_t*>(cudnn_seed.data_ptr()), static_cast<int64_t*>(cudnn_offset.data_ptr()), at::cuda::getCurrentCUDAStream());
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
Tensor debugmask;
run_cudnn_SDP_fprop(batch_size/*int64_t b*/,
num_heads/*int64_t h*/,
max_seqlen_batch_q/*int64_t s_q*/,
max_seqlen_batch_kv/*int64_t s_kv*/,
head_dim_qk/*int64_t d_qk*/,
head_dim_v/*int64_t d_v*/,
softmax_scale/*float scaling_factor*/,
compute_logsumexp/* bool */,
is_causal/* bool */,
dropout_p/*double dropout_probability*/,
query/* Tensor q*/,
key/* Tensor k*/,
value/* Tensor v*/,
attn_bias_ /* std::optional<Tensor> */,
log_sumexp/*Tensor softmaxstats*/,
attention/*Tensor o*/,
cudnn_seed/*Tensor dropoutseed*/,
cudnn_offset/*Tensor dropoutoffset*/);
// TODO(eqy): support debug_attn_mask
return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
// Adapted from TE
// extract seed and offset from PhiloxCudaState
__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) {
if (arg.captured_) {
*seed_ptr = static_cast<int64_t>(*arg.seed_.ptr);
*offset_ptr = static_cast<int64_t>(
*(arg.offset_.ptr) + static_cast<int64_t>(arg.offset_intragraph_));
} else {
//auto [
// query_buffer_reshaped,
// key_buffer_reshaped,
// value_buffer_reshaped,
// cumulative_sequence_length_q,
// cumulative_sequence_length_kv,
// max_seqlen_batch_q,
// max_seqlen_batch_kv,
// output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);
// C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
// TODO(eqy): debug mask support
// BHSD ...
const int64_t batch_size = cumulative_sequence_length_q.value().size(0) - 1;
const int64_t num_heads_q = query.size(-2);
const int64_t num_heads_k = key.size(-2);
const int64_t num_heads_v = value.size(-2);
const int64_t head_dim_qk = query.size(-1);
const int64_t head_dim_v = value.size(-1);
auto attn_bias_ = attn_bias;
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_kv});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_kv});
} else {
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_kv});
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
}
}
Tensor attention, log_sumexp;
at::Tensor cudnn_seed, cudnn_offset;
cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
// See Note [Seed and Offset Device] in _efficient_attention_forward
at::PhiloxCudaState philox_state;
const bool in_capture_stream =
at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None;
if (use_dropout) {
// Device
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
// if using dropout, we produce 1 random number for each element of the
// attention tensor
// TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount
philox_state = gen->philox_cuda_state(batch_size * num_heads_q * max_seqlen_batch_q * max_seqlen_batch_kv);
at::cuda::philox::unpack_cudnn_wrapper(philox_state, static_cast<int64_t*>(cudnn_seed.data_ptr()), static_cast<int64_t*>(cudnn_offset.data_ptr()), at::cuda::getCurrentCUDAStream());
}
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
run_cudnn_SDP_fprop_nestedtensor(batch_size/*int64_t b*/,
num_heads_q/*int64_t h_q*/,
num_heads_k/*int64_t h_k*/,
num_heads_v/*int64_t h_v*/,
max_seqlen_batch_q/*int64_t s_q*/,
max_seqlen_batch_kv/*int64_t s_kv*/,
head_dim_qk/*int64_t d_qk*/,
head_dim_v/*int64_t d_v*/,
softmax_scale/*float scaling_factor*/,
compute_logsumexp/* bool */,
is_causal/* bool */,
dropout_p/*double dropout_probability*/,
cumulative_sequence_length_q.value(),
cumulative_sequence_length_kv.value(),
query/* Tensor q*/,
key/* Tensor k*/,
value/* Tensor v*/,
attn_bias_ /* std::optional<Tensor> */,
log_sumexp/*Tensor softmaxstats*/,
attention/*Tensor o*/,
cudnn_seed/*Tensor dropoutseed*/,
cudnn_offset/*Tensor dropoutoffset*/);
//attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q.value(), cumulative_sequence_length_kv.value(), max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
*seed_ptr = static_cast<int64_t>(arg.seed_.val);
*offset_ptr = static_cast<int64_t>(arg.offset_.val);
}
}
@ -940,88 +757,84 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
std::optional<double> scale) {
// Used for tracking usage statistics
C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
// TODO(eqy): debug mask support
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t max_seqlen_batch_q = query.size(2);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_batch_k = key.size(2);
const int64_t max_seqlen_batch_v = value.size(2);
TORCH_CHECK(
max_seqlen_batch_k == max_seqlen_batch_v,
"Key and Value must have the same sequence length");
auto attn_bias_ = attn_bias;
if (attn_bias_.has_value()) {
const auto bias_dim = attn_bias_.value().dim();
if (bias_dim == 2) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else if (bias_dim == 3) {
attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
} else {
TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
}
}
return at::_cudnn_attention_forward(query, key, value, attn_bias, std::nullopt, std::nullopt, max_seqlen_batch_q, max_seqlen_batch_k, compute_logsumexp, dropout_p, is_causal, return_debug_mask, scale);
//// TODO(eqy): debug mask support
//// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
//// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
//// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
//const int64_t batch_size = query.size(0);
//const int64_t num_heads = query.size(1);
//const int64_t max_seqlen_batch_q = query.size(2);
//const int64_t head_dim_qk = query.size(3);
//const int64_t head_dim_v = value.size(3);
//const int64_t max_seqlen_batch_k = key.size(2);
//const int64_t max_seqlen_batch_v = value.size(2);
//TORCH_CHECK(
// max_seqlen_batch_k == max_seqlen_batch_v,
// "Key and Value must have the same sequence length");
//auto attn_bias_ = attn_bias;
//if (attn_bias_.has_value()) {
// const auto bias_dim = attn_bias_.value().dim();
// if (bias_dim == 2) {
// attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
// } else if (bias_dim == 3) {
// attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
// } else {
// TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
// attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
// }
//}
Tensor attention, log_sumexp;
//Tensor attention, log_sumexp;
at::Tensor cudnn_seed, cudnn_offset;
cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
//at::Tensor cudnn_seed, cudnn_offset;
//cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
//cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
//const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
// See Note [Seed and Offset Device] in _efficient_attention_forward
at::PhiloxCudaState philox_state;
const bool in_capture_stream =
at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None;
if (use_dropout) {
// Device
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
//// See Note [Seed and Offset Device] in _efficient_attention_forward
//at::PhiloxCudaState philox_state;
//const bool in_capture_stream =
// at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None;
//if (use_dropout) {
// // Device
// auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
// std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
// if using dropout, we produce 1 random number for each element of the
// attention tensor
// TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount
philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k);
unpack_cudnn<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_state, static_cast<int64_t*>(cudnn_seed.data_ptr()), static_cast<int64_t*>(cudnn_offset.data_ptr()));
}
// // See Note [Acquire lock when using random generators]
// std::lock_guard<std::mutex> lock(gen->mutex_);
// // if using dropout, we produce 1 random number for each element of the
// // attention tensor
// // TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount
// philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k);
// at::cuda::philox::unpack_cudnn_wrapper(
// philox_state, static_cast<int64_t*>(cudnn_seed.data_ptr()), static_cast<int64_t*>(cudnn_offset.data_ptr()), at::cuda::getCurrentCUDAStream());
//}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
Tensor debugmask;
//const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
//Tensor debugmask;
run_cudnn_SDP_fprop(batch_size/*int64_t b*/,
num_heads/*int64_t h*/,
max_seqlen_batch_q/*int64_t s_q*/,
max_seqlen_batch_k/*int64_t s_kv*/,
head_dim_qk/*int64_t d_qk*/,
head_dim_v/*int64_t d_v*/,
softmax_scale/*float scaling_factor*/,
compute_logsumexp/* bool */,
is_causal/* bool */,
dropout_p/*double dropout_probability*/,
query/* Tensor q*/,
key/* Tensor k*/,
value/* Tensor v*/,
attn_bias_ /* std::optional<Tensor> */,
log_sumexp/*Tensor softmaxstats*/,
attention/*Tensor o*/,
cudnn_seed/*Tensor dropoutseed*/,
cudnn_offset/*Tensor dropoutoffset*/);
//run_cudnn_SDP_fprop(batch_size/*int64_t b*/,
// num_heads/*int64_t h*/,
// max_seqlen_batch_q/*int64_t s_q*/,
// max_seqlen_batch_k/*int64_t s_kv*/,
// head_dim_qk/*int64_t d_qk*/,
// head_dim_v/*int64_t d_v*/,
// softmax_scale/*float scaling_factor*/,
// compute_logsumexp/* bool */,
// is_causal/* bool */,
// dropout_p/*double dropout_probability*/,
// query/* Tensor q*/,
// key/* Tensor k*/,
// value/* Tensor v*/,
// attn_bias_ /* std::optional<Tensor> */,
// log_sumexp/*Tensor softmaxstats*/,
// attention/*Tensor o*/,
// cudnn_seed/*Tensor dropoutseed*/,
// cudnn_offset/*Tensor dropoutoffset*/);
//// TODO(eqy): support debug_attn_mask
//return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
// TODO(eqy): support debug_attn_mask
return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
}
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(

View File

@ -505,23 +505,10 @@ bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
}
bool check_for_nested_inputs(sdp_params const& params, bool debug) {
static const bool enable_cudnn_nested = c10::utils::check_env("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED") == true;
if (has_for_nested_inputs(params) && !enable_cudnn_nested) {
if (debug) {
TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled.");
}
return false;
} else if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) {
if (debug) {
TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward.");
}
}
const auto dprop = at::cuda::getCurrentDeviceProperties();
// Check that the input is nested
if (dprop->major != 9 && has_for_nested_inputs(params)) {
if (has_for_nested_inputs(params)) {
if (debug) {
TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0.");
TORCH_WARN("CuDNN currently does not support nested inputs.");
}
return false;
}
@ -587,6 +574,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
check_runtime_disabled_cudnn,
check_for_nested_inputs,
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>,
check_all_tensors_on_device,
check_tensor_shapes,
check_cudnn_tensor_shapes,
@ -600,18 +588,6 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
return false;
}
}
constexpr auto dense_constraints =
c10::array_of<bool (*)(sdp_params const&, bool)>(
check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>
);
if (has_only_dense_inputs(params)) {
for (auto& constraint : dense_constraints) {
if (!constraint(params, debug)) {
return false;
}
}
}
return true;
}

View File

@ -75,7 +75,6 @@ aten::_ctc_loss.out
aten::_ctc_loss_backward
aten::_ctc_loss_backward.Tensor
aten::_ctc_loss_backward.out
aten::_cudnn_attention_forward
aten::_cudnn_ctc_loss
aten::_cudnn_ctc_loss.Tensor
aten::_cudnn_ctc_loss.out

View File

@ -4,7 +4,6 @@ import ast
import io
import itertools
import math
import os
import random
import sys
import tempfile
@ -31,7 +30,6 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FUSED_ATTENTION,
SM70OrLater,
SM80OrLater,
tf32_on_and_off,
)
from torch.testing._internal.common_device_type import (
dtypes,
@ -59,7 +57,6 @@ from torch.testing._internal.common_utils import (
NestedTensorTestCase,
parametrize,
run_tests,
serialTest,
skipIfRocm,
skipIfSlowGradcheckEnv,
skipIfTorchDynamo,
@ -1997,7 +1994,6 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
@onlyCUDA
@dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
@tf32_on_and_off(0.005)
def test_bmm_cuda(self, device, dtype):
self._test_bmm(device, dtype)
@ -2022,7 +2018,6 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
)
@dtypes(torch.float, torch.double)
@tf32_on_and_off(0.005)
def test_matmul_with_bmm_path(self, device, dtype):
def unbind_rebind_matmul(nt1, nt2):
t1s = nt1.unbind()
@ -2647,7 +2642,6 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
nt_noncont.narrow(dim=0, start=0, length=1)
@parametrize("input_dim", [3, 4])
@tf32_on_and_off(0.005)
def test_scaled_dot_product_attention(self, device, input_dim):
def rand_tensor(*shape):
return torch.randn(shape, device=device)
@ -3055,7 +3049,6 @@ class TestNestedTensorAutograd(NestedTensorTestCase):
data = (a, b, c, d)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
@tf32_on_and_off(0.008)
def test_nested_tensor_bmm_backward(self, device):
nt0 = torch.nested.nested_tensor(
[torch.randn((2, 6)), torch.randn((3, 6))],
@ -3801,13 +3794,11 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
@onlyCUDA
@dtypes(torch.float32)
@serialTest()
def test_linear_backward_memory_usage(self, device, dtype):
# Verify that linear_backward() doesn't use more memory than it should
# for higher dim input sizes.
# See https://github.com/pytorch/pytorch/issues/141112
B, D, max_seq_len = 64, 512, 100
torch._C._cuda_clearCublasWorkspaces()
m = torch.nn.Linear(D, D, device=device)
nt = torch.nested.as_nested_tensor(
[
@ -6469,7 +6460,6 @@ torch.cuda.synchronize()
TEST_WITH_ROCM,
"ROCm doesn't support flash attention or mem_efficient attention for NT",
)
@tf32_on_and_off(0.005)
@dtypes(
*(
[torch.float16, torch.bfloat16, torch.float32]
@ -6654,24 +6644,10 @@ torch.cuda.synchronize()
)
self.assertEqual(attn_out.shape, q_nt_3.shape)
@parametrize("skip_backward", [True, False])
def check_forward_backward(skip_backward=False):
if not skip_backward:
attn_nt = torch.nn.functional.scaled_dot_product_attention(
q_nt_t, k_nt_t, v_nt_t
).transpose(1, 2)
else:
x_nt.requires_grad = False
q_nt.requires_grad = False
k_nt.requires_grad = False
v_nt.requires_grad = False
tq = q_nt_t.detach()
tk = k_nt_t.detach()
tv = v_nt_t.detach()
with torch.no_grad():
attn_nt = torch.nn.functional.scaled_dot_product_attention(
tq, tk, tv
).transpose(1, 2)
def check_forward_backward():
attn_nt = torch.nn.functional.scaled_dot_product_attention(
q_nt_t, k_nt_t, v_nt_t
).transpose(1, 2)
attn_nts = attn_nt.unbind()
self.assertEqual(
@ -6687,26 +6663,23 @@ torch.cuda.synchronize()
rtol=output_ref_rtol,
)
if not skip_backward:
nt_grads = torch.autograd.grad(
attn_nt.values().sum(), (q_nt, k_nt, v_nt)
nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt))
for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(
nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols
):
unbound_nt_grads = nt_grad.unbind()
self.assertEqual(
d1_grad,
unbound_nt_grads[0].unsqueeze(0),
atol=grad_atol,
rtol=grad_rtol,
)
self.assertEqual(
d2_grad,
unbound_nt_grads[1].unsqueeze(0),
atol=grad_atol,
rtol=grad_rtol,
)
for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(
nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols
):
unbound_nt_grads = nt_grad.unbind()
self.assertEqual(
d1_grad,
unbound_nt_grads[0].unsqueeze(0),
atol=grad_atol,
rtol=grad_rtol,
)
self.assertEqual(
d2_grad,
unbound_nt_grads[1].unsqueeze(0),
atol=grad_atol,
rtol=grad_rtol,
)
# Default
check_forward_backward()
@ -6725,17 +6698,6 @@ torch.cuda.synchronize()
# "group_gemm_dispatch" not implemented for 'BFloat16'
if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
check_forward_backward()
check_cudnn = os.getenv("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED", "0") == "1"
if (
"cuda" in str(device)
and check_cudnn
and (dtype == torch.float16 or dtype == torch.bfloat16)
):
with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"):
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
):
check_forward_backward()
@skipIfTorchDynamo("SDPA test compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@ -8627,7 +8589,6 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
[op for op in njt_op_db if op.supports_njt],
allowed_dtypes=(torch.float32,),
)
@tf32_on_and_off(0.005)
@sample_skips_and_xfails(FORWARD_SKIPS_AND_XFAILS)
def test_forward(self, device, dtype, op):
for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
@ -8656,7 +8617,6 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
[op for op in njt_op_db if op.supports_njt and op.supports_autograd],
allowed_dtypes=(torch.float32,),
)
@tf32_on_and_off(0.005)
@sample_skips_and_xfails(BACKWARD_SKIPS_AND_XFAILS)
def test_backward(self, device, dtype, op):
for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(

View File

@ -2725,40 +2725,6 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "Fused SDPA was not built for this system")
@unittest.skipIf("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED" not in os.environ, "cuDNN Nested Tensor support not enabled")
@parametrize("type", ["nested"])
@parametrize("is_contiguous", [True])
def test_scaled_dot_product_attention_cudnn_nested(self, device, type: str, is_contiguous: bool):
if TEST_WITH_ROCM and type == 'nested':
self.skipTest("ROCM does not support efficient attention on nested tensors, for now")
make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
batch_size, seq_len, num_heads, head_dim = 8, 64, 16, 64
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
# Test Packed
qkv = make_tensor(shape)
query, key, value = qkv.chunk(3, dim=-1)
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
if is_contiguous:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref = torch.nn.functional.scaled_dot_product_attention(
query.contiguous(), key.contiguous(), value.contiguous(),
attn_mask=None, dropout_p=0.0, is_causal=False)
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("type", ["dense", "nested"])
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if

View File

@ -6,10 +6,8 @@ import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
can_use_cudnn_attention,
can_use_efficient_attention,
can_use_flash_attention,
cudnn_sdp_enabled,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
@ -113,32 +111,6 @@ def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
return True
def _check_head_dim_size_cudnn_nested(params: SDPAParams, debug=False) -> bool:
max_size = 128
query_size_last = params.query.size(-1)
key_size_last = params.key.size(-1)
value_size_last = params.value.size(-1)
same_head_dim_size = (
query_size_last == key_size_last and query_size_last == value_size_last
)
if not (
same_head_dim_size
and (query_size_last % 8 == 0)
and (query_size_last <= max_size)
):
if debug:
log.warning(
"For NestedTensor inputs, cuDNN attention requires q,k,v to have the same "
"last dimension and to be a multiple of 8 and less than or equal to 128. "
"Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
query_size_last,
key_size_last,
value_size_last,
)
return False
return True
def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
param: torch.Tensor, param_name: str, debug=False
) -> bool:
@ -295,7 +267,6 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable
not flash_sdp_enabled()
and not mem_efficient_sdp_enabled()
and not math_sdp_enabled()
and not cudnn_sdp_enabled()
):
return SDPBackend.ERROR
@ -303,15 +274,11 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
SDPBackend.CUDNN_ATTENTION,
)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
for backend in ordering:
if backend == SDPBackend.CUDNN_ATTENTION:
if can_use_cudnn_attention(params):
return SDPBackend.CUDNN_ATTENTION
if backend == SDPBackend.FLASH_ATTENTION:
if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
return SDPBackend.FLASH_ATTENTION
@ -332,8 +299,6 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable
_can_use_flash_sdpa_jagged(params, debug=True)
log.warning("Math attention kernel not used because:")
_can_use_math_sdpa_jagged(params, debug=True)
log.warning("cuDNN attention kernel not used because:")
can_use_cudnn_attention(params, debug=True)
return SDPBackend.ERROR
@ -785,6 +750,7 @@ def jagged_scaled_dot_product_attention(
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
(
attention,
_logsumexp,
@ -804,6 +770,7 @@ def jagged_scaled_dot_product_attention(
False,
scale=og_scale,
)
# Reshape output to convert nnz to batch_size and seq_len
attention = nested_view_from_values_offsets_lengths(
attention, # output from flash_attn is [total_q, num_heads, head_size_og]
@ -842,51 +809,12 @@ def jagged_scaled_dot_product_attention(
compute_logsumexp,
scale=scale,
)
# Reshape output to convert nnz to batch_size and seq_len
return nested_view_from_values_offsets_lengths(
attention.squeeze(0),
**output_nt_info,
).transpose(1, 2)
elif backend_choice == SDPBackend.CUDNN_ATTENTION:
(
query_reshaped,
key_reshaped,
value_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query, key, value)
(
attention,
logsumexp,
cum_seqlen_q,
cum_seqlen_kv,
max_seqlen_q,
max_seqlen_kv,
seed,
offset,
_,
) = torch.ops.aten._cudnn_attention_forward(
query_reshaped,
key_reshaped,
value_reshaped,
attn_mask,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
compute_logsumexp,
dropout_p,
is_causal,
False,
scale=scale,
)
return nested_view_from_values_offsets_lengths(
attention,
**output_nt_info,
).transpose(1, 2)
elif backend_choice == SDPBackend.MATH:
# save the offsets and shape of the inputs, so we can reshape the final output
# query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1]