[Nested Tensor] fix sdpa backward for the special case with ragged second batch dim and constant length (#128349)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128349 Approved by: https://github.com/jbschlosser
This commit is contained in:
parent 7b7f357042
commit 00f675bb4c
test/test_nestedtensor.py
@@ -5174,18 +5174,26 @@ class TestNestedTensorSubclass(TestCase):
         # S: (constant) sequence length
         # D: embedding size
         query = random_nt_from_dims(
-            [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged
+            [4, None, 8, 10],
+            device=device,
+            dtype=dtype,
+            layout=torch.jagged,
+            requires_grad=True,
         )
         key = random_nt_from_similar(query)
         value = random_nt_from_similar(query)
         output = F.scaled_dot_product_attention(query, key, value)
         self.assertTrue(isinstance(output, NestedTensor))
+        output.values().sum().backward()

+        query_dense = query.clone().detach().requires_grad_(True)
         # should be equivalent to just running the buffers through
         output_dense = F.scaled_dot_product_attention(
-            query._values, key._values, value._values
+            query_dense.values(), key.values(), value.values()
         )
-        self.assertEqual(output._values, output_dense)
+        torch._dynamo.disable(self.assertEqual)(output._values, output_dense)
+        output_dense.sum().backward()
+        torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)

     @onlyCUDA
     @unittest.skipIf(
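For context, a minimal sketch of the scenario this test covers, written against the public torch.nested API rather than the test helpers (random_nt_from_dims / random_nt_from_similar are test-internal); the sizes here are hypothetical. The point is a jagged-layout nested tensor whose ragged dimension is the second batch dim while the sequence length stays constant, pushed through SDPA with a backward pass:

import torch
import torch.nn.functional as F

# Hypothetical sizes: batch B=3, ragged second batch dim (2, 3, 5 entries),
# constant sequence length S=8, embedding size D=10 -- the SAM-style layout
# the commit title refers to.
components = [torch.randn(n, 8, 10) for n in (2, 3, 5)]
query = torch.nested.nested_tensor(
    components, layout=torch.jagged, requires_grad=True
)

# Self-attention keeps the sketch short; key/value then trivially share
# query's ragged structure.
out = F.scaled_dot_product_attention(query, query, query)

# The backward pass this PR fixes: gradients should reach the nested input.
out.values().sum().backward()
assert query.grad is not None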
torch/nested/_internal/sdpa.py
@@ -13,8 +13,8 @@ from torch.backends.cuda import (
     mem_efficient_sdp_enabled,
     SDPAParams,
 )

 from torch.nn.attention import SDPBackend

 from .nested_tensor import NestedTensor

 log = logging.getLogger(__name__)
@@ -51,9 +51,9 @@ def _validate_sdpa_input(
             f"but got query.device: {query.device}, key.device: {key.device}, "
             f"and value.device: {value.device} instead."
         )
-    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
+    if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
         raise ValueError(
-            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
+            f"Expected query, key, and value to all be at least 3 dimensional, but got query.dim: "
             f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
         )
     if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
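A quick aside, not part of the patch: nested SDPA inputs are shaped like (batch, ragged or constant sequence, ..., embedding), so the smallest tensor this validator can meaningfully accept is 3-D, and the check and its message now require at least 3 dimensions rather than 2. A tiny hypothetical example of a minimal shape that clears the new check:

import torch

# (B=2, ragged sequence, D=6): a minimal jagged-layout nested tensor.
q = torch.nested.nested_tensor(
    [torch.randn(3, 6), torch.randn(5, 6)], layout=torch.jagged
)
print(q.dim())  # 3, so it passes the new `dim() < 3` guard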
@@ -630,26 +630,24 @@ def jagged_scaled_dot_product_attention(
         and isinstance(key, NestedTensor)
         and isinstance(value, NestedTensor)
     )
+    from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

     # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
     # second batch dim instead). For this case, we can just send the dense buffers through
     # vanilla SDPA.
     if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
-        from torch.nested._internal.ops import extract_kwargs
-
         output = F.scaled_dot_product_attention(
-            query._values,
-            key._values,
-            value._values,
+            query.values(),
+            key.values(),
+            value.values(),
             attn_mask=(
-                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
+                attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
             ),
             dropout_p=dropout_p,
             is_causal=is_causal,
             scale=scale,
         )

-        return NestedTensor(output, **extract_kwargs(query))
+        return nested_view_from_values_offsets(output, query.offsets())

     compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
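The functional change in this hunk: the dense buffers are read through the differentiable query.values() / key.values() / value.values() accessors instead of the raw ._values attributes, and the result is rebuilt with nested_view_from_values_offsets rather than the bare NestedTensor constructor, which keeps the output connected to the inputs in autograd. A standalone sketch (hypothetical data; nested_view_from_values_offsets is the internal helper already imported above) of why the view-based construction matters for backward:

import torch
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

values = torch.randn(10, 8, requires_grad=True)   # packed (sum of ragged lengths, D) buffer
offsets = torch.tensor([0, 2, 5, 10])              # ragged boundaries per batch element
nt = nested_view_from_values_offsets(values, offsets)  # jagged NT that views `values`

# Because the nested tensor is a view, backward through it reaches the buffer.
nt.values().sum().backward()
assert values.grad is not None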
@@ -694,7 +692,6 @@ def jagged_scaled_dot_product_attention(
             False,
             scale=og_scale,
         )
-        from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

         # Reshape output to convert nnz to batch_size and seq_len
         attention = nested_view_from_values_offsets(
@@ -737,8 +734,6 @@ def jagged_scaled_dot_product_attention(
             scale=scale,
         )

-        from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
-
         # Reshape output to convert nnz to batch_size and seq_len
         return nested_view_from_values_offsets(
             attention.squeeze(0),
@@ -779,10 +774,7 @@ def jagged_scaled_dot_product_attention(
             query, key, value, attn_mask, dropout_p, is_causal, scale=scale
         )[0]

-        from torch.nested._internal.nested_tensor import (
-            _load_val_from_tensor,
-            nested_view_from_values_offsets,
-        )
+        from torch.nested._internal.nested_tensor import _load_val_from_tensor

         # convert strided layout Nested Tensor back to jagged layout Nested Tensor
         attn_out = attn_out.transpose(1, 2).contiguous().values()