[SDPA] Adds basic correctness checks (#94274)

# Summary
Adds more checks around shape constraints and updates sdp_utils to properly catch mismatched head_dims between query/key and value for flash_attention, which is not supported.
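For illustration only (not part of the commit), a minimal sketch of the user-facing effect of the new validation, assuming a build that includes this change: mismatched dtypes between query, key, and value now fail fast with a descriptive error on any backend.

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 8, 16, dtype=torch.float32)
k = torch.randn(1, 4, 8, 16, dtype=torch.float16)  # dtype differs from query
v = torch.randn(1, 4, 8, 16, dtype=torch.float16)

try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as err:
    # e.g. "Expected query, key, and value to have the same dtype, ..."
    print(err)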
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94274
Approved by: https://github.com/cpuhrsch
Driss Guessous 2023-02-09 08:05:22 +00:00 committed by PyTorch MergeBot
parent 92f569fe11
commit 81bbee7d7e
4 changed files with 222 additions and 58 deletions


@@ -710,6 +710,33 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
query_, key, value, attn_mask_, dropout_p, is_causal);
}
inline void validate_sdpa_input(
const Tensor& query_,
const Tensor& key,
const Tensor& value,
const c10::optional<Tensor>& attn_mask_,
double dropout_p,
bool is_causal) {
TORCH_CHECK(
query_.dtype() == key.dtype() && query_.dtype() == value.dtype(),
"Expected query, key, and value to have the same dtype, but got query.dtype: ",
query_.dtype(), " key.dtype: ", key.dtype(), " and value.dtype: ", value.dtype(), " instead.");
TORCH_CHECK(
query_.device() == key.device() && query_.device() == value.device(),
"Expected query, key, and value to have the same device type, but got query.device: ",
query_.device(), " key.device: ", key.device(), " and value.device: ", value.device(), " instead.");
TORCH_CHECK(
query_.dim() >= 2 && key.dim() >= 2 && value.dim() >= 2,
"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: ",
query_.dim(), " key.dim: ", key.dim(), " and value.dim: ", value.dim(), " instead.");
if (attn_mask_.has_value()){
auto mask_dtype = attn_mask_->dtype();
TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == query_.dtype(),
"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: ",
mask_dtype, " and query.dtype: ", query_.dtype(), " instead.");
}
return;
}
// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
// greater than 0.0 is specified.
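A hedged Python-level sketch of the attn_mask rule enforced in validate_sdpa_input above (illustrative, not part of the diff): the mask must be boolean or match the query dtype, and the check fires before any backend is selected.

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)                        # float32 by default
bool_mask = torch.ones(1, 2, 4, 4, dtype=torch.bool)       # accepted: boolean mask
half_mask = torch.zeros(1, 2, 4, 4, dtype=torch.float16)   # rejected: dtype differs from query

F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)   # runs (math path on CPU)
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=half_mask)
except RuntimeError as err:
    print(err)  # "Expected attn_mask dtype to be bool or to match query dtype, ..."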
@@ -745,6 +772,7 @@ Tensor scaled_dot_product_attention(
const c10::optional<Tensor>& attn_mask_,
double dropout_p,
bool is_causal) {
validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal);
int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
if (query_.device().type() == DeviceType::CUDA){
choice_int = _fused_sdp_choice_stub(query_.device().type(),


@@ -214,22 +214,51 @@ inline bool check_tensor_shapes(sdp_params params, bool debug) {
return true;
}
inline bool check_equal_batch_size_and_num_heads(sdp_params params, bool debug) {
// This is expected to be called after check_tensor_shapes ensuring that the size()
// calls won't error since the inputs are all 4 dimensional
bool same_batch_size = params.query.size(0) == params.key.size(0) &&
params.query.size(0) == params.value.size(0);
// We pass through for NestedTensors since this is checked in a later filter
bool same_num_heads = params.query.is_nested()
? true
: params.query.size(1) == params.key.size(1) &&
params.query.size(1) == params.value.size(1);
if (!(same_batch_size && same_num_heads)) {
if (debug) {
TORCH_WARN(
"Both fused kernels requires query, key and value to have the same batch_size and num_heads. Query.sizes(): ",
params.query.sizes(),
", Key sizes(): ",
params.key.sizes(),
", Value sizes(): ",
params.value.sizes(),
" instead.");
}
return false;
}
return true;
}
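A usage-level sketch of what this new filter turns away (illustrative, assuming a CUDA device with the fused kernels built): the fused backends do not broadcast across batch_size or num_heads, so with only the fused backends enabled a broadcast-shaped query raises, mirroring test_invalid_fused_inputs_broadcast in the test changes below.

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

device, dtype = "cuda", torch.float16
q = torch.randn(1, 4, 3, 8, device=device, dtype=dtype)   # batch_size 1 (broadcast shape)
k = torch.randn(2, 4, 3, 8, device=device, dtype=dtype)   # batch_size 2
v = torch.randn(2, 4, 3, 8, device=device, dtype=dtype)

with sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as err:
        print(err)  # neither fused backend accepts the broadcast batch dimension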
inline bool check_head_dim_size(sdp_params params, bool debug) {
const int64_t query_size_last = params.query.size(-1);
const int64_t key_size_last = params.key.size(-1);
const int64_t value_size_last = params.value.size(-1);
if (!(query_size_last == params.key.size(-1) && query_size_last % 8 == 0 &&
if (!(query_size_last == key_size_last &&
query_size_last == value_size_last && query_size_last % 8 == 0 &&
query_size_last <= 128 && value_size_last % 8 == 0 &&
value_size_last <= 128)) {
if (debug) {
TORCH_WARN(
"Flash attention requires last dimension of inputs to be a multiple of 8 and less than or equal to 128.",
"Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.size(-1),
", Value.size(-1): ",
params.value.size(-1),
" instead.");
"Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 128.",
" Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.size(-1),
", Value.size(-1): ",
params.value.size(-1),
" instead.");
}
return false;
}
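An illustrative sketch of the head_dim constraint checked above, assuming a CUDA device where flash attention is supported: q, k, and v must share a last dimension that is a multiple of 8 and at most 128.

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

device, dtype = "cuda", torch.float16

def run_flash_only(head_dim):
    shape = (2, 2, 3, head_dim)
    q = torch.randn(shape, device=device, dtype=dtype)
    k = torch.randn(shape, device=device, dtype=dtype)
    v = torch.randn(shape, device=device, dtype=dtype)
    with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        return F.scaled_dot_product_attention(q, k, v)

try:
    run_flash_only(9)   # 9 is not a multiple of 8: flash is filtered out -> RuntimeError
except RuntimeError as err:
    print(err)

run_flash_only(64)      # multiple of 8 and <= 128: eligible for flash on supported hardware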
@@ -393,9 +422,10 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
return false;
#endif
// Define gate functions that determine if a flash kernel can be run
constexpr std::array<bool(*)(sdp_params, bool), 7> constraints {{
constexpr std::array<bool(*)(sdp_params, bool), 8> constraints {{
check_runtime_disabled_flash,
check_tensor_shapes,
check_equal_batch_size_and_num_heads,
check_for_attn_mask,
check_head_dim_size,
check_gpu_sm75_or_greater,
@@ -427,11 +457,12 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
at::kHalf, at::kFloat, at::kBFloat16};
// Define gate functions that determine if a mem_efficient kernel can be run
constexpr std::array<bool(*)(sdp_params, bool), 10> constraints{{
constexpr std::array<bool(*)(sdp_params, bool), 11> constraints{{
check_gpu_sm50_or_greater,
check_runtime_disabled_mem_efficient,
check_requires_grad_and_nested,
check_tensor_shapes,
check_equal_batch_size_and_num_heads,
check_for_attn_mask,
check_head_dim_size_mem_efficient,
check_gpu_sm86_head_dim_128,


@@ -1076,6 +1076,13 @@ class TestSDPA(NNTestCase):
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
backend_map = {
SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}
def rand_tensor(self, shape: Tuple[int], device: str, dtype: torch.dtype,
type: str, requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
"""Creates rand dense or nested tensor with given shape and type.
@@ -1480,22 +1487,22 @@ class TestSDPA(NNTestCase):
assert torch._fused_sdp_choice(query, key, value) == (
SDPBackend.EFFICIENT_ATTENTION if warn_only else SDPBackend.MATH)
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "CUDA unavailable")
def test_sdp_runtime_dispatch(self):
# We will test all the constraints that we know will cause a failure
# The problem is that any code path that goes down flash_attention
# will fail on CI/CD because it is not compiled with the right flags
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "CUDA unavailable")
def test_memory_efficeint_sm86_failure(self):
device = 'cuda'
dtype = torch.float16
make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
if isSM86Device:
# See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
size = (2, 2, 4, 128)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
# See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
size = (2, 2, 4, 128)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
def test_dispatch_fails_no_backend(self):
dtype = torch.float16
device = "cuda"
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
size = (2, 3, 4)
q = torch.randn(size, device=device, dtype=dtype)
@@ -1506,42 +1513,92 @@ class TestSDPA(NNTestCase):
self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))
if SM80OrLater:
with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
# Failures for invalid input
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
@parametrize(
"kernel",
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
if SM80OrLater
else [SDPBackend.EFFICIENT_ATTENTION],
)
def test_invalid_fused_inputs_dim_3(self, kernel: SDPBackend):
with sdp_kernel(**self.backend_map[kernel]):
# Dim is not 4
device = "cuda"
size = (2, 3, 8)
dtype = torch.float16
q = torch.randn(size, device=device, dtype=dtype)
k = torch.randn(size, device=device, dtype=dtype)
v = torch.randn(size, device=device, dtype=dtype)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
# Dim is not 4
q = torch.randn(size, device=device, dtype=dtype)
k = torch.randn(size, device=device, dtype=dtype)
v = torch.randn(size, device=device, dtype=dtype)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
@parametrize(
"kernel",
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
if SM80OrLater
else [SDPBackend.EFFICIENT_ATTENTION],
)
def test_invalid_fused_inputs_broadcast(self, kernel: SDPBackend):
with sdp_kernel(**self.backend_map[kernel]):
# Fused Kernels don't support broadcasting
device = "cuda"
dtype = torch.float16
size = (2, 4, 3, 8)
size_broadcast = (1, 4, 3, 8)
q = torch.randn(size_broadcast, device=device, dtype=dtype)
k = torch.randn(size, device=device, dtype=dtype)
v = torch.randn(size, device=device, dtype=dtype)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
# The embed dim per head is not divisible by 8 for flash attention
size = (2, 2, 3, 4)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused scaled dot product attention")
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_fused_inputs_head_dim(self, kernel: SDPBackend):
with sdp_kernel(**self.backend_map[kernel]):
# The embed dim per head is not divisible by 8 for flash attention
device = "cuda"
dtype = torch.float16
make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
size = (2, 2, 3, 9)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
# Invalid dtype for both Flash Attention and Mem Efficient Attention
size = (2, 2, 3, 16)
make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float64)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
@parametrize(
"kernel",
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
if SM80OrLater
else [SDPBackend.EFFICIENT_ATTENTION],
)
def test_invalid_fused_inputs_invalid_dtype(self, kernel: SDPBackend):
with sdp_kernel(**self.backend_map[kernel]):
# Invalid dtype for both Flash Attention and Mem Efficient Attention
device = "cuda"
size = (2, 2, 3, 16)
make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float64)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
# Invalid dtype for Flash Attention
make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float32)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
# Failures for unsupported SDP args
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
# Non-None attention mask
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, torch.ones_like(q), 0.0, False))
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
@parametrize(
"kernel",
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
if SM80OrLater
else [SDPBackend.EFFICIENT_ATTENTION],
)
def test_invalid_fused_inputs_attn_mask_present(self, kernel: SDPBackend):
with sdp_kernel(**self.backend_map[kernel]):
# Failures for unsupported SDP args
device = "cuda"
size = (2, 2, 3, 16)
make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float16)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
# Non-None attention mask
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, torch.ones_like(q), 0.0, False))
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
def test_unaligned_tensors(self):
@@ -1784,6 +1841,39 @@ class TestSDPA(NNTestCase):
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
@parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
def test_invalid_inputs_different_datatypes(self, kernel: SDPBackend, device: str):
with sdp_kernel(**self.backend_map[kernel]):
# Different datatypes
shape = (1, 4, 8, 16)
query = torch.randn(shape, dtype=torch.float32, device=device)
key = torch.randn(shape, dtype=torch.float16, device=device)
value = torch.randn(shape, dtype=torch.float16, device=device)
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
@parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
def test_invalid_inputs_different_devices(self, kernel: SDPBackend, device: str):
# Different devices
shape = (1, 4, 8, 16)
if device == "cuda":
query = torch.randn(shape, dtype=torch.float32, device=device)
key = torch.randn(shape, dtype=torch.float16, device='cpu')
value = torch.randn(shape, dtype=torch.float16, device='cpu')
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
@parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
def test_invalid_inputs_1_dimensional_inputs(self, kernel: SDPBackend, device: str):
with sdp_kernel(**self.backend_map[kernel]):
# 1 dimensional input
shape = (1, 4)
query = torch.randn(4, dtype=torch.float16, device=device)
key = torch.randn(shape, dtype=torch.float16, device=device)
value = torch.randn(shape, dtype=torch.float16, device=device)
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
instantiate_parametrized_tests(TestTransformers)


@@ -7670,17 +7670,32 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape)]
broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim))
qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
samples = []
for qkv_shapes, is_causal, dropout_p in product(
qkv_shapes, [True, False], [0.0, 0.5]):
shape_q, shape_kv = qkv_shapes
yield SampleInput(
samples.append(SampleInput(
make(shape_q),
make(shape_kv),
make(shape_kv),
is_causal=is_causal,
dropout_p=dropout_p
)
))
# Add non-standard shapes
diff_v_head_dim = SampleInput(
make((batch, num_heads, seq_q, head_dim)),
make((batch, num_heads, seq_kv, head_dim)),
make((batch, num_heads, seq_kv, head_dim + 8)),
is_causal=is_causal,
dropout_p=dropout_p
)
samples.append(diff_v_head_dim)
yield from samples
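The diff_v_head_dim sample added above targets the case called out in the summary: value's head_dim differs from query/key's. A hedged sketch of the expected behavior: flash attention is filtered out for such inputs, while the math path still handles them, with the output taking value's head_dim.

import torch
import torch.nn.functional as F

batch, num_heads, seq_q, seq_kv, head_dim = 2, 4, 8, 8, 16
q = torch.randn(batch, num_heads, seq_q, head_dim)
k = torch.randn(batch, num_heads, seq_kv, head_dim)
v = torch.randn(batch, num_heads, seq_kv, head_dim + 8)   # value head_dim differs from q/k

out = F.scaled_dot_product_attention(q, k, v)              # math path handles this shape
print(out.shape)                                           # torch.Size([2, 4, 8, 24])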
def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)