pytorch/torch/nested/_internal/sdpa.py
drisspg 4e29f01bf2 Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689)
# Summary
Simplification of Backend Selection

This PR deprecates the `torch.backends/cuda/sdp_kernel` context manager and replaces it with a new context manager `torch.nn.attention.sdpa_kernel`. This context manager also changes the api for this context manager.

For `sdp_kernel` one would specify the backend choice by taking the negation of what kernel they would like to run. The purpose of this backend manager was to only to be a debugging tool, "turn off the math backend" and see if you can run one of the fused implementations.

Problems:
- This pattern makes sense if majority of users don't care to know anything about the backends that can be run. However, if users are seeking to use this context manager then they are explicitly trying to run a specific backend.
- This is not scalable. We are working on adding the cudnn backend and this API makes it so so that more implementations will need to be turned off if user wants to explicitly run a given backend.
- Discoverability of the current context manager. It is somewhat un-intutive that this backend manager is in backends/cuda/init when this now also controls the CPU fused kernel behavior. I think centralizing to attention namespace will be helpful.

Other concerns:
- Typically backends (kernels) for operators are entirely hidden from users and implementation details of the framework. We have exposed this to users already, albeit not by default and with beta warnings. Does making backends choices even more explicit lead to problems when we potentially want to remove existing backends, (perhaps inputs shapes will get covered by newer backends).

A nice side effect is now that we aren't using the `BACKEND_MAP` in test_transformers many, many dynamo failures are passing for CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
2024-01-24 22:28:04 +00:00

776 lines
27 KiB
Python

import logging
import math
from typing import Optional, Tuple
import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
SDPAParams,
)
from torch.nn.attention import SDPBackend
from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer
log = logging.getLogger(__name__)
def _validate_sdpa_input(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
if (
not isinstance(query, NestedTensor)
or not isinstance(key, NestedTensor)
or not isinstance(value, NestedTensor)
):
raise ValueError(
f"Expected query, key, and value to be nested tensors, "
f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
f"and value.is_nested: {value.is_nested} instead."
)
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
f"Expected query, key, and value to have the same dtype, "
f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
f"and value.dtype: {value.dtype} instead."
)
if query.device != key.device or query.device != value.device:
raise ValueError(
f"Expected query, key, and value to have the same device type, "
f"but got query.device: {query.device}, key.device: {key.device}, "
f"and value.device: {value.device} instead."
)
if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
raise ValueError(
f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
)
if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
raise ValueError(
f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
)
if attn_mask is not None:
# TODO: Figure out whether masks are actually supported for this layout or not
raise ValueError("Masks are not yet supported!")
if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
raise ValueError(
f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
)
def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
# This is expected to be called after check_tensor_shapes ensuring that the
# size() calls won't error since the inputs are all 4 dimensional
q_batch_size = params.query.size(0)
k_batch_size = params.key.size(0)
v_batch_size = params.value.size(0)
# num_heads logic for nested input is checked in
# check_for_seq_len_0_nested_tensor as there is handling there to make sure
# num_heads is not ragged
return q_batch_size == k_batch_size and q_batch_size == v_batch_size
def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
max_size = 256
query_size_last = params.query.size(-1)
key_size_last = params.key.size(-1)
value_size_last = params.value.size(-1)
same_head_dim_size = (
query_size_last == key_size_last and query_size_last == value_size_last
)
if not (
same_head_dim_size
and (query_size_last % 8 == 0)
and (query_size_last <= max_size)
):
if debug:
log.warning(
"For NestedTensor inputs, Flash attention requires q,k,v to have the same "
"last dimension and to be a multiple of 8 and less than or equal to 256. "
"Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
query_size_last,
key_size_last,
value_size_last,
)
return False
return True
def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
param: torch.Tensor, param_name: str, debug=False
) -> bool:
assert isinstance(param, NestedTensor), "param should be a jagged NT"
if param._ragged_idx == 1:
# num_head_dims is ragged
if debug:
log.warning(
"Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
param_name,
)
return False
# This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
if param._min_seqlen == 0:
if debug:
log.warning(
"Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
param_name,
)
return False
return True
def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
max_size = max(q_size, k_size, v_size)
if (
(q_size != max_size and q_size != 1)
or (k_size != max_size and k_size != 1)
or (v_size != max_size and v_size != 1)
):
if debug:
log.warning(
"Both fused kernels require query, key and value to have broadcastable %s, "
"got Query %s %d, Key %s %d, Value %s %d instead.",
param_name,
param_name,
q_size,
param_name,
k_size,
param_name,
v_size,
)
return False
return True
def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
# When this function is called we are assured that the nt is dim==4
q_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.query, "query", debug
)
if params.query.is_nested
else True
)
# short circuit if any is unsafe
if not q_is_safe:
return False
k_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.key, "key", debug
)
if params.key.is_nested
else True
)
# short circuit if any is unsafe
if not k_is_safe:
return False
v_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.value, "value", debug
)
if params.value.is_nested
else True
)
# short circuit if any is unsafe
if not v_is_safe:
return False
# We now know none of the inputs have ragged num_heads, so we can safely
# access .size(1)
q_num_heads = params.query.size(1)
k_num_heads = params.key.size(1)
v_num_heads = params.value.size(1)
same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads
if not same_num_heads:
if (
params.query.requires_grad
or params.key.requires_grad
or params.value.requires_grad
):
if debug:
log.warning(
"Both fused kernels do not support training with broadcasted NT inputs."
)
return False
return _try_broadcast_param_size(
q_num_heads, k_num_heads, v_num_heads, "num heads", debug
)
return True
def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
constraints = (
_check_batch_size_nested,
_check_head_dim_size_flash_nested,
_check_for_seq_len_0_nested,
)
for constraint in constraints:
if not constraint(params, debug):
return False
return True
def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
constraints = (
_check_batch_size_nested,
_check_for_seq_len_0_nested,
)
for constraint in constraints:
if not constraint(params, debug):
return False
return True
def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
if (
not params.query.transpose(1, 2).is_contiguous()
or not params.key.transpose(1, 2).is_contiguous()
or not params.value.transpose(1, 2).is_contiguous()
):
if debug:
log.warning(
"If inputs are nested tensors they must be contiguous after transposing."
)
return False
if params.is_causal:
if debug:
log.warning(
"Nested tensors for query / key are not supported when is_causal=True."
)
return False
return True
def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
if (
not flash_sdp_enabled()
and not mem_efficient_sdp_enabled()
and not math_sdp_enabled()
):
return SDPBackend.ERROR
ordering = (
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)
for backend in ordering:
if backend == SDPBackend.FLASH_ATTENTION:
if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
return SDPBackend.FLASH_ATTENTION
if backend == SDPBackend.EFFICIENT_ATTENTION:
if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
params
):
return SDPBackend.EFFICIENT_ATTENTION
if backend == SDPBackend.MATH:
if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
return SDPBackend.MATH
log.warning("Memory efficient kernel not used because:")
can_use_efficient_attention(params, debug=True)
_can_use_efficient_sdpa_jagged(params, debug=True)
log.warning("Flash attention kernel not used because:")
can_use_flash_attention(params, debug=True)
_can_use_flash_sdpa_jagged(params, debug=True)
log.warning("Math attention kernel not used because:")
_can_use_math_sdpa_jagged(params, debug=True)
return SDPBackend.ERROR
def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
# This function is used to calculate two pieces of metadata that are needed
# for use with flash-attention and efficient_attention kernels. They are the
# cumulative sequence_length over a batch of sequences and the maximum
# sequence length.
# It returns a tuple of cumulative sequence lengths and the maximum sequence
# length, and the last element in the cumulative_sequence_lengths
if not isinstance(qkv, NestedTensor):
raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
if qkv.lengths() is None:
# TODO: Explore performance impact of copying
cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
max_seqlen = qkv._max_seqlen
n_elem = qkv.values().shape[0]
else:
# TODO: Explore performance impact of copying
cumulative_seqlen = (
qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
)
batch_size = qkv.size(0)
max_seqlen = qkv._max_seqlen
# TODO: Explore performance impact when compiling
n_elem = int(cumulative_seqlen[-1].item())
return cumulative_seqlen, max_seqlen, n_elem
def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
# This function checks if a nested tensor is valid for
# use with the flash-attention and efficient_attention kernels without
# needing to call contiguous on the nested tensor input.
# It checks that the storage offsets' adjacent_differences are a constant
# mutiple of the previous tensor in the nested tensor and that the strides
# are monitonically decreasing. This check is done after calling transpose on
# the nested tensor resulting in a Nt of shape [bsz, {seq_len}, num_heads, dim]
# Returns a boolean indicating if contiguous needs to be called for input
assert isinstance(tensor, NestedTensor)
offsets = tensor.offsets()
strides = tensor._stride
n_tensors = offsets.size(0) - 1
if n_tensors <= 1:
return True
# Check initially that the tensor strides are in strictly descending order
prev_stride = strides[1]
for stride in strides[2:]:
if prev_stride <= stride:
# This would mean that the last stride is greater than the seq_len
# stride
return False
prev_stride = stride
# Congrats you made it!
return True
def _view_as_dense(
tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
if tensor.is_nested:
return buffer_from_jagged(tensor)
return tensor.view(Nnz, num_heads, head_dim)
# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
# # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
# # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# q_batch_size = query.size(0)
# k_batch_size = key.size(0)
# v_batch_size = value.size(0)
# output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
# q_num_heads = query.size(1)
# k_num_heads = key.size(1)
# v_num_heads = value.size(1)
# output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
# head_dim_qk = query.size(3)
# head_dim_v = value.size(3)
# q_t = query.transpose(1, 2)
# k_t = key.transpose(1, 2)
# v_t = value.transpose(1, 2)
# # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
# # output_batch_size/num_heads then they are 1
# q_batch_size_needs_broadcast = q_batch_size != output_batch_size
# k_batch_size_needs_broadcast = k_batch_size != output_batch_size
# v_batch_size_needs_broadcast = v_batch_size != output_batch_size
# # If {*}_batch_size_needs_broadcast, then
# # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
# # this is because needs_broadcast indicates that the batch_size is 1
# # and hence there is only 1 value for seq_len
# # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
# # ..., outut_batch_size * {*}_t.size(1)]
# # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
# if q_batch_size_needs_broadcast or not q_t.is_nested:
# max_seqlen_batch_q = q_t.size(1)
# cumulative_sequence_length_q = torch.arange(
# 0,
# (output_batch_size + 1) * max_seqlen_batch_q,
# max_seqlen_batch_q,
# device=q_t.device,
# dtype=torch.int32,
# )
# Nnz_q = output_batch_size * max_seqlen_batch_q
# else:
# (
# cumulative_sequence_length_q,
# max_seqlen_batch_q,
# Nnz_q,
# ) = _cumulative_and_max_seq_len_nnz(q_t)
# if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
# assert k_t.size(1) == v_t.size(1)
# max_seqlen_batch_kv = k_t.size(1)
# cumulative_sequence_length_kv = torch.arange(
# 0,
# (output_batch_size + 1) * max_seqlen_batch_kv,
# max_seqlen_batch_kv,
# device=k_t.device,
# dtype=torch.int32,
# )
# Nnz_kv = output_batch_size * max_seqlen_batch_kv
# else:
# cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
# _cumulative_and_max_seq_len_nnz(v_t)
# if k_batch_size_needs_broadcast
# else _cumulative_and_max_seq_len_nnz(k_t)
# )
# q_num_heads_needs_broadcast = q_num_heads != output_num_heads
# k_num_heads_needs_broadcast = k_num_heads != output_num_heads
# v_num_heads_needs_broadcast = v_num_heads != output_num_heads
# if not q_t.is_nested:
# query_buffer_reshaped = q_t.expand(
# output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
# )
# query_buffer_reshaped = query_buffer_reshaped.reshape(
# Nnz_q, output_num_heads, head_dim_qk
# )
# else:
# if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
# q_t = q_t.contiguous()
# # If we are broadcasting then Nnz_q will be the output_batch_size since
# # seq_len is 1
# effective_batch_size_q = (
# output_batch_size if q_batch_size_needs_broadcast else Nnz_q
# )
# query_buffer_reshaped = _view_as_dense(
# q_t, effective_batch_size_q, output_num_heads, head_dim_qk
# )
# # If the physical layout of the NestedTensor's storage
# # is not: batch, {seq_len}, num_heads, head_dim then we need
# # to call contiguous
# if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
# k_t = k_t.contiguous()
# if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
# v_t = v_t.contiguous()
# effective_batch_size_k = (
# output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
# )
# key_buffer_reshaped = _view_as_dense(
# k_t, effective_batch_size_k, output_num_heads, head_dim_qk
# )
# effective_batch_size_v = (
# output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
# )
# value_buffer_reshaped = _view_as_dense(
# v_t, effective_batch_size_v, output_num_heads, head_dim_v
# )
# if not q_batch_size_needs_broadcast:
# output_shape = q_t._size
# if head_dim_v != head_dim_qk:
# output_shape[-1] = head_dim_v
# if q_num_heads_needs_broadcast:
# output_shape[1] = output_num_heads
# else:
# output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
# output_shape[0] = q_t.size(1)
# output_shape[1] = output_num_heads
# output_shape[2] = head_dim_v
# return (
# query_buffer_reshaped,
# key_buffer_reshaped,
# value_buffer_reshaped,
# cumulative_sequence_length_q,
# cumulative_sequence_length_kv,
# max_seqlen_batch_q,
# max_seqlen_batch_kv,
# output_shape,
# )
def _sdpa_nested_preprocessing(query, key, value):
# Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
# Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
q_batch_size = query.size(0)
k_batch_size = key.size(0)
v_batch_size = value.size(0)
q_num_heads = query.size(1)
k_num_heads = key.size(1)
v_num_heads = value.size(1)
if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
q_num_heads == k_num_heads and k_num_heads == v_num_heads
):
raise RuntimeError(
"This path is currently not implemented for jagged layout NT."
)
# return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
num_heads = query.size(1)
head_dim_qk = query.size(3)
head_dim_v = value.size(3)
q_t = query.transpose(1, 2)
k_t = key.transpose(1, 2)
v_t = value.transpose(1, 2)
(
cumulative_sequence_length_q,
max_seqlen_batch_q,
Nnz_q,
) = _cumulative_and_max_seq_len_nnz(q_t)
(
cumulative_sequence_length_kv,
max_seqlen_batch_kv,
Nnz_kv,
) = _cumulative_and_max_seq_len_nnz(k_t)
# [TODO] K and V have to have the same Nnz, should probably torch_check
# assume in order to not iterate over v
# If the physical layout of the NestedTensor's storage
# is not: batch, {seq_len}, num_heads, head_dim then we need
# to call contiguous
if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
q_t = q_t.contiguous()
if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
k_t = k_t.contiguous()
if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
v_t = v_t.contiguous()
query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
output_nt_info = {
"offsets": q_t.offsets(),
"_max_seqlen": q_t._max_seqlen,
"_min_seqlen": q_t._min_seqlen,
}
return (
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
)
def _pad_last_dim(
tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
# FlashAttentionV2 requires that head dimension be a multiple of 8
# This was previously done within the kernel, however
# This causes the kernel to maybe alias query, key, value
# So instead we pad the head_dimensions to be a multiple of 8
# in the composite region
last_dim_size = tensor.size(-1)
if last_dim_size % alignment_size == 0:
return tensor
pad_count = alignment_size - (last_dim_size % alignment_size)
tensor = torch.nn.functional.pad(tensor, [0, pad_count])
if slice:
return tensor[..., 0:last_dim_size]
return tensor
# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
softmax_scale = scale if scale is not None else math.sqrt(1.0 / query.size(-1))
return softmax_scale
def _post_process_flash_output(out: torch.Tensor, og_size):
if not out.is_nested and out.size(-1) != og_size:
out = out[..., 0:og_size]
return out
def jagged_scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
_validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
# for mypy, ugh
assert (
isinstance(query, NestedTensor)
and isinstance(key, NestedTensor)
and isinstance(value, NestedTensor)
)
# Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
# second batch dim instead). For this case, we can just send the dense buffers through
# vanilla SDPA.
if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
from torch.nested._internal.ops import extract_kwargs
output = F.scaled_dot_product_attention(
query._values,
key._values,
value._values,
attn_mask=(
attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
return NestedTensor(output, **extract_kwargs(query))
compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
backend_choice = _select_sdp_backend(
query, key, value, attn_mask, dropout_p, is_causal
)
if backend_choice == SDPBackend.FLASH_ATTENTION:
og_size = query.size(-1)
query_padded = _pad_last_dim(query, 8, False)
key_padded = _pad_last_dim(key, 8, False)
value_padded = _pad_last_dim(value, 8, False)
# We need to calculate the scale based off the OG head dim size
og_scale = _calculate_scale(query, scale)
(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
(
attention,
logsumexp,
philox_seed,
philox_offset,
debug_attn_mask,
) = torch.ops.aten._flash_attention_forward(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
dropout_p,
is_causal,
False,
scale=og_scale,
)
# Reshape output to convert nnz to batch_size and seq_len
attention = ViewNestedFromBuffer.apply(
attention.squeeze(0), output_nt_info["offsets"]
).transpose(1, 2)
return _post_process_flash_output(attention, og_size)
elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
(
query_reshaped,
key_reshaped,
value_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query, key, value)
(
attention,
log_sumexp,
seed,
offset,
max_seqlen_q,
max_seqlen_batch_kv,
) = torch.ops.aten._efficient_attention_forward(
query_reshaped.unsqueeze(0),
key_reshaped.unsqueeze(0),
value_reshaped.unsqueeze(0),
None,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
dropout_p,
int(is_causal),
compute_logsumexp,
scale=scale,
)
# Reshape output to convert nnz to batch_size and seq_len
return ViewNestedFromBuffer.apply(
attention.squeeze(0), output_nt_info["offsets"]
).transpose(1, 2)
elif backend_choice == SDPBackend.MATH:
# save the offsets and shape of the inputs, so we can reshape the final output
# query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1]
# attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
offsets = query.offsets()
d1 = query._size[1]
d2 = value._size[-1]
# convert jagged layout Nested Tensor to strided layout Nested Tensor
# which support the math implementation of SDPA
def get_strided_layout_nested_tensor(jagged_layout_nt):
lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
transpose = torch.transpose(jagged_layout_nt, 1, 2)
tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0)
strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
strided_nt = strided_nt.transpose(1, 2).contiguous()
return strided_nt
query = get_strided_layout_nested_tensor(query)
key = get_strided_layout_nested_tensor(key)
value = get_strided_layout_nested_tensor(value)
attn_out = torch._scaled_dot_product_attention_math(
query, key, value, attn_mask, dropout_p, is_causal, scale=scale
)[0]
# convert strided layout Nested Tensor back to jagged layout Nested Tensor
attn_out = attn_out.transpose(1, 2).contiguous().values()
attn_out = attn_out.view(-1, d1, d2)
attn_out = ViewNestedFromBuffer.apply(attn_out, offsets)
attn_out = attn_out.transpose(1, 2)
return attn_out
else:
raise RuntimeError(
"No viable backend for scaled_dot_product_attention was found."
)