[Nested Tensor] fix sdpa backward for the special case with ragged second batch dim and constant length (#128349)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128349
Approved by: https://github.com/jbschlosser
This commit is contained in:
parent
7b7f357042
commit
00f675bb4c
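For context, a minimal sketch of the scenario this commit targets: a jagged nested tensor that is ragged in its second batch dim while the sequence length stays constant, run through SDPA forward and backward. This is not code from the PR; it uses the public torch.nested API instead of the test helpers shown below, and the component sizes are illustrative.

# Sketch (assumptions: public torch.nested API, CPU, illustrative sizes).
import torch
import torch.nn.functional as F

S, D = 8, 10  # constant sequence length and embedding size
# Each component differs in its first dim, so the nested tensor of shape
# (4, ragged, S, D) is ragged at dim 1 (the "second batch dim").
components = [torch.randn(n, S, D) for n in (2, 3, 5, 4)]
query = torch.nested.nested_tensor(components, layout=torch.jagged, requires_grad=True)
key = torch.nested.nested_tensor(components, layout=torch.jagged)
value = torch.nested.nested_tensor(components, layout=torch.jagged)

output = F.scaled_dot_product_attention(query, key, value)
output.values().sum().backward()  # the backward pass this commit fixes
print(query.grad is not None)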
@@ -5174,18 +5174,26 @@ class TestNestedTensorSubclass(TestCase):
         # S: (constant) sequence length
         # D: embedding size
         query = random_nt_from_dims(
-            [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged
+            [4, None, 8, 10],
+            device=device,
+            dtype=dtype,
+            layout=torch.jagged,
+            requires_grad=True,
         )
         key = random_nt_from_similar(query)
         value = random_nt_from_similar(query)
         output = F.scaled_dot_product_attention(query, key, value)
         self.assertTrue(isinstance(output, NestedTensor))
+        output.values().sum().backward()
 
+        query_dense = query.clone().detach().requires_grad_(True)
         # should be equivalent to just running the buffers through
         output_dense = F.scaled_dot_product_attention(
-            query._values, key._values, value._values
+            query_dense.values(), key.values(), value.values()
         )
-        self.assertEqual(output._values, output_dense)
+        torch._dynamo.disable(self.assertEqual)(output._values, output_dense)
+        output_dense.sum().backward()
+        torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)
 
     @onlyCUDA
     @unittest.skipIf(
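The dense reference used in the updated test is valid because, with a constant sequence length, the jagged values buffer is itself an ordinary dense tensor. A quick illustration of that buffer shape using only the public API (sizes are illustrative, not taken from the test):

# Sketch (assumption: public torch.nested API, illustrative sizes).
import torch

S, D = 8, 10
comps = [torch.randn(n, S, D) for n in (2, 3, 5, 4)]
nt = torch.nested.nested_tensor(comps, layout=torch.jagged)
print(nt.shape)           # (4, j, 8, 10): ragged at dim 1, constant S and D
print(nt.values().shape)  # torch.Size([14, 8, 10]): a plain dense buffer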
@@ -13,8 +13,8 @@ from torch.backends.cuda import (
     mem_efficient_sdp_enabled,
     SDPAParams,
 )
 
 from torch.nn.attention import SDPBackend
 
 from .nested_tensor import NestedTensor
 
 log = logging.getLogger(__name__)
@@ -51,9 +51,9 @@ def _validate_sdpa_input(
             f"but got query.device: {query.device}, key.device: {key.device}, "
             f"and value.device: {value.device} instead."
         )
-    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
+    if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
         raise ValueError(
-            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
+            f"Expected query, key, and value to all be at least 3 dimensional, but got query.dim: "
             f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
         )
     if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
@@ -630,26 +630,24 @@ def jagged_scaled_dot_product_attention(
         and isinstance(key, NestedTensor)
         and isinstance(value, NestedTensor)
     )
+    from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
 
     # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
     # second batch dim instead). For this case, we can just send the dense buffers through
     # vanilla SDPA.
     if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
-        from torch.nested._internal.ops import extract_kwargs
-
         output = F.scaled_dot_product_attention(
-            query._values,
-            key._values,
-            value._values,
+            query.values(),
+            key.values(),
+            value.values(),
             attn_mask=(
-                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
+                attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
             ),
             dropout_p=dropout_p,
             is_causal=is_causal,
             scale=scale,
         )
-
-        return NestedTensor(output, **extract_kwargs(query))
+        return nested_view_from_values_offsets(output, query.offsets())
 
     compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
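This hunk carries the substance of the fix: inputs are unpacked with the public .values() accessor instead of ._values, and the dense SDPA result is rebuilt as a jagged view over query's offsets via nested_view_from_values_offsets rather than constructed directly from extract_kwargs(query), so the whole special path stays on the autograd tape and backward works. A rough public-API equivalent (a sketch, not the PR's code: torch.nested.nested_tensor_from_jagged stands in for the internal nested_view_from_values_offsets helper, and the function name is made up for illustration):

# Sketch (assumptions: public torch.nested API; hypothetical wrapper name).
import torch
import torch.nn.functional as F

def sdpa_ragged_batch_dim(query, key, value, **kwargs):
    # query/key/value: jagged nested tensors ragged at dim 1 with a constant
    # sequence length. Run vanilla SDPA over the dense values buffers...
    out_values = F.scaled_dot_product_attention(
        query.values(), key.values(), value.values(), **kwargs
    )
    # ...then re-wrap the dense result as a jagged view sharing query's
    # offsets, so gradients flow back to the nested inputs.
    return torch.nested.nested_tensor_from_jagged(out_values, query.offsets())

Because the output is a view built from autograd-visible ops rather than a freshly constructed NestedTensor, the gradient comparison added to the test above (query.grad vs. the dense reference) can pass.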
@@ -694,7 +692,6 @@ def jagged_scaled_dot_product_attention(
             False,
             scale=og_scale,
         )
-        from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
 
         # Reshape output to convert nnz to batch_size and seq_len
         attention = nested_view_from_values_offsets(
@@ -737,8 +734,6 @@ def jagged_scaled_dot_product_attention(
             scale=scale,
         )
 
-        from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
-
         # Reshape output to convert nnz to batch_size and seq_len
         return nested_view_from_values_offsets(
             attention.squeeze(0),
@@ -779,10 +774,7 @@ def jagged_scaled_dot_product_attention(
             query, key, value, attn_mask, dropout_p, is_causal, scale=scale
         )[0]
 
-        from torch.nested._internal.nested_tensor import (
-            _load_val_from_tensor,
-            nested_view_from_values_offsets,
-        )
+        from torch.nested._internal.nested_tensor import _load_val_from_tensor
 
         # convert strided layout Nested Tensor back to jagged layout Nested Tensor
         attn_out = attn_out.transpose(1, 2).contiguous().values()