Fix score_mod.py dynamic max autotune for backward (#151270)

Same as https://github.com/pytorch/pytorch/pull/148991, which fixed the forward path; this PR applies the same fix (filtering the autotuning inputs down to Inductor IR nodes) to the backward path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151270
Approved by: https://github.com/drisspg, https://github.com/bobrenjc93
Commit ccfce9ae86 (parent afaadce083)
Author: Chien-Chin Huang, 2025-04-14 16:18:56 -07:00; committed by PyTorch MergeBot
2 changed files with 5 additions and 1 deletion
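
In outline, the change has two parts: the learnable-bias test now marks query, key, and value as requiring grad and runs a backward pass, and the backward flex-attention lowering filters its autotuning inputs down to Inductor IR nodes. A standalone repro sketch follows the test diff below.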


@@ -5335,6 +5335,9 @@ class TestLearnableBiases(InductorTestCase):
         query = torch.randn(2, 16, 512, 64, device="cuda")
         key = torch.randn(2, 16, 512, 64, device="cuda")
         value = torch.randn(2, 16, 512, 64, device="cuda")
+        query.requires_grad = True
+        key.requires_grad = True
+        value.requires_grad = True
         shape = (2, 16, 512, 16, 512, 64)
         B, Hq, M, Hkv, N, D = shape
@@ -5360,6 +5363,7 @@ class TestLearnableBiases(InductorTestCase):
             enable_gqa=True,
             kernel_options=None,
         )
+        out.sum().backward()
         self.assertEqual(
             out.shape, query.shape, f"Expected shape {query.shape}, got {out.shape}"

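For context, the configuration the updated test exercises can be sketched as a standalone script. This is a minimal sketch, not code from the test suite: the `score_mod`, the compile flags, and the `max-autotune-no-cudagraphs` mode are illustrative assumptions, and a CUDA device is assumed.

```python
# Minimal standalone sketch of what the updated test exercises: the compiled
# flex_attention backward pass under max-autotune with dynamic shapes.
# Assumes a CUDA device; score_mod and compile flags are illustrative.
import torch
from torch.nn.attention.flex_attention import flex_attention

def score_mod(score, b, h, q_idx, kv_idx):
    # Any score_mod works here; a relative-position bias keeps it simple.
    return score + (q_idx - kv_idx)

query = torch.randn(2, 16, 512, 64, device="cuda", requires_grad=True)
key = torch.randn(2, 16, 512, 64, device="cuda", requires_grad=True)
value = torch.randn(2, 16, 512, 64, device="cuda", requires_grad=True)

compiled = torch.compile(
    flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
)
out = compiled(query, key, value, score_mod=score_mod)
out.sum().backward()  # the path this commit fixes
```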

@@ -2672,7 +2672,7 @@ def flex_attention_backward(*args, **kwargs):
     broadcasted_grad_key = autotune_select_algorithm(
         "flex_attention_backward",
         choices,
-        inputs_for_autotuning,
+        [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)],
         layout_broadcasted_k,
         input_gen_fns=input_gen_fns,
     )  # [Bq, Hkv, seq_len_kv, k_head_dim]
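
The one-line fix narrows the flattened input list to genuine Inductor IR nodes before invoking the autotuner; under dynamic shapes the list can also pick up non-IR entries (such as symbolic sizes), which the forward-path PR filtered out the same way. A sketch of the pattern in isolation (the helper name is hypothetical; `torch._inductor.ir.IRNode` is the real class referenced in the diff):

```python
# Sketch of the filtering pattern in isolation. torch._inductor.ir.IRNode is
# the real base class used in the diff; the helper name is hypothetical.
from torch._inductor import ir

def keep_ir_nodes(inputs_for_autotuning):
    # Under dynamic shapes the gathered inputs can include non-IR entries
    # (e.g. symbolic sizes); only IR nodes are valid autotuning inputs.
    return [x for x in inputs_for_autotuning if isinstance(x, ir.IRNode)]
```

Filtering at the call site matches the forward fix in #148991 and leaves `autotune_select_algorithm` itself unchanged.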