[SDPA] Adds basic correctness checks (#94274)
# Summary

Add more checks around shape constraints, and update sdp_utils to properly catch mismatched head_dims between q/k and v for flash attention, which is not supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94274
Approved by: https://github.com/cpuhrsch
commit 81bbee7d7e
parent 92f569fe11
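For context, here is a minimal Python sketch (not part of this diff; it assumes a torch build that includes this change) of what the new input validation rejects: query/key/value with mismatched dtypes, mismatched devices, fewer than two dimensions, or an attn_mask whose dtype is neither bool nor the query dtype should now raise a RuntimeError before any backend is selected.

import torch
import torch.nn.functional as F

# Mismatched dtypes between query and key/value.
q = torch.randn(1, 4, 8, 16, dtype=torch.float32)
k = torch.randn(1, 4, 8, 16, dtype=torch.float16)
v = torch.randn(1, 4, 8, 16, dtype=torch.float16)
try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as e:
    print(e)  # "Expected query, key, and value to have the same dtype, ..."

# attn_mask must be bool or match the query dtype.
q = k = v = torch.randn(1, 4, 8, 16, dtype=torch.float32)
bad_mask = torch.zeros(1, 4, 8, 8, dtype=torch.int64)
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=bad_mask)
except RuntimeError as e:
    print(e)  # "Expected attn_mask dtype to be bool or to match query dtype, ..."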
@@ -710,6 +710,33 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
       query_, key, value, attn_mask_, dropout_p, is_causal);
 }
 
+inline void validate_sdpa_input(
+    const Tensor& query_,
+    const Tensor& key,
+    const Tensor& value,
+    const c10::optional<Tensor>& attn_mask_,
+    double dropout_p,
+    bool is_causal) {
+  TORCH_CHECK(
+      query_.dtype() == key.dtype() && query_.dtype() == value.dtype(),
+      "Expected query, key, and value to have the same dtype, but got query.dtype: ",
+      query_.dtype(), " key.dtype: ", key.dtype(), " and value.dtype: ", value.dtype(), " instead.");
+  TORCH_CHECK(
+      query_.device() == key.device() && query_.device() == value.device(),
+      "Expected query, key, and value to have the same device type, but got query.device: ",
+      query_.device(), " key.device: ", key.device(), " and value.device: ", value.device(), " instead.");
+  TORCH_CHECK(
+      query_.dim() >= 2 && key.dim() >= 2 && value.dim() >= 2,
+      "Expected query, key, and value to all be at least 2 dimensional, but got query.dim: ",
+      query_.dim(), " key.dim: ", key.dim(), " and value.dim: ", value.dim(), " instead.");
+  if (attn_mask_.has_value()){
+    auto mask_dtype = attn_mask_->dtype();
+    TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == query_.dtype(),
+      "Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: ",
+      mask_dtype, " and query.dtype: ", query_.dtype(), " instead.");
+  }
+  return;
+}
 // Computes scaled dot product attention on query, key and value tensors, using
 // an optional attention mask if passed, and applying dropout if a probability
 // greater than 0.0 is specified.
@@ -745,6 +772,7 @@ Tensor scaled_dot_product_attention(
     const c10::optional<Tensor>& attn_mask_,
     double dropout_p,
     bool is_causal) {
+  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal);
   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
   if (query_.device().type() == DeviceType::CUDA){
     choice_int = _fused_sdp_choice_stub(query_.device().type(),
@@ -214,22 +214,51 @@ inline bool check_tensor_shapes(sdp_params params, bool debug) {
   return true;
 }
 
+inline bool check_equal_batch_size_and_num_heads(sdp_params params, bool debug) {
+  // This is expected to be called after check_tensor_shapes ensuring that the size()
+  // calls won't error since the inputs are all 4 dimensional
+  bool same_batch_size = params.query.size(0) == params.key.size(0) &&
+      params.query.size(0) == params.value.size(0);
+  // We pass through for NestedTensors since this is checked in a later filter
+  bool same_num_heads = params.query.is_nested()
+      ? true
+      : params.query.size(1) == params.key.size(1) &&
+          params.query.size(1) == params.value.size(1);
+
+  if (!(same_batch_size && same_num_heads)) {
+    if (debug) {
+      TORCH_WARN(
+          "Both fused kernels requires query, key and value to have the same batch_size and num_heads. Query.sizes(): ",
+          params.query.sizes(),
+          ", Key sizes(): ",
+          params.key.sizes(),
+          ", Value sizes(): ",
+          params.value.sizes(),
+          " instead.");
+    }
+    return false;
+  }
+  return true;
+}
+
 inline bool check_head_dim_size(sdp_params params, bool debug) {
   const int64_t query_size_last = params.query.size(-1);
+  const int64_t key_size_last = params.key.size(-1);
   const int64_t value_size_last = params.value.size(-1);
-  if (!(query_size_last == params.key.size(-1) && query_size_last % 8 == 0 &&
+  if (!(query_size_last == key_size_last &&
+        query_size_last == value_size_last && query_size_last % 8 == 0 &&
         query_size_last <= 128 && value_size_last % 8 == 0 &&
         value_size_last <= 128)) {
     if (debug) {
       TORCH_WARN(
-          "Flash attention requires last dimension of inputs to be a multiple of 8 and less than or equal to 128.",
-          "Got Query.size(-1): ",
-          query_size_last,
-          ", Key.size(-1): ",
-          params.key.size(-1),
-          ", Value.size(-1): ",
-          params.value.size(-1),
-          " instead.");
+          "Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 128.",
+          " Got Query.size(-1): ",
+          query_size_last,
+          ", Key.size(-1): ",
+          params.key.size(-1),
+          ", Value.size(-1): ",
+          params.value.size(-1),
+          " instead.");
     }
     return false;
   }
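Illustrative only (the shapes and sdp_kernel usage below mirror the tests later in this diff, and assume a CUDA build where torch.backends.cuda.sdp_kernel is available): with only flash attention enabled, q/k/v whose last dimensions disagree, are not multiples of 8, or exceed 128 are now filtered out, so dispatch raises instead of reaching an unsupported kernel.

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

device, dtype = "cuda", torch.float16
q = torch.randn(2, 2, 8, 64, device=device, dtype=dtype)
k = torch.randn(2, 2, 8, 64, device=device, dtype=dtype)
v = torch.randn(2, 2, 8, 96, device=device, dtype=dtype)  # value head_dim differs from q/k

with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as e:
        # check_head_dim_size now rejects the mismatch; with debug warnings on it also prints
        # "Flash attention requires q,k,v to have the same last dimension ..."
        print(e)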
@@ -393,9 +422,10 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
   return false;
 #endif
   // Define gate functions that determine if a flash kernel can be ran
-  constexpr std::array<bool(*)(sdp_params, bool), 7> constraints {{
+  constexpr std::array<bool(*)(sdp_params, bool), 8> constraints {{
       check_runtime_disabled_flash,
       check_tensor_shapes,
+      check_equal_batch_size_and_num_heads,
       check_for_attn_mask,
       check_head_dim_size,
       check_gpu_sm75_or_greater,
@@ -427,11 +457,12 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
       at::kHalf, at::kFloat, at::kBFloat16};
 
   // Define gate functions that determine if a flash kernel can be ran
-  constexpr std::array<bool(*)(sdp_params, bool), 10> constraints{{
+  constexpr std::array<bool(*)(sdp_params, bool), 11> constraints{{
       check_gpu_sm50_or_greater,
       check_runtime_disabled_mem_efficient,
       check_requires_grad_and_nested,
       check_tensor_shapes,
+      check_equal_batch_size_and_num_heads,
       check_for_attn_mask,
       check_head_dim_size_mem_efficient,
       check_gpu_sm86_head_dim_128,
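Both constraint arrays implement the same gate pattern: each entry is a predicate over the SDP parameters, and a backend is eligible only if every gate returns true. A rough Python analogue of that dispatch loop (the names below are illustrative stand-ins, not the C++ API):

from dataclasses import dataclass
from typing import Callable, Optional, Sequence

import torch


@dataclass
class SDPParams:  # hypothetical mirror of the C++ sdp_params struct
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    attn_mask: Optional[torch.Tensor]
    dropout_p: float
    is_causal: bool


def check_tensor_shapes(params: SDPParams, debug: bool) -> bool:
    # Fused kernels expect 4-D (batch, num_heads, seq_len, head_dim) inputs.
    return all(t.dim() == 4 for t in (params.query, params.key, params.value))


def check_equal_batch_size_and_num_heads(params: SDPParams, debug: bool) -> bool:
    q, k, v = params.query, params.key, params.value
    ok = (q.size(0) == k.size(0) == v.size(0)) and (q.size(1) == k.size(1) == v.size(1))
    if not ok and debug:
        print("query, key and value must share batch_size and num_heads")
    return ok


def can_use_backend(params: SDPParams,
                    constraints: Sequence[Callable[[SDPParams, bool], bool]],
                    debug: bool = False) -> bool:
    # Mirrors the C++ dispatch: every gate must pass for the backend to be considered.
    return all(gate(params, debug) for gate in constraints)


params = SDPParams(torch.randn(2, 4, 8, 64), torch.randn(2, 4, 8, 64),
                   torch.randn(2, 4, 8, 64), None, 0.0, False)
print(can_use_backend(params, [check_tensor_shapes, check_equal_batch_size_and_num_heads]))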
@@ -1076,6 +1076,13 @@ class TestSDPA(NNTestCase):
     _do_cuda_memory_leak_check = True
     _do_cuda_non_default_stream = True
 
+    backend_map = {
+        SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
+        SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
+        SDPBackend.EFFICIENT_ATTENTION: {
+            "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
+    }
+
     def rand_tensor(self, shape: Tuple[int], device: str, dtype: torch.dtype,
                     type: str, requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
         """Creates rand dense or nested tensor with given shape and type.
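A short sketch of how a map like backend_map is meant to be used: each SDPBackend key expands into sdp_kernel keyword arguments so that exactly one backend is enabled per parametrized run (this assumes a CUDA build; at the time of this PR, sdp_kernel and SDPBackend are imported from torch.backends.cuda, and builds compiled without the fused kernels will raise instead of dispatching to them).

import torch
import torch.nn.functional as F
from torch.backends.cuda import SDPBackend, sdp_kernel

backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}

q = k = v = torch.randn(2, 4, 8, 64, device="cuda", dtype=torch.float16)
for backend, flags in backend_map.items():
    with sdp_kernel(**flags):  # only the selected backend may be chosen inside this block
        try:
            out = F.scaled_dot_product_attention(q, k, v)
            print(backend, out.shape)
        except RuntimeError as e:
            # e.g. a build without flash attention compiled in
            print(backend, e)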
@@ -1480,22 +1487,22 @@ class TestSDPA(NNTestCase):
         assert torch._fused_sdp_choice(query, key, value) == (
             SDPBackend.EFFICIENT_ATTENTION if warn_only else SDPBackend.MATH)
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "CUDA unavailable")
-    def test_sdp_runtime_dispatch(self):
-        # We will test all the constraints that we know will cause a failure
-        # The problem is that any code path that goes down flash_attention
-        # will fail on CI/CD becuase it is not compiled with the right flags
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "CUDA unavailable")
+    def test_memory_efficeint_sm86_failure(self):
         device = 'cuda'
         dtype = torch.float16
         make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
-        if isSM86Device:
-            # See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
-            size = (2, 2, 4, 128)
-            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-            with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
+        # See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+        size = (2, 2, 4, 128)
+        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+        with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
     def test_dispatch_fails_no_backend(self):
         dtype = torch.float16
         device = "cuda"
         with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
             size = (2, 3, 4)
             q = torch.randn(size, device=device, dtype=dtype)
@@ -1506,42 +1513,92 @@ class TestSDPA(NNTestCase):
             self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
                                    lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))
 
-        if SM80OrLater:
-            with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
-                # Failures for invalid input
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    @parametrize(
+        "kernel",
+        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+        if SM80OrLater
+        else [SDPBackend.EFFICIENT_ATTENTION],
+    )
+    def test_invalid_fused_inputs_dim_3(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Dim is not 4
+            device = "cuda"
+            size = (2, 3, 8)
+            dtype = torch.float16
+            q = torch.randn(size, device=device, dtype=dtype)
+            k = torch.randn(size, device=device, dtype=dtype)
+            v = torch.randn(size, device=device, dtype=dtype)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
-                # Dim is not 4
-                q = torch.randn(size, device=device, dtype=dtype)
-                k = torch.randn(size, device=device, dtype=dtype)
-                v = torch.randn(size, device=device, dtype=dtype)
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    @parametrize(
+        "kernel",
+        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+        if SM80OrLater
+        else [SDPBackend.EFFICIENT_ATTENTION],
+    )
+    def test_invalid_fused_inputs_broadcast(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Fused Kernels don't support broadcasting
+            device = "cuda"
+            dtype = torch.float16
+            size = (2, 4, 3, 8)
+            size_broadcast = (1, 4, 3, 8)
+            q = torch.randn(size_broadcast, device=device, dtype=dtype)
+            k = torch.randn(size, device=device, dtype=dtype)
+            v = torch.randn(size, device=device, dtype=dtype)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
-                # The embed dim per head is not divisible by 8 for flash attention
-                size = (2, 2, 3, 4)
-                q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused scaled dot product attention")
+    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    def test_invalid_fused_inputs_head_dim(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # The embed dim per head is not divisible by 8 for flash attention
+            device = "cuda"
+            dtype = torch.float16
+            make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
+            size = (2, 2, 3, 9)
+            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
-                # Invalid dtype for both Flash Attention and Mem Efficient Attention
-                size = (2, 2, 3, 16)
-                make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float64)
-                q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    @parametrize(
+        "kernel",
+        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+        if SM80OrLater
+        else [SDPBackend.EFFICIENT_ATTENTION],
+    )
+    def test_invalid_fused_inputs_invalid_dtype(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Invalid dtype for both Flash Attention and Mem Efficient Attention
+            device = "cuda"
+            size = (2, 2, 3, 16)
+            make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float64)
+            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
-                # Invalid dtype for Flash Attention
-                make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float32)
-                q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
-
-                # Failures for unsupported SDP args
-                q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-
-                # Non-None attention mask
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, torch.ones_like(q), 0.0, False))
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    @parametrize(
+        "kernel",
+        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+        if SM80OrLater
+        else [SDPBackend.EFFICIENT_ATTENTION],
+    )
+    def test_invalid_fused_inputs_attn_mask_present(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Failures for unsupported SDP args
+            device = "cuda"
+            size = (2, 2, 3, 16)
+            make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float16)
+            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+            # Non-None attention mask
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, torch.ones_like(q), 0.0, False))
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
     def test_unaligned_tensors(self):
@@ -1784,6 +1841,39 @@ class TestSDPA(NNTestCase):
             self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                              atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
 
+    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
+    def test_invalid_inputs_different_datatypes(self, kernel: SDPBackend, device: str):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Different datatypes
+            shape = (1, 4, 8, 16)
+            query = torch.randn(shape, dtype=torch.float32, device=device)
+            key = torch.randn(shape, dtype=torch.float16, device=device)
+            value = torch.randn(shape, dtype=torch.float16, device=device)
+            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
+
+    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
+    def test_invalid_inputs_different_devices(self, kernel: SDPBackend, device: str):
+        # Different devices
+        shape = (1, 4, 8, 16)
+        if device == "cuda":
+            query = torch.randn(shape, dtype=torch.float32, device=device)
+            key = torch.randn(shape, dtype=torch.float16, device='cpu')
+            value = torch.randn(shape, dtype=torch.float16, device='cpu')
+            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
+
+    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
+    def test_invalid_inputs_1_dimensional_inputs(self, kernel: SDPBackend, device: str):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # 1 dimensional input
+            shape = (1, 4)
+            query = torch.randn(4, dtype=torch.float16, device=device)
+            key = torch.randn(shape, dtype=torch.float16, device=device)
+            value = torch.randn(shape, dtype=torch.float16, device=device)
+            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
+
 # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
 # cross device / dtype testing.
 instantiate_parametrized_tests(TestTransformers)
@@ -7670,17 +7670,32 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
     dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
     dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
 
-    qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape)]
+    broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim))
+
+    qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
+    samples = []
     for qkv_shapes, is_causal, dropout_p in product(
             qkv_shapes, [True, False], [0.0, 0.5]):
         shape_q, shape_kv = qkv_shapes
-        yield SampleInput(
+        samples.append(SampleInput(
             make(shape_q),
             make(shape_kv),
             make(shape_kv),
             is_causal=is_causal,
             dropout_p=dropout_p
-        )
+        ))
+
+    # Add non standard shapes
+    diff_v_head_dim = SampleInput(
+        make((batch, num_heads, seq_q, head_dim)),
+        make((batch, num_heads, seq_kv, head_dim)),
+        make((batch, num_heads, seq_kv, head_dim + 8)),
+        is_causal=is_causal,
+        dropout_p=dropout_p
+    )
+    samples.append(diff_v_head_dim)
+
+    yield from samples
 
 def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
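One note on the new diff_v_head_dim sample: the math reference path can handle a value head_dim that differs from query/key (the output's last dimension follows value), which is exactly the case the updated check_head_dim_size now keeps away from flash attention. A small CPU sketch, with illustrative sizes rather than the OpInfo's actual ones:

import torch
import torch.nn.functional as F

batch, num_heads, seq_q, seq_kv, head_dim = 2, 4, 8, 8, 16
q = torch.randn(batch, num_heads, seq_q, head_dim)
k = torch.randn(batch, num_heads, seq_kv, head_dim)
v = torch.randn(batch, num_heads, seq_kv, head_dim + 8)  # value head_dim differs from q/k

out = F.scaled_dot_product_attention(q, k, v)  # CPU math path
print(out.shape)  # expected: torch.Size([2, 4, 8, 24]) -- last dim follows value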