[sdpa] move seq_len_1 check and replace with seq_len_0 check in sdp_utils (#95486)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95486
Approved by: https://github.com/drisspg, https://github.com/cpuhrsch
Authored by Mikayla Gawarecki on 2023-03-01 17:54:36 +00:00; committed by PyTorch MergeBot
parent 72b9d45e76
commit 63796d35ef
3 changed files with 110 additions and 11 deletions
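
In short: after this change the kernel-selection checks in sdp_utils reject a nested input only when one of its components has sequence length 0, while the stricter seq_len <= 1 restriction applies only to the packed-projection fast path in native_multi_head_attention_cuda. The sketch below mirrors the new tests added in this commit; it assumes a CUDA build with the fused SDPA kernels compiled in, and nested_qkv is an illustrative helper, not a library function.

# Minimal sketch of the user-visible effect, mirroring the new tests in this commit.
# Assumes a CUDA build with fused SDPA kernels; nested_qkv is illustrative only.
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

device, dtype = "cuda", torch.float16
num_heads, head_dim = 16, 64

def nested_qkv(seq_lens):
    # Build [batch, num_heads, seq_len, head_dim] nested tensors with ragged seq_len.
    def make():
        return torch.nested.nested_tensor(
            [torch.randn(s, num_heads, head_dim, device=device, dtype=dtype) for s in seq_lens]
        ).transpose(1, 2)
    return make(), make(), make()

# Components with seq_len == 1 are accepted by the fused kernels after this change ...
q, k, v = nested_qkv([1, 7, 1, 12])
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# ... while components with seq_len == 0 are rejected during kernel selection.
q, k, v = nested_qkv([0, 7, 3, 12])
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as e:
        print(e)  # the new test below expects "No available kernel"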


@@ -568,7 +568,11 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
sdp::sdp_params kernel_params{q, k, v, mask.has_value(), 0.0, false};
auto backend = select_sdp_backend(kernel_params);
if (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention) {
// strides from packed projection for nested tensors when seq_len is 1 will
// trigger a contiguous call in the kernel, so we prevent this
bool no_seq_len_1_nested = query.is_nested() ? check_for_seq_len_1_nested_tensor(kernel_params, false) : true;
if (no_seq_len_1_nested &&
(backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention)) {
auto x = at::linear(query, qkv_weight, qkv_bias);
auto chunks = x.chunk(3, -1);
auto x_size_0 = x.size(0);
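
The guard above keeps nested inputs that contain a length-1 sequence off the packed qkv-projection fast path: chunking the fused projection produces strided views, and for a seq_len == 1 component the kernel would have to fall back to a contiguous copy. A rough illustration of that layout, with arbitrary shapes and with the behavior of linear/chunk on nested tensors assumed from the surrounding diff:

# Rough illustration of the packed projection that the new guard protects.
# embed_dim and the random weights are arbitrary; this only shows the layout.
import torch
import torch.nn.functional as F

embed_dim = 8
x = torch.nested.nested_tensor([torch.randn(1, embed_dim),   # seq_len == 1 component
                                torch.randn(5, embed_dim)])
qkv_weight = torch.randn(3 * embed_dim, embed_dim)
qkv_bias = torch.randn(3 * embed_dim)

x = F.linear(x, qkv_weight, qkv_bias)   # packed q/k/v projection, last dim = 3 * embed_dim
q, k, v = x.chunk(3, -1)                # strided views into the packed buffer
for t in q.unbind():
    print(tuple(t.shape))               # (1, 8) and (5, 8): per-sequence query slices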


@@ -124,7 +124,7 @@ inline bool check_for_non_zero_dropout(sdp_params params, bool debug) {
return true;
}
inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
inline bool check_for_seq_len_0_nested_tensor(sdp_params params, bool debug) {
// When this function is called we are assured that the nt is dim==4
if (!params.query.is_nested()) {
return true;
@@ -145,11 +145,36 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
const int64_t n_tensors = params.query.size(0);
const int64_t size_tensor_stride = sizes.stride(0);
// This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
for (const auto i : c10::irange(n_tensors)) {
if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) {
if (debug) {
TORCH_WARN("Memory efficient attention does not support sequence_length == 0");
}
return false;
}
}
return true;
}
inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
// When this function is called we are assured that the nt is dim==4
if (!params.query.is_nested()) {
return true;
}
const auto nt_q_tensor_impl = at::native::get_nested_tensor_impl(params.query);
const at::Tensor& sizes = nt_q_tensor_impl->get_nested_size_tensor();
auto* sizes_ptr = sizes.data_ptr<int64_t>();
const int64_t n_tensors = params.query.size(0);
const int64_t size_tensor_stride = sizes.stride(0);
// This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
for (const auto i : c10::irange(n_tensors)) {
if (sizes_ptr[(i * size_tensor_stride) + 1] <= 1) {
if (debug) {
TORCH_WARN("Memory efficient attention does not support sequence_length <= 1");
TORCH_WARN("Packed projection for fused kernels does not support sequence_length <= 1");
}
return false;
}
@@ -451,7 +476,7 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
check_head_dim_size,
check_gpu_sm75_or_greater,
check_requires_grad_and_head_dim_128_and_sm86,
check_for_seq_len_1_nested_tensor);
check_for_seq_len_0_nested_tensor);
for (auto& constraint : constraints) {
if (!constraint(params, debug)) {
return false;
@@ -487,7 +512,7 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
check_for_attn_mask,
check_head_dim_size_mem_efficient,
check_gpu_sm86_head_dim_128,
check_for_seq_len_1_nested_tensor,
check_for_seq_len_0_nested_tensor,
check_for_non_zero_dropout,
check_use_deterministic_algorithms);
for (auto& constraint : constraints) {

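The two constraint functions above differ only in the threshold applied to each nested component's sequence length: the kernel-selection check rejects seq_len == 0, while the packed-projection check rejects seq_len <= 1. A rough Python analogue, operating on a plain list of lengths rather than the nested size tensor (function names are illustrative, not the library's API):

# Illustrative analogues of the two nested seq_len checks.
import warnings
from typing import List

def ok_for_fused_sdpa(seq_lens: List[int], debug: bool = False) -> bool:
    # Analogue of check_for_seq_len_0_nested_tensor: only seq_len == 0 is rejected.
    for s in seq_lens:
        if s == 0:
            if debug:
                warnings.warn("fused SDPA kernels do not support sequence_length == 0")
            return False
    return True

def ok_for_packed_projection(seq_lens: List[int], debug: bool = False) -> bool:
    # Analogue of check_for_seq_len_1_nested_tensor: seq_len <= 1 skips the packed fast path.
    for s in seq_lens:
        if s <= 1:
            if debug:
                warnings.warn("packed projection does not support sequence_length <= 1")
            return False
    return True

assert ok_for_fused_sdpa([1, 7, 12]) and not ok_for_packed_projection([1, 7, 12])
assert not ok_for_fused_sdpa([0, 7, 12])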

@@ -13,7 +13,7 @@ from torch.backends.cuda import sdp_kernel, SDPBackend
import torch.optim as optim
from torch.testing._internal.common_dtype import floating_types_and_half
from typing import Tuple
from typing import List, Tuple, Union
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
TEST_FAIRSEQ,
@@ -1092,7 +1092,7 @@ class TestSDPA(NNTestCase):
"enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}
def rand_tensor(self, shape: Tuple[int], device: str, dtype: torch.dtype,
def rand_tensor(self, shape: Tuple[Union[int, List[int]]], device: str, dtype: torch.dtype,
type: str, requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
"""Creates rand dense or nested tensor with given shape and type.
@@ -1109,11 +1109,20 @@
"""
batch, seq_len, num_heads, head_dim = shape
if type == "nested":
size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
return torch.nested.nested_tensor([
torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
for _ in range(batch)])
if isinstance(seq_len, list):
def _size(i):
return (seq_len[i], num_heads, head_dim) if not packed else (seq_len[i], 3 * num_heads * head_dim)
return torch.nested.nested_tensor([
torch.randn(_size(i), device=device, dtype=dtype, requires_grad=requires_grad)
for i in range(batch)])
else:
size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
return torch.nested.nested_tensor([
torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
for _ in range(batch)])
else:
assert(isinstance(seq_len, int))
size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim)
return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -1897,6 +1906,67 @@ class TestSDPA(NNTestCase):
value = torch.randn(shape, dtype=torch.float16, device=device)
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
def test_fused_kernels_seq_len_1_inputs(self, fused_kernel):
if (not SM80OrLater) and fused_kernel == SDPBackend.FLASH_ATTENTION:
return
rand_nested_tensor = partial(self.rand_tensor, type="nested", device="cuda", dtype=torch.float16)
batch, num_heads, head_dim = 32, 16, 64
seq_lens = torch.randint(low=1, high=32, size=(batch,))
# make sure some seq_lens are 1
num_ones = 10
indices = torch.randint(low=0, high=batch, size=(num_ones,))
seq_lens.scatter_(0, indices, 1)
shape = (batch, seq_lens.tolist(), num_heads, head_dim)
query = rand_nested_tensor(shape)
key = rand_nested_tensor(shape)
value = rand_nested_tensor(shape)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
with sdp_kernel(**self.backend_map[fused_kernel]):
actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
math_ref = torch.nn.functional.scaled_dot_product_attention(
query.contiguous().to(torch.float32),
key.contiguous().to(torch.float32),
value.contiguous().to(torch.float32),
attn_mask=None, dropout_p=0.0, is_causal=False)
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
def test_fused_kernels_seq_len_0_inputs(self, fused_kernel):
if (not SM80OrLater) and fused_kernel == SDPBackend.FLASH_ATTENTION:
return
rand_nested_tensor = partial(self.rand_tensor, type="nested", device="cuda", dtype=torch.float16)
batch, num_heads, head_dim = 32, 16, 64
seq_lens = torch.randint(low=1, high=32, size=(batch,))
# make sure some seq_lens are 0
num_zeros = 10
indices = torch.randint(low=0, high=batch, size=(num_zeros,))
seq_lens.scatter_(0, indices, 0)
shape = (batch, seq_lens.tolist(), num_heads, head_dim)
query = rand_nested_tensor(shape)
key = rand_nested_tensor(shape)
value = rand_nested_tensor(shape)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
with sdp_kernel(**self.backend_map[fused_kernel]):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
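
For context on the TODO above: instantiate_device_type_tests() generates per-device copies of a generic test class, so device coverage comes from the test framework rather than hard-coded "cuda" strings inside each test. A rough sketch of that pattern (the class and test names below are hypothetical, not part of this commit):

# Hypothetical sketch of device-type test instantiation; not part of this commit.
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

class TestSDPAGeneric(TestCase):
    def test_smoke(self, device):
        # The framework passes a device string (e.g. "cpu" or "cuda:0") to each test.
        x = torch.ones(2, 2, device=device)
        self.assertEqual(x.sum().item(), 4.0)

# Creates TestSDPAGenericCPU, TestSDPAGenericCUDA, ... in this module's namespace.
instantiate_device_type_tests(TestSDPAGeneric, globals())

if __name__ == "__main__":
    run_tests()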