Map digamma +/-inf results to nan in test (fixes #7651) (#7665)

This commit is contained in:
Thomas Viehmann 2018-05-18 16:35:00 +02:00 committed by Soumith Chintala
parent 50d8473ccc
commit bf95dff85b

View File

@ -378,7 +378,21 @@ class TestTorch(TestCase):
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
def test_digamma(self):
from scipy.special import digamma
self._test_math(torch.digamma, digamma, self._digamma_input())
# scipy 1.1.0 changed when it returns +/-inf vs. NaN
def torch_digamma_without_inf(inp):
res = torch.digamma(inp)
res[(res == float('-inf')) | (res == float('inf'))] = float('nan')
return res
def scipy_digamma_without_inf(inp):
res = digamma(inp)
if np.isscalar(res):
return res if np.isfinite(res) else float('nan')
res[np.isinf(res)] = float('nan')
return res
self._test_math(torch_digamma_without_inf, scipy_digamma_without_inf, self._digamma_input())
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
def test_polygamma(self):