PR #161102 caused TF32 to become the default precision for flex attention. This PR forward-fixes the broken logic and restores the ROCm MI200 CI flex attention test.
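
The breakage appears to be an operator-precedence issue: a Python conditional expression binds more loosely than "or", so without the added parentheses the trailing "or torch.version.hip" / "or torch.mtia.is_available()" terms were folded into the else-branch and skipped whenever the new fp32_precision API was set. A minimal illustrative sketch (plain Python; the variable names are stand-ins, not PyTorch code):

    # Stand-in flags used only to illustrate the precedence pitfall.
    new_api_set = True      # stands in for torch.backends.cuda.matmul.fp32_precision != "none"
    wants_ieee = False      # stands in for torch.backends.cuda.matmul.fp32_precision == "ieee"
    legacy_highest = False  # stands in for torch.get_float32_matmul_precision() == "highest"
    on_hip = True           # stands in for bool(torch.version.hip)

    # Broken form: parses as `wants_ieee if new_api_set else (legacy_highest or on_hip)`,
    # so the on_hip override is ignored whenever the new API is set.
    broken = wants_ieee if new_api_set else legacy_highest or on_hip

    # Fixed form: the parenthesized conditional is just one operand of `or`,
    # so on_hip (ROCm) always forces the IEEE path.
    fixed = (wants_ieee if new_api_set else legacy_highest) or on_hip

    print(broken, fixed)  # False True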

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161465
Approved by: https://github.com/jeffdaily, https://github.com/eqy

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Author: Jeff Daily
Date: 2025-08-26 15:11:54 +00:00
Committed by: PyTorch MergeBot
parent 818ba434c7
commit 10e67f5ec3


@@ -53,9 +53,11 @@ def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv
 def get_float32_precision():
     if (
-        torch.backends.cuda.matmul.fp32_precision == "ieee"
-        if torch.backends.cuda.matmul.fp32_precision != "none"
-        else torch.get_float32_matmul_precision() == "highest"
+        (
+            torch.backends.cuda.matmul.fp32_precision == "ieee"
+            if torch.backends.cuda.matmul.fp32_precision != "none"
+            else torch.get_float32_matmul_precision() == "highest"
+        )
         or torch.version.hip
         or torch.mtia.is_available()
     ):
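
For reference, a sketch of what the corrected helper plausibly looks like in full; the return values below are assumptions for illustration, since the hunk only shows the condition:

    import torch

    def get_float32_precision():
        # Use full-precision (IEEE) fp32 matmuls for flex attention when the user
        # requested it via either the new fp32_precision API or the legacy
        # float32 matmul precision API, and unconditionally on ROCm (HIP) and MTIA.
        if (
            (
                torch.backends.cuda.matmul.fp32_precision == "ieee"
                if torch.backends.cuda.matmul.fp32_precision != "none"
                else torch.get_float32_matmul_precision() == "highest"
            )
            or torch.version.hip
            or torch.mtia.is_available()
        ):
            return "'ieee'"  # assumed return value; not shown in the hunk
        return "'tf32'"      # assumed return value; not shown in the hunk

Note how, with the parentheses in place, torch.version.hip and torch.mtia.is_available() force the IEEE path regardless of which precision API is in effect.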