Fix performance regression and memory storage handling of Flash Attention on ROCm (#122857)

This PR fixes two major issues that were discovered after the initial merge of PR #121561:
1. The Flash Attention support added by PR #121561 has severe performance regressions on regular shapes (power-of-two head dimensions and sequence lengths) compared with PR #115981: it runs slower than the math backend and offers only numerical-stability advantages. This PR fixes the regression.
2. PR #121561 has a flaw in its memory storage handling: the gradients are never copied back to the designated output tensors. This PR removes the now-unnecessary `TensorStorageSanitizer` class, since the more flexible backward kernel shipped with PR #121561 can write into the output tensors directly (see the sketch below).
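
For illustration only, a minimal Python sketch (hypothetical, not code from this PR) of the staging pattern `TensorStorageSanitizer` implemented: when an output tensor's strides differ from the layout the kernel expects, the kernel writes into a contiguous temporary, and the result reaches the designated tensor only if that temporary is copied back at the end. The backward path instantiated the sanitizer with `COPY_BACK == false`, so the copy-back never ran:

    import torch

    # `ref` has the layout the kernel expects; `dq` is the designated output
    # tensor with the same shape but different strides (e.g. from a transpose).
    ref = torch.randn(2, 8, 128, 64)
    dq = torch.empty(2, 128, 8, 64).transpose(1, 2)
    assert ref.stride() != dq.stride()

    temp = torch.empty_like(ref)  # contiguous staging buffer
    temp.fill_(1.0)               # stand-in for the kernel writing gradients
    dq.copy_(temp)                # the copy-back step; with COPY_BACK == false
                                  # it never ran and dq kept its old contents
    assert (dq == 1.0).all()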

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122857
Approved by: https://github.com/jeffdaily, https://github.com/drisspg
Xinya Zhang 2024-03-29 16:37:24 +00:00 committed by PyTorch MergeBot
parent d8b69de73b
commit b83c94339e
3 changed files with 6 additions and 45 deletions


@@ -157,43 +157,6 @@ aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
                      cast_dtype(q.dtype()));
 }
 
-template<bool COPY_FROM_INPUT, // For Input Tensor
-         bool COPY_BACK>       // For Output Tensor
-class TensorStorageSanitizer {
- public:
-  TensorStorageSanitizer(const at::Tensor& ref,
-                         at::Tensor& to_sanitize)
-    : ref_(ref), to_sanitize_(to_sanitize)
-  {
-    need_sanitize = ref_.strides() != to_sanitize_.strides();
-    if (!need_sanitize)
-      return;
-
-    temp_ = at::empty_like(ref_);
-    if (COPY_FROM_INPUT) {
-      temp_.copy_(to_sanitize_);
-    }
-  }
-
-  ~TensorStorageSanitizer()
-  {
-    if (need_sanitize && COPY_BACK)
-      to_sanitize_.copy_(temp_);
-  }
-
-  at::Tensor& sanitized_tensor()
-  {
-    if (need_sanitize)
-      return temp_;
-    return to_sanitize_;
-  }
-
- private:
-  const at::Tensor& ref_;
-  at::Tensor& to_sanitize_;
-  at::Tensor temp_;
-  bool need_sanitize = false;
-};
 
 }
 
 #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
@@ -531,9 +494,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
   int d_head = head_size_og;
   hipError_t err; // TODO: Error handling
   {
-    TensorStorageSanitizer<true, false> dq_s(q_t, dq_t);
-    TensorStorageSanitizer<true, false> dk_s(k_t, dk_t);
-    TensorStorageSanitizer<true, false> dv_s(v_t, dv_t);
     using aotriton::v2::flash::attn_bwd;
     err = attn_bwd(mk_aotensor(q_t, "q"),
                    mk_aotensor(k_t, "k"),
@@ -541,9 +501,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
                    softmax_scale,
                    mk_aotensor(out_t, "out"),
                    mk_aotensor(dout_t, "dout"),
-                   mk_aotensor(dq_s.sanitized_tensor(), "dq"),
-                   mk_aotensor(dk_s.sanitized_tensor(), "dk"),
-                   mk_aotensor(dv_s.sanitized_tensor(), "dv"),
+                   mk_aotensor(dq_t, "dq"),
+                   mk_aotensor(dk_t, "dk"),
+                   mk_aotensor(dv_t, "dv"),
                    mk_aotensor<2>(softmax_lse_cont, "L"),
                    mk_aotensor<2>(delta, "delta"),
                    p_dropout,
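
With the sanitizer gone, `attn_bwd` receives `dq_t`, `dk_t` and `dv_t` directly and writes the gradients into them in place. A quick end-to-end check from Python (a sketch, assuming a CUDA or ROCm build with flash attention available and the `torch.nn.attention.sdpa_kernel` API):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()  # exercises mha_bwd; gradients must land in q.grad/k.grad/v.grad

    assert q.grad is not None and q.grad.abs().sum() > 0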


@@ -6,7 +6,7 @@ if(NOT __AOTRITON_INCLUDED)
   set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch")
   ExternalProject_Add(aotriton_external
     GIT_REPOSITORY https://github.com/ROCm/aotriton.git
-    GIT_TAG 9044fe5eb16130e49a0a1f781ea15037353ad542
+    GIT_TAG 5d9a1dbcf5b17ff798ff77f60a8a08fa41953ff0
     SOURCE_DIR ${__AOTRITON_SOURCE_DIR}
     BINARY_DIR ${__AOTRITON_BUILD_DIR}
     PREFIX ${__AOTRITON_INSTALL_DIR}


@@ -2396,7 +2396,6 @@ class TestSDPACudaOnly(NNTestCase):
         # Cast up and compare
         self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
 
-    @skipIfRocm  # TODO: Packed QKV
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
     @parametrize("contiguous_inputs", [True, False])
     @parametrize("is_causal", [True, False])
@@ -2447,6 +2446,8 @@ class TestSDPACudaOnly(NNTestCase):
         # Bump down the tolerance for bfloat16
         atol = 7e-4 if dtype == torch.float16 else 7e-3
         rtol = 7e-4 if dtype == torch.float16 else 7e-3
+        if TEST_WITH_ROCM:
+            atol = 9e-4 if dtype == torch.float16 else 9e-3
         self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)
 
     @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
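
Taken in isolation, the tolerance logic the test now applies looks like the sketch below (`TEST_WITH_ROCM` is the existing flag from `torch.testing._internal.common_utils`; `assert_close` stands in for `self.assertEqual`):

    import torch
    from torch.testing._internal.common_utils import TEST_WITH_ROCM

    def grad_tolerances(dtype):
        # Looser bounds for bfloat16 than float16, and looser still on ROCm.
        atol = 7e-4 if dtype == torch.float16 else 7e-3
        rtol = 7e-4 if dtype == torch.float16 else 7e-3
        if TEST_WITH_ROCM:
            atol = 9e-4 if dtype == torch.float16 else 9e-3
        return atol, rtol

    ref = torch.randn(8, dtype=torch.float64)
    atol, rtol = grad_tolerances(torch.float16)
    torch.testing.assert_close(ref.to(torch.float16).to(torch.float64), ref, atol=atol, rtol=rtol)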