From 10e67f5ec3834da93fc2022caa7ac69cf97c01f0 Mon Sep 17 00:00:00 2001
From: Jeff Daily
Date: Tue, 26 Aug 2025 15:11:54 +0000
Subject: [PATCH] forward fix #161102 (#161465)

PR #161102 caused tf32 to become the default precision for flex
attention. This PR forward-fixes the broken logic and restores the ROCm
MI200 CI flex attention test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161465
Approved by: https://github.com/jeffdaily, https://github.com/eqy

Co-authored-by: Jeff Daily
---
 torch/_inductor/kernel/flex/flex_attention.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
index 816dedb8eff..cbb86b6090e 100644
--- a/torch/_inductor/kernel/flex/flex_attention.py
+++ b/torch/_inductor/kernel/flex/flex_attention.py
@@ -53,9 +53,11 @@ def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv
 
 def get_float32_precision():
     if (
-        torch.backends.cuda.matmul.fp32_precision == "ieee"
-        if torch.backends.cuda.matmul.fp32_precision != "none"
-        else torch.get_float32_matmul_precision() == "highest"
+        (
+            torch.backends.cuda.matmul.fp32_precision == "ieee"
+            if torch.backends.cuda.matmul.fp32_precision != "none"
+            else torch.get_float32_matmul_precision() == "highest"
+        )
         or torch.version.hip
         or torch.mtia.is_available()
     ):
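
Note on the fix: in Python, a conditional expression binds more loosely
than "or", so the pre-fix condition parsed as
A if cond else (B or torch.version.hip or torch.mtia.is_available()),
meaning the HIP and MTIA checks were consulted only on the else branch.
The added parentheses make the ternary a single operand that is then
or'd with the HIP/MTIA checks. The sketch below is not part of the
patch; it uses hypothetical stand-in values (fp32_precision,
matmul_precision, is_hip, has_mtia) purely to illustrate the precedence
difference between the removed and added expressions.

# Standalone illustration of the operator-precedence difference.
# The names are stand-ins, not the real torch APIs: fp32_precision mimics
# torch.backends.cuda.matmul.fp32_precision, matmul_precision mimics
# torch.get_float32_matmul_precision(), and is_hip / has_mtia mimic
# torch.version.hip / torch.mtia.is_available().

def broken_condition(fp32_precision, matmul_precision, is_hip, has_mtia):
    # Pre-fix parse: A if cond else (B or is_hip or has_mtia).
    # The HIP/MTIA checks are reachable only on the else branch.
    return (
        fp32_precision == "ieee"
        if fp32_precision != "none"
        else matmul_precision == "highest" or is_hip or has_mtia
    )

def fixed_condition(fp32_precision, matmul_precision, is_hip, has_mtia):
    # Post-fix parse: (A if cond else B) or is_hip or has_mtia.
    # The HIP/MTIA checks apply regardless of which ternary branch is taken.
    return (
        (
            fp32_precision == "ieee"
            if fp32_precision != "none"
            else matmul_precision == "highest"
        )
        or is_hip
        or has_mtia
    )

if __name__ == "__main__":
    # With a ROCm-like configuration where the precision knob is neither
    # "none" nor "ieee", the pre-fix condition is False (tf32 path) while
    # the fixed condition is True (ieee path), which is what the forward
    # fix restores for the MI200 CI test.
    assert broken_condition("tf32", "high", is_hip=True, has_mtia=False) is False
    assert bool(fixed_condition("tf32", "high", is_hip=True, has_mtia=False)) is True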