Fix score_mod.py dynamic max autotune for backward (#151270)
Same as https://github.com/pytorch/pytorch/pull/148991, but this PR fixes the backward path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151270
Approved by: https://github.com/drisspg, https://github.com/bobrenjc93
This commit is contained in:
parent afaadce083
commit ccfce9ae86
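For context, the scenario this commit targets is flex attention compiled with max autotune and dynamic shapes, where the backward pass also has to go through Inductor's kernel autotuning. Below is a minimal sketch of that usage, assuming a CUDA device; the learnable-bias score_mod and the compile mode are illustrative assumptions, not the exact test added in this PR.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative trainable bias table indexed by (query, key) position.
bias = torch.randn(512, 512, device="cuda", requires_grad=True)

def learnable_bias(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx, kv_idx]

query = torch.randn(2, 16, 512, 64, device="cuda", requires_grad=True)
key = torch.randn(2, 16, 512, 64, device="cuda", requires_grad=True)
value = torch.randn(2, 16, 512, 64, device="cuda", requires_grad=True)

# dynamic=True plus a max-autotune mode is the combination the fix is about.
compiled = torch.compile(flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs")
out = compiled(query, key, value, score_mod=learnable_bias)
out.sum().backward()  # drives flex_attention_backward and its autotuning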
@@ -5335,6 +5335,9 @@ class TestLearnableBiases(InductorTestCase):
         query = torch.randn(2, 16, 512, 64, device="cuda")
         key = torch.randn(2, 16, 512, 64, device="cuda")
         value = torch.randn(2, 16, 512, 64, device="cuda")
+        query.requires_grad = True
+        key.requires_grad = True
+        value.requires_grad = True

         shape = (2, 16, 512, 16, 512, 64)
         B, Hq, M, Hkv, N, D = shape
@@ -5360,6 +5363,7 @@ class TestLearnableBiases(InductorTestCase):
             enable_gqa=True,
             kernel_options=None,
         )
+        out.sum().backward()

         self.assertEqual(
             out.shape, query.shape, f"Expected shape {query.shape}, got {out.shape}"
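The hunk above makes the test inputs require gradients and drives a backward pass through the compiled kernel. A hedged sketch of the kind of follow-up check such a test could add (illustrative only, not the PR's exact assertions):

# After out.sum().backward(), every leaf created with requires_grad above
# should have received a gradient of matching shape.
for name, t in (("query", query), ("key", key), ("value", value)):
    assert t.grad is not None, f"{name} received no gradient"
    assert t.grad.shape == t.shape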
@@ -2672,7 +2672,7 @@ def flex_attention_backward(*args, **kwargs):
     broadcasted_grad_key = autotune_select_algorithm(
         "flex_attention_backward",
         choices,
-        inputs_for_autotuning,
+        [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)],
         layout_broadcasted_k,
         input_gen_fns=input_gen_fns,
     )  # [Bq, Hkv, seq_len_kv, k_head_dim]
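The one-line change above narrows what is handed to autotune_select_algorithm as benchmarkable inputs. Presumably, with dynamic shapes enabled, inputs_for_autotuning can also carry non-IRNode entries (such as symbolic sizes) that the autotuner is not set up to benchmark, so only tensor IR nodes are kept. A minimal sketch of the idea, using only names visible in the diff; the helper name is hypothetical:

from torch._inductor import ir

def benchmarkable_inputs(inputs_for_autotuning):
    # Keep only tensor IR nodes; symbolic/scalar entries are dropped before
    # they reach the autotuner.
    return [x for x in inputs_for_autotuning if isinstance(x, ir.IRNode)]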