Expose Flash attention to autograd (#114378)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114378
Approved by: https://github.com/drisspg

parent: 80d8a2a237
commit: d47f715d29
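In user-visible terms, the change gives torch.ops.aten._flash_attention_forward a registered derivative, so calling the private op directly on tensors that require grad now produces a differentiable first output instead of bypassing autograd. A minimal sketch of that behavior, assuming a CUDA build on a GPU with flash-attention support; the shapes below are illustrative, not taken from the PR:

import torch

# Illustrative shapes: (batch, seq_len, num_heads, head_dim). The flash kernel
# expects fp16/bf16 CUDA tensors whose last dimension is contiguous.
q = torch.randn(2, 8, 4, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 8, 4, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 8, 4, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Dense (non-varlen) path: cum_seq_q / cum_seq_k are None, and max_q / max_k are
# the sequence lengths, which this PR turns into required (Sym)Int arguments.
out, logsumexp, philox_seed, philox_offset, debug_mask = (
    torch.ops.aten._flash_attention_forward(
        q, k, v, None, None, 8, 8, 0.0, False, False))

# Gradients now flow back to query, key and value through the new backward formula.
out.sum().backward()
print(q.grad.shape)  # matches q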
@@ -14493,7 +14493,7 @@
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda
   tags: nondeterministic_seeded
 
-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt? max_q, SymInt? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   variants: function
   dispatch:
     CUDA: _flash_attention_forward
@@ -708,8 +708,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
       v_t,
       c10::nullopt,
       c10::nullopt,
-      c10::nullopt,
-      c10::nullopt,
+      max_seqlen_batch_q,
+      max_seqlen_batch_k,
       dropout_p,
       is_causal,
       return_debug_mask,
@@ -779,8 +779,8 @@ _flash_attention_forward(
     const Tensor& value,
     const c10::optional<Tensor>& cumulative_sequence_length_q,
     const c10::optional<Tensor>& cumulative_sequence_length_k,
-    c10::optional<int64_t> max_seqlen_batch_q,
-    c10::optional<int64_t> max_seqlen_batch_k,
+    int64_t max_seqlen_batch_q,
+    int64_t max_seqlen_batch_k,
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
@@ -797,15 +797,9 @@ _flash_attention_forward(
       cumulative_sequence_length_q.has_value() ==
           cumulative_sequence_length_k.has_value(),
       "cumulative_sequence_length_q and cumulative_sequence_length_k must be both set or both not set");
-  TORCH_CHECK(
-      max_seqlen_batch_q.has_value() == max_seqlen_batch_k.has_value(),
-      "max_seqlen_batch_q and max_seqlen_batch_k must be both set or both not set");
   Tensor output, q_padded, k_padded, v_padded, logsumexp, output_shape,
       philox_seed, philox_offset, debug_attn_mask;
   if (cumulative_sequence_length_q.has_value()) {
-    TORCH_CHECK(
-        max_seqlen_batch_q.has_value(),
-        "max_seqlen_batch_q must be set when cumulative_sequence_length_q is set");
     std::tie(
         output,
         q_padded,
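The surviving TORCH_CHECK keeps the pairing rule for the cumulative-sequence-length tensors, while the two checks on the optional max lengths are dropped because max_seqlen_batch_q/k are now plain int64_t parameters. The same invariant, roughly, in Python; the helper name is invented for illustration:

def _check_varlen_args(cum_seq_q, cum_seq_k):
    # Mirrors the remaining TORCH_CHECK: the two cumulative-length tensors must
    # be passed together. max_q / max_k no longer need a has_value() check
    # because they are always provided.
    assert (cum_seq_q is None) == (cum_seq_k is None), (
        "cumulative_sequence_length_q and cumulative_sequence_length_k "
        "must be both set or both not set")
    return cum_seq_q is not None  # True selects the varlen path, False the dense path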
@@ -822,8 +816,8 @@ _flash_attention_forward(
             out,
             cumulative_sequence_length_q.value(),
             cumulative_sequence_length_k.value(),
-            max_seqlen_batch_q.value(),
-            max_seqlen_batch_k.value(),
+            max_seqlen_batch_q,
+            max_seqlen_batch_k,
             dropout_p,
             softmax_scale,
             false /*zero_tensors*/,
@@ -71,7 +71,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
   c10::optional<at::Tensor> dk{c10::nullopt};
   c10::optional<at::Tensor> dv{c10::nullopt};
 
-  // The kernel computes irregadless we will drop for this functions return
+  // The kernel computes irregardless we will drop for this functions return
   Tensor grad_softmax;
 
   // We check the whether the cumulative_sequence_length_q is defined
@@ -266,6 +266,8 @@ ALLOW_LIST = [
     ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
     ("aten::_efficient_attention_forward", datetime.date(2023, 11, 30)),
     ("aten::_efficient_attention_backward", datetime.date(2023, 11, 30)),
+    ("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
+    ("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
     ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
     ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
     ("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),
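The two new allow-list entries waive the forward/backward-compatibility schema check for the changed flash-attention signatures until the given dates. Roughly the kind of date gating such an entry implies, sketched in Python; the real matching logic lives in the compatibility test script and may differ in detail:

import datetime
import re

ALLOW_LIST = [
    ("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
    ("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
]

def schema_check_is_waived(op_name, today=None):
    # An entry suppresses the BC check for a matching op until its expiry date.
    today = today or datetime.date.today()
    return any(re.match(pattern, op_name) and today < expiry
               for pattern, expiry in ALLOW_LIST)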
@@ -392,6 +392,7 @@ class TestOperators(TestCase):
         # query: last dimension must be contiguous
         # Fused attention kernels require last dim to be contiguous
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail("torch.ops.aten._flash_attention_forward"),
         xfail("torch.ops.aten._efficient_attention_forward"),
     }))
     @opsToleranceOverride('TestOperators', 'test_grad', (
@@ -475,6 +476,7 @@ class TestOperators(TestCase):
         xfail("_native_batch_norm_legit"),  # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._flash_attention_forward'),
         xfail('torch.ops.aten._efficient_attention_forward'),
         xfail('nn.functional.rrelu'),  # in-place test errors out with no formula implemented
@@ -603,6 +605,7 @@ class TestOperators(TestCase):
         # RuntimeError: query: last dimension must be contiguous
         # The fused attention kernels require the last dim to be contiguous
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._flash_attention_forward'),
         xfail('torch.ops.aten._efficient_attention_forward'),
         # BUG
         # AssertionError: Tensor-likes are not close!
@@ -682,6 +685,7 @@ class TestOperators(TestCase):
         xfail('sparse.sampled_addmm', ''),  # sparse tensors have no strides
         xfail('sparse.mm', 'reduce'),  # sparse tensors have no strides
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._flash_attention_forward'),
         xfail('torch.ops.aten._efficient_attention_forward'),
         # AssertionError: Tensor-likes are not close!
         # Mismatched elements: 1 / 15 (6.7%)
@@ -3573,6 +3573,9 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('le', device_type='cuda'),
         xfail('lt', device_type='cuda'),
         xfail('ne', device_type='cuda'),
+
+        # RuntimeError: aten::_flash_attention_forward hit the vmap fallback which is currently disabled
+        xfail('torch.ops.aten._flash_attention_forward'),
     }
 
     @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
@@ -238,6 +238,7 @@ inductor_expected_failures_single_sample["cuda"] = {
     "sparse.sampled_addmm": {f32, f64},
     "to_sparse": {f16, f32, f64},
     "torch.ops.aten._efficient_attention_forward": {f16, bf16, f32},
+    "torch.ops.aten._flash_attention_forward": {f16, bf16, f32},
 }
 
@@ -2779,9 +2779,9 @@
   output_differentiability: [True, False, False, False, False, False, False, False, False]
   query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
 
-# - name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor query_padded, Tensor key_padded, Tensor value_padded, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
-# output_differentiability: [True, False, False, False, False, False, False, False]
-# query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
+- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  output_differentiability: [True, False, False, False, False]
+  query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
 
 - name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
   output_differentiability: [True, False, False, False, False, False]
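This derivatives.yaml entry is the core of the PR: of the five outputs only the first is differentiable, and the gradients for query, key and value come from _flash_attention_backward_symint. The observable effect, roughly, under the same hardware assumptions as the first snippet:

import torch

q = torch.randn(2, 8, 4, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 8, 4, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 8, 4, 64, device="cuda", dtype=torch.float16, requires_grad=True)

out, logsumexp, philox_seed, philox_offset, debug_mask = (
    torch.ops.aten._flash_attention_forward(
        q, k, v, None, None, 8, 8, 0.0, False, False))

# output_differentiability: [True, False, False, False, False]
print(out.requires_grad)        # True
print(logsumexp.requires_grad)  # False

# The registered formula routes the incoming gradient through the flash
# backward kernel and returns gradients only for query, key and value.
grad_q, grad_k, grad_v = torch.autograd.grad(out.sum(), (q, k, v))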
@@ -2022,8 +2022,6 @@ make_fallback(aten.cumsum, require_dense, warn=False)
 make_fallback(aten.cumprod, require_dense, warn=False)
 make_fallback(aten._embedding_bag, require_contiguous)
 make_fallback(aten._embedding_bag_forward_only, require_contiguous)
-make_fallback(aten._flash_attention_forward)
-make_fallback(aten._flash_attention_backward)
 make_fallback(aten._fused_moving_avg_obs_fq_helper)
 make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
 make_fallback(aten.grid_sampler_2d_backward, require_dense)
@@ -2113,8 +2111,10 @@ make_fallback(
     sdpa_constraint,
     warn=False,
 )
-make_fallback(torch.ops.aten._efficient_attention_forward.default)
-make_fallback(torch.ops.aten._efficient_attention_backward.default)
+make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
+make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
+make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
+make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
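The blanket make_fallback registrations are replaced with ones that carry sdpa_constraint, so Inductor still does not codegen these kernels: a compiled graph containing them calls the eager CUDA implementations, with the layout constrained so the last dimension stays contiguous. A sketch of the user-visible effect, under the same hardware assumptions as above:

import torch

def attn(q, k, v):
    # Inductor does not generate code for this op; the compiled graph falls back
    # to the eager flash-attention CUDA kernel.
    return torch.ops.aten._flash_attention_forward(
        q, k, v, None, None, q.shape[1], k.shape[1], 0.0, False, False)[0]

compiled_attn = torch.compile(attn)
# out = compiled_attn(q, k, v)  # with q, k, v as in the first snippet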
@@ -5296,6 +5296,93 @@ def meta__scaled_dot_product_efficient_backward(
     return grad_q, grad_k, grad_v, grad_bias
 
 
+@register_meta(
+    [
+        aten._flash_attention_forward,
+    ]
+)
+def meta__flash_attention_forward(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    cum_seq_q: Optional[Tensor],
+    cum_seq_k: Optional[Tensor],
+    max_q: int,
+    max_k: int,
+    dropout_p: float,
+    is_causal: bool,
+    return_debug_mask: bool,
+    scale: Optional[float] = None,
+):
+    batch_size = query.size(0)
+    max_seqlen_batch_q = query.size(1)
+    num_heads = query.size(2)
+    head_dim = query.size(3)
+
+    max_seqlen_batch_k = key.size(1)
+
+    # Cuda Path
+    attention = torch.empty_like(query)
+    logsumexp = torch.empty(
+        (batch_size, num_heads, max_seqlen_batch_q),
+        dtype=torch.float,
+        device=query.device,
+    )
+
+    if return_debug_mask:
+        blocksize_c = 128 if head_dim > 64 else 256
+        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
+        if max_seqlen_batch_k <= 128:
+            max_seqlen_k = 128
+        elif max_seqlen_batch_k <= 256:
+            max_seqlen_k = 256
+        debug_mask = torch.empty(
+            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
+            dtype=query.dtype,
+            device=query.device,
+        )
+    else:
+        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
+
+    # See Note [Seed and Offset]:
+    return (
+        attention,
+        logsumexp,
+        torch.empty((), dtype=torch.long, device="meta"),
+        torch.empty((), dtype=torch.long, device="meta"),
+        debug_mask,
+    )
+
+
+@register_meta(
+    [
+        aten._flash_attention_backward,
+    ]
+)
+def meta__flash_attention_backward(
+    grad_out: Tensor,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    out: Tensor,
+    logsumexp: Tensor,
+    cum_seq_q: Tensor,
+    cum_seq_k: Tensor,
+    max_q: int,
+    max_k: int,
+    dropout_p: float,
+    is_causal: bool,
+    philox_seed: Tensor,
+    philox_offset: Tensor,
+    scale: Optional[float] = None,
+):
+    grad_query = torch.empty_like(query)
+    grad_key = torch.empty_like(key)
+    grad_value = torch.empty_like(value)
+
+    return grad_query, grad_key, grad_value
+
+
 @register_meta(
     [
         aten._efficient_attention_forward,
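The meta kernels above give the two ops shape and dtype propagation without launching CUDA, which is what FakeTensor tracing and torch.compile rely on. For example, the forward meta can be exercised directly with tensors on the "meta" device; a sketch with illustrative shapes:

import torch

q = torch.empty(2, 8, 4, 64, device="meta", dtype=torch.float16)
k = torch.empty(2, 8, 4, 64, device="meta", dtype=torch.float16)
v = torch.empty(2, 8, 4, 64, device="meta", dtype=torch.float16)

out, logsumexp, philox_seed, philox_offset, debug_mask = (
    torch.ops.aten._flash_attention_forward(
        q, k, v, None, None, 8, 8, 0.0, False, False))

print(out.shape)           # torch.Size([2, 8, 4, 64]), i.e. empty_like(query)
print(logsumexp.shape)     # torch.Size([2, 4, 8]): (batch, num_heads, max_seqlen_q)
print(debug_mask.numel())  # 0, since return_debug_mask=False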
@@ -47,7 +47,10 @@ def output_alias_each_other(outputs):
 
 def is_sdpa_error(func, idx, e):
     if (
-        func is aten._scaled_dot_product_flash_attention.default
+        (
+            func is aten._scaled_dot_product_flash_attention.default
+            or func is aten._flash_attention_forward.default
+        )
         and idx in (6, 7)
         and "Devices" in repr(e)
     ):
@@ -26,7 +26,7 @@ from torch.testing._internal.common_device_type import \
     skipCPUIfNoMklSparse,
     toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
-    SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION, SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
     _get_torch_cuda_version, _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (
@@ -8511,6 +8511,38 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g
 
     yield from samples
 
+def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batch, num_heads, head_dim = 4, 4, 8
+    seq_q = 11
+    seq_kv = 32
+
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
+    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
+    samples = []
+    scales = [None, 1.0]
+
+    for qkv_shape, is_causal, dropout_p, scale in product(
+            qkv_shapes, [True, False], [0.0, 0.5], scales):
+        shape_q, shape_kv = qkv_shape
+        samples.append(SampleInput(
+            make(shape_q).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            cum_seq_q=None,
+            cum_seq_k=None,
+            max_q=seq_q,
+            max_k=seq_kv,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            return_debug_mask=False,
+            scale=scale,
+        ))
+
+    yield from samples
+
 def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
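Note the transpose(1, 2): samples are built as (batch, num_heads, seq, head_dim) and transposed into the (batch, seq, num_heads, head_dim) layout the kernel consumes, which keeps the last (head_dim) dimension contiguous, the property the fused kernels require and that several of the test annotations above mention. A quick, illustrative check of that property:

import torch

t = torch.randn(4, 4, 11, 8).transpose(1, 2)  # -> (batch, seq_q, num_heads, head_dim)
print(t.shape)            # torch.Size([4, 11, 4, 8])
print(t.stride()[-1])     # 1: the last dimension is still contiguous
print(t.is_contiguous())  # False overall, which the fused kernels accept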
@@ -14240,6 +14272,31 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
                             device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
     ),
+    OpInfo(
+        'torch.ops.aten._flash_attention_forward',
+        sample_inputs_func=sample_inputs_flash_attention_forward,
+        dtypes=empty_types(),
+        dtypesIfCUDA=custom_types(torch.float16)
+        if not SM80OrLater
+        else custom_types(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_autograd=True,
+        supports_fwgrad_bwgrad=False,
+        supports_forward_ad=False,
+        check_batched_forward_grad=False,
+        decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")],
+        skips=(
+            # Device mismatch due to philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
+            # Checking the scalar value of the philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
+            # None Mismatch Tensor
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
+        )
+    ),
     OpInfo(
         'torch.ops.aten._efficient_attention_forward',
         sample_inputs_func=sample_inputs_efficient_attention_forward,