@@ -8,12 +8,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.functional import scaled_dot_product_attention
+from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.bias import CausalVariant, causal_lower_right, causal_upper_left
 from torch.nn.parameter import Parameter
 import unittest
 from unittest.mock import patch, MagicMock, ANY
 import math
-from torch.backends.cuda import sdp_kernel, SDPBackend
 import torch.optim as optim
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
 from typing import List, Tuple, Optional
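The import hunk above captures the whole migration: the deprecated `torch.backends.cuda.sdp_kernel` context manager gives way to `torch.nn.attention.sdpa_kernel`, and `SDPBackend` is now imported from `torch.nn.attention`. A minimal before/after sketch of the two APIs (assuming PyTorch 2.3+ with a CUDA device; shapes are illustrative only):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

    # Old API (deprecated): pick a backend by toggling three boolean flags.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out_old = F.scaled_dot_product_attention(q, k, v)

    # New API: list the allowed SDPBackend members explicitly.
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        out_new = F.scaled_dot_product_attention(q, k, v)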
@@ -103,12 +103,6 @@ def get_tolerances(
     rtol = default_rtol[computed_value.dtype]
     return atol, rtol

-backend_map = {
-    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
-    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
-    SDPBackend.EFFICIENT_ATTENTION: {
-        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
-}

 def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: torch.dtype = None):
     """ Clones the query, key, and value tensors and moves them to the specified dtype. """
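With the boolean flags gone, the deleted `backend_map` has nothing left to translate: everywhere the tests previously wrote `with sdp_kernel(**backend_map[kernel])`, the `SDPBackend` member is now passed straight through. A sketch of the pattern, with `kernel` standing in for a parametrized test input:

    from torch.nn.attention import sdpa_kernel, SDPBackend

    kernel = SDPBackend.EFFICIENT_ATTENTION  # e.g. supplied by @parametrize

    # Before: with sdp_kernel(**backend_map[kernel]): ...
    with sdpa_kernel(backends=[kernel]):
        pass  # kernel-specific assertions run here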
@@ -960,7 +954,7 @@ class TestTransformers(NNTestCase):
                          f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask"
                          if attn_dim is not None else "no_attn_mask")))
     @parametrize("dropout_p", [0.0, 0.2, 0.5])
-    @sdp_kernel(enable_flash=False, enable_mem_efficient=False)
+    @sdpa_kernel(backends=[SDPBackend.MATH])
     def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
         def sdp_ref(
                 q,
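Because `sdpa_kernel` is implemented as a `contextlib` context manager, it can also be applied as a decorator, as the hunk above does; every call to the decorated test then runs with only the listed backends enabled. A small sketch of that usage (the function name and shapes here are hypothetical):

    import torch
    from torch.nn.attention import sdpa_kernel, SDPBackend

    @sdpa_kernel(backends=[SDPBackend.MATH])
    def math_only_sdpa(q, k, v):
        # Dispatch is restricted to the math backend for the duration of the call.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    out = math_only_sdpa(*(torch.randn(2, 4, 8, 16) for _ in range(3)))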
@@ -1213,7 +1207,7 @@ class TestTransformers(NNTestCase):
         mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY)

         # check expected numerical values with all kernels
-        self.is_causal_kernels(["math"], device)
+        self.is_causal_kernels([SDPBackend.MATH], device)

     def is_causal_kernels(self, kernels, device):
         def ones_tensor(*shape):
@@ -1230,15 +1224,11 @@ class TestTransformers(NNTestCase):
         )

         for kernel in kernels:
-            with torch.backends.cuda.sdp_kernel(
-                enable_math=(kernel == 'math'),
-                enable_flash=(kernel == 'flash'),
-                enable_mem_efficient=(kernel == 'meff')
-            ):
+            with sdpa_kernel(backends=[kernel]):
                 actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True)
                 self.assertTrue(torch.equal(actual, expected))

-                if kernel != 'math':
+                if kernel != SDPBackend.MATH:
                     # fails with embedding size not multiple of 4
                     with self.assertRaisesRegex(RuntimeError, "No available kernel"):
                         qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device)
@@ -1254,7 +1244,7 @@ class TestTransformers(NNTestCase):
     )
     def test_is_causal_gpu(self):
         device = 'cuda'
-        self.is_causal_kernels(["math", "meff"], device)
+        self.is_causal_kernels([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], device)

     def test_script_mha_in_proj_weight_none(self):
         mha = torch.nn.MultiheadAttention(
@@ -1342,10 +1332,10 @@ class TestSDPAFailureModes(NNTestCase):
         size = (2, 2, 4, head_dim)
         q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)

-        with sdp_kernel(enable_mem_efficient=False, enable_flash=False, enable_math=True):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)

-        with sdp_kernel(enable_mem_efficient=False, enable_flash=True, enable_math=False):
+        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
             # Should not fail because inputs don't require grad
             flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)

@@ -1361,7 +1351,7 @@ class TestSDPAFailureModes(NNTestCase):
     @onlyCUDA
     def test_dispatch_fails_no_backend(self, device):
         dtype = torch.float16
-        with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.ERROR]):
             size = (2, 3, 4)
             q = torch.randn(size, device=device, dtype=dtype)
             k = torch.randn(size, device=device, dtype=dtype)
@@ -1378,7 +1368,7 @@ class TestSDPAFailureModes(NNTestCase):
         PLATFORM_SPECIFIC_SDPA,
     )
     def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # Dim is not 4
             size = (2, 3, 8)
             dtype = torch.float16
@@ -1396,7 +1386,7 @@ class TestSDPAFailureModes(NNTestCase):
         PLATFORM_SPECIFIC_SDPA,
     )
     def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # Fused Kernels don't support broadcasting for dense inputs
             dtype = torch.float16
             size = (2, 4, 3, 8)
@@ -1411,7 +1401,7 @@ class TestSDPAFailureModes(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
     def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # Passing in a q,k,v with 0 length sequences will error
             dtype = torch.float16
             make_tensor = partial(torch.rand, device=device, dtype=dtype)
@@ -1425,7 +1415,7 @@ class TestSDPAFailureModes(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
     def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # Passing in q,k,v whose last dimension has a stride other than 1 will error
             dtype = torch.float16
             make_tensor = partial(torch.rand, device=device, dtype=dtype)
@@ -1440,7 +1430,7 @@ class TestSDPAFailureModes(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention fused scaled dot product attention")
     @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
     def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # The embed dim per head is not divisible by 8 for flash attention
             dtype = torch.float16
             make_tensor = partial(torch.rand, device=device, dtype=dtype)
@@ -1456,7 +1446,7 @@ class TestSDPAFailureModes(NNTestCase):
         PLATFORM_SPECIFIC_SDPA,
     )
     def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # Invalid dtype for both Flash Attention and Mem Efficient Attention
             size = SdpaShape(2, 2, 3, 16)
             make_tensor = partial(torch.rand, device=device, dtype=torch.float64)
@@ -1468,7 +1458,7 @@ class TestSDPAFailureModes(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
     @parametrize("kernel", [SDPBackend.FLASH_ATTENTION])
     def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # Failures for unsupported SDP args
             size = SdpaShape(2, 2, 3, 16)
             make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16)
@@ -1486,7 +1476,7 @@ class TestSDPAFailureModes(NNTestCase):
         size = SdpaShape(2, 2, 8, 5)
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
-        with sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, None, 0.0, False))

@@ -1497,7 +1487,7 @@ class TestSDPAFailureModes(NNTestCase):
         size = SdpaShape(16, 16, 32, 32)
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
-        with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
+        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
             with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"):
                 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, None, 0.0, False))
@@ -1510,7 +1500,7 @@ class TestSDPAFailureModes(NNTestCase):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with torch.autocast(device_type='cuda', dtype=torch.float16):
-            with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
+            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                 _ = torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, None, 0.0, False)

@@ -1522,14 +1512,14 @@ class TestSDPAFailureModes(NNTestCase):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
-            with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
+            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                 _ = torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, None, 0.0, False)

     # Note: do not truncate the list according to platforms. These tests should always raise errors.
     @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
     def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # Different datatypes
             shape = (1, 4, 8, 16)
             query = torch.randn(shape, dtype=torch.float32, device=device)
@@ -1549,7 +1539,7 @@ class TestSDPAFailureModes(NNTestCase):

     @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
     def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend):
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             # 1 dimensional input
             shape = (1, 4)
             query = torch.randn(4, dtype=torch.float16, device=device)
@@ -1575,7 +1565,7 @@ class TestSDPAFailureModes(NNTestCase):
         key = rand_nested_tensor(k_shape).transpose(1, 2)
         value = rand_nested_tensor(v_shape).transpose(1, 2)

-        with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             with self.assertRaisesRegex(RuntimeError, "No available kernel"):
                 torch.nn.functional.scaled_dot_product_attention(
                     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
@@ -1588,7 +1578,7 @@ class TestSDPAFailureModes(NNTestCase):
         shape = SdpaShape(5, 8, seq_len_list, 57)
         make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
-        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
             with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
                 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, None, 0.0, False))
@@ -1602,7 +1592,7 @@ class TestSDPAFailureModes(NNTestCase):
         size = SdpaShape(16, 16, 32, 32)
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"):
                 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, None, 0.0, False))
@@ -1629,7 +1619,7 @@ class TestSDPAFailureModes(NNTestCase):
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)

-        with sdp_kernel(**backend_map[fused_kernel]):
+        with sdpa_kernel(backends=[fused_kernel]):
             with self.assertRaisesRegex(RuntimeError, "No available kernel"):
                 torch.nn.functional.scaled_dot_product_attention(
                     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
@@ -1653,7 +1643,7 @@ class TestSDPAFailureModes(NNTestCase):
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)

-        with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
             with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
                 with self.assertRaisesRegex(RuntimeError, "No available kernel"):
                     out = torch.nn.functional.scaled_dot_product_attention(
@@ -1669,7 +1659,7 @@ class TestSDPAFailureModes(NNTestCase):
         make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
         q, k, v = make_q(), make_kv(), make_kv()
         warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k."
-        with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
             with self.assertWarnsRegex(UserWarning, warning_str):
                 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, None, 0.0, is_causal=True))
@@ -1745,7 +1735,7 @@ class TestSDPA(NNTestCase):
         key = key.contiguous()
         value = value.contiguous()

-        with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             assert gradcheck(lambda *args, **kwargs:
                              wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
                              (query, key, value, None, 0.0, False)
@@ -1820,10 +1810,10 @@ class TestSDPA(NNTestCase):
         q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
         v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)

-        with sdp_kernel(**backend_map[fused_kernel]):
+        with sdpa_kernel(backends=[fused_kernel]):
             actual = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal)
-        with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             math_ref = torch.nn.functional.scaled_dot_product_attention(
                 q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal)

@@ -1907,10 +1897,10 @@ class TestSDPA(NNTestCase):
         k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
         v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)

-        with sdp_kernel(**backend_map[fused_kernel]):
+        with sdpa_kernel(backends=[fused_kernel]):
             actual = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
-        with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             if not bool_mask and dtype is torch.bfloat16:
                 attn_mask = attn_mask.float()
             math_ref = torch.nn.functional.scaled_dot_product_attention(
@@ -1944,7 +1934,7 @@ class TestSDPA(NNTestCase):

         x = torch.randn(1, 3, 64, 64, device=device)
         ref_result = ref(x)
-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
         self.assertEqual(ref_result, sdp_math)

@@ -2038,7 +2028,7 @@ class TestSDPACudaOnly(NNTestCase):
             mask = torch.randn((num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
         elif mask_dim == 4:
             mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             out = F.scaled_dot_product_attention(query, key, value, mask)
         out.sum().backward()

@@ -2052,7 +2042,7 @@ class TestSDPACudaOnly(NNTestCase):
         kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
         key, value = make_tensor(kv_shape), make_tensor(kv_shape)
         mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             out = F.scaled_dot_product_attention(query, key, value, mask)
         out.sum().backward()

@@ -2067,7 +2057,7 @@ class TestSDPACudaOnly(NNTestCase):
         key, value = make_tensor(kv_shape), make_tensor(kv_shape)
         mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
         mask = torch.as_strided(mask, (batch, num_heads, seq_len_q, seq_len_kv), (0, 0, 0, 1))
-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             out = F.scaled_dot_product_attention(query, key, value, mask)
         out.sum().backward()

@@ -2084,7 +2074,7 @@ class TestSDPACudaOnly(NNTestCase):
         kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
         key, value = make_tensor(kv_shape), make_tensor(kv_shape)
         mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             out = F.scaled_dot_product_attention(query, key, value, mask)
         out.sum().backward()

@@ -2113,7 +2103,7 @@ class TestSDPACudaOnly(NNTestCase):
         value = torch.randn(value_num_elements, device=device).as_strided(value_size, value_strides)
         bias = torch.randn(attention_mask_num_elements, device=device).as_strided(attention_mask_size, attn_mask_strides)

-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             out = F.scaled_dot_product_attention(query, key, value, bias)
             out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous())

@@ -2152,10 +2142,10 @@ class TestSDPACudaOnly(NNTestCase):
         key = key.contiguous()
         value = value.contiguous()

-        with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             actual = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
-        with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             math_ref = torch.nn.functional.scaled_dot_product_attention(
                 query.contiguous(), key.contiguous(), value.contiguous(),
                 attn_mask=None, dropout_p=0.0, is_causal=False)
@@ -2196,11 +2186,11 @@ class TestSDPACudaOnly(NNTestCase):
         key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

-        with sdp_kernel(**backend_map[fused_kernel]):
+        with sdpa_kernel(backends=[fused_kernel]):
             actual = torch.nn.functional.scaled_dot_product_attention(
                 query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False)

-        with sdp_kernel(**backend_map[SDPBackend.MATH]):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             math_ref_lp = torch.nn.functional.scaled_dot_product_attention(
                 query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(),
                 attn_mask=None, dropout_p=0.0, is_causal=False)
@@ -2259,10 +2249,10 @@ class TestSDPACudaOnly(NNTestCase):
         key_lp = key_lp.contiguous()
         value_lp = value_lp.contiguous()

-        with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)

-        with sdp_kernel(enable_math=False, enable_mem_efficient=True, enable_flash=False):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             out_lp = torch.nn.functional.scaled_dot_product_attention(
                 query_lp, key_lp, value_lp, None, 0.0, is_causal)

@@ -2308,10 +2298,10 @@ class TestSDPACudaOnly(NNTestCase):
         key_lp = key_lp.contiguous()
         value_lp = value_lp.contiguous()

-        with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)

-        with sdp_kernel(enable_math=False, enable_mem_efficient=False, enable_flash=True):
+        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
             out_lp = torch.nn.functional.scaled_dot_product_attention(
                 query_lp, key_lp, value_lp, None, 0.0, is_causal)

@@ -2369,7 +2359,8 @@ class TestSDPACudaOnly(NNTestCase):
         query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)

         with use_deterministic_algorithims(True, warn_only=warn_only):
-            with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
+            # Note that this should switch to a testing version once we remove the old context manager
+            with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
                 assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value

     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
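Unlike the old flag-based API, `sdpa_kernel` accepts a list, so one context can allow several backends at once; the hunk above uses this to check that dispatch still prefers the fused kernel when both it and the math fallback are enabled. The same check in isolation (a sketch; `torch._fused_sdp_choice` is the private dispatch probe this test suite already relies on, and CUDA fp16 inputs are assumed):

    import torch
    from torch.nn.attention import sdpa_kernel, SDPBackend

    q = k = v = torch.randn(2, 8, 64, 32, device="cuda", dtype=torch.float16)

    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
        # The memory-efficient kernel wins over the math fallback when eligible.
        assert torch._fused_sdp_choice(q, k, v) == SDPBackend.EFFICIENT_ATTENTION.value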
@@ -2389,7 +2380,7 @@ class TestSDPACudaOnly(NNTestCase):
             else contextlib.nullcontext()
         )
         with use_deterministic_algorithims(True, warn_only=warn_only):
-            with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+            with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                 with warning_context:
                     torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()

@@ -2406,7 +2397,7 @@ class TestSDPACudaOnly(NNTestCase):
         value = torch.rand(batch_size, n_heads, seq_len, head_dim,
                            device=device, dtype=dtype, requires_grad=True)

-        with sdp_kernel(enable_mem_efficient=True, enable_math=False, enable_flash=False):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             # Run once to establish baseline
             out = F.scaled_dot_product_attention(query, key, value)
             upward_grad = torch.rand_like(out)
@@ -2483,13 +2474,13 @@ class TestSDPACudaOnly(NNTestCase):
         query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)

         # Create real output
-        with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             # Set the seed and run the kernel
             torch.manual_seed(seed)
             out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)

         if dropout_p == 0.0:
-            with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+            with sdpa_kernel(backends=[SDPBackend.MATH]):
                 # High Precision Math Reference
                 out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
                                                          dropout_p=dropout_p, is_causal=is_causal, scale=scale)
@@ -2589,14 +2580,14 @@ class TestSDPACudaOnly(NNTestCase):
         attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)

         # Create real output
-        with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             # Set the seed and run the kernel
             torch.manual_seed(seed)
             out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p,
                                                  is_causal=is_causal, scale=scale)

         if dropout_p == 0.0:
-            with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+            with sdpa_kernel(backends=[SDPBackend.MATH]):
                 # High Precision Math Reference
                 out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref,
                                                          dropout_p=dropout_p, is_causal=is_causal, scale=scale)
@@ -2710,9 +2701,9 @@ class TestSDPACudaOnly(NNTestCase):
         if not is_dropout:
             # Problem: We pad sizes in the composite region of the top level SDPA. But we need the
             # debug mask when we have dropout. So I am going to manually pad up here when testing dropout.
-            with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                 out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-            with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+            with sdpa_kernel(backends=[SDPBackend.MATH]):
                 # High Precision Math Reference
                 out_ref = F.scaled_dot_product_attention(
                     query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
@@ -2884,7 +2875,7 @@ class TestSDPACudaOnly(NNTestCase):
             # replays produce different results
             self.assertNotEqual(out_first, out)

-        with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             if dropout_p == 0.0:
                 # High Precision Math Reference
                 out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
@@ -2964,10 +2955,10 @@ class TestSDPACudaOnly(NNTestCase):
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)

-        with sdp_kernel(**backend_map[fused_kernel]):
+        with sdpa_kernel(backends=[fused_kernel]):
             actual = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
-        with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             math_ref = torch.nn.functional.scaled_dot_product_attention(
                 query.contiguous().to(torch.float32),
                 key.contiguous().to(torch.float32),
@@ -3056,10 +3047,10 @@ class TestSDPACudaOnly(NNTestCase):
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)

-        with sdp_kernel(**backend_map[kernel]):
+        with sdpa_kernel(backends=[kernel]):
             actual = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
-        with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             math_ref = torch.nn.functional.scaled_dot_product_attention(
                 query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(),
                 attn_mask=None, dropout_p=0.0, is_causal=False)
@@ -3090,10 +3081,10 @@ class TestSDPACudaOnly(NNTestCase):
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)

-        with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
+        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
             actual = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
-        with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
             math_ref = torch.nn.functional.scaled_dot_product_attention(
                 query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(),
                 attn_mask=None, dropout_p=0.0, is_causal=False)
@@ -3146,9 +3137,9 @@ class TestSDPACudaOnly(NNTestCase):
         is_dropout = dropout_p > 0.0

         if not is_dropout:
-            with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                 out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-            with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
+            with sdpa_kernel(backends=[SDPBackend.MATH]):
                 # High Precision Math Reference
                 out_ref = F.scaled_dot_product_attention(
                     query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)