fix flex attention eager bwd: more rounding (#164317)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164317
Approved by: https://github.com/drisspg
ghstack dependencies: #163986
parent afeec56a5a
commit a7fa1a91e3
@@ -456,15 +456,17 @@ class TestFlexAttention(InductorTestCase):
         compiled_out: torch.Tensor,
         fudge_factor: float,
         tensor_name: Optional[str] = None,
+        fudge_atol: float = 0,
     ):
         compiled_error = (golden_out - compiled_out).abs().mean()
         ref_error = (golden_out - ref_out).abs().mean()
         if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
-            self.assertTrue(False, "Output/Grad with NaN")
-        if compiled_error > ref_error * fudge_factor:
-            name = tensor_name if tensor_name is not None else ""
-            msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
-            self.assertTrue(False, msg)
+            self.fail("Output/Grad with NaN")
+        name = tensor_name if tensor_name is not None else ""
+        msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
+        torch.testing.assert_close(
+            compiled_error, ref_error, rtol=fudge_factor, atol=1e-7, msg=msg
+        )

     def _check_out(
         self,
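For context: the rewritten check delegates to torch.testing.assert_close, which passes when |compiled_error - ref_error| <= atol + rtol * |ref_error|, so fudge_factor now acts as a relative tolerance with a small absolute floor rather than a hand-rolled threshold. A minimal sketch of the new semantics, with made-up error values (not from the test suite):

    import torch

    # Hypothetical error magnitudes; the compiled kernel is allowed to
    # deviate from the golden reference somewhat more than eager does.
    compiled_error = torch.tensor(1.9e-3)
    ref_error = torch.tensor(1.0e-3)
    fudge_factor = 2.0

    # Passes: |1.9e-3 - 1.0e-3| <= 1e-7 + 2.0 * 1.0e-3
    torch.testing.assert_close(compiled_error, ref_error, rtol=fudge_factor, atol=1e-7)

On failure, the supplied msg still embeds both error values, matching the old message.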
@@ -6436,7 +6438,7 @@ class TestLearnableBiases(InductorTestCase):
         bias = torch.randn(
             params.seq_length,
             device=device,
-            dtype=params.dtype,
+            dtype=torch.float32,
             requires_grad=True,
         )

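This hunk (and the two matching ones below) creates the learnable bias in float32 instead of params.dtype, so the trainable leaf and its gradient stay in full precision even when query/key/value run in a reduced-precision dtype. A minimal sketch of the pattern, assuming hypothetical shapes and bfloat16 inputs (not taken from the test):

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    B, H, S, D = 2, 4, 128, 64  # made-up sizes
    q, k, v = (torch.randn(B, H, S, D, dtype=torch.bfloat16) for _ in range(3))

    # The learnable bias is kept in float32 regardless of the q/k/v dtype.
    bias = torch.randn(S, dtype=torch.float32, requires_grad=True)

    def bias_func(score, b, h, q_idx, kv_idx):
        return score + bias[q_idx]

    out = flex_attention(q, k, v, score_mod=bias_func)
    out.sum().backward()
    assert bias.grad is not None and bias.grad.dtype == torch.float32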
@@ -6619,12 +6621,12 @@ class TestLearnableBiases(InductorTestCase):
         gate_score = torch.randn(
             params.num_heads,
             device=device,
-            dtype=params.dtype,
+            dtype=torch.float32,
             requires_grad=True,
         )

         def bias_func(score, b, h, q_idx, kv_idx):
-            return score * torch.sigmoid(gate_score[h].to(torch.float32))
+            return score * torch.sigmoid(gate_score[h])

         flex_compiled = torch.compile(flex_attention, mode=mode)
         out_eager = flex_attention(query, key, value, score_mod=bias_func)
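With gate_score now created in float32, the explicit upcast inside the score_mod becomes redundant, which is why the second change drops .to(torch.float32). A quick sketch of why the two forms are equivalent:

    import torch

    num_heads = 8  # hypothetical
    gate_score = torch.randn(num_heads, dtype=torch.float32, requires_grad=True)
    h = 3

    # .to(torch.float32) on an already-float32 tensor is a no-op, so both
    # forms compute identical values.
    assert torch.equal(
        torch.sigmoid(gate_score[h].to(torch.float32)),  # old form
        torch.sigmoid(gate_score[h]),                    # new form
    )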
@@ -6659,7 +6661,7 @@ class TestLearnableBiases(InductorTestCase):
         bias2 = torch.randn(
             params.seq_length,
             device=device,
-            dtype=params.dtype,
+            dtype=torch.float32,
             requires_grad=True,
         )

@@ -902,9 +902,15 @@ def sdpa_dense_backward(
     grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out

-    grad_softmax_scores = grad_out @ value.transpose(-2, -1)
+    grad_softmax_scores = grad_out.to(dtype=softmax_scores.dtype) @ value.to(
+        dtype=softmax_scores.dtype
+    ).transpose(-2, -1)

-    sum_scores = torch.sum(out * grad_out, -1, keepdim=True)
+    sum_scores = torch.sum(
+        out.to(dtype=softmax_scores.dtype) * grad_out.to(dtype=softmax_scores.dtype),
+        -1,
+        keepdim=True,
+    )
     grad_score_mod = softmax_scores * (
         grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1)
     )
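This hunk carries the actual fix: the eager backward previously ran these two reductions in the input dtype, while softmax_scores lives in a higher-precision dtype (float32 for reduced-precision inputs), so eager picked up extra rounding relative to the compiled kernels. A standalone sketch of the rounding the upcast removes (made-up shapes, not the PR's code):

    import torch

    torch.manual_seed(0)
    grad_out = torch.randn(64, 64, dtype=torch.bfloat16)
    value = torch.randn(64, 64, dtype=torch.bfloat16)

    # Low-precision path: the matmul result is rounded back to bfloat16.
    low = grad_out @ value.transpose(-2, -1)

    # Upcast path, as in the fix: promote the operands first and keep the
    # float32 result, avoiding the round-trip through bfloat16.
    high = grad_out.to(torch.float32) @ value.to(torch.float32).transpose(-2, -1)

    print((low.to(torch.float32) - high).abs().max())  # nonzero rounding gap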