[sdpa] move seq_len_1 check and replace with seq_len_0 check in sdp_utils (#95486)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95486 Approved by: https://github.com/drisspg, https://github.com/cpuhrsch
commit 63796d35ef (parent 72b9d45e76)
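For context, the class of inputs this commit is about can be sketched in a few lines of Python. This is an illustration only (shapes chosen arbitrarily), mirroring what the new tests construct with rand_tensor further down; it is not part of the patch:

    import torch
    import torch.nn.functional as F

    # A nested query whose second batch element has sequence length 0. Components are
    # [seq_len, num_heads, head_dim]; the transpose gives the [batch, heads, seq_len, dim]
    # layout that SDPA sees internally.
    make = lambda seq_len: torch.randn(seq_len, 16, 64, device="cuda", dtype=torch.float16)
    query = torch.nested.nested_tensor([make(8), make(0)]).transpose(1, 2)

    # With the math fallback disabled, both fused backends now reject such inputs and
    # dispatch fails with "No available kernel", as the new seq_len_0 test asserts.
    with torch.backends.cuda.sdp_kernel(enable_math=False):
        F.scaled_dot_product_attention(query, query, query)  # expected to raise RuntimeError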
@@ -568,7 +568,11 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
   sdp::sdp_params kernel_params{q, k, v, mask.has_value(), 0.0, false};
   auto backend = select_sdp_backend(kernel_params);
-  if (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention) {
+  // strides from packed projection for nested tensors when seq_len is 1 will be
+  // and will trigger a contiguous call in the kernel, so we prevent this
+  bool no_seq_len_1_nested = query.is_nested() ? check_for_seq_len_1_nested_tensor(kernel_params, false) : true;
+  if (no_seq_len_1_nested &&
+      (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention)) {
     auto x = at::linear(query, qkv_weight, qkv_bias);
     auto chunks = x.chunk(3, -1);
     auto x_size_0 = x.size(0);
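Read as pseudocode, the restructured condition above does roughly the following (a Python sketch with hypothetical names; the real check is the C++ check_for_seq_len_1_nested_tensor shown further down):

    def packed_projection_ok(backend, query, per_batch_seq_lens):
        # Fused backends only; for a nested query every component additionally needs
        # seq_len > 1, because (per the in-code comment) the strides coming out of the
        # packed projection would otherwise trigger a contiguous call in the kernel.
        fused = backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION)
        if not query.is_nested():
            return fused
        return fused and all(s > 1 for s in per_batch_seq_lens)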
@@ -124,7 +124,7 @@ inline bool check_for_non_zero_dropout(sdp_params params, bool debug) {
   return true;
 }
 
-inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
+inline bool check_for_seq_len_0_nested_tensor(sdp_params params, bool debug) {
   // When this function is called we are assured that the nt is dim==4
   if (!params.query.is_nested()) {
     return true;
@@ -145,11 +145,36 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
   const int64_t n_tensors = params.query.size(0);
   const int64_t size_tensor_stride = sizes.stride(0);
 
   // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
   for (const auto i : c10::irange(n_tensors)) {
+    if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) {
+      if (debug) {
+        TORCH_WARN("Memory efficient attention does not support sequence_length == 0");
+      }
+      return false;
+    }
+  }
+
+  return true;
+}
+
+inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
+  // When this function is called we are assured that the nt is dim==4
+  if (!params.query.is_nested()) {
+    return true;
+  }
+
+  const auto nt_q_tensor_impl = at::native::get_nested_tensor_impl(params.query);
+  const at::Tensor& sizes = nt_q_tensor_impl->get_nested_size_tensor();
+  auto* sizes_ptr = sizes.data_ptr<int64_t>();
+  const int64_t n_tensors = params.query.size(0);
+  const int64_t size_tensor_stride = sizes.stride(0);
+
+  // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
+  for (const auto i : c10::irange(n_tensors)) {
     if (sizes_ptr[(i * size_tensor_stride) + 1] <= 1) {
       if (debug) {
-        TORCH_WARN("Memory efficient attention does not support sequence_length <= 1");
+        TORCH_WARN("Packed projection for fused kernels does not support sequence_length <= 1");
       }
       return false;
     }
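The net effect in sdp_utils is that the old single check is split in two: the kernel-level check now only rejects empty sequences, while the stricter seq_len <= 1 condition is kept solely for the packed-projection path in native_multi_head_attention_cuda. In rough Python terms (seq_lens standing for the per-batch sequence lengths read from the nested size tensor):

    def check_for_seq_len_0(seq_lens) -> bool:
        # Used by use_flash_attention / use_mem_efficient_attention below:
        # the fused kernels themselves only refuse truly empty sequences.
        return all(s != 0 for s in seq_lens)

    def check_for_seq_len_1(seq_lens) -> bool:
        # Used by the packed-projection gate in native_multi_head_attention_cuda:
        # sequences of length 0 or 1 are both excluded.
        return all(s > 1 for s in seq_lens)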
@@ -451,7 +476,7 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
       check_head_dim_size,
       check_gpu_sm75_or_greater,
       check_requires_grad_and_head_dim_128_and_sm86,
-      check_for_seq_len_1_nested_tensor);
+      check_for_seq_len_0_nested_tensor);
   for (auto& constraint : constraints) {
     if (!constraint(params, debug)) {
       return false;
@@ -487,7 +512,7 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
       check_for_attn_mask,
       check_head_dim_size_mem_efficient,
       check_gpu_sm86_head_dim_128,
-      check_for_seq_len_1_nested_tensor,
+      check_for_seq_len_0_nested_tensor,
       check_for_non_zero_dropout,
       check_use_deterministic_algorithms);
   for (auto& constraint : constraints) {
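Both selectors follow the same pattern: an ordered list of predicates over sdp_params, any one of which can veto the backend (and, when debug is set, warn about the reason). A minimal Python sketch of that pattern, with hypothetical names:

    def can_use_backend(params, constraints, debug=False) -> bool:
        # Every constraint must pass; a failing constraint vetoes the backend and,
        # with debug=True, explains why via a warning.
        for constraint in constraints:
            if not constraint(params, debug):
                return False
        return True

    # e.g. can_use_backend(params, [check_head_dim_size, check_gpu_sm75_or_greater,
    #                               check_for_seq_len_0_nested_tensor], debug=True)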
@@ -13,7 +13,7 @@ from torch.backends.cuda import sdp_kernel, SDPBackend
 import torch.optim as optim
 from torch.testing._internal.common_dtype import floating_types_and_half
 
-from typing import Tuple
+from typing import List, Tuple, Union
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
     TEST_FAIRSEQ,
@@ -1092,7 +1092,7 @@ class TestSDPA(NNTestCase):
             "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
     }
 
-    def rand_tensor(self, shape: Tuple[int], device: str, dtype: torch.dtype,
+    def rand_tensor(self, shape: Tuple[Union[int, List[int]]], device: str, dtype: torch.dtype,
                     type: str, requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
         """Creates rand dense or nested tensor with given shape and type.
@@ -1109,11 +1109,20 @@ class TestSDPA(NNTestCase):
         """
         batch, seq_len, num_heads, head_dim = shape
         if type == "nested":
-            size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
-            return torch.nested.nested_tensor([
-                torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
-                for _ in range(batch)])
+            if isinstance(seq_len, list):
+                def _size(i):
+                    return (seq_len[i], num_heads, head_dim) if not packed else (seq_len[i], 3 * num_heads * head_dim)
+
+                return torch.nested.nested_tensor([
+                    torch.randn(_size(i), device=device, dtype=dtype, requires_grad=requires_grad)
+                    for i in range(batch)])
+            else:
+                size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
+                return torch.nested.nested_tensor([
+                    torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
+                    for _ in range(batch)])
         else:
+            assert(isinstance(seq_len, int))
             size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim)
             return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
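With this change, passing a list of per-batch sequence lengths produces a nested tensor whose i-th component uses seq_len[i]. A hedged usage sketch (values arbitrary; meant to run inside a TestSDPA method, mirroring how the new tests below build their inputs):

    seq_lens = [3, 1, 7, 0]
    shape = (len(seq_lens), seq_lens, 16, 64)  # (batch, per-batch seq_lens, num_heads, head_dim)
    q = self.rand_tensor(shape, device="cuda", dtype=torch.float16, type="nested")
    # q.unbind()[1].shape == (1, 16, 64); q.unbind()[3].shape == (0, 16, 64)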
@@ -1897,6 +1906,67 @@ class TestSDPA(NNTestCase):
         value = torch.randn(shape, dtype=torch.float16, device=device)
         self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
+    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    def test_fused_kernels_seq_len_1_inputs(self, fused_kernel):
+        if (not SM80OrLater) and fused_kernel == SDPBackend.FLASH_ATTENTION:
+            return
+        rand_nested_tensor = partial(self.rand_tensor, type="nested", device="cuda", dtype=torch.float16)
+        batch, num_heads, head_dim = 32, 16, 64
+        seq_lens = torch.randint(low=1, high=32, size=(batch,))
+        # make sure some seq_lens are 1
+        num_ones = 10
+        indices = torch.randint(low=0, high=batch, size=(num_ones,))
+        seq_lens.scatter_(0, indices, 1)
+
+        shape = (batch, seq_lens.tolist(), num_heads, head_dim)
+        query = rand_nested_tensor(shape)
+        key = rand_nested_tensor(shape)
+        value = rand_nested_tensor(shape)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        with sdp_kernel(**self.backend_map[fused_kernel]):
+            actual = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+        with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
+            math_ref = torch.nn.functional.scaled_dot_product_attention(
+                query.contiguous().to(torch.float32),
+                key.contiguous().to(torch.float32),
+                value.contiguous().to(torch.float32),
+                attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)
+
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
+    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    def test_fused_kernels_seq_len_0_inputs(self, fused_kernel):
+        if (not SM80OrLater) and fused_kernel == SDPBackend.FLASH_ATTENTION:
+            return
+        rand_nested_tensor = partial(self.rand_tensor, type="nested", device="cuda", dtype=torch.float16)
+        batch, num_heads, head_dim = 32, 16, 64
+        seq_lens = torch.randint(low=1, high=32, size=(batch,))
+        # make sure some seq_lens are 0
+        num_zeros = 10
+        indices = torch.randint(low=0, high=batch, size=(num_zeros,))
+        seq_lens.scatter_(0, indices, 0)
+
+        shape = (batch, seq_lens.tolist(), num_heads, head_dim)
+        query = rand_nested_tensor(shape)
+        key = rand_nested_tensor(shape)
+        value = rand_nested_tensor(shape)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        with sdp_kernel(**self.backend_map[fused_kernel]):
+            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
+                torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+
 
 # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
 # cross device / dtype testing.