Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
PR #161102 caused tf32 to become the default precision for flex attention. This PR forward-fixes the broken logic and restores the ROCm MI200 CI flex attention test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161465
Approved by: https://github.com/jeffdaily, https://github.com/eqy
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
parent 818ba434c7
commit 10e67f5ec3
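The root cause the message describes is Python's operator precedence for conditional expressions: "A if B else C or D" parses as "A if B else (C or D)", so trailing "or" terms only take effect when B is false. A minimal, hypothetical sketch of that difference (illustration only, not PyTorch source):

# In Python, "A if B else C or D" parses as "A if B else (C or D)",
# so D is only consulted when B is false.

def ungrouped(a, b, c, d):
    # Mirrors the pre-fix expression shape: the trailing "or d" is absorbed
    # into the else-branch instead of applying to the whole conditional.
    return a if b else c or d

def grouped(a, b, c, d):
    # Mirrors the fixed expression shape: the conditional is parenthesized,
    # so "or d" applies regardless of which branch was taken.
    return (a if b else c) or d

# When b is true and a is false, only the grouped form still honors d:
assert ungrouped(False, True, False, True) is False
assert grouped(False, True, False, True) is True

With the pre-fix shape, the torch.version.hip and torch.mtia.is_available() checks were skipped whenever the fp32_precision branch was taken, which is how tf32 became the default; the hunk below parenthesizes the conditional so those checks always apply.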
@@ -53,9 +53,11 @@ def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv
 def get_float32_precision():
     if (
-        torch.backends.cuda.matmul.fp32_precision == "ieee"
-        if torch.backends.cuda.matmul.fp32_precision != "none"
-        else torch.get_float32_matmul_precision() == "highest"
+        (
+            torch.backends.cuda.matmul.fp32_precision == "ieee"
+            if torch.backends.cuda.matmul.fp32_precision != "none"
+            else torch.get_float32_matmul_precision() == "highest"
+        )
         or torch.version.hip
         or torch.mtia.is_available()
     ):
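Read as a whole, the patched helper looks roughly like the sketch below. This is a reconstruction from the hunk above; the returned strings ("'ieee'" vs "'tf32'", presumably substituted into the generated Triton kernel) are an assumption about surrounding code this diff does not show.

import torch

def get_float32_precision():
    # Sketch of the patched helper; return values are assumed, not shown in the hunk.
    if (
        (
            torch.backends.cuda.matmul.fp32_precision == "ieee"
            if torch.backends.cuda.matmul.fp32_precision != "none"
            else torch.get_float32_matmul_precision() == "highest"
        )
        or torch.version.hip          # ROCm builds keep full fp32 precision
        or torch.mtia.is_available()  # so do MTIA devices
    ):
        return "'ieee'"
    else:
        return "'tf32'"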