Fix test_binary_ufuncs.py for NumPy 2 (#137937)

Related to #107302

The following tests failed in test_binary_ufuncs.py when testing with NumPy 2.

```
FAILED [0.0050s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_scalar_support__refs_sub_cpu_complex64 - AssertionError
FAILED [0.0043s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_scalar_support__refs_sub_cpu_float32 - AssertionError
FAILED [0.0048s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_scalar_support_sub_cpu_complex64 - AssertionError
FAILED [0.0043s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_scalar_support_sub_cpu_float32 - AssertionError
FAILED [0.0028s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_shift_limits_cpu_uint8 - OverflowError: Python integer -100 out of bounds for uint8
```

This PR fixes them.

More details:
* `test_shift_limits` failed because `np.left_shift()` and `np.right_shift()` no longer support negative shift values in NumPy 2.
* `test_scalar_support` failed because NumPy 2 changed its dtype promo rules. We special-cased the incompatible cases by changing the expected dtypes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137937
Approved by: https://github.com/albanD
This commit is contained in:
Haifeng Jin 2024-10-15 17:04:21 +00:00 committed by PyTorch MergeBot
parent e4d7676c1b
commit bdbe0cfffa

View File

@ -159,6 +159,15 @@ class TestBinaryUfuncs(TestCase):
actual = op(l, r)
expected = op.ref(l_numpy, r_numpy)
# Dtype promo rules have changed since NumPy 2.
# Specialize the backward-incompatible cases.
if (
np.__version__ > "2"
and op.name in ("sub", "_refs.sub")
and isinstance(l_numpy, np.ndarray)
):
expected = expected.astype(l_numpy.dtype)
# Crafts a custom error message for smaller, printable tensors
def _numel(x):
if isinstance(x, torch.Tensor):
@ -3199,7 +3208,12 @@ class TestBinaryUfuncs(TestCase):
):
shift_left_expected = torch.zeros_like(input)
shift_right_expected = torch.clamp(input, -1, 0)
for shift in chain(range(-100, -1), range(bits, 100)):
# NumPy 2 does not support negative shift values.
if np.__version__ > "2":
iterator = range(bits, 100)
else:
iterator = chain(range(-100, -1), range(bits, 100))
for shift in iterator:
shift_left = input << shift
self.assertEqual(shift_left, shift_left_expected, msg=f"<< {shift}")
self.compare_with_numpy(