mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
clip sigmoid to prevent transforms return inf/nan values (#20288)
Summary: This PR addresses some numerical issues of Sigmoid/StickBreakingTransform, where these transforms give +-inf when the unconstrained values move to +-20 areas. For example, with ``` t = torch.distributions.SigmoidTransform() x = torch.tensor(20.) t.inv(t(x)), t.log_abs_det_jacobian(x, t(x)) ``` current behaviour the inverse will return `inf` and logdet return `-inf` while this PR makes it to `15.9424` and `-15.9424`. And for ``` t = torch.distributions.StickBreakingTransform() x = torch.tensor([20., 20.]) t.inv(t(x)), t.log_abs_det_jacobian(x, t(x)) ``` current value is `(inf, nan)` and `-inf` for logdet, while this PR makes it `[16.6355, 71.3942]` and `-47.8272` for logdet. Although these finite values are wrong and seems unavoidable, it is better than returning `inf` or `nan` in my opinion. This is useful in HMC where despite that the grad will be zero when the unconstrained parameter moves to unstable area (due to clipping), velocity variable will force the parameter move to another area which by chance can move the parameter out of unstable area. But inf/nan can be useful to stop doing inference early. So the changes in this PR might be inappropriate. I also fix some small issues of `_Simplex` and `_RealVector` constraints where batch shape of the input is not respected when checking validation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20288 Differential Revision: D15742047 Pulled By: ezyang fbshipit-source-id: b427ed1752c41327abb3957f98d4b289307a7d17
This commit is contained in:
parent
4bdbd30b96
commit
91ea2cd5a7
|
|
@ -3469,7 +3469,7 @@ class TestKL(TestCase):
|
|||
|
||||
|
||||
class TestConstraints(TestCase):
|
||||
def test_params_contains(self):
|
||||
def test_params_constraints(self):
|
||||
for Dist, params in EXAMPLES:
|
||||
for i, param in enumerate(params):
|
||||
dist = Dist(**param)
|
||||
|
|
@ -3492,7 +3492,7 @@ class TestConstraints(TestCase):
|
|||
Dist.__name__, i + 1, len(params), name, value)
|
||||
self.assertTrue(constraint.check(value).all(), msg=message)
|
||||
|
||||
def test_support_contains(self):
|
||||
def test_support_constraints(self):
|
||||
for Dist, params in EXAMPLES:
|
||||
self.assertIsInstance(Dist.support, Constraint)
|
||||
for i, param in enumerate(params):
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ class _Simplex(Constraint):
|
|||
Specifically: `x >= 0` and `x.sum(-1) == 1`.
|
||||
"""
|
||||
def check(self, value):
|
||||
return (value >= 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all()
|
||||
return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
|
||||
|
||||
|
||||
class _LowerTriangular(Constraint):
|
||||
|
|
@ -295,7 +295,7 @@ class _RealVector(Constraint):
|
|||
but additionally reduces across the `event_shape` dimension.
|
||||
"""
|
||||
def check(self, value):
|
||||
return (value == value).all() # False for NANs.
|
||||
return torch.all(value == value, dim=-1) # False for NANs.
|
||||
|
||||
|
||||
class _Cat(Constraint):
|
||||
|
|
|
|||
|
|
@ -274,11 +274,14 @@ class ComposeTransform(Transform):
|
|||
if not self.parts:
|
||||
return torch.zeros_like(x)
|
||||
result = 0
|
||||
for part in self.parts:
|
||||
y = part(x)
|
||||
result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
|
||||
for part in self.parts[:-1]:
|
||||
y_tmp = part(x)
|
||||
result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y_tmp),
|
||||
self.event_dim - part.event_dim)
|
||||
x = y
|
||||
x = y_tmp
|
||||
part = self.parts[-1]
|
||||
result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
|
||||
self.event_dim - part.event_dim)
|
||||
return result
|
||||
|
||||
def __repr__(self):
|
||||
|
|
@ -341,6 +344,11 @@ class PowerTransform(Transform):
|
|||
return (self.exponent * y / x).abs().log()
|
||||
|
||||
|
||||
def _clipped_sigmoid(x):
|
||||
finfo = torch.finfo(x.dtype)
|
||||
return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
|
||||
|
||||
|
||||
class SigmoidTransform(Transform):
|
||||
r"""
|
||||
Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
|
||||
|
|
@ -354,9 +362,11 @@ class SigmoidTransform(Transform):
|
|||
return isinstance(other, SigmoidTransform)
|
||||
|
||||
def _call(self, x):
|
||||
return torch.sigmoid(x)
|
||||
return _clipped_sigmoid(x)
|
||||
|
||||
def _inverse(self, y):
|
||||
finfo = torch.finfo(y.dtype)
|
||||
y = y.clamp(min=finfo.tiny, max=1. - finfo.eps)
|
||||
return y.log() - (-y).log1p()
|
||||
|
||||
def log_abs_det_jacobian(self, x, y):
|
||||
|
|
@ -495,23 +505,27 @@ class StickBreakingTransform(Transform):
|
|||
return isinstance(other, StickBreakingTransform)
|
||||
|
||||
def _call(self, x):
|
||||
offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
|
||||
z = torch.sigmoid(x - offset.log())
|
||||
offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
|
||||
z = _clipped_sigmoid(x - offset.log())
|
||||
z_cumprod = (1 - z).cumprod(-1)
|
||||
y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
|
||||
return y
|
||||
|
||||
def _inverse(self, y):
|
||||
shape = y.shape[:-1] + (y.shape[-1] - 1,)
|
||||
offset = (shape[-1] + 1) - y.new([1]).expand(shape).cumsum(-1)
|
||||
sf = (1 - y.cumsum(-1))[..., :-1]
|
||||
x = y[..., :-1].log() - sf.log() + offset.log()
|
||||
y_crop = y[..., :-1]
|
||||
offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
|
||||
sf = 1 - y_crop.cumsum(-1)
|
||||
# we clamp to make sure that sf is positive which sometimes does not
|
||||
# happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
|
||||
sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
|
||||
x = y_crop.log() - sf.log() + offset.log()
|
||||
return x
|
||||
|
||||
def log_abs_det_jacobian(self, x, y):
|
||||
offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
|
||||
z = torch.sigmoid(x - offset.log())
|
||||
detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1)
|
||||
offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
|
||||
x = x - offset.log()
|
||||
# use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
|
||||
detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
|
||||
return detJ
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user