mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix numerically instability of SigmoidTransform (#19802)
Summary: fix #18254 for numerically instability of `SigmoidTransform` Pull Request resolved: https://github.com/pytorch/pytorch/pull/19802 Differential Revision: D15701837 Pulled By: ezyang fbshipit-source-id: fe6c755c523487c8bbdcc3bfb8455801617c70a4
This commit is contained in:
parent
f8cab38578
commit
c5d5d45f40
|
|
@ -3,6 +3,7 @@ import numbers
|
|||
import weakref
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.utils import (_sum_rightmost, broadcast_all,
|
||||
lazy_property)
|
||||
|
|
@ -359,7 +360,7 @@ class SigmoidTransform(Transform):
|
|||
return y.log() - (-y).log1p()
|
||||
|
||||
def log_abs_det_jacobian(self, x, y):
|
||||
return -(y.reciprocal() + (1 - y).reciprocal()).log()
|
||||
return -F.softplus(-x) - F.softplus(x)
|
||||
|
||||
|
||||
class AbsTransform(Transform):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user