fix cuDNN SDPA meta registration (#148921)

Update the `cuDNN SDPA` meta registration to match the memory layout behavior introduced in: https://github.com/pytorch/pytorch/pull/138354

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148921
Approved by: https://github.com/drisspg, https://github.com/jbschlosser
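
For context, a rough sketch of the eager-side behavior the meta function needs to mirror (illustrative shapes and names; assumes a CUDA build and a GPU/cuDNN combination where the cuDNN SDPA backend accepts these inputs): the cuDNN backend's output follows the query's memory layout, so a query allocated "BSHD" and transposed into the logical (B, H, S, D) form yields an output with the same transposed stride ordering.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative only: requires CUDA and a cuDNN SDPA backend that supports these shapes.
if torch.cuda.is_available():
    B, S, H, D = 2, 128, 8, 64
    # Allocate "BSHD", then transpose into the (B, H, S, D) form SDPA expects.
    q, k, v = (
        torch.randn(B, S, H, D, device="cuda", dtype=torch.half).transpose(1, 2)
        for _ in range(3)
    )
    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v)
    # The output's stride ordering matches the query's, which is what the
    # updated meta registration below reproduces for fake tensors.
    print(q.stride(), out.stride())
```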
eqy 2025-03-13 07:33:12 +00:00 committed by PyTorch MergeBot
parent 2a7d583452
commit ec93aa7f84

@@ -5583,7 +5583,20 @@ def meta__scaled_dot_product_cudnn_attention(
     S_KV = key.size(2)
     D_V = value.size(-1)
-    res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
+    res_shape = (B, H, S_Q, D_V)
+    if tuple(query.shape) == res_shape:
+        query_t = query.transpose(1, 2)
+        res = torch.empty_like(query_t).transpose(1, 2)
+    else:
+        dim_order = sorted(
+            [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
+        )
+        permuted_shape = [res_shape[idx] for idx in dim_order]
+        final_permute = [dim_order.index(i) for i in range(len(dim_order))]
+        res = torch.empty(
+            permuted_shape, dtype=query.dtype, device=query.device
+        ).permute(final_permute)
     logsum_exp = torch.empty(
         (B, H, S_Q),
         dtype=torch.float,
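
For illustration, here is the layout-matching allocation from the new meta code as a standalone sketch (names like `q` and the shapes are hypothetical, not part of the patch; it runs on CPU and does not touch cuDNN): the output is allocated contiguously in the query's physical dimension order and then permuted back to the logical (B, H, S_Q, D_V) shape, so its stride ordering matches the query's.

```python
import torch

# Minimal sketch of the stride-order-matching allocation in the patch above.
# All names and shapes here are illustrative, not part of the patch itself.
B, H, S_Q, D_QK, D_V = 2, 4, 8, 64, 64
# A "BSHD" query transposed to the logical (B, H, S_Q, D_QK) view: non-contiguous.
q = torch.empty(B, S_Q, H, D_QK).transpose(1, 2)

res_shape = (B, H, S_Q, D_V)
# Order the dimensions from largest to smallest query stride ...
dim_order = sorted([0, 1, 2, 3], key=lambda idx: q.stride()[idx], reverse=True)
# ... allocate the output contiguously in that physical order ...
permuted_shape = [res_shape[idx] for idx in dim_order]
# ... and permute back so the logical shape is (B, H, S_Q, D_V).
final_permute = [dim_order.index(i) for i in range(len(dim_order))]
res = torch.empty(permuted_shape, dtype=q.dtype).permute(final_permute)

print(q.stride())    # (2048, 64, 256, 1): H and S_Q are swapped in memory
print(res.stride())  # (2048, 64, 256, 1): same stride ordering as the query
```

The `tuple(query.shape) == res_shape` fast path in the patch covers the common case where the value head dim equals the query head dim, in which `torch.empty_like` on the transposed query already gives an output with the query's layout.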