clip sigmoid to prevent transforms from returning inf/nan values (#20288)

Summary:
This PR addresses numerical issues in SigmoidTransform and StickBreakingTransform, which return ±inf when the unconstrained values move into the ±20 range.

For example, with
```
t = torch.distributions.SigmoidTransform()
x = torch.tensor(20.)
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```
the current behaviour is that the inverse returns `inf` and the logdet returns `-inf`, while with this PR they become `15.9424` and `-15.9424` respectively.
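As a side note, this clamped value appears to follow directly from the float32 clamp boundary: with the sigmoid output clamped to at most `1 - eps`, the largest finite inverse is `logit(1 - eps)`. A quick check (assuming default float32 tensors):

```
import math
import torch

# eps is the float32 machine epsilon; the sigmoid output is clamped to at most 1 - eps,
# so the largest finite value the inverse can return is logit(1 - eps).
eps = torch.finfo(torch.float32).eps
print(math.log((1 - eps) / eps))  # ~15.9424, matching the clamped inverse above
```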

And for
```
t = torch.distributions.StickBreakingTransform()
x = torch.tensor([20., 20.])
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```
the current inverse is `(inf, nan)` and the logdet is `-inf`, while with this PR they become `[16.6355, 71.3942]` and `-47.8272` respectively.

Although these finite values are wrong, that seems unavoidable, and in my opinion returning them is better than returning `inf` or `nan`. This is useful in HMC: even though the gradient will be zero when the unconstrained parameter moves into the unstable area (due to clipping), the velocity variable will keep pushing the parameter around, which can by chance carry it out of the unstable area. On the other hand, `inf`/`nan` can be useful for stopping inference early, so the changes in this PR might be inappropriate.
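To illustrate that HMC point, here is a minimal sketch (not the PR code, assuming default float32 tensors) of how clamping the sigmoid keeps the inverse finite while turning the gradient in the saturated region into an exact zero rather than `nan`:

```
import torch

# Minimal sketch of the clipping idea: clamp sigmoid away from {0, 1}
# so that the logit of the result stays finite.
x = torch.tensor(20., requires_grad=True)
finfo = torch.finfo(x.dtype)
y = torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
inv = y.log() - (-y).log1p()  # logit(y)
inv.backward()
# inv is ~15.9424 instead of inf; the gradient w.r.t. x is 0 because the clamp saturates.
print(inv.item(), x.grad.item())
```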

I also fix small issues in the `_Simplex` and `_RealVector` constraints, where the batch shape of the input was not respected during validation checks.
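For illustration, a small sketch of the `_Simplex` change (the input tensor is just an assumed example): the old check reduced over every dimension and returned a single scalar, whereas the new one reduces only over the last (event) dimension and returns one result per batch element:

```
import torch

value = torch.tensor([[0.3, 0.7],
                      [0.5, 0.5]])  # a batch of two 2-dimensional simplex vectors
# old check: collapses the batch dimension into a single scalar result
old = (value >= 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all()
# new check: keeps one result per batch element
new = torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
print(old.shape, new.shape)  # torch.Size([]) vs torch.Size([2])
```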
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20288

Differential Revision: D15742047

Pulled By: ezyang

fbshipit-source-id: b427ed1752c41327abb3957f98d4b289307a7d17
fehiepsi 2019-06-10 11:10:11 -07:00 committed by Facebook Github Bot
parent 4bdbd30b96
commit 91ea2cd5a7
3 changed files with 32 additions and 18 deletions

test/test_distributions.py

```
@@ -3469,7 +3469,7 @@ class TestKL(TestCase):
 
 
 class TestConstraints(TestCase):
-    def test_params_contains(self):
+    def test_params_constraints(self):
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
@@ -3492,7 +3492,7 @@ class TestConstraints(TestCase):
                         Dist.__name__, i + 1, len(params), name, value)
                     self.assertTrue(constraint.check(value).all(), msg=message)
 
-    def test_support_contains(self):
+    def test_support_constraints(self):
         for Dist, params in EXAMPLES:
             self.assertIsInstance(Dist.support, Constraint)
             for i, param in enumerate(params):
```

torch/distributions/constraints.py

```
@@ -251,7 +251,7 @@ class _Simplex(Constraint):
     Specifically: `x >= 0` and `x.sum(-1) == 1`.
     """
     def check(self, value):
-        return (value >= 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all()
+        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
 
 
 class _LowerTriangular(Constraint):
@@ -295,7 +295,7 @@ class _RealVector(Constraint):
     but additionally reduces across the `event_shape` dimension.
     """
     def check(self, value):
-        return (value == value).all()  # False for NANs.
+        return torch.all(value == value, dim=-1)  # False for NANs.
 
 
 class _Cat(Constraint):
```

torch/distributions/transforms.py

```
@@ -274,11 +274,14 @@ class ComposeTransform(Transform):
         if not self.parts:
             return torch.zeros_like(x)
         result = 0
-        for part in self.parts:
-            y = part(x)
-            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
+        for part in self.parts[:-1]:
+            y_tmp = part(x)
+            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y_tmp),
                                              self.event_dim - part.event_dim)
-            x = y
+            x = y_tmp
+        part = self.parts[-1]
+        result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
+                                         self.event_dim - part.event_dim)
         return result
 
     def __repr__(self):
@@ -341,6 +344,11 @@ class PowerTransform(Transform):
         return (self.exponent * y / x).abs().log()
 
 
+def _clipped_sigmoid(x):
+    finfo = torch.finfo(x.dtype)
+    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
+
+
 class SigmoidTransform(Transform):
     r"""
     Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
@@ -354,9 +362,11 @@ class SigmoidTransform(Transform):
         return isinstance(other, SigmoidTransform)
 
     def _call(self, x):
-        return torch.sigmoid(x)
+        return _clipped_sigmoid(x)
 
     def _inverse(self, y):
+        finfo = torch.finfo(y.dtype)
+        y = y.clamp(min=finfo.tiny, max=1. - finfo.eps)
         return y.log() - (-y).log1p()
 
     def log_abs_det_jacobian(self, x, y):
@@ -495,23 +505,27 @@ class StickBreakingTransform(Transform):
         return isinstance(other, StickBreakingTransform)
 
     def _call(self, x):
-        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
-        z = torch.sigmoid(x - offset.log())
+        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
+        z = _clipped_sigmoid(x - offset.log())
         z_cumprod = (1 - z).cumprod(-1)
         y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
         return y
 
     def _inverse(self, y):
-        shape = y.shape[:-1] + (y.shape[-1] - 1,)
-        offset = (shape[-1] + 1) - y.new([1]).expand(shape).cumsum(-1)
-        sf = (1 - y.cumsum(-1))[..., :-1]
-        x = y[..., :-1].log() - sf.log() + offset.log()
+        y_crop = y[..., :-1]
+        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
+        sf = 1 - y_crop.cumsum(-1)
+        # we clamp to make sure that sf is positive which sometimes does not
+        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
+        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
+        x = y_crop.log() - sf.log() + offset.log()
         return x
 
     def log_abs_det_jacobian(self, x, y):
-        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
-        z = torch.sigmoid(x - offset.log())
-        detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1)
+        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
+        x = x - offset.log()
+        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
+        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
         return detJ
```
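As a side note (not part of the commit), the identity mentioned in the comment above, `log(1 - sigmoid(x)) = -x + logsigmoid(x)`, can be checked numerically:

```
import torch
import torch.nn.functional as F

# log(1 - sigmoid(x)) == -x + logsigmoid(x), the identity the new detJ relies on
x = torch.linspace(-5., 5., steps=11)
lhs = torch.log(1 - torch.sigmoid(x))
rhs = -x + F.logsigmoid(x)
print(torch.allclose(lhs, rhs, atol=1e-6))  # True
```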