Fix performance regression and memory storage handling of Flash Attention on ROCM (#122857)
This PR fixes two major issues that were discovered after the initial merge of PR #121561:

1. The Flash Attention support added by that PR has severe performance regressions on regular shapes (power-of-two head dimensions and sequence lengths) compared with PR #115981. Its performance is worse than the math backend, and it only has numerical-stability advantages. This PR fixes that problem.
2. PR #121561 has a flaw in its memory storage handling: it does not copy the gradients back to the designated output tensors. This PR removes the deprecated `TensorStorageSanitizer` class, which is unnecessary thanks to the more flexible backward kernel shipped by PR #121561.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122857
Approved by: https://github.com/jeffdaily, https://github.com/drisspg
This commit is contained in:
parent d8b69de73b
commit b83c94339e
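As a sanity check of the storage-handling fix described above, the flash-backend gradients on ROCm can be compared against the math backend, confirming that they actually land in the tensors autograd returns. This is only a minimal sketch, not part of the PR: it assumes a ROCm build of PyTorch recent enough to provide `torch.nn.attention.sdpa_kernel`, and the shapes and tolerances are illustrative.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: batch=2, heads=8, seq_len=128, head_dim=64 (power of two).
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))
dout = torch.randn_like(q)

def grads(backend):
    # Run SDPA under the given backend and return the input gradients.
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(q, k, v)
    return torch.autograd.grad(out, (q, k, v), dout)

flash_grads = grads(SDPBackend.FLASH_ATTENTION)
math_grads = grads(SDPBackend.MATH)
for g_flash, g_math in zip(flash_grads, math_grads):
    # Loose fp16 tolerances, in the spirit of the test changes below.
    torch.testing.assert_close(g_flash, g_math, atol=9e-4, rtol=9e-4)
```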
@@ -157,43 +157,6 @@ aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view ten
                      cast_dtype(q.dtype()));
 }
 
-template<bool COPY_FROM_INPUT, // For Input Tensor
-         bool COPY_BACK>       // For Output Tensor
-class TensorStorageSanitizer {
-public:
-  TensorStorageSanitizer(const at::Tensor& ref,
-                         at::Tensor& to_sanitize)
-    : ref_(ref), to_sanitize_(to_sanitize)
-  {
-    need_sanitize = ref_.strides() != to_sanitize_.strides();
-    if (!need_sanitize)
-      return;
-
-    temp_ = at::empty_like(ref_);
-    if (COPY_FROM_INPUT) {
-      temp_.copy_(to_sanitize_);
-    }
-  }
-
-  ~TensorStorageSanitizer()
-  {
-    if (need_sanitize && COPY_BACK)
-      to_sanitize_.copy_(temp_);
-  }
-
-  at::Tensor& sanitized_tensor()
-  {
-    if (need_sanitize)
-      return temp_;
-    return to_sanitize_;
-  }
-private:
-  const at::Tensor& ref_;
-  at::Tensor& to_sanitize_;
-  at::Tensor temp_;
-  bool need_sanitize = false;
-};
-
 }
 
 #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
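For context, the removed class implemented a stage-and-copy-back workaround: when a gradient tensor's strides did not match the reference tensor's, the kernel was pointed at a freshly allocated temporary, which could then be copied back. The backward pass instantiated it as `TensorStorageSanitizer<true, false>`, i.e. with `COPY_BACK` disabled, so gradients written to the temporary never reached the caller-provided `dq`/`dk`/`dv`; the new backward kernel accepts the destination tensors directly, making the class unnecessary. A minimal Python sketch of the staging pattern (illustration only; `run_kernel` is a hypothetical stand-in for the backward kernel):

```python
import torch

def write_via_staging(ref: torch.Tensor, out: torch.Tensor, run_kernel) -> None:
    # If the destination already has the reference layout, the kernel can
    # write into it directly.
    if out.stride() == ref.stride():
        run_kernel(out)
        return
    # Otherwise stage through a temporary laid out like `ref` ...
    tmp = torch.empty_like(ref)
    run_kernel(tmp)
    # ... and copy the result back into the designated output tensor.
    # The removed <true, false> instantiation skipped this final copy,
    # which is the storage-handling flaw described in the commit message.
    out.copy_(tmp)
```

Passing the destination tensors straight to `attn_bwd`, as the hunks below do, avoids both the extra allocation and the copy.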
@@ -531,9 +494,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
   int d_head = head_size_og;
   hipError_t err; // TODO: Error handling
   {
-    TensorStorageSanitizer<true, false> dq_s(q_t, dq_t);
-    TensorStorageSanitizer<true, false> dk_s(k_t, dk_t);
-    TensorStorageSanitizer<true, false> dv_s(v_t, dv_t);
     using aotriton::v2::flash::attn_bwd;
     err = attn_bwd(mk_aotensor(q_t, "q"),
                    mk_aotensor(k_t, "k"),
@@ -541,9 +501,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                    softmax_scale,
                    mk_aotensor(out_t, "out"),
                    mk_aotensor(dout_t, "dout"),
-                   mk_aotensor(dq_s.sanitized_tensor(), "dq"),
-                   mk_aotensor(dk_s.sanitized_tensor(), "dk"),
-                   mk_aotensor(dv_s.sanitized_tensor(), "dv"),
+                   mk_aotensor(dq_t, "dq"),
+                   mk_aotensor(dk_t, "dk"),
+                   mk_aotensor(dv_t, "dv"),
                    mk_aotensor<2>(softmax_lse_cont, "L"),
                    mk_aotensor<2>(delta, "delta"),
                    p_dropout,
cmake/External/aotriton.cmake (vendored): 2 lines changed
@@ -6,7 +6,7 @@ if(NOT __AOTRITON_INCLUDED)
   set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch")
   ExternalProject_Add(aotriton_external
     GIT_REPOSITORY https://github.com/ROCm/aotriton.git
-    GIT_TAG 9044fe5eb16130e49a0a1f781ea15037353ad542
+    GIT_TAG 5d9a1dbcf5b17ff798ff77f60a8a08fa41953ff0
     SOURCE_DIR ${__AOTRITON_SOURCE_DIR}
     BINARY_DIR ${__AOTRITON_BUILD_DIR}
     PREFIX ${__AOTRITON_INSTALL_DIR}
@@ -2396,7 +2396,6 @@ class TestSDPACudaOnly(NNTestCase):
         # Cast up and compare
         self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
 
-    @skipIfRocm # TODO: Packed QKV
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
     @parametrize("contiguous_inputs", [True, False])
     @parametrize("is_causal", [True, False])
@@ -2447,6 +2446,8 @@ class TestSDPACudaOnly(NNTestCase):
         # Bump down the tolearnce for blfoat16
         atol = 7e-4 if dtype == torch.float16 else 7e-3
         rtol = 7e-4 if dtype == torch.float16 else 7e-3
+        if TEST_WITH_ROCM:
+            atol = 9e-4 if dtype == torch.float16 else 9e-3
         self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)
 
     @skipIfRocm # Missing nested and EFFICIENT_ATTENTION
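The test change above follows a common pattern for mixed-precision gradient checks: cast the low-precision gradients up to `float64`, pick `atol`/`rtol` from the low-precision dtype, and loosen the tolerance further on ROCm. A standalone sketch of that pattern, outside the test class, using `torch.testing.assert_close` in place of the test case's `assertEqual` (`TEST_WITH_ROCM` is imported from `torch.testing._internal.common_utils`; the values mirror the diff above):

```python
import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM

def check_low_precision_grad(grad_ref: torch.Tensor,
                             grad_lp: torch.Tensor,
                             dtype: torch.dtype) -> None:
    # Tolerances keyed off the low-precision dtype; ROCm gets a looser atol.
    atol = 7e-4 if dtype == torch.float16 else 7e-3
    rtol = 7e-4 if dtype == torch.float16 else 7e-3
    if TEST_WITH_ROCM:
        atol = 9e-4 if dtype == torch.float16 else 9e-3
    # Compare the float64 reference gradient against the upcast
    # low-precision gradient.
    torch.testing.assert_close(grad_ref, grad_lp.to(torch.float64),
                               atol=atol, rtol=rtol)
```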