fix flex attention eager bwd: more rounding (#164317)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164317
Approved by: https://github.com/drisspg
ghstack dependencies: #163986
Markus Hoehnerbach 2025-10-08 16:29:00 -07:00 committed by PyTorch MergeBot
parent afeec56a5a
commit a7fa1a91e3
2 changed files with 19 additions and 11 deletions


@@ -456,15 +456,17 @@ class TestFlexAttention(InductorTestCase):
         compiled_out: torch.Tensor,
         fudge_factor: float,
         tensor_name: Optional[str] = None,
+        fudge_atol: float = 0,
     ):
         compiled_error = (golden_out - compiled_out).abs().mean()
         ref_error = (golden_out - ref_out).abs().mean()
         if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
-            self.assertTrue(False, "Output/Grad with NaN")
-        if compiled_error > ref_error * fudge_factor:
-            name = tensor_name if tensor_name is not None else ""
-            msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
-            self.assertTrue(False, msg)
+            self.fail("Output/Grad with NaN")
+        name = tensor_name if tensor_name is not None else ""
+        msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
+        torch.testing.assert_close(
+            compiled_error, ref_error, rtol=fudge_factor, atol=1e-7, msg=msg
+        )

     def _check_out(
         self,
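
For reference, torch.testing.assert_close(actual, expected, rtol=..., atol=...) passes when |actual - expected| <= atol + rtol * |expected|, so the rewritten check tolerates a compiled error up to roughly fudge_factor times the reference error plus a tiny absolute slack, instead of the previous one-sided hard failure. A minimal standalone sketch of the same tolerance check, with the error magnitudes and fudge factor made up for illustration:

import torch

# Hypothetical mean absolute errors of the compiled and reference outputs,
# both measured against a common golden output.
compiled_error = torch.tensor(3.0e-3)
ref_error = torch.tensor(2.0e-3)
fudge_factor = 2.0

# Passes because |3.0e-3 - 2.0e-3| <= 1e-7 + 2.0 * 2.0e-3.
torch.testing.assert_close(compiled_error, ref_error, rtol=fudge_factor, atol=1e-7)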
@@ -6436,7 +6438,7 @@ class TestLearnableBiases(InductorTestCase):
         bias = torch.randn(
             params.seq_length,
             device=device,
-            dtype=params.dtype,
+            dtype=torch.float32,
             requires_grad=True,
         )
@@ -6619,12 +6621,12 @@ class TestLearnableBiases(InductorTestCase):
         gate_score = torch.randn(
             params.num_heads,
             device=device,
-            dtype=params.dtype,
+            dtype=torch.float32,
             requires_grad=True,
         )

         def bias_func(score, b, h, q_idx, kv_idx):
-            return score * torch.sigmoid(gate_score[h].to(torch.float32))
+            return score * torch.sigmoid(gate_score[h])

         flex_compiled = torch.compile(flex_attention, mode=mode)
         out_eager = flex_attention(query, key, value, score_mod=bias_func)
@@ -6659,7 +6661,7 @@ class TestLearnableBiases(InductorTestCase):
         bias2 = torch.randn(
             params.seq_length,
             device=device,
-            dtype=params.dtype,
+            dtype=torch.float32,
             requires_grad=True,
         )
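
These learnable-bias tests now keep the gate and bias parameters in float32 regardless of params.dtype, which is why the score_mod can drop its explicit .to(torch.float32) cast. A rough standalone sketch of the gated score_mod pattern being tested, assuming torch.nn.attention.flex_attention is importable; the shapes and the bfloat16 input dtype are chosen purely for illustration:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16  # stand-in for params.dtype

query = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
key = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
value = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)

# Learnable per-head gate kept in float32, as in the updated tests.
gate_score = torch.randn(H, device=device, dtype=torch.float32, requires_grad=True)

def bias_func(score, b, h, q_idx, kv_idx):
    # gate_score is already float32, so no explicit upcast is needed here.
    return score * torch.sigmoid(gate_score[h])

out_eager = flex_attention(query, key, value, score_mod=bias_func)
out_eager.sum().backward()  # exercises the eager backward path patched below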


@@ -902,9 +902,15 @@ def sdpa_dense_backward(
     grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out

-    grad_softmax_scores = grad_out @ value.transpose(-2, -1)
+    grad_softmax_scores = grad_out.to(dtype=softmax_scores.dtype) @ value.to(
+        dtype=softmax_scores.dtype
+    ).transpose(-2, -1)

-    sum_scores = torch.sum(out * grad_out, -1, keepdim=True)
+    sum_scores = torch.sum(
+        out.to(dtype=softmax_scores.dtype) * grad_out.to(dtype=softmax_scores.dtype),
+        -1,
+        keepdim=True,
+    )
     grad_score_mod = softmax_scores * (
         grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1)
     )
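
The eager backward now casts grad_out, value, and out to softmax_scores.dtype before the matmul and the row-sum, so for half-precision inputs these intermediates stay in the scores' (typically float32) precision instead of being rounded back to the input dtype. A tiny self-contained sketch of what that cast changes, with shapes and dtypes made up for illustration (this is not the backward itself):

import torch

# Stand-ins for two of the tensors used above; only the dtypes matter here.
grad_out = torch.randn(1, 2, 64, 32, dtype=torch.bfloat16)
value = torch.randn(1, 2, 64, 32, dtype=torch.bfloat16)

# bfloat16 matmul: the result is rounded back to bfloat16.
low_precision = grad_out @ value.transpose(-2, -1)

# Cast first, as the patched backward does: the result stays in float32.
high_precision = grad_out.to(torch.float32) @ value.to(torch.float32).transpose(-2, -1)

# The gap between the two is the rounding error removed from this step.
print((low_precision.to(torch.float32) - high_precision).abs().max())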