Grouped Query Attention (#132689)

### Approach: Using the current function declaration

**Constraint:** `Q_Heads % KV_Heads == 0`

**Major change:**
- Added a new argument `enable_gqa: bool` to the `scaled_dot_product_attention` (SDPA) call.
- The flag gives the third-to-last (head) dimension GQA semantics: the query may carry more heads than the key and value, as long as the head counts divide evenly.

Sample use case this enables: Llama 3.

```python
# Llama 3 8B-style call to SDPA
import torch
from torch.nn.functional import scaled_dot_product_attention

batch, seq_len_q, seq_len_kv, D = 1, 2048, 2048, 128  # example sizes

query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output shape: (batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0` (and likewise for Value).
- If the head counts differ, the key and value tensors are expanded with `repeat_interleave` along the head dimension to match the query's number of heads, so the attention computation itself is unchanged (see the sketch after this list).
- `enable_gqa` defaults to `False`, so regular SDPA behavior is unchanged.
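
The math-backend handling is essentially equivalent to the following minimal sketch (an illustration only, not the actual C++ helper; the function name `expand_kv_for_gqa` is made up here):

```python
import torch
import torch.nn.functional as F

def expand_kv_for_gqa(query, key, value):
    # Repeat key/value heads along dim -3 until they match the query's head count.
    q_heads, kv_heads = query.size(-3), key.size(-3)
    assert q_heads % kv_heads == 0, "Q_Heads % KV_Heads must be 0"
    if q_heads == kv_heads:
        return key, value
    rep = q_heads // kv_heads
    return key.repeat_interleave(rep, dim=-3), value.repeat_interleave(rep, dim=-3)

query = torch.rand(2, 32, 16, 64)
key = torch.rand(2, 8, 16, 64)
value = torch.rand(2, 8, 16, 64)

k, v = expand_kv_for_gqa(query, key, value)
out = F.scaled_dot_product_attention(query, k, v, is_causal=True)
# Numerically this matches (modulo kernel selection):
# F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
```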

### Benchmarks:

- **sdpa.py: #130634**
  Across batch sizes, `enable_gqa=True` shows a substantial improvement in SDPA forward time (a hedged reproduction sketch follows this list):

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time (enable_gqa=True) | forward_time (enable_gqa=False) |
| ---------- | ----------- | ------------ | --------- | ---------- | --------- | ------------------------------ | ------------------------------- |
| 1          | 32          | 8            | 2048      | 2048       | 2048      | 100.71                         | 119.70                          |
| 8          | 32          | 8            | 2048      | 2048       | 2048      | 539.78                         | 628.83                          |
| 16         | 32          | 8            | 2048      | 2048       | 2048      | 1056.81                        | 1225.48                         |
| 32         | 32          | 8            | 2048      | 2048       | 2048      | 2099.54                        | 2440.45                         |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**
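
For context, a minimal sketch of how such a comparison could be reproduced with `torch.utils.benchmark` (this is not the actual sdpa.py script from #130634; the sizes come from the table above, `head_dim = embed_dim // q_num_heads` is an assumption, and the baseline here manually repeats the K/V heads before calling SDPA):

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

batch, q_heads, kv_heads, seq_len, embed_dim = 8, 32, 8, 2048, 2048
head_dim = embed_dim // q_heads  # assumption about how embed_dim maps to heads

q = torch.rand(batch, q_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.rand(batch, kv_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
v = torch.rand(batch, kv_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)

# GQA path: the kernel consumes the smaller K/V tensors directly.
t_gqa = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)",
    globals={"F": F, "q": q, "k": k, "v": v},
)

# Baseline: repeat K/V heads up front, then call SDPA without the flag.
k_rep = k.repeat_interleave(q_heads // kv_heads, dim=1)
v_rep = v.repeat_interleave(q_heads // kv_heads, dim=1)
t_base = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True)",
    globals={"F": F, "q": q, "k": k_rep, "v": v_rep},
)

print(t_gqa.timeit(50))
print(t_base.timeit(50))
```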

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
Apurva Jain 2024-08-07 05:35:36 +00:00 committed by PyTorch MergeBot
parent 527f104a69
commit 8bc5ef563e
19 changed files with 372 additions and 170 deletions


@@ -14711,21 +14711,21 @@
     CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
   autogen: _native_multi_head_attention.out

-- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
+- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
   python_module: nn
   variants: function
   autogen: scaled_dot_product_attention.out
   tags: nondeterministic_seeded

 # This aten function is kept so that we can test the choice function from Python
-- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int
+- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int
   dispatch:
     Meta: _fused_sdp_choice_meta
     CPU, NestedTensorCPU: _fused_sdp_choice_cpp
     CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
   tags: nondeterministic_seeded

-- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
+- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
   variants: function
   tags: nondeterministic_seeded


@ -430,8 +430,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
} }
int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){ const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal}; sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
auto backend = sdp::select_sdp_backend_cpp(kernel_params); auto backend = sdp::select_sdp_backend_cpp(kernel_params);
if (backend == sdp::SDPBackend::error) { if (backend == sdp::SDPBackend::error) {
TORCH_CHECK( TORCH_CHECK(
@ -455,12 +455,13 @@ int64_t _fused_sdp_choice_meta(
const std::optional<Tensor>& attn_mask_, const std::optional<Tensor>& attn_mask_,
double dropout_p, double dropout_p,
bool is_causal, bool is_causal,
std::optional<double> scale) { std::optional<double> scale,
bool enable_gqa) {
auto query_key_set = query_.key_set(); auto query_key_set = query_.key_set();
#if defined(USE_ROCM) #if defined(USE_ROCM)
bool has_rocm = query_key_set.has(c10::DispatchKey::HIP); bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
if (has_rocm) { if (has_rocm) {
auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale); auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa);
return choice_int; return choice_int;
} }
#else #else
@ -474,7 +475,8 @@ int64_t _fused_sdp_choice_meta(
attn_mask_, attn_mask_,
dropout_p, dropout_p,
is_causal, is_causal,
scale); scale,
enable_gqa);
return choice_int; return choice_int;
} }
#endif #endif
@ -607,6 +609,36 @@ bool should_compute_logsumexp(const Tensor& query, const Tensor& key, const Tens
return any_inputs_require_grad && gradmode_enabled; return any_inputs_require_grad && gradmode_enabled;
} }
std::tuple<at::Tensor, at::Tensor> pre_process_group_query_attention_input(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const bool enable_gqa) {
if (!enable_gqa) {
return std::make_tuple(key, value);
}
const auto q_num_heads = query.sym_size(-3);
const auto k_num_heads = key.sym_size(-3);
const auto v_num_heads = value.sym_size(-3);
bool all_equal = q_num_heads == k_num_heads && k_num_heads == v_num_heads;
bool key_divisible = q_num_heads % k_num_heads == 0;
bool value_divisible = q_num_heads % v_num_heads == 0;
TORCH_CHECK(all_equal || (key_divisible && value_divisible),
"Number of heads in key and value must divide the number of heads in ");
if (all_equal){
return std::make_tuple(key, value);
}
auto repeat_key_shape = query.sym_size(-3) / key.sym_size(-3);
auto repeat_value_shape = query.sym_size(-3) / value.sym_size(-3);
at::Tensor key_repeated = key.repeat_interleave_symint(repeat_key_shape, -3);
at::Tensor value_repeated = value.repeat_interleave_symint(repeat_value_shape, -3);
return std::make_tuple(std::move(key_repeated), std::move(value_repeated));
}
} // namespace } // namespace
// Computes scaled dot product attention on query, key and value tensors, using // Computes scaled dot product attention on query, key and value tensors, using
@ -645,12 +677,13 @@ Tensor scaled_dot_product_attention(
const std::optional<Tensor>& attn_mask_, const std::optional<Tensor>& attn_mask_,
double dropout_p, double dropout_p,
bool is_causal, bool is_causal,
std::optional<double> scale) { std::optional<double> scale,
bool enable_gqa) {
validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale); validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math); int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) { if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) {
choice_int = _fused_sdp_choice_stub(query_.device().type(), choice_int = _fused_sdp_choice_stub(query_.device().type(),
query_, key, value, attn_mask_, dropout_p, is_causal, scale); query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa);
} }
sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int); sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
@ -712,8 +745,9 @@ Tensor scaled_dot_product_attention(
attn_mask, attn_mask,
dropout_p, dropout_p,
is_causal, is_causal,
std::nullopt, /*dropout_mask*/ c10::nullopt, /*dropout_mask*/
scale)); scale,
enable_gqa));
default: default:
TORCH_CHECK( TORCH_CHECK(
false, false,
@ -725,7 +759,7 @@ Tensor scaled_dot_product_attention(
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math( std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
const Tensor& query_, const Tensor& key, const Tensor& value, const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
const std::optional<Tensor>& dropout_mask, std::optional<double> scale) { const std::optional<Tensor>& dropout_mask, std::optional<double> scale, bool enable_gqa) {
C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback"); C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
if (query_.is_nested() || key.is_nested() || value.is_nested()) { if (query_.is_nested() || key.is_nested() || value.is_nested()) {
TORCH_CHECK( TORCH_CHECK(
@ -781,7 +815,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril(); at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype()); attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
} }
auto attn = at::matmul(query, key_acc.transpose(-2, -1) * scaling_factor);
// MQA/GQA handling
auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key_acc, value_acc, enable_gqa);
auto attn = at::matmul(query, key_expanded.transpose(-2, -1) * scaling_factor);
if (attn_mask.has_value()) { if (attn_mask.has_value()) {
if (at::areAnyTensorSubclassLike({attn, *attn_mask})) { if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
attn = attn.add(*attn_mask); attn = attn.add(*attn_mask);
@ -797,13 +835,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes."); TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
attn = attn.masked_fill(dropout_mask->logical_not(), 0.0); attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
auto dropout_scaling = 1.0 / (1 - dropout_p); auto dropout_scaling = 1.0 / (1 - dropout_p);
return std::make_tuple(at::matmul(attn, value_acc * dropout_scaling).to(origin_dtype), attn.to(origin_dtype)); return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling).to(origin_dtype), attn.to(origin_dtype));
} else { } else {
attn = at::dropout(attn, dropout_p, true); attn = at::dropout(attn, dropout_p, true);
} }
} }
return std::make_tuple(at::matmul(attn, value_acc).to(origin_dtype), attn.to(origin_dtype)); return std::make_tuple(at::matmul(attn, value_expanded).to(origin_dtype), attn.to(origin_dtype));
} }
std::tuple<at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor>


@@ -9,7 +9,7 @@ namespace at {
namespace native {

using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
-        const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale);
+        const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa);

DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);


@@ -560,7 +560,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
  auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
  auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);

-  sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false};
+  sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false};
  auto backend = select_sdp_backend(kernel_params);
  // strides from packed projection for nested tensors when seq_len is 1 will be
  // and will trigger a contiguous call in the kernel, so we prevent this

@@ -868,8 +868,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
}

int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
-        const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
+        const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
-  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
+  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
  auto backend = select_sdp_backend(kernel_params);
  if (backend == sdp::SDPBackend::error) {
    TORCH_CHECK(


@@ -607,7 +607,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
  }
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
-        check_batch_size_and_num_heads_dense,
+        check_batch_size_and_num_heads_dense<true /*supports_grouped_query_attention=*/>,
        check_nonzero_sequence_lengths_dense,
        check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
    for (auto& constraint : dense_constraints) {

@@ -665,9 +665,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
  }
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
-        check_batch_size_and_num_heads_dense,
        check_nonzero_sequence_lengths_dense,
-        check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>);
+        check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>,
+        check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention=*/>);
    for (auto& constraint : dense_constraints) {
      if (!constraint(params, debug)) {
        return false;


@@ -42,7 +42,7 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
      check_nested_tensor,
      check_for_dropout,
      check_tensor_shapes,
-      check_batch_size_and_num_heads_dense,
+      check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention*/>,
      check_attn_mask_shape,
      check_head_dim_size_cpp,
      check_nonzero_sequence_lengths_dense,


@@ -48,6 +48,7 @@ struct sdp_params {
  std::optional<at::Tensor> attn_mask;
  double dropout;
  bool is_causal;
+  bool enable_gqa;
};

SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);

@@ -353,6 +354,46 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
  return true;
}
inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
const auto q_num_heads = params.query.sym_size(-3);
const auto k_num_heads = params.key.sym_size(-3);
const auto v_num_heads = params.value.sym_size(-3);
const bool same_kv_heads = k_num_heads == v_num_heads;
if (!(same_kv_heads)){
if (debug) {
TORCH_WARN(
"Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
"Key sizes: ",
params.key.sizes(),
", Value sizes: ",
params.value.sizes(),
", Query sizes: ",
params.query.sizes(),
" instead.");
}
return false;
}
// Check if grouped query attention is supported and validate the number of
// heads
if (q_num_heads % k_num_heads != 0) {
if (debug) {
TORCH_WARN(
"FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
"Got input Key sizes(): ",
params.key.sym_size(-3),
", Value sizes(): ",
params.value.sym_size(-3),
", Query sizes(): ",
params.query.sym_size(-3),
" instead.");
}
return false;
}
return true;
}
template <bool supports_gqa>
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) { inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
// This is expected to be called after check_tensor_shapes ensuring that the // This is expected to be called after check_tensor_shapes ensuring that the
// size() calls won't error since the inputs are all 4 dimensional // size() calls won't error since the inputs are all 4 dimensional
@ -364,16 +405,36 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
bool same_batch_size = bool same_batch_size =
q_batch_size == k_batch_size && q_batch_size == v_batch_size; q_batch_size == k_batch_size && q_batch_size == v_batch_size;
auto q_num_heads = params.query.sym_size(1); auto q_num_heads = params.query.sym_size(-3);
auto k_num_heads = params.key.sym_size(1); auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(1); auto v_num_heads = params.value.sym_size(-3);
bool same_num_heads = bool same_num_heads =
q_num_heads == k_num_heads && q_num_heads == v_num_heads; q_num_heads == k_num_heads && q_num_heads == v_num_heads;
if (!(same_batch_size && same_num_heads)) { if (!same_batch_size){
if(debug) {
TORCH_WARN(
"For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
"Query.sizes(): ",
params.query.sizes(),
", Key.sizes(): ",
params.key.sizes(),
", Value.sizes(): ",
params.value.sizes(),
" instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
}
return false;
}
if(params.enable_gqa && supports_gqa){
return check_grouped_query_attention(params, debug);
}
if (!same_num_heads){
if (debug) { if (debug) {
TORCH_WARN( TORCH_WARN(
"For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. ", "For dense input, both fused kernels require query, key and value to have the same num_heads. ",
"Query.sizes(): ", "Query.sizes(): ",
params.query.sizes(), params.query.sizes(),
", Key sizes(): ", ", Key sizes(): ",
@ -384,6 +445,7 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
} }
return false; return false;
} }
// If all checks pass, return true
return true; return true;
} }


@@ -128,7 +128,7 @@ void quantize_tensor_per_tensor_affine_privateuse1(
}

int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
-    const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale){
+    const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale, bool enable_gqa){
  auto backend = sdp::SDPBackend::overrideable;
  return static_cast<int64_t>(backend);
}


@@ -302,8 +302,9 @@ class DistMatrixOpsTest(DTensorTestBase):
        # TODO: Add test cases where is_causal=False and an attention mask is provided.
        # Gaps include missing op support for aten.masked_fill_.Scalar.
        is_causal = True
+       enable_gqa = False
        params = torch.backends.cuda.SDPAParams(
-           query, key, value, None, dropout_p, is_causal
+           query, key, value, None, dropout_p, is_causal, enable_gqa
        )
        if torch.backends.cuda.can_use_flash_attention(params, debug=False):
            available_backends.append(SDPBackend.FLASH_ATTENTION)


@ -31,7 +31,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(fullgraph=True, backend=counter) @torch.compile(fullgraph=True, backend=counter)
def fn(q, k, v, m): def fn(q, k, v, m):
return SDPAParams(q, k, v, m, 0.1, True) return SDPAParams(q, k, v, m, 0.1, True, False)
q = torch.randn(10) q = torch.randn(10)
k = torch.randn(10) k = torch.randn(10)
@ -39,7 +39,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
m = torch.randn(10) m = torch.randn(10)
o = fn(q, k, v, m) o = fn(q, k, v, m)
self.assertTrue(isinstance(o, SDPAParams)) self.assertTrue(isinstance(o, SDPAParams))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True)) self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.frame_count, 1)
def test_graph_break_SDPAParams(self): def test_graph_break_SDPAParams(self):
@ -48,7 +48,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(backend=counter) @torch.compile(backend=counter)
def fn(q, k, v, m): def fn(q, k, v, m):
z = SDPAParams(q, k, v, m, 0.1, True) z = SDPAParams(q, k, v, m, 0.1, True, False)
torch._dynamo.graph_break() torch._dynamo.graph_break()
return z, q + 1 return z, q + 1
@ -58,7 +58,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
m = torch.randn(10) m = torch.randn(10)
o, _ = fn(q, k, v, m) o, _ = fn(q, k, v, m)
self.assertTrue(isinstance(o, SDPAParams)) self.assertTrue(isinstance(o, SDPAParams))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True)) self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
self.assertEqual(counter.frame_count, 2) self.assertEqual(counter.frame_count, 2)
def test_input_SDPAParams(self): def test_input_SDPAParams(self):
@ -74,7 +74,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
k = torch.randn(10) k = torch.randn(10)
v = torch.randn(10) v = torch.randn(10)
m = torch.randn(10) m = torch.randn(10)
s = SDPAParams(q, k, v, m, 0.1, True) s = SDPAParams(q, k, v, m, 0.1, True, False)
o, _ = fn(s, q) o, _ = fn(s, q)
self.assertIs(o, s) self.assertIs(o, s)
self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.frame_count, 1)
@ -86,7 +86,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(fullgraph=True, backend=counter) @torch.compile(fullgraph=True, backend=counter)
def fn(q, k, v, m): def fn(q, k, v, m):
q += 1 q += 1
z = SDPAParams(q, k, v, m, 0.1, True) z = SDPAParams(q, k, v, m, 0.1, True, False)
a = z.query a = z.query
return a + 1, z, q return a + 1, z, q
@ -95,7 +95,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
v = torch.randn(10) v = torch.randn(10)
m = torch.randn(10) m = torch.randn(10)
_, o, _ = fn(q, k, v, m) _, o, _ = fn(q, k, v, m)
expected = SDPAParams(q, k, v, m, 0.1, True) expected = SDPAParams(q, k, v, m, 0.1, True, False)
self.assert_ref_equals_params(o, expected) self.assert_ref_equals_params(o, expected)
self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.frame_count, 1)


@ -1561,6 +1561,36 @@ class TestSDPAFailureModes(NNTestCase):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False)) q, k, v, None, 0.0, False))
@onlyCUDA
@skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel):
rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
with sdpa_kernel(fused_kernel):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, "
"key and value to have"):
F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0,
is_causal=False, enable_gqa=True)
@onlyCPU
@skipIfRocm # Nested Tensor
def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device):
rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, "
"key and value to have"):
F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0,
is_causal=False, enable_gqa=True)
@onlyCUDA @onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention") @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention")
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA) @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
@ -1712,7 +1742,8 @@ class TestSDPAFailureModes(NNTestCase):
seq_len_list = [2, 4, 5, 6, 7] seq_len_list = [2, 4, 5, 6, 7]
shape = SdpaShape(5, 8, seq_len_list, 57) shape = SdpaShape(5, 8, seq_len_list, 57)
make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype) make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor() q, k, v = make_tensor().transpose(1, 2), make_tensor().transpose(1, 2), make_tensor().transpose(1, 2)
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"): with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
@ -1792,7 +1823,7 @@ class TestSDPAFailureModes(NNTestCase):
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"): with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
with self.assertRaisesRegex(RuntimeError, "No available kernel"): with self.assertRaisesRegex(RuntimeError, "No available kernel"):
out = torch.nn.functional.scaled_dot_product_attention( torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
@onlyCUDA @onlyCUDA
@ -2949,23 +2980,32 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("dropout_p", [0.0, 0.22, 0.48]) @parametrize("dropout_p", [0.0, 0.22, 0.48])
@parametrize("dtype", [torch.float16, torch.bfloat16]) @parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("scale", [None, "l1"]) @parametrize("scale", [None, "l1"])
@parametrize("enable_gqa", [True, False])
@parametrize("n_heads", [[16, 8], [10, 2]])
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str): scale: str, enable_gqa: bool, n_heads: List[int]):
if isSM8XDevice and head_dim in range(193, 256 + 1): if isSM8XDevice and head_dim in range(193, 256 + 1):
self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled") self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
if is_causal and seq_len_q != seq_len_k: if is_causal and seq_len_q != seq_len_k:
self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1: if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1:
torch.cuda.empty_cache() # Prevent memory fragmentation torch.cuda.empty_cache() # Prevent memory fragmentation
if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
unittest.skip("Reference implementation OOM")
return
scale = scale if scale is None else (1 / head_dim) scale = scale if scale is None else (1 / head_dim)
n_heads = 4 num_heads_q = num_heads_kv = 4
query = torch.rand(batch_size, n_heads, seq_len_q, head_dim, if enable_gqa:
num_heads_q = n_heads[0]
num_heads_kv = n_heads[1]
query = torch.rand(batch_size, num_heads_q, seq_len_q, head_dim,
device=device, dtype=dtype, requires_grad=True) device=device, dtype=dtype, requires_grad=True)
key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device, key = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim, device=device,
dtype=dtype, requires_grad=True) dtype=dtype, requires_grad=True)
value = torch.rand(batch_size, n_heads, seq_len_k, head_dim, value = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim,
device=device, dtype=dtype, requires_grad=True) device=device, dtype=dtype, requires_grad=True)
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
@ -2975,14 +3015,15 @@ class TestSDPACudaOnly(NNTestCase):
if not is_dropout: if not is_dropout:
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) out = F.scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
with sdpa_kernel(backends=[SDPBackend.MATH]): with sdpa_kernel(backends=[SDPBackend.MATH]):
# High Precision Math Reference # High Precision Math Reference
out_ref = F.scaled_dot_product_attention( out_ref = F.scaled_dot_product_attention(
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale) query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
# Low Precision Math Reference # Low Precision Math Reference
out_lp_ref = F.scaled_dot_product_attention( out_lp_ref = F.scaled_dot_product_attention(
query, key, value, is_causal=is_causal, scale=scale) query, key, value, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
else: else:
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the # Problem: We pad sizes in the composite region of the top level SDPA. But we need the
# Debug mask when have dropout. So I am going to manualy pad up here when testing dropout # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
@ -3009,11 +3050,12 @@ class TestSDPACudaOnly(NNTestCase):
dropout_mask = softmax_mask >= 0 dropout_mask = softmax_mask >= 0
# High Precision Math Reference # High Precision Math Reference
out_ref = torch.ops.aten._scaled_dot_product_attention_math( out_ref = torch.ops.aten._scaled_dot_product_attention_math(
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0] query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
scale=scale, dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
# Low Precision Math Reference # Low Precision Math Reference
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
dropout_mask=dropout_mask)[0] dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
upstream_grad = torch.rand_like(out, requires_grad=False) upstream_grad = torch.rand_like(out, requires_grad=False)
@ -3185,6 +3227,7 @@ class TestSDPACudaOnly(NNTestCase):
} }
) )
@skipIfRocm # Nested Tensor @skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if


@@ -466,6 +466,7 @@ def gen_nn_functional(fm: FileManager) -> None:
                "dropout_p: float = 0.0",
                "is_causal: bool = False",
                "scale: Optional[float] = None",
+               "enable_gqa: bool = False",
            ]
        )
    )


@@ -1958,6 +1958,7 @@ class _SDPAParams:
    attn_mask: Optional[Tensor]
    dropout: _float
    is_causal: _bool
+   enable_gqa: _bool
    def __init__(
        self,
        query: Tensor,
@@ -1965,7 +1966,8 @@ class _SDPAParams:
        value: Tensor,
        attn_mask: Optional[Tensor],
        dropout: _float,
-       is_causal: _bool) -> None: ...
+       is_causal: _bool,
+       enable_gqa: _bool) -> None: ...

class _SDPBackend(Enum):
    ERROR = -1


@@ -34,6 +34,9 @@ class SDPAParamsVariable(VariableTracker):
        is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
            value.is_causal
        )
+       enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
+           value.enable_gqa
+       )
        param_vars = [
            query_var,
            key_var,
@@ -41,6 +44,7 @@ class SDPAParamsVariable(VariableTracker):
            attn_mask_var,
            dropout_var,
            is_causal_var,
+           enable_gqa_var,
        ]
        return TorchInGraphFunctionVariable(SDPAParams).call_function(
            tx, param_vars, {}


@ -1955,16 +1955,24 @@ Call this whenever a new thread is created in order to propagate values from
at::Tensor const& value, at::Tensor const& value,
std::optional<at::Tensor> attn_mask, std::optional<at::Tensor> attn_mask,
double dropout, double dropout,
bool is_causal) { bool is_causal,
bool enable_gqa) {
return sdp::sdp_params{ return sdp::sdp_params{
query, key, value, std::move(attn_mask), dropout, is_causal}; query,
key,
value,
std::move(attn_mask),
dropout,
is_causal,
enable_gqa};
})) }))
.def_readonly("query", &sdp::sdp_params::query) .def_readonly("query", &sdp::sdp_params::query)
.def_readonly("key", &sdp::sdp_params::key) .def_readonly("key", &sdp::sdp_params::key)
.def_readonly("value", &sdp::sdp_params::value) .def_readonly("value", &sdp::sdp_params::value)
.def_readonly("attn_mask", &sdp::sdp_params::attn_mask) .def_readonly("attn_mask", &sdp::sdp_params::attn_mask)
.def_readonly("dropout", &sdp::sdp_params::dropout) .def_readonly("dropout", &sdp::sdp_params::dropout)
.def_readonly("is_causal", &sdp::sdp_params::is_causal); .def_readonly("is_causal", &sdp::sdp_params::is_causal)
.def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa);
py::enum_<sdp::SDPBackend>( py::enum_<sdp::SDPBackend>(
py_module, py_module,


@@ -262,7 +262,7 @@ def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    return True

-def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
+def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
    if (
        not flash_sdp_enabled()
        and not mem_efficient_sdp_enabled()
@@ -276,7 +276,7 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
        SDPBackend.MATH,
    )

-   params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)
+   params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
    for backend in ordering:
        if backend == SDPBackend.FLASH_ATTENTION:
@@ -623,6 +623,7 @@ def jagged_scaled_dot_product_attention(
    dropout_p=0.0,
    is_causal=False,
    scale=None,
+   enable_gqa=False,
):
    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    # for mypy, ugh
@@ -653,7 +654,7 @@ def jagged_scaled_dot_product_attention(
    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad

    backend_choice = _select_sdp_backend(
-       query, key, value, attn_mask, dropout_p, is_causal
+       query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
    )

    if backend_choice == SDPBackend.FLASH_ATTENTION:


@ -175,6 +175,7 @@ class CausalBias(torch.Tensor):
dropout_p: float = 0.0, dropout_p: float = 0.0,
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
Handles the logic for computing attention with the specified causal bias. Handles the logic for computing attention with the specified causal bias.
@ -191,6 +192,7 @@ class CausalBias(torch.Tensor):
are set. are set.
scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`. to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
Returns: Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`. output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
@ -214,10 +216,13 @@ class CausalBias(torch.Tensor):
dropout_p=dropout_p, dropout_p=dropout_p,
is_causal=True, is_causal=True,
scale=scale, scale=scale,
enable_gqa=enable_gqa,
) )
elif attn_mask.variant == CausalVariant.LOWER_RIGHT: elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
_validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale) _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal) sdpa_params = SDPAParams(
query, key, value, None, dropout_p, is_causal, enable_gqa
)
if can_use_flash_attention(sdpa_params): if can_use_flash_attention(sdpa_params):
needs_padding = query.size(-1) % 8 != 0 needs_padding = query.size(-1) % 8 != 0
og_head_size = query.size(-1) og_head_size = query.size(-1)
@ -266,6 +271,7 @@ class CausalBias(torch.Tensor):
dropout_p=dropout_p, dropout_p=dropout_p,
is_causal=False, is_causal=False,
scale=scale, scale=scale,
enable_gqa=enable_gqa,
) )
else: else:
raise ValueError( raise ValueError(


@ -5606,20 +5606,21 @@ def _in_projection(
scaled_dot_product_attention = _add_docstr( scaled_dot_product_attention = _add_docstr(
torch._C._nn.scaled_dot_product_attention, torch._C._nn.scaled_dot_product_attention,
r""" r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor: is_causal=False, scale=None, enable_gqa=False) -> Tensor:
Computes scaled dot product attention on query, key and value tensors, using Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
an optional attention mask if passed, and applying dropout if a probability and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
greater than 0.0 is specified. The optional scale argument can only be specified as a keyword argument. specified as a keyword argument.
.. code-block:: python .. code-block:: python
# Efficient implementation equivalent to the following: # Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
L, S = query.size(-2), key.size(-2) L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) attn_bias = torch.zeros(L, S, dtype=query.dtype)
if is_causal: if is_causal:
assert attn_mask is None assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
@ -5630,17 +5631,22 @@ greater than 0.0 is specified. The optional scale argument can only be specified
if attn_mask.dtype == torch.bool: if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else: else:
attn_bias = attn_mask + attn_bias attn_bias += attn_mask
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True) attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value return attn_weight @ value
.. warning:: This function is beta and subject to change. .. warning::
This function is beta and subject to change.
.. warning::
.. warning::
This function always applies dropout according to the specified ``dropout_p`` argument. This function always applies dropout according to the specified ``dropout_p`` argument.
To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
that makes the function call is not in training mode. that makes the function call is not in training mode.
@ -5655,9 +5661,10 @@ greater than 0.0 is specified. The optional scale argument can only be specified
self.p = p self.p = p
def forward(self, ...): def forward(self, ...):
return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0)) return F.scaled_dot_product_attention(...,
dropout_p=(self.p if self.training else 0.0))
Note: Note:
There are currently three supported implementations of scaled dot product attention: There are currently three supported implementations of scaled dot product attention:
@ -5689,16 +5696,24 @@ Note:
For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16. For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
For more information please see :doc:`/notes/numerical_accuracy` For more information please see :doc:`/notes/numerical_accuracy`
Note: Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention
and math kernel on CUDA tensor, and does not support Nested tensor.
Constraints for GQA:
- number_of_heads_query % number_of_heads_key_value == 0 and,
- number_of_heads_key == number_of_heads_value
Note:
{cudnn_reproducibility_note} {cudnn_reproducibility_note}
""".format( """.format(
**reproducibility_notes **reproducibility_notes
) )
+ r""" + r"""
Args: Args:
query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`. query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`. key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`. value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
which is :math:`(N,..., L, S)`. Two types of masks are supported. which is :math:`(N,..., L, S)`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element *should* take part in attention. A boolean mask where a value of True indicates that the element *should* take part in attention.
@ -5710,19 +5725,21 @@ Args:
An error is thrown if both attn_mask and is_causal are set. An error is thrown if both attn_mask and is_causal are set.
scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`. to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
Returns: Shape legend:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}` - :math:`S: \text{Source sequence length}`
- :math:`L: \text{Target sequence length}` - :math:`L: \text{Target sequence length}`
- :math:`E: \text{Embedding dimension of the query and key}` - :math:`E: \text{Embedding dimension of the query and key}`
- :math:`Ev: \text{Embedding dimension of the value}` - :math:`Ev: \text{Embedding dimension of the value}`
- :math:`Hq: \text{Number of heads of query}`
- :math:`H: \text{Number of heads of key and value}`
Examples: Examples:
>>> # Optionally use the context manager to ensure one of the fused kernels is run >>> # Optionally use the context manager to ensure one of the fused kernels is run
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
@ -5732,12 +5749,21 @@ Examples:
>>> F.scaled_dot_product_attention(query,key,value) >>> F.scaled_dot_product_attention(query,key,value)
.. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning: >>> # Sample for GQA for llama3
https://arxiv.org/abs/2307.08691 >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
.. _Memory-Efficient Attention: >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
https://github.com/facebookresearch/xformers >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with sdpa_kernel(backends=[SDPBackend.MATH]):
>>> F.scaled_dot_product_attention(query,key,value,enable_gqa=True)
""",
.. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
https://arxiv.org/abs/2307.08691
.. _Memory-Efficient Attention:
https://github.com/facebookresearch/xformers
.. _Grouped-Query Attention:
https://arxiv.org/pdf/2305.13245
""",
) )


@ -8689,6 +8689,7 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs): def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8 batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
num_heads_q_gqa, num_heads_kv_gqa = 32, 8
dim_3_q_shape = (batch, seq_q, head_dim) dim_3_q_shape = (batch, seq_q, head_dim)
dim_3_kv_shape = (batch, seq_kv, head_dim) dim_3_kv_shape = (batch, seq_kv, head_dim)
@ -8699,8 +8700,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
samples = [] samples = []
for qkv_shape, is_causal, dropout_p in product( for qkv_shape, is_causal, dropout_p, enable_gqa in product(
qkv_shapes, [True, False], [0.0, 0.5]): qkv_shapes, [True, False], [0.0, 0.5], [True, False]):
shape_q, shape_kv = qkv_shape shape_q, shape_kv = qkv_shape
samples.append(SampleInput( samples.append(SampleInput(
make(shape_q), make(shape_q),
@ -8730,6 +8731,15 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
dropout_p=0.0) dropout_p=0.0)
) )
samples.append(
SampleInput(
make((batch, num_heads_q_gqa, seq_q, head_dim)),
make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
enable_gqa=True
)
)
yield from samples yield from samples