Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689)

# Summary
Simplification of Backend Selection

This PR deprecates the `torch.backends.cuda.sdp_kernel` context manager and replaces it with a new one, `torch.nn.attention.sdpa_kernel`, which also changes the API for selecting a backend.

With `sdp_kernel`, one specified the backend choice by negation, i.e. by disabling every kernel other than the one to run. The context manager was intended only as a debugging tool: "turn off the math backend" and see whether one of the fused implementations can run.
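
For illustration, a minimal sketch of the old pattern (hypothetical shapes; assumes a CUDA device where the fused kernels are available):

```python
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# To force the flash kernel, the user has to negate every *other* backend.
with sdp_kernel(enable_math=False, enable_mem_efficient=False, enable_flash=True):
    out = F.scaled_dot_product_attention(q, k, v)
```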

Problems:
- This pattern makes sense if the majority of users don't care to know anything about the available backends. However, users who reach for this context manager are explicitly trying to run a specific backend.
- It does not scale. We are working on adding the cuDNN backend, and with this API every additional implementation has to be explicitly turned off whenever a user wants to run a given backend (contrast with the `sdpa_kernel` sketch after this list).
- Discoverability. It is somewhat unintuitive that this context manager lives in `torch.backends.cuda` when it now also controls the CPU fused kernel behavior. Centralizing it in the attention namespace should help.
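
By contrast, a sketch of the proposed `sdpa_kernel` API (same hypothetical shapes and CUDA assumptions as the snippet above):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# State the backend you want directly; only the listed backends are enabled
# inside the context manager.
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
```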

Other concerns:
- Typically, operator backends (kernels) are implementation details of the framework and entirely hidden from users. We have already exposed them, albeit not by default and with beta warnings. Does making backend choices even more explicit create problems if we later want to remove existing backends (for example, once newer backends cover the same input shapes)?

A nice side effect: now that `test_transformers` no longer uses the `BACKEND_MAP`, many dynamo failures on the CPU tests are passing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
drisspg 2024-01-24 22:28:04 +00:00 committed by PyTorch MergeBot
parent 77186af028
commit 4e29f01bf2
13 changed files with 225 additions and 1073 deletions


@ -0,0 +1,11 @@
+.. role:: hidden
+    :class: hidden-section
+.. currentmodule:: {{ module }}
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+    :members:
+
+.. autogenerated from source/_templates/autosummary/class.rst


@ -68,8 +68,6 @@ torch.backends.cuda
 .. autofunction:: torch.backends.cuda.preferred_linalg_library
-.. autoclass:: torch.backends.cuda.SDPBackend
 .. autoclass:: torch.backends.cuda.SDPAParams
 .. autofunction:: torch.backends.cuda.flash_sdp_enabled


@ -93,7 +93,7 @@ Features described in this documentation are classified by release status:
 torch.package <package>
 profiler
 nn.init
-nn.attention.bias
+nn.attention
 onnx
 optim
 complex_numbers


@ -10,7 +10,13 @@ torch.nn.attention.bias
 CausalBias
 ==========
-.. autoclass:: CausalBias
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classnoinheritance.rst
+
+    CausalBias
 .. autosummary::
     :toctree: generated


@ -0,0 +1,28 @@
+.. role:: hidden
+    :class: hidden-section
+
+torch.nn.attention
+==================
+
+.. automodule:: torch.nn.attention
+
+Utils
+-------------------
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    sdpa_kernel
+    SDPBackend
+
+Submodules
+----------
+
+.. autosummary::
+    :nosignatures:
+
+    bias
+
+.. toctree::
+    :hidden:
+
+    nn.attention.bias


@ -527,7 +527,6 @@ Lazy Modules Initialization
 .. This module needs to be documented. Adding here in the meantime
 .. for tracking purposes
-.. py:module:: torch.nn.attention
 .. py:module:: torch.nn.backends
 .. py:module:: torch.nn.utils.stateless
 .. py:module:: torch.nn.backends.thnn


@ -8,12 +8,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.functional import scaled_dot_product_attention
+from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.bias import CausalVariant, causal_lower_right, causal_upper_left
 from torch.nn.parameter import Parameter
 import unittest
 from unittest.mock import patch, MagicMock, ANY
 import math
-from torch.backends.cuda import sdp_kernel, SDPBackend
 import torch.optim as optim
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
 from typing import List, Tuple, Optional
@ -103,12 +103,6 @@ def get_tolerances(
     rtol = default_rtol[computed_value.dtype]
     return atol, rtol

-backend_map = {
-    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
-    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
-    SDPBackend.EFFICIENT_ATTENTION: {
-        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
-}

 def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: torch.dtype = None):
     """ Clones the query, key, and value tensors and moves them to the specified dtype. """
@ -960,7 +954,7 @@ class TestTransformers(NNTestCase):
f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask" f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask"
if attn_dim is not None else "no_attn_mask"))) if attn_dim is not None else "no_attn_mask")))
@parametrize("dropout_p", [0.0, 0.2, 0.5]) @parametrize("dropout_p", [0.0, 0.2, 0.5])
@sdp_kernel(enable_flash=False, enable_mem_efficient=False) @sdpa_kernel(backends=[SDPBackend.MATH])
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p): def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
def sdp_ref( def sdp_ref(
q, q,
@ -1213,7 +1207,7 @@ class TestTransformers(NNTestCase):
mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY) mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY)
# check expected numerical values with all kernels # check expected numerical values with all kernels
self.is_causal_kernels(["math"], device) self.is_causal_kernels([SDPBackend.MATH], device)
def is_causal_kernels(self, kernels, device): def is_causal_kernels(self, kernels, device):
def ones_tensor(*shape): def ones_tensor(*shape):
@ -1230,15 +1224,11 @@ class TestTransformers(NNTestCase):
) )
for kernel in kernels: for kernel in kernels:
with torch.backends.cuda.sdp_kernel( with sdpa_kernel(backends=[kernel]):
enable_math=(kernel == 'math'),
enable_flash=(kernel == 'flash'),
enable_mem_efficient=(kernel == 'meff')
):
actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True) actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True)
self.assertTrue(torch.equal(actual, expected)) self.assertTrue(torch.equal(actual, expected))
if kernel != 'math': if kernel != SDPBackend.MATH:
# fails with embedding size not multiple of 4 # fails with embedding size not multiple of 4
with self.assertRaisesRegex(RuntimeError, "No available kernel"): with self.assertRaisesRegex(RuntimeError, "No available kernel"):
qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device) qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device)
@ -1254,7 +1244,7 @@ class TestTransformers(NNTestCase):
) )
def test_is_causal_gpu(self): def test_is_causal_gpu(self):
device = 'cuda' device = 'cuda'
self.is_causal_kernels(["math", "meff"], device) self.is_causal_kernels([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], device)
def test_script_mha_in_proj_weight_none(self): def test_script_mha_in_proj_weight_none(self):
mha = torch.nn.MultiheadAttention( mha = torch.nn.MultiheadAttention(
@ -1342,10 +1332,10 @@ class TestSDPAFailureModes(NNTestCase):
size = (2, 2, 4, head_dim) size = (2, 2, 4, head_dim)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
with sdp_kernel(enable_mem_efficient=False, enable_flash=False, enable_math=True): with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
with sdp_kernel(enable_mem_efficient=False, enable_flash=True, enable_math=False): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
# Should not fail because inputs don't require grad # Should not fail because inputs don't require grad
flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
@ -1361,7 +1351,7 @@ class TestSDPAFailureModes(NNTestCase):
@onlyCUDA @onlyCUDA
def test_dispatch_fails_no_backend(self, device): def test_dispatch_fails_no_backend(self, device):
dtype = torch.float16 dtype = torch.float16
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.ERROR]):
size = (2, 3, 4) size = (2, 3, 4)
q = torch.randn(size, device=device, dtype=dtype) q = torch.randn(size, device=device, dtype=dtype)
k = torch.randn(size, device=device, dtype=dtype) k = torch.randn(size, device=device, dtype=dtype)
@ -1378,7 +1368,7 @@ class TestSDPAFailureModes(NNTestCase):
PLATFORM_SPECIFIC_SDPA, PLATFORM_SPECIFIC_SDPA,
) )
def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend): def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# Dim is not 4 # Dim is not 4
size = (2, 3, 8) size = (2, 3, 8)
dtype = torch.float16 dtype = torch.float16
@ -1396,7 +1386,7 @@ class TestSDPAFailureModes(NNTestCase):
PLATFORM_SPECIFIC_SDPA, PLATFORM_SPECIFIC_SDPA,
) )
def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend): def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# Fused Kernels don't support broadcasting for dense inputs # Fused Kernels don't support broadcasting for dense inputs
dtype = torch.float16 dtype = torch.float16
size = (2, 4, 3, 8) size = (2, 4, 3, 8)
@ -1411,7 +1401,7 @@ class TestSDPAFailureModes(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA) @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
def test_invalid_sequence_lengths(self, device, kernel: SDPBackend): def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# Passing in a q,k,v with 0 length sequences will error # Passing in a q,k,v with 0 length sequences will error
dtype = torch.float16 dtype = torch.float16
make_tensor = partial(torch.rand, device=device, dtype=dtype) make_tensor = partial(torch.rand, device=device, dtype=dtype)
@ -1425,7 +1415,7 @@ class TestSDPAFailureModes(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA) @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# Passing in a q,k,v with 0 length sequences will error # Passing in a q,k,v with 0 length sequences will error
dtype = torch.float16 dtype = torch.float16
make_tensor = partial(torch.rand, device=device, dtype=dtype) make_tensor = partial(torch.rand, device=device, dtype=dtype)
@ -1440,7 +1430,7 @@ class TestSDPAFailureModes(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention") @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention")
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend): def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# The embed dim per head is not divisible by 8 for flash attention # The embed dim per head is not divisible by 8 for flash attention
dtype = torch.float16 dtype = torch.float16
make_tensor = partial(torch.rand, device=device, dtype=dtype) make_tensor = partial(torch.rand, device=device, dtype=dtype)
@ -1456,7 +1446,7 @@ class TestSDPAFailureModes(NNTestCase):
PLATFORM_SPECIFIC_SDPA, PLATFORM_SPECIFIC_SDPA,
) )
def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend): def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# Invalid dtype for both Flash Attention and Mem Efficient Attention # Invalid dtype for both Flash Attention and Mem Efficient Attention
size = SdpaShape(2, 2, 3, 16) size = SdpaShape(2, 2, 3, 16)
make_tensor = partial(torch.rand, device=device, dtype=torch.float64) make_tensor = partial(torch.rand, device=device, dtype=torch.float64)
@ -1468,7 +1458,7 @@ class TestSDPAFailureModes(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention") @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION]) @parametrize("kernel", [SDPBackend.FLASH_ATTENTION])
def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend): def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# Failures for unsupported SDP args # Failures for unsupported SDP args
size = SdpaShape(2, 2, 3, 16) size = SdpaShape(2, 2, 3, 16)
make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16) make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16)
@ -1486,7 +1476,7 @@ class TestSDPAFailureModes(NNTestCase):
size = SdpaShape(2, 2, 8, 5) size = SdpaShape(2, 2, 8, 5)
make_tensor = partial(torch.rand, size, device=device, dtype=dtype) make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor() q, k, v = make_tensor(), make_tensor(), make_tensor()
with sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False)) q, k, v, None, 0.0, False))
@ -1497,7 +1487,7 @@ class TestSDPAFailureModes(NNTestCase):
size = SdpaShape(16, 16, 32, 32) size = SdpaShape(16, 16, 32, 32)
make_tensor = partial(torch.rand, size, device=device, dtype=dtype) make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor() q, k, v = make_tensor(), make_tensor(), make_tensor()
with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"): with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False)) q, k, v, None, 0.0, False))
@ -1510,7 +1500,7 @@ class TestSDPAFailureModes(NNTestCase):
make_tensor = partial(torch.rand, size, device=device, dtype=dtype) make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor() q, k, v = make_tensor(), make_tensor(), make_tensor()
with torch.autocast(device_type='cuda', dtype=torch.float16): with torch.autocast(device_type='cuda', dtype=torch.float16):
with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
_ = torch.nn.functional.scaled_dot_product_attention( _ = torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False) q, k, v, None, 0.0, False)
@ -1522,14 +1512,14 @@ class TestSDPAFailureModes(NNTestCase):
make_tensor = partial(torch.rand, size, device=device, dtype=dtype) make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor() q, k, v = make_tensor(), make_tensor(), make_tensor()
with torch.autocast(device_type='cuda', dtype=torch.bfloat16): with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
_ = torch.nn.functional.scaled_dot_product_attention( _ = torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False) q, k, v, None, 0.0, False)
# Note: do not truncate the list according to platforms. These tests should always raise errors. # Note: do not truncate the list according to platforms. These tests should always raise errors.
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend): def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# Different datatypes # Different datatypes
shape = (1, 4, 8, 16) shape = (1, 4, 8, 16)
query = torch.randn(shape, dtype=torch.float32, device=device) query = torch.randn(shape, dtype=torch.float32, device=device)
@ -1549,7 +1539,7 @@ class TestSDPAFailureModes(NNTestCase):
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend): def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
# 1 dimensional input # 1 dimensional input
shape = (1, 4) shape = (1, 4)
query = torch.randn(4, dtype=torch.float16, device=device) query = torch.randn(4, dtype=torch.float16, device=device)
@ -1575,7 +1565,7 @@ class TestSDPAFailureModes(NNTestCase):
key = rand_nested_tensor(k_shape).transpose(1, 2) key = rand_nested_tensor(k_shape).transpose(1, 2)
value = rand_nested_tensor(v_shape).transpose(1, 2) value = rand_nested_tensor(v_shape).transpose(1, 2)
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
with self.assertRaisesRegex(RuntimeError, "No available kernel"): with self.assertRaisesRegex(RuntimeError, "No available kernel"):
torch.nn.functional.scaled_dot_product_attention( torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
@ -1588,7 +1578,7 @@ class TestSDPAFailureModes(NNTestCase):
shape = SdpaShape(5, 8, seq_len_list, 57) shape = SdpaShape(5, 8, seq_len_list, 57)
make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype) make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor() q, k, v = make_tensor(), make_tensor(), make_tensor()
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"): with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False)) q, k, v, None, 0.0, False))
@ -1602,7 +1592,7 @@ class TestSDPAFailureModes(NNTestCase):
size = SdpaShape(16, 16, 32, 32) size = SdpaShape(16, 16, 32, 32)
make_tensor = partial(torch.rand, size, device=device, dtype=dtype) make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor() q, k, v = make_tensor(), make_tensor(), make_tensor()
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"): with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False)) q, k, v, None, 0.0, False))
@ -1629,7 +1619,7 @@ class TestSDPAFailureModes(NNTestCase):
key = key.transpose(1, 2) key = key.transpose(1, 2)
value = value.transpose(1, 2) value = value.transpose(1, 2)
with sdp_kernel(**backend_map[fused_kernel]): with sdpa_kernel(backends=[fused_kernel]):
with self.assertRaisesRegex(RuntimeError, "No available kernel"): with self.assertRaisesRegex(RuntimeError, "No available kernel"):
torch.nn.functional.scaled_dot_product_attention( torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
@ -1653,7 +1643,7 @@ class TestSDPAFailureModes(NNTestCase):
key = key.transpose(1, 2) key = key.transpose(1, 2)
value = value.transpose(1, 2) value = value.transpose(1, 2)
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"): with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
with self.assertRaisesRegex(RuntimeError, "No available kernel"): with self.assertRaisesRegex(RuntimeError, "No available kernel"):
out = torch.nn.functional.scaled_dot_product_attention( out = torch.nn.functional.scaled_dot_product_attention(
@ -1669,7 +1659,7 @@ class TestSDPAFailureModes(NNTestCase):
make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
q, k, v = make_q(), make_kv(), make_kv() q, k, v = make_q(), make_kv(), make_kv()
warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k." warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k."
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, warning_str): with self.assertWarnsRegex(UserWarning, warning_str):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, is_causal=True)) q, k, v, None, 0.0, is_causal=True))
@ -1745,7 +1735,7 @@ class TestSDPA(NNTestCase):
key = key.contiguous() key = key.contiguous()
value = value.contiguous() value = value.contiguous()
with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
assert gradcheck(lambda *args, **kwargs: assert gradcheck(lambda *args, **kwargs:
wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs), wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
(query, key, value, None, 0.0, False) (query, key, value, None, 0.0, False)
@ -1820,10 +1810,10 @@ class TestSDPA(NNTestCase):
q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
with sdp_kernel(**backend_map[fused_kernel]): with sdpa_kernel(backends=[fused_kernel]):
actual = torch.nn.functional.scaled_dot_product_attention( actual = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal) q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal)
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref = torch.nn.functional.scaled_dot_product_attention( math_ref = torch.nn.functional.scaled_dot_product_attention(
q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal) q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal)
@ -1907,10 +1897,10 @@ class TestSDPA(NNTestCase):
k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
with sdp_kernel(**backend_map[fused_kernel]): with sdpa_kernel(backends=[fused_kernel]):
actual = torch.nn.functional.scaled_dot_product_attention( actual = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
if not bool_mask and dtype is torch.bfloat16: if not bool_mask and dtype is torch.bfloat16:
attn_mask = attn_mask.float() attn_mask = attn_mask.float()
math_ref = torch.nn.functional.scaled_dot_product_attention( math_ref = torch.nn.functional.scaled_dot_product_attention(
@ -1944,7 +1934,7 @@ class TestSDPA(NNTestCase):
x = torch.randn(1, 3, 64, 64, device=device) x = torch.randn(1, 3, 64, 64, device=device)
ref_result = ref(x) ref_result = ref(x)
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001) sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
self.assertEqual(ref_result, sdp_math) self.assertEqual(ref_result, sdp_math)
@ -2038,7 +2028,7 @@ class TestSDPACudaOnly(NNTestCase):
mask = torch.randn((num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) mask = torch.randn((num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
elif mask_dim == 4: elif mask_dim == 4:
mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, mask) out = F.scaled_dot_product_attention(query, key, value, mask)
out.sum().backward() out.sum().backward()
@ -2052,7 +2042,7 @@ class TestSDPACudaOnly(NNTestCase):
kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
key, value = make_tensor(kv_shape), make_tensor(kv_shape) key, value = make_tensor(kv_shape), make_tensor(kv_shape)
mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, mask) out = F.scaled_dot_product_attention(query, key, value, mask)
out.sum().backward() out.sum().backward()
@ -2067,7 +2057,7 @@ class TestSDPACudaOnly(NNTestCase):
key, value = make_tensor(kv_shape), make_tensor(kv_shape) key, value = make_tensor(kv_shape), make_tensor(kv_shape)
mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
mask = torch.as_strided(mask, (batch, num_heads, seq_len_q, seq_len_kv), (0, 0, 0, 1)) mask = torch.as_strided(mask, (batch, num_heads, seq_len_q, seq_len_kv), (0, 0, 0, 1))
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, mask) out = F.scaled_dot_product_attention(query, key, value, mask)
out.sum().backward() out.sum().backward()
@ -2084,7 +2074,7 @@ class TestSDPACudaOnly(NNTestCase):
kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
key, value = make_tensor(kv_shape), make_tensor(kv_shape) key, value = make_tensor(kv_shape), make_tensor(kv_shape)
mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, mask) out = F.scaled_dot_product_attention(query, key, value, mask)
out.sum().backward() out.sum().backward()
@ -2113,7 +2103,7 @@ class TestSDPACudaOnly(NNTestCase):
value = torch.randn(value_num_elements, device=device).as_strided(value_size, value_strides) value = torch.randn(value_num_elements, device=device).as_strided(value_size, value_strides)
bias = torch.randn(attention_mask_num_elements, device=device).as_strided(attention_mask_size, attn_mask_strides) bias = torch.randn(attention_mask_num_elements, device=device).as_strided(attention_mask_size, attn_mask_strides)
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, bias) out = F.scaled_dot_product_attention(query, key, value, bias)
out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous()) out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous())
@ -2152,10 +2142,10 @@ class TestSDPACudaOnly(NNTestCase):
key = key.contiguous() key = key.contiguous()
value = value.contiguous() value = value.contiguous()
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
actual = torch.nn.functional.scaled_dot_product_attention( actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref = torch.nn.functional.scaled_dot_product_attention( math_ref = torch.nn.functional.scaled_dot_product_attention(
query.contiguous(), key.contiguous(), value.contiguous(), query.contiguous(), key.contiguous(), value.contiguous(),
attn_mask=None, dropout_p=0.0, is_causal=False) attn_mask=None, dropout_p=0.0, is_causal=False)
@ -2196,11 +2186,11 @@ class TestSDPACudaOnly(NNTestCase):
key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
with sdp_kernel(**backend_map[fused_kernel]): with sdpa_kernel(backends=[fused_kernel]):
actual = torch.nn.functional.scaled_dot_product_attention( actual = torch.nn.functional.scaled_dot_product_attention(
query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False) query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False)
with sdp_kernel(**backend_map[SDPBackend.MATH]): with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref_lp = torch.nn.functional.scaled_dot_product_attention( math_ref_lp = torch.nn.functional.scaled_dot_product_attention(
query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(), query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(),
attn_mask=None, dropout_p=0.0, is_causal=False) attn_mask=None, dropout_p=0.0, is_causal=False)
@ -2259,10 +2249,10 @@ class TestSDPACudaOnly(NNTestCase):
key_lp = key_lp.contiguous() key_lp = key_lp.contiguous()
value_lp = value_lp.contiguous() value_lp = value_lp.contiguous()
with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal) out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)
with sdp_kernel(enable_math=False, enable_mem_efficient=True, enable_flash=False): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
out_lp = torch.nn.functional.scaled_dot_product_attention( out_lp = torch.nn.functional.scaled_dot_product_attention(
query_lp, key_lp, value_lp, None, 0.0, is_causal) query_lp, key_lp, value_lp, None, 0.0, is_causal)
@ -2308,10 +2298,10 @@ class TestSDPACudaOnly(NNTestCase):
key_lp = key_lp.contiguous() key_lp = key_lp.contiguous()
value_lp = value_lp.contiguous() value_lp = value_lp.contiguous()
with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal) out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)
with sdp_kernel(enable_math=False, enable_mem_efficient=False, enable_flash=True): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
out_lp = torch.nn.functional.scaled_dot_product_attention( out_lp = torch.nn.functional.scaled_dot_product_attention(
query_lp, key_lp, value_lp, None, 0.0, is_causal) query_lp, key_lp, value_lp, None, 0.0, is_causal)
@ -2369,7 +2359,8 @@ class TestSDPACudaOnly(NNTestCase):
query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
with use_deterministic_algorithims(True, warn_only=warn_only): with use_deterministic_algorithims(True, warn_only=warn_only):
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True): # Note that this should swith to a testing version with we remove old context manager
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA") @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
@ -2389,7 +2380,7 @@ class TestSDPACudaOnly(NNTestCase):
else contextlib.nullcontext() else contextlib.nullcontext()
) )
with use_deterministic_algorithims(True, warn_only=warn_only): with use_deterministic_algorithims(True, warn_only=warn_only):
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
with warning_context: with warning_context:
torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()
@ -2406,7 +2397,7 @@ class TestSDPACudaOnly(NNTestCase):
value = torch.rand(batch_size, n_heads, seq_len, head_dim, value = torch.rand(batch_size, n_heads, seq_len, head_dim,
device=device, dtype=dtype, requires_grad=True) device=device, dtype=dtype, requires_grad=True)
with sdp_kernel(enable_mem_efficient=True, enable_math=False, enable_flash=False): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
# Run once to establish baseline # Run once to establish baseline
out = F.scaled_dot_product_attention(query, key, value) out = F.scaled_dot_product_attention(query, key, value)
upward_grad = torch.rand_like(out) upward_grad = torch.rand_like(out)
@ -2483,13 +2474,13 @@ class TestSDPACudaOnly(NNTestCase):
query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype) query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
# Create real output # Create real output
with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
# Set the seed and run the kernel # Set the seed and run the kernel
torch.manual_seed(seed) torch.manual_seed(seed)
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
if dropout_p == 0.0: if dropout_p == 0.0:
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
# High Precision Math Reference # High Precision Math Reference
out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
dropout_p=dropout_p, is_causal=is_causal, scale=scale) dropout_p=dropout_p, is_causal=is_causal, scale=scale)
@ -2589,14 +2580,14 @@ class TestSDPACudaOnly(NNTestCase):
attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True) attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
# Create real output # Create real output
with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
# Set the seed and run the kernel # Set the seed and run the kernel
torch.manual_seed(seed) torch.manual_seed(seed)
out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p, out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p,
is_causal=is_causal, scale=scale) is_causal=is_causal, scale=scale)
if dropout_p == 0.0: if dropout_p == 0.0:
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
# High Precision Math Reference # High Precision Math Reference
out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref, out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref,
dropout_p=dropout_p, is_causal=is_causal, scale=scale) dropout_p=dropout_p, is_causal=is_causal, scale=scale)
@ -2710,9 +2701,9 @@ class TestSDPACudaOnly(NNTestCase):
if not is_dropout: if not is_dropout:
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the # Problem: We pad sizes in the composite region of the top level SDPA. But we need the
# Debug mask when have dropout. So I am going to manualy pad up here when testing dropout # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
# High Precision Math Reference # High Precision Math Reference
out_ref = F.scaled_dot_product_attention( out_ref = F.scaled_dot_product_attention(
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale) query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
@ -2884,7 +2875,7 @@ class TestSDPACudaOnly(NNTestCase):
# replays produce different results # replays produce different results
self.assertNotEqual(out_first, out) self.assertNotEqual(out_first, out)
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
if dropout_p == 0.0: if dropout_p == 0.0:
# High Precision Math Reference # High Precision Math Reference
out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
@ -2964,10 +2955,10 @@ class TestSDPACudaOnly(NNTestCase):
key = key.transpose(1, 2) key = key.transpose(1, 2)
value = value.transpose(1, 2) value = value.transpose(1, 2)
with sdp_kernel(**backend_map[fused_kernel]): with sdpa_kernel(backends=[fused_kernel]):
actual = torch.nn.functional.scaled_dot_product_attention( actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref = torch.nn.functional.scaled_dot_product_attention( math_ref = torch.nn.functional.scaled_dot_product_attention(
query.contiguous().to(torch.float32), query.contiguous().to(torch.float32),
key.contiguous().to(torch.float32), key.contiguous().to(torch.float32),
@ -3056,10 +3047,10 @@ class TestSDPACudaOnly(NNTestCase):
key = key.transpose(1, 2) key = key.transpose(1, 2)
value = value.transpose(1, 2) value = value.transpose(1, 2)
with sdp_kernel(**backend_map[kernel]): with sdpa_kernel(backends=[kernel]):
actual = torch.nn.functional.scaled_dot_product_attention( actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref = torch.nn.functional.scaled_dot_product_attention( math_ref = torch.nn.functional.scaled_dot_product_attention(
query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(), query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(),
attn_mask=None, dropout_p=0.0, is_causal=False) attn_mask=None, dropout_p=0.0, is_causal=False)
@ -3090,10 +3081,10 @@ class TestSDPACudaOnly(NNTestCase):
key = key.transpose(1, 2) key = key.transpose(1, 2)
value = value.transpose(1, 2) value = value.transpose(1, 2)
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
actual = torch.nn.functional.scaled_dot_product_attention( actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref = torch.nn.functional.scaled_dot_product_attention( math_ref = torch.nn.functional.scaled_dot_product_attention(
query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(), query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(),
attn_mask=None, dropout_p=0.0, is_causal=False) attn_mask=None, dropout_p=0.0, is_causal=False)
@ -3146,9 +3137,9 @@ class TestSDPACudaOnly(NNTestCase):
is_dropout = dropout_p > 0.0 is_dropout = dropout_p > 0.0
if not is_dropout: if not is_dropout:
with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False): with sdpa_kernel(backends=[SDPBackend.MATH]):
# High Precision Math Reference # High Precision Math Reference
out_ref = F.scaled_dot_product_attention( out_ref = F.scaled_dot_product_attention(
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale) query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)


@ -1,4 +1,5 @@
 import contextlib
+import warnings
 from typing import Union
@ -13,7 +14,6 @@ __all__ = [
"preferred_linalg_library", "preferred_linalg_library",
"cufft_plan_cache", "cufft_plan_cache",
"matmul", "matmul",
"SDPBackend",
"SDPAParams", "SDPAParams",
"enable_flash_sdp", "enable_flash_sdp",
"flash_sdp_enabled", "flash_sdp_enabled",
@ -204,10 +204,9 @@ def preferred_linalg_library(
     return torch._C._get_linalg_preferred_backend()

-from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend
+from torch._C import _SDPAParams as SDPAParams

 # Set the __module__ attribute
-SDPBackend.__module__ = "torch.backends.cuda"
 SDPAParams.__module__ = "torch.backends.cuda"
 SDPAParams.__name__ = "SDPAParams"
@ -318,18 +317,30 @@ def sdp_kernel(
     This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention.
     Upon exiting the context manager, the previous state of the flags will be restored.
     """
-    previous_flash: bool = flash_sdp_enabled()
-    previous_mem_efficient: bool = mem_efficient_sdp_enabled()
-    previous_math: bool = math_sdp_enabled()
-    try:
-        enable_flash_sdp(enable_flash)
-        enable_mem_efficient_sdp(enable_mem_efficient)
-        enable_math_sdp(enable_math)
-        yield {}
-    finally:
-        enable_flash_sdp(previous_flash)
-        enable_mem_efficient_sdp(previous_mem_efficient)
-        enable_math_sdp(previous_math)
+    warnings.warn(
+        (
+            "torch.backends.cuda.sdp_kernel() "
+            "is deprecated. In the future, this context manager will be removed. "
+            "Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated "
+            "signature."
+        ),
+        FutureWarning,
+    )
+    from torch.nn.attention import sdpa_kernel, SDPBackend
+
+    backend_list = []
+    if enable_flash:
+        backend_list.append(SDPBackend.FLASH_ATTENTION)
+    if enable_mem_efficient:
+        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
+    if enable_math:
+        backend_list.append(SDPBackend.MATH)
+
+    with sdpa_kernel(backend_list) as context:
+        try:
+            yield context
+        finally:
+            pass

 cufft_plan_cache = cuFFTPlanCacheManager()


@ -1809,7 +1809,9 @@ Call this whenever a new thread is created in order to propagate values from
   py::enum_<sdp::SDPBackend>(
       py_module,
       "_SDPBackend",
-      "Enum class for the scaled dot product attention backends\n\n... warning:: This class is in beta and subject to change.")
+      "An enum-like class that contains the different backends for scaled dot product attention.\n\n... warning:: This class is in beta and subject to change.\n\n"
+      "This backend class is designed to be used with the sdpa_kernel context manager."
+      "See :func: torch.nn.attention.sdpa_kernel for more details.")
       .value("ERROR", sdp::SDPBackend::error)
       .value("MATH", sdp::SDPBackend::math)
       .value("FLASH_ATTENTION", sdp::SDPBackend::flash_attention)


@ -12,9 +12,10 @@ from torch.backends.cuda import (
     math_sdp_enabled,
     mem_efficient_sdp_enabled,
     SDPAParams,
-    SDPBackend,
 )
+from torch.nn.attention import SDPBackend
 from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer

 log = logging.getLogger(__name__)


@ -1,16 +1,24 @@
-from typing import List
+""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
+import contextlib
+from typing import List, Union
 from warnings import warn

 from torch.backends.cuda import (
     can_use_efficient_attention,
     can_use_flash_attention,
+    enable_flash_sdp,
+    enable_math_sdp,
+    enable_mem_efficient_sdp,
+    flash_sdp_enabled,
+    math_sdp_enabled,
+    mem_efficient_sdp_enabled,
     SDPAParams,
 )

-__all__: List[str] = []
+__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]

 # Note: [SDPA warnings]
-# TODO: Consider using this to sdpa regardless of subclasses
+# TODO: Consider using this for sdpa regardless of subclasses
 # This only effects users of bias subclasses
 # If this is set to True, we will warn the user if they are not using the fused kernels
 # As well, it will raise warnings for all the reasons why the fused kernels can't be run.
@ -19,6 +27,21 @@ __all__: List[str] = []
 WARN_FOR_UNFUSED_KERNELS = False

+from torch._C import _SDPBackend as SDPBackend
+
+# Hacks for Sphinx documentation:
+# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
+SDPBackend = SDPBackend
+r"""An enum-like class that contains the different backends for scaled dot product attention.
+This backend class is designed to be used with the sdpa_kernel context manager.
+See :func: torch.nn.attention.sdpa_kernel for more details.
+
+... warning:: This class is in beta and subject to change.
+"""
+SDPBackend.__module__ = __name__
+SDPBackend.__name__ = "SDPBackend"
+

 def _raise_kernel_warnings(params: SDPAParams) -> None:
     """
     If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings
@ -31,3 +54,39 @@ def _raise_kernel_warnings(params: SDPAParams) -> None:
         if not can_use_flash_attention(params):
             warn("Flash attention can't be used because:")
             can_use_flash_attention(params, True)
+
+
+@contextlib.contextmanager
+def sdpa_kernel(backends: List[SDPBackend]):
+    r"""
+    Context manager to select which backend to use for scaled dot product attention.
+
+    .. warning:: This function is beta and subject to change.
+
+    Args:
+        backend (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention.
+
+    This context manager can be used to select which backend to use for scaled dot product attention.
+    Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends.
+    """
+    assert backends is None or isinstance(
+        backends, list
+    ), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
+
+    backends = set(backends)
+    previous_flash: bool = flash_sdp_enabled()
+    previous_mem_efficient: bool = mem_efficient_sdp_enabled()
+    previous_math: bool = math_sdp_enabled()
+    try:
+        enable_flash = SDPBackend.FLASH_ATTENTION in backends
+        enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
+        enable_math = SDPBackend.MATH in backends
+
+        enable_flash_sdp(enable_flash)
+        enable_mem_efficient_sdp(enable_mem_efficient)
+        enable_math_sdp(enable_math)
+        yield {}
+    finally:
+        enable_flash_sdp(previous_flash)
+        enable_mem_efficient_sdp(previous_mem_efficient)
+        enable_math_sdp(previous_math)
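
For reference, a hedged usage sketch of the context manager added above (assumes a CUDA build where the listed kernels can actually run; if none of the enabled backends supports the inputs, `scaled_dot_product_attention` raises a "No available kernel" error):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict dispatch to the memory-efficient kernel only.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)

# Several backends may be listed; dispatch picks among the enabled ones, and the
# previous enable/disable flags are restored when the context manager exits.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)
```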


@ -1,4 +1,4 @@
"""Defines utilities for interacting with scaled_dot_product_attention""" """Defines bias subclasses that work with scaled_dot_product_attention"""
from enum import auto, IntEnum from enum import auto, IntEnum
from typing import Optional from typing import Optional
from warnings import warn from warnings import warn

File diff suppressed because it is too large.