mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
parent
50d8473ccc
commit
bf95dff85b
|
|
@ -378,7 +378,21 @@ class TestTorch(TestCase):
|
|||
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
|
||||
def test_digamma(self):
|
||||
from scipy.special import digamma
|
||||
self._test_math(torch.digamma, digamma, self._digamma_input())
|
||||
|
||||
# scipy 1.1.0 changed when it returns +/-inf vs. NaN
|
||||
def torch_digamma_without_inf(inp):
|
||||
res = torch.digamma(inp)
|
||||
res[(res == float('-inf')) | (res == float('inf'))] = float('nan')
|
||||
return res
|
||||
|
||||
def scipy_digamma_without_inf(inp):
|
||||
res = digamma(inp)
|
||||
if np.isscalar(res):
|
||||
return res if np.isfinite(res) else float('nan')
|
||||
res[np.isinf(res)] = float('nan')
|
||||
return res
|
||||
|
||||
self._test_math(torch_digamma_without_inf, scipy_digamma_without_inf, self._digamma_input())
|
||||
|
||||
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
|
||||
def test_polygamma(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user