Fix meta registration for _flash_attention_forward() (#119812)

The meta registration wrongly assumes 4D inputs, while the underlying op allows 3D inputs for the `mha_varlen_fwd()` case.
Testing: I added `detach()` calls so the NJT test `test_sdpa_compile()` won't fail for a view-related reason; it should pass now with this fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119812
Approved by: https://github.com/drisspg
Joel Schlosser 2024-02-13 18:29:52 -05:00 committed by PyTorch MergeBot
parent 179ecab7e7
commit 31e59766e7
3 changed files with 18 additions and 16 deletions
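For reference, here is a minimal standalone sketch of the shape logic described above, assuming only the two input layouts named in the meta-registration comment further below (the helper name and example tensors are illustrative, not part of this PR): the dense path takes 4D `(batch_size, seqlen, num_heads, head_dim)` inputs, while the varlen path takes 3D `(total, num_heads, head_dim)` inputs plus `cum_seq_q` / `cum_seq_k` offset tensors.

```python
import torch

def flash_fwd_meta_shapes(query, key, cum_seq_q=None, cum_seq_k=None,
                          max_q=None, max_k=None):
    # Illustrative helper (not the real meta kernel): mirrors how the fixed
    # registration derives shapes for both paths.
    # Dense path: 4D (batch, seqlen, heads, head_dim).
    # Varlen path: 3D (total, heads, head_dim), batch boundaries in cum_seq_*.
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)  # indexing from the back works for 3D and 4D
    head_dim = query.size(-1)
    return batch_size, max_seqlen_q, max_seqlen_k, num_heads, head_dim

# Dense case: batch of 2, seqlen 8, 4 heads, head_dim 16
q = k = torch.empty(2, 8, 4, 16, device="meta")
print(flash_fwd_meta_shapes(q, k))  # (2, 8, 8, 4, 16)

# Varlen case: sequences of lengths 3 and 5 packed into total=8
cum = torch.tensor([0, 3, 8], dtype=torch.int32)
qv = kv = torch.empty(8, 4, 16, device="meta")
print(flash_fwd_meta_shapes(qv, kv, cum, cum, max_q=5, max_k=5))  # (2, 5, 5, 4, 16)
```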

View File

@@ -11,13 +11,14 @@ import numpy as np
import torch
import torch.nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
skipCUDAIf,
skipMeta,
PYTORCH_CUDA_MEMCHECK,
)
@@ -3827,11 +3828,10 @@ class TestNestedTensorSubclass(TestCase):
if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
check_forward_backward()
# This requires NT -> NT views to work in inductor, which is a TODO
@unittest.expectedFailure # noqa: E301
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@onlyCUDA
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
SM80OrLater else [torch.float16, torch.float32])
@dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater
else [torch.float16, torch.float32]))
def test_sdpa_compile(self, device, dtype):
batch_size = 1
emb_dims = 1024
@@ -3857,9 +3857,9 @@ class TestNestedTensorSubclass(TestCase):
k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2)
# High Precision Math Reference
q_d1_f32 = q_d1.to(torch.float32)

View File

@@ -5354,12 +5354,15 @@ def meta__flash_attention_forward(
return_debug_mask: bool,
scale: Optional[float] = None,
):
batch_size = query.size(0)
max_seqlen_batch_q = query.size(1)
num_heads = query.size(2)
head_dim = query.size(3)
max_seqlen_batch_k = key.size(1)
# NB: there are two underlying paths:
# 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
# 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
# includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
num_heads = query.size(-2)
head_dim = query.size(-1)
# Cuda Path
attention = torch.empty_like(query)

View File

@@ -1,5 +1,4 @@
import logging
import math
from typing import Optional, Tuple
import torch
@@ -604,7 +603,7 @@ def _pad_last_dim(
# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
softmax_scale = scale if scale is not None else math.sqrt(1.0 / query.size(-1))
softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
return softmax_scale
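A brief note on the `torch.sym_sqrt` swap above (the motivation is inferred here, not stated in the commit message): unlike `math.sqrt`, `torch.sym_sqrt` can stay symbolic when `query.size(-1)` is a traced symbolic size under `torch.compile`, while falling back to an ordinary float result for plain Python numbers. A minimal sketch with a regular tensor, using a hypothetical `_scale` helper that follows the same pattern as the patched code:

```python
import math
import torch

def _scale(query, scale=None):
    # Same pattern as the patched helper above; for plain ints this is
    # numerically identical to the old math.sqrt version.
    return scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))

q = torch.randn(2, 8, 64)
assert math.isclose(_scale(q), 1.0 / math.sqrt(64))  # softmax scale = 0.125
```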