Expose Flash attn to autograd (#114378)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114378
Approved by: https://github.com/drisspg
Antoni Viros 2023-12-01 20:15:35 +00:00 committed by PyTorch MergeBot
parent 80d8a2a237
commit d47f715d29
12 changed files with 174 additions and 23 deletions

View File

@@ -14493,7 +14493,7 @@
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
tags: nondeterministic_seeded
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt? max_q, SymInt? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
variants: function
dispatch:
CUDA: _flash_attention_forward
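With this schema change, max_q and max_k become required SymInts instead of optional ones, so direct callers pass the sequence lengths explicitly even on the dense path where cum_seq_q/cum_seq_k are None. A minimal sketch of calling the newly differentiable op — assuming a CUDA build with flash-attention support and fp16 inputs in (batch, seqlen, num_heads, head_dim) layout; the shapes are hypothetical:

    import torch

    # Hypothetical shapes; the op expects (batch, seqlen, num_heads, head_dim).
    q = torch.randn(4, 11, 4, 8, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn(4, 32, 4, 8, device="cuda", dtype=torch.float16, requires_grad=True)
    v = torch.randn(4, 32, 4, 8, device="cuda", dtype=torch.float16, requires_grad=True)

    out, logsumexp, philox_seed, philox_offset, debug_mask = (
        torch.ops.aten._flash_attention_forward(
            q, k, v,
            None, None,   # cum_seq_q / cum_seq_k: dense (non-varlen) path
            11, 32,       # max_q / max_k are now required
            0.0,          # dropout_p
            False,        # is_causal
            False,        # return_debug_mask
        )
    )
    out.sum().backward()  # gradients flow through the newly exposed backward
    print(q.grad.shape)   # torch.Size([4, 11, 4, 8])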

View File

@@ -708,8 +708,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
v_t,
c10::nullopt,
c10::nullopt,
c10::nullopt,
c10::nullopt,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
is_causal,
return_debug_mask,
@@ -779,8 +779,8 @@ _flash_attention_forward(
const Tensor& value,
const c10::optional<Tensor>& cumulative_sequence_length_q,
const c10::optional<Tensor>& cumulative_sequence_length_k,
c10::optional<int64_t> max_seqlen_batch_q,
c10::optional<int64_t> max_seqlen_batch_k,
int64_t max_seqlen_batch_q,
int64_t max_seqlen_batch_k,
double dropout_p,
bool is_causal,
bool return_debug_mask,
@@ -797,15 +797,9 @@ _flash_attention_forward(
cumulative_sequence_length_q.has_value() ==
cumulative_sequence_length_k.has_value(),
"cumulative_sequence_length_q and cumulative_sequence_length_k must be both set or both not set");
TORCH_CHECK(
max_seqlen_batch_q.has_value() == max_seqlen_batch_k.has_value(),
"max_seqlen_batch_q and max_seqlen_batch_k must be both set or both not set");
Tensor output, q_padded, k_padded, v_padded, logsumexp, output_shape,
philox_seed, philox_offset, debug_attn_mask;
if (cumulative_sequence_length_q.has_value()) {
TORCH_CHECK(
max_seqlen_batch_q.has_value(),
"max_seqlen_batch_q must be set when cumulative_sequence_length_q is set");
std::tie(
output,
q_padded,
@@ -822,8 +816,8 @@ _flash_attention_forward(
out,
cumulative_sequence_length_q.value(),
cumulative_sequence_length_k.value(),
max_seqlen_batch_q.value(),
max_seqlen_batch_k.value(),
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
softmax_scale,
false /*zero_tensors*/,

View File

@@ -71,7 +71,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
c10::optional<at::Tensor> dk{c10::nullopt};
c10::optional<at::Tensor> dv{c10::nullopt};
// The kernel computes irregadless we will drop for this functions return
// The kernel computes irregardless we will drop for this functions return
Tensor grad_softmax;
// We check the whether the cumulative_sequence_length_q is defined

View File

@@ -266,6 +266,8 @@ ALLOW_LIST = [
("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
("aten::_efficient_attention_forward", datetime.date(2023, 11, 30)),
("aten::_efficient_attention_backward", datetime.date(2023, 11, 30)),
("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),

View File

@@ -392,6 +392,7 @@ class TestOperators(TestCase):
# query: last dimension must be contiguous
# Fused attention kernels require last dim to be contiguous
xfail('nn.functional.scaled_dot_product_attention'),
xfail("torch.ops.aten._flash_attention_forward"),
xfail("torch.ops.aten._efficient_attention_forward"),
}))
@opsToleranceOverride('TestOperators', 'test_grad', (
@@ -475,6 +476,7 @@ class TestOperators(TestCase):
xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
xfail('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._flash_attention_forward'),
xfail('torch.ops.aten._efficient_attention_forward'),
xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented
@@ -603,6 +605,7 @@ class TestOperators(TestCase):
# RuntimeError: query: last dimension must be contiguous
# The fused attention kernels require the last dim to be contiguous
xfail('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._flash_attention_forward'),
xfail('torch.ops.aten._efficient_attention_forward'),
# BUG
# AssertionError: Tensor-likes are not close!
@@ -682,6 +685,7 @@ class TestOperators(TestCase):
xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides
xfail('sparse.mm', 'reduce'), # sparse tensors have no strides
skip('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._flash_attention_forward'),
xfail('torch.ops.aten._efficient_attention_forward'),
# AssertionError: Tensor-likes are not close!
# Mismatched elements: 1 / 15 (6.7%)

View File

@@ -3573,6 +3573,9 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('le', device_type='cuda'),
xfail('lt', device_type='cuda'),
xfail('ne', device_type='cuda'),
# RuntimeError: aten::_flash_attention_forward hit the vmap fallback which is currently disabled
xfail('torch.ops.aten._flash_attention_forward'),
}
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798

View File

@@ -238,6 +238,7 @@ inductor_expected_failures_single_sample["cuda"] = {
"sparse.sampled_addmm": {f32, f64},
"to_sparse": {f16, f32, f64},
"torch.ops.aten._efficient_attention_forward": {f16, bf16, f32},
"torch.ops.aten._flash_attention_forward": {f16, bf16, f32},
}

View File

@@ -2779,9 +2779,9 @@
output_differentiability: [True, False, False, False, False, False, False, False, False]
query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
# - name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor query_padded, Tensor key_padded, Tensor value_padded, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
# output_differentiability: [True, False, False, False, False, False, False, False]
# query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
output_differentiability: [True, False, False, False, False]
query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
output_differentiability: [True, False, False, False, False, False]
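The uncommented derivatives.yaml entry marks only the attention output as differentiable and routes the query/key/value gradients through _flash_attention_backward_symint, passing along the logsumexp and philox state saved by the forward. A rough sketch of what that enables, under the same hypothetical shapes and CUDA/flash-attention assumptions as above:

    import torch

    q, k, v = (torch.randn(2, 8, 2, 16, device="cuda", dtype=torch.float16,
                           requires_grad=True) for _ in range(3))
    out = torch.ops.aten._flash_attention_forward(
        q, k, v, None, None, 8, 8, 0.0, True, False)[0]
    # Only the first output carries gradients; the logsumexp, philox seed/offset
    # and debug mask are marked non-differentiable.
    grad_q, grad_k, grad_v = torch.autograd.grad(out.sum(), (q, k, v))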

View File

@@ -2022,8 +2022,6 @@ make_fallback(aten.cumsum, require_dense, warn=False)
make_fallback(aten.cumprod, require_dense, warn=False)
make_fallback(aten._embedding_bag, require_contiguous)
make_fallback(aten._embedding_bag_forward_only, require_contiguous)
make_fallback(aten._flash_attention_forward)
make_fallback(aten._flash_attention_backward)
make_fallback(aten._fused_moving_avg_obs_fq_helper)
make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
make_fallback(aten.grid_sampler_2d_backward, require_dense)
@@ -2113,8 +2111,10 @@ make_fallback(
sdpa_constraint,
warn=False,
)
make_fallback(torch.ops.aten._efficient_attention_forward.default)
make_fallback(torch.ops.aten._efficient_attention_backward.default)
make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
make_fallback(aten.sort)
make_fallback(aten.sort.stable)
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
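Registering the flash/efficient attention ops as fallbacks with sdpa_constraint means Inductor does not try to codegen them; it enforces the SDPA stride requirements (last dimension contiguous) on the inputs and calls back into the eager CUDA kernels. A sketch of a compiled graph that would exercise that path, under the same hypothetical CUDA/flash-attention assumptions:

    import torch

    @torch.compile
    def attn(q, k, v):
        # Inductor keeps this op as a fallback, so the eager flash-attention
        # kernel runs with sdpa-constrained input strides.
        return torch.ops.aten._flash_attention_forward(
            q, k, v, None, None, q.shape[1], k.shape[1], 0.0, True, False)[0]

    q = torch.randn(2, 8, 2, 16, device="cuda", dtype=torch.float16)
    out = attn(q, q, q)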

View File

@@ -5296,6 +5296,93 @@ def meta__scaled_dot_product_efficient_backward(
return grad_q, grad_k, grad_v, grad_bias
@register_meta(
[
aten._flash_attention_forward,
]
)
def meta__flash_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
cum_seq_q: Optional[Tensor],
cum_seq_k: Optional[Tensor],
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
return_debug_mask: bool,
scale: Optional[float] = None,
):
batch_size = query.size(0)
max_seqlen_batch_q = query.size(1)
num_heads = query.size(2)
head_dim = query.size(3)
max_seqlen_batch_k = key.size(1)
# Cuda Path
attention = torch.empty_like(query)
logsumexp = torch.empty(
(batch_size, num_heads, max_seqlen_batch_q),
dtype=torch.float,
device=query.device,
)
if return_debug_mask:
blocksize_c = 128 if head_dim > 64 else 256
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
if max_seqlen_batch_k <= 128:
max_seqlen_k = 128
elif max_seqlen_batch_k <= 256:
max_seqlen_k = 256
debug_mask = torch.empty(
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
dtype=query.dtype,
device=query.device,
)
else:
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
# See Note [Seed and Offset]:
return (
attention,
logsumexp,
torch.empty((), dtype=torch.long, device="meta"),
torch.empty((), dtype=torch.long, device="meta"),
debug_mask,
)
@register_meta(
[
aten._flash_attention_backward,
]
)
def meta__flash_attention_backward(
grad_out: Tensor,
query: Tensor,
key: Tensor,
value: Tensor,
out: Tensor,
logsumexp: Tensor,
cum_seq_q: Tensor,
cum_seq_k: Tensor,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
philox_seed: Tensor,
philox_offset: Tensor,
scale: Optional[float] = None,
):
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
return grad_query, grad_key, grad_value
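The meta registrations above give _flash_attention_forward and _flash_attention_backward shape and dtype propagation without a CUDA device, which is what fake-tensor tracing and torch.compile rely on. A small sketch on the meta device, with hypothetical shapes matching the (batch, seqlen, num_heads, head_dim) layout used elsewhere:

    import torch

    q = torch.empty(4, 11, 4, 8, device="meta", dtype=torch.float16)
    k = torch.empty(4, 32, 4, 8, device="meta", dtype=torch.float16)
    v = torch.empty(4, 32, 4, 8, device="meta", dtype=torch.float16)
    out, lse, seed, offset, dbg = torch.ops.aten._flash_attention_forward(
        q, k, v, None, None, 11, 32, 0.0, False, False)
    print(out.shape)  # torch.Size([4, 11, 4, 8])  -- empty_like(query)
    print(lse.shape)  # torch.Size([4, 4, 11])     -- (batch, num_heads, max_seqlen_q)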
@register_meta(
[
aten._efficient_attention_forward,

View File

@@ -47,7 +47,10 @@ def output_alias_each_other(outputs):
def is_sdpa_error(func, idx, e):
if (
func is aten._scaled_dot_product_flash_attention.default
(
func is aten._scaled_dot_product_flash_attention.default
or func is aten._flash_attention_forward.default
)
and idx in (6, 7)
and "Devices" in repr(e)
):

View File

@@ -26,7 +26,7 @@ from torch.testing._internal.common_device_type import \
skipCPUIfNoMklSparse,
toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
PLATFORM_SUPPORTS_FLASH_ATTENTION, SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
_get_torch_cuda_version, _get_torch_rocm_version,
)
from torch.testing._internal.common_utils import (
@@ -8511,6 +8511,38 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g
yield from samples
def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
batch, num_heads, head_dim = 4, 4, 8
seq_q = 11
seq_kv = 32
dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
samples = []
scales = [None, 1.0]
for qkv_shape, is_causal, dropout_p, scale in product(
qkv_shapes, [True, False], [0.0, 0.5], scales):
shape_q, shape_kv = qkv_shape
samples.append(SampleInput(
make(shape_q).transpose(1, 2),
make(shape_kv).transpose(1, 2),
make(shape_kv).transpose(1, 2),
cum_seq_q=None,
cum_seq_k=None,
max_q=seq_q,
max_k=seq_kv,
dropout_p=dropout_p,
is_causal=is_causal,
return_debug_mask=False,
scale=scale,
))
yield from samples
def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -14240,6 +14272,31 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
),
OpInfo(
'torch.ops.aten._flash_attention_forward',
sample_inputs_func=sample_inputs_flash_attention_forward,
dtypes=empty_types(),
dtypesIfCUDA=custom_types(torch.float16)
if not SM80OrLater
else custom_types(torch.float16, torch.bfloat16),
supports_out=False,
supports_autograd=True,
supports_fwgrad_bwgrad=False,
supports_forward_ad=False,
check_batched_forward_grad=False,
decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")],
skips=(
# Device mismatch due to philox seed and offset
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
# Checking the scalar value of the philox seed and offset
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
# None Mismatch Tensor
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
)
),
OpInfo(
'torch.ops.aten._efficient_attention_forward',
sample_inputs_func=sample_inputs_efficient_attention_forward,