Fix numerically instability of SigmoidTransform (#19802)

Summary:
fix #18254 for numerically instability of `SigmoidTransform`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19802

Differential Revision: D15701837

Pulled By: ezyang

fbshipit-source-id: fe6c755c523487c8bbdcc3bfb8455801617c70a4
This commit is contained in:
Xingdong Zuo 2019-06-06 13:44:29 -07:00 committed by Facebook Github Bot
parent f8cab38578
commit c5d5d45f40

View File

@ -3,6 +3,7 @@ import numbers
import weakref
import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (_sum_rightmost, broadcast_all,
lazy_property)
@ -359,7 +360,7 @@ class SigmoidTransform(Transform):
return y.log() - (-y).log1p()
def log_abs_det_jacobian(self, x, y):
return -(y.reciprocal() + (1 - y).reciprocal()).log()
return -F.softplus(-x) - F.softplus(x)
class AbsTransform(Transform):