diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
index 816dedb8eff..cbb86b6090e 100644
--- a/torch/_inductor/kernel/flex/flex_attention.py
+++ b/torch/_inductor/kernel/flex/flex_attention.py
@@ -53,9 +53,11 @@ def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv
 
 def get_float32_precision():
     if (
-        torch.backends.cuda.matmul.fp32_precision == "ieee"
-        if torch.backends.cuda.matmul.fp32_precision != "none"
-        else torch.get_float32_matmul_precision() == "highest"
+        (
+            torch.backends.cuda.matmul.fp32_precision == "ieee"
+            if torch.backends.cuda.matmul.fp32_precision != "none"
+            else torch.get_float32_matmul_precision() == "highest"
+        )
         or torch.version.hip
         or torch.mtia.is_available()
     ):
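
Not part of the patch: a minimal sketch of the operator-precedence issue the added parentheses appear to fix. In Python a conditional expression binds more loosely than "or", so the original condition grouped as "A if B else (C or D or E)", which skips the torch.version.hip / torch.mtia.is_available() checks whenever fp32_precision != "none". The helper names below (old_grouping, new_grouping) are illustrative, not from the source.

# Illustration only (not from the patch): how Python groups the two forms.

def old_grouping(a: bool, b: bool, c: bool, d: bool) -> bool:
    # Without parentheses, "a if b else c or d" parses as below,
    # so d is only consulted on the else branch.
    return a if b else (c or d)

def new_grouping(a: bool, b: bool, c: bool, d: bool) -> bool:
    # With the explicit parentheses from the patch, d is always consulted.
    return (a if b else c) or d

# b=True models fp32_precision != "none"; d=True models a ROCm build
# (torch.version.hip truthy). The two groupings disagree when a is False:
assert old_grouping(False, True, False, True) is False  # hip check skipped
assert new_grouping(False, True, False, True) is True   # hip check honored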