[ROCm] fix numpy version detection and adjust fudge_factors for MI355 (#161429)

This PR fixes:

- NumPy >= 2.1 version detection (instead of Python 3.13 version detection) to skip the tests below, since NumPy 2.1 can also be installed on older Python versions (see the sketch after the list):
```
test_quantization.py::TestDynamicQuantizedOps::test_qlinear
test_quantization.py::TestDynamicQuantizedOps::test_qlinear_legacy
test_quantization.py::TestQuantizedLinear::test_qlinear
test_quantization.py::TestQuantizedLinear::test_qlinear_leaky_relu
test_quantization.py::TestQuantizedLinear::test_qlinear_relu
test_quantization.py::TestQuantizedLinear::test_qlinear_tanh
test_quantization.py::TestQuantizedLinear::test_qlinear_with_input_q_dq_qweight_dq_output_fp32
```
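
A minimal standalone sketch of the new skip condition, assuming `numpy` and `packaging` are installed (the actual change lives in the quantization test shown in the diff below):

```
import numpy as np
from packaging.version import Version

# Skip based on the installed NumPy version rather than the Python version:
# NumPy >= 2.1 (where the overflow error appears) can also be installed on
# Python < 3.13, which the old sys.version_info check missed.
if Version(np.__version__) >= Version("2.1"):
    print("skip: numpy 2.1 overflow error")
```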
- A couple of SDPA tests on MI355, by adjusting their fudge_factors (a sketch of the architecture check follows the list):

```
test_transformers.py::TestSDPACudaOnlyCUDA::test_mem_efficient_attention_attn_mask_vs_math_ref_grads_batch_size_1_seq_len_q_2048_seq_len_k_8_head_dim_8_is_causal_False_dropout_p_0_0_float32_scale_l1_cuda_float32
test_transformers.py::TestSDPACudaOnlyCUDA::test_mem_efficient_attention_vs_math_ref_grads_batch_size_8_seq_len_q_2048_seq_len_k_8_head_dim_128_is_causal_True_dropout_p_0_0_float32_scale0_cuda_float32
```
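
The MI355 adjustment is gated on the reported GPU architecture string; a small standalone sketch of that gating, assuming a ROCm build where `gcnArchName` reports a gfx95* architecture for MI355-class GPUs:

```
import torch

fudge_factors = {'grad_key': 90.0, 'grad_query': 670.0}

if torch.cuda.is_available():
    # gcnArchName identifies the GCN/CDNA architecture (e.g. "gfx90a").
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    # Relax the grad_value tolerance only on gfx95* (MI355-class) GPUs.
    if "gfx95" in arch:
        fudge_factors['grad_value'] = 12.0
```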

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161429
Approved by: https://github.com/jeffdaily
Dmitry Nikolaev authored on 2025-08-28 19:32:06 +00:00, committed by PyTorch MergeBot
commit b76f6d117a (parent 130e50afff)
2 changed files with 4 additions and 2 deletions


```
@@ -7,8 +7,8 @@ import itertools
 import numpy as np
 import operator
 import random
-import sys
 import unittest
+from packaging.version import Version
 from typing import NamedTuple

 import torch
@@ -73,7 +73,7 @@ class PointwisePostOp(NamedTuple):
 def avoid_vpmaddubsw_overflow_linear(
     batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max
 ):
-    if sys.version_info >= (3, 13):
+    if Version(np.__version__) >= Version("2.1"):
         raise unittest.SkipTest("numpy 2.1 overflow error")
     for i, j in np.ndindex((batch_size, output_channels)):
         for k in range(0, input_channels // 2 * 2, 2):
```


```
@@ -3555,6 +3555,8 @@ class TestSDPACudaOnly(NNTestCase):
                     fudge_factors['grad_query'] = 670.0  # gfx90a
                 if dtype == torch.float32:
                     fudge_factors['grad_key'] = 90.0
+                if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
+                    fudge_factors['grad_value'] = 12.0

             check_out_and_grad(
                 (out_ref, out_lp_ref, out),
```