mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix test_binary_ufuncs.py for NumPy 2 (#137937)
Related to #107302 The following tests failed in test_binary_ufuncs.py when testing with NumPy 2. ``` FAILED [0.0050s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_scalar_support__refs_sub_cpu_complex64 - AssertionError FAILED [0.0043s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_scalar_support__refs_sub_cpu_float32 - AssertionError FAILED [0.0048s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_scalar_support_sub_cpu_complex64 - AssertionError FAILED [0.0043s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_scalar_support_sub_cpu_float32 - AssertionError FAILED [0.0028s] test/test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_shift_limits_cpu_uint8 - OverflowError: Python integer -100 out of bounds for uint8 ``` This PR fixes them. More details: * `test_shift_limits` failed because `np.left_shift()` and `np.right_shift()` no longer support negative shift values in NumPy 2. * `test_scalar_support` failed because NumPy 2 changed its dtype promo rules. We special-cased the incompatible cases by changing the expected dtypes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/137937 Approved by: https://github.com/albanD
This commit is contained in:
parent
e4d7676c1b
commit
bdbe0cfffa
|
|
@ -159,6 +159,15 @@ class TestBinaryUfuncs(TestCase):
|
|||
actual = op(l, r)
|
||||
expected = op.ref(l_numpy, r_numpy)
|
||||
|
||||
# Dtype promo rules have changed since NumPy 2.
|
||||
# Specialize the backward-incompatible cases.
|
||||
if (
|
||||
np.__version__ > "2"
|
||||
and op.name in ("sub", "_refs.sub")
|
||||
and isinstance(l_numpy, np.ndarray)
|
||||
):
|
||||
expected = expected.astype(l_numpy.dtype)
|
||||
|
||||
# Crafts a custom error message for smaller, printable tensors
|
||||
def _numel(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
|
|
@ -3199,7 +3208,12 @@ class TestBinaryUfuncs(TestCase):
|
|||
):
|
||||
shift_left_expected = torch.zeros_like(input)
|
||||
shift_right_expected = torch.clamp(input, -1, 0)
|
||||
for shift in chain(range(-100, -1), range(bits, 100)):
|
||||
# NumPy 2 does not support negative shift values.
|
||||
if np.__version__ > "2":
|
||||
iterator = range(bits, 100)
|
||||
else:
|
||||
iterator = chain(range(-100, -1), range(bits, 100))
|
||||
for shift in iterator:
|
||||
shift_left = input << shift
|
||||
self.assertEqual(shift_left, shift_left_expected, msg=f"<< {shift}")
|
||||
self.compare_with_numpy(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user