Fix meta registration for _flash_attention_forward() (#119812)
Meta registration wrongly assumes 4D inputs, while the underlying op allows 3D inputs for the `mha_varlen_fwd()` case.

Testing: I added `detach()`es so the NJT test `test_sdpa_compile()` won't fail for a view-related reason. It should pass now with this fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119812
Approved by: https://github.com/drisspg
parent 179ecab7e7
commit 31e59766e7
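For context on the shape issue described in the commit message, here is a minimal sketch (illustrative names only, not the PyTorch source) of the two input layouts _flash_attention_forward() accepts and how the fixed meta registration derives batch size and max sequence length from them:

import torch

# Hypothetical helper mirroring the shape logic of the fixed meta registration.
def infer_flash_fwd_shape(query, cum_seq_q=None, max_q=None):
    if cum_seq_q is None:
        # Dense path: 4D query of shape (batch_size, seqlen, num_heads, head_dim).
        batch_size = query.size(0)
        max_seqlen_q = query.size(1)
    else:
        # Varlen path: 3D query of shape (total, num_heads, head_dim), where
        # cum_seq_q holds cumulative offsets into the packed "total" dimension.
        batch_size = cum_seq_q.numel() - 1
        max_seqlen_q = max_q
    # Heads and head_dim are the trailing dims in both layouts, hence size(-2)/size(-1).
    return batch_size, max_seqlen_q, query.size(-2), query.size(-1)

# Dense: batch 2, seqlen 8, 4 heads, head_dim 16.
print(infer_flash_fwd_shape(torch.empty(2, 8, 4, 16)))
# Varlen: sequences of length 5 and 8 packed into 13 total tokens.
print(infer_flash_fwd_shape(torch.empty(13, 4, 16), torch.tensor([0, 5, 13]), max_q=8))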
@@ -11,13 +11,14 @@ import numpy as np
 import torch
 import torch.nn
 import torch.nn.functional as F
-from torch.testing._internal.common_cuda import SM80OrLater
+from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater
 from torch.testing._internal.common_device_type import (
     dtypes,
     dtypesIfCUDA,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,
+    skipCUDAIf,
     skipMeta,
     PYTORCH_CUDA_MEMCHECK,
 )
@@ -3827,11 +3828,10 @@ class TestNestedTensorSubclass(TestCase):
         if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
             check_forward_backward()
 
-    # This requires NT -> NT views to work in inductor, which is a TODO
-    @unittest.expectedFailure  # noqa: E301
+    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
     @onlyCUDA
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
-                 SM80OrLater else [torch.float16, torch.float32])
+    @dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater
+              else [torch.float16, torch.float32]))
     def test_sdpa_compile(self, device, dtype):
         batch_size = 1
         emb_dims = 1024
@@ -3857,9 +3857,9 @@ class TestNestedTensorSubclass(TestCase):
         k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
         v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
 
-        q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
-        k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
-        v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
+        q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
+        k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
+        v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
 
         # High Precision Math Reference
         q_d1_f32 = q_d1.to(torch.float32)
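The detach() additions above are the testing change mentioned in the commit message; a plain-tensor sketch (the test itself uses nested tensors) of what detach() changes:

import torch

# detach() keeps the same values but cuts the result out of the autograd graph
# before the transpose; per the commit message this avoids a view-related
# failure when test_sdpa_compile() is run under torch.compile.
x = torch.randn(2, 8, 64, requires_grad=True)
a = x.view(2, 8, 4, 16).transpose(1, 2)           # tracked by autograd
b = x.view(2, 8, 4, 16).detach().transpose(1, 2)  # same values, no grad history
print(torch.equal(a, b))                 # True
print(a.requires_grad, b.requires_grad)  # True False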
@@ -5354,12 +5354,15 @@ def meta__flash_attention_forward(
     return_debug_mask: bool,
     scale: Optional[float] = None,
 ):
-    batch_size = query.size(0)
-    max_seqlen_batch_q = query.size(1)
-    num_heads = query.size(2)
-    head_dim = query.size(3)
-
-    max_seqlen_batch_k = key.size(1)
+    # NB: there are two underlying paths:
+    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
+    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
+    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
+    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
+    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
+    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
+    num_heads = query.size(-2)
+    head_dim = query.size(-1)
 
     # Cuda Path
     attention = torch.empty_like(query)
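One side note on the hunk above (an observation, not stated in the commit): because the meta kernel builds the attention output with empty_like(query), the output automatically follows whichever layout was passed in, 4D dense or 3D packed.

import torch

# empty_like mirrors the input layout on the meta device, so the output needs
# no extra shape branching of its own.
q_dense = torch.empty(2, 8, 4, 16, device="meta")
q_packed = torch.empty(13, 4, 16, device="meta")
print(torch.empty_like(q_dense).shape)   # torch.Size([2, 8, 4, 16])
print(torch.empty_like(q_packed).shape)  # torch.Size([13, 4, 16])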
@@ -1,5 +1,4 @@
 import logging
-import math
 from typing import Optional, Tuple
 
 import torch
@@ -604,7 +603,7 @@ def _pad_last_dim(
 
 # TODO: coalesce with torch/nn/utils/attention.py
 def _calculate_scale(query, scale):
-    softmax_scale = scale if scale is not None else math.sqrt(1.0 / query.size(-1))
+    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
     return softmax_scale
 
 
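The math.sqrt to torch.sym_sqrt swap in the last hunk is not explained in the commit message; a plausible reading is that query.size(-1) can be a symbolic size when this code is traced by torch.compile with dynamic shapes, and the sym_* variant keeps it symbolic instead of forcing a concrete Python number. Assuming torch.sym_sqrt falls back to math.sqrt for plain Python numbers (as the sym_* helpers generally do), the two agree in eager mode:

import math
import torch

q = torch.randn(2, 4, 8, 16)
head_dim = q.size(-1)  # a plain int in eager mode; a SymInt under dynamic-shape tracing
print(math.sqrt(1.0 / head_dim))       # 0.25
print(torch.sym_sqrt(1.0 / head_dim))  # 0.25 as well for plain numbers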