[Nested Tensor]fix sdpa backward for the special case with ragged second batch dim and constant length (#128349)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128349
Approved by: https://github.com/jbschlosser
This commit is contained in:
yuqingj 2024-06-24 12:01:43 -07:00 committed by PyTorch MergeBot
parent 7b7f357042
commit 00f675bb4c
2 changed files with 21 additions and 21 deletions

View File

@ -5174,18 +5174,26 @@ class TestNestedTensorSubclass(TestCase):
# S: (constant) sequence length
# D: embedding size
query = random_nt_from_dims(
[4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged
[4, None, 8, 10],
device=device,
dtype=dtype,
layout=torch.jagged,
requires_grad=True,
)
key = random_nt_from_similar(query)
value = random_nt_from_similar(query)
output = F.scaled_dot_product_attention(query, key, value)
self.assertTrue(isinstance(output, NestedTensor))
output.values().sum().backward()
query_dense = query.clone().detach().requires_grad_(True)
# should be equivalent to just running the buffers through
output_dense = F.scaled_dot_product_attention(
query._values, key._values, value._values
query_dense.values(), key.values(), value.values()
)
self.assertEqual(output._values, output_dense)
torch._dynamo.disable(self.assertEqual)(output._values, output_dense)
output_dense.sum().backward()
torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)
@onlyCUDA
@unittest.skipIf(

View File

@ -13,8 +13,8 @@ from torch.backends.cuda import (
mem_efficient_sdp_enabled,
SDPAParams,
)
from torch.nn.attention import SDPBackend
from .nested_tensor import NestedTensor
log = logging.getLogger(__name__)
@ -51,9 +51,9 @@ def _validate_sdpa_input(
f"but got query.device: {query.device}, key.device: {key.device}, "
f"and value.device: {value.device} instead."
)
if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
raise ValueError(
f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
f"Expected query, key, and value to all be at least 3 dimensional, but got query.dim: "
f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
)
if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
@ -630,26 +630,24 @@ def jagged_scaled_dot_product_attention(
and isinstance(key, NestedTensor)
and isinstance(value, NestedTensor)
)
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
# Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
# second batch dim instead). For this case, we can just send the dense buffers through
# vanilla SDPA.
if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
from torch.nested._internal.ops import extract_kwargs
output = F.scaled_dot_product_attention(
query._values,
key._values,
value._values,
query.values(),
key.values(),
value.values(),
attn_mask=(
attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
return NestedTensor(output, **extract_kwargs(query))
return nested_view_from_values_offsets(output, query.offsets())
compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
@ -694,7 +692,6 @@ def jagged_scaled_dot_product_attention(
False,
scale=og_scale,
)
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
# Reshape output to convert nnz to batch_size and seq_len
attention = nested_view_from_values_offsets(
@ -737,8 +734,6 @@ def jagged_scaled_dot_product_attention(
scale=scale,
)
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
# Reshape output to convert nnz to batch_size and seq_len
return nested_view_from_values_offsets(
attention.squeeze(0),
@ -779,10 +774,7 @@ def jagged_scaled_dot_product_attention(
query, key, value, attn_mask, dropout_p, is_causal, scale=scale
)[0]
from torch.nested._internal.nested_tensor import (
_load_val_from_tensor,
nested_view_from_values_offsets,
)
from torch.nested._internal.nested_tensor import _load_val_from_tensor
# convert strided layout Nested Tensor back to jagged layout Nested Tensor
attn_out = attn_out.transpose(1, 2).contiguous().values()