mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix cuDNN SDPA meta registration (#148921)
Update `cuDNN SDPA` meta registration to match the memory layout behavior in: https://github.com/pytorch/pytorch/pull/138354 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148921 Approved by: https://github.com/drisspg, https://github.com/jbschlosser
This commit is contained in:
parent
2a7d583452
commit
ec93aa7f84
|
|
@ -5583,7 +5583,20 @@ def meta__scaled_dot_product_cudnn_attention(
|
|||
S_KV = key.size(2)
|
||||
D_V = value.size(-1)
|
||||
|
||||
res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
|
||||
res_shape = (B, H, S_Q, D_V)
|
||||
if tuple(query.shape) == res_shape:
|
||||
query_t = query.transpose(1, 2)
|
||||
res = torch.empty_like(query_t).transpose(1, 2)
|
||||
else:
|
||||
dim_order = sorted(
|
||||
[0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
|
||||
)
|
||||
permuted_shape = [res_shape[idx] for idx in dim_order]
|
||||
final_permute = [dim_order.index(i) for i in range(len(dim_order))]
|
||||
res = torch.empty(
|
||||
permuted_shape, dtype=query.dtype, device=query.device
|
||||
).permute(final_permute)
|
||||
|
||||
logsum_exp = torch.empty(
|
||||
(B, H, S_Q),
|
||||
dtype=torch.float,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user