From 91ea2cd5a76172bf25b51b19fbd2191544ef3e45 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 10 Jun 2019 11:10:11 -0700 Subject: [PATCH] clip sigmoid to prevent transforms return inf/nan values (#20288) Summary: This PR addresses some numerical issues of Sigmoid/StickBreakingTransform, where these transforms give +-inf when the unconstrained values move to +-20 areas. For example, with ``` t = torch.distributions.SigmoidTransform() x = torch.tensor(20.) t.inv(t(x)), t.log_abs_det_jacobian(x, t(x)) ``` current behaviour the inverse will return `inf` and logdet return `-inf` while this PR makes it to `15.9424` and `-15.9424`. And for ``` t = torch.distributions.StickBreakingTransform() x = torch.tensor([20., 20.]) t.inv(t(x)), t.log_abs_det_jacobian(x, t(x)) ``` current value is `(inf, nan)` and `-inf` for logdet, while this PR makes it `[16.6355, 71.3942]` and `-47.8272` for logdet. Although these finite values are wrong and seems unavoidable, it is better than returning `inf` or `nan` in my opinion. This is useful in HMC where despite that the grad will be zero when the unconstrained parameter moves to unstable area (due to clipping), velocity variable will force the parameter move to another area which by chance can move the parameter out of unstable area. But inf/nan can be useful to stop doing inference early. So the changes in this PR might be inappropriate. I also fix some small issues of `_Simplex` and `_RealVector` constraints where batch shape of the input is not respected when checking validation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20288 Differential Revision: D15742047 Pulled By: ezyang fbshipit-source-id: b427ed1752c41327abb3957f98d4b289307a7d17 --- test/test_distributions.py | 4 +-- torch/distributions/constraints.py | 4 +-- torch/distributions/transforms.py | 42 ++++++++++++++++++++---------- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 891d44164f8..62a19cd1b85 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3469,7 +3469,7 @@ class TestKL(TestCase): class TestConstraints(TestCase): - def test_params_contains(self): + def test_params_constraints(self): for Dist, params in EXAMPLES: for i, param in enumerate(params): dist = Dist(**param) @@ -3492,7 +3492,7 @@ class TestConstraints(TestCase): Dist.__name__, i + 1, len(params), name, value) self.assertTrue(constraint.check(value).all(), msg=message) - def test_support_contains(self): + def test_support_constraints(self): for Dist, params in EXAMPLES: self.assertIsInstance(Dist.support, Constraint) for i, param in enumerate(params): diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 88ca01f3ee2..7bcbc586434 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -251,7 +251,7 @@ class _Simplex(Constraint): Specifically: `x >= 0` and `x.sum(-1) == 1`. """ def check(self, value): - return (value >= 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all() + return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6) class _LowerTriangular(Constraint): @@ -295,7 +295,7 @@ class _RealVector(Constraint): but additionally reduces across the `event_shape` dimension. """ def check(self, value): - return (value == value).all() # False for NANs. + return torch.all(value == value, dim=-1) # False for NANs. class _Cat(Constraint): diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 5e0935a90de..c7e99e188ac 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -274,11 +274,14 @@ class ComposeTransform(Transform): if not self.parts: return torch.zeros_like(x) result = 0 - for part in self.parts: - y = part(x) - result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y), + for part in self.parts[:-1]: + y_tmp = part(x) + result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y_tmp), self.event_dim - part.event_dim) - x = y + x = y_tmp + part = self.parts[-1] + result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y), + self.event_dim - part.event_dim) return result def __repr__(self): @@ -341,6 +344,11 @@ class PowerTransform(Transform): return (self.exponent * y / x).abs().log() +def _clipped_sigmoid(x): + finfo = torch.finfo(x.dtype) + return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps) + + class SigmoidTransform(Transform): r""" Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`. @@ -354,9 +362,11 @@ class SigmoidTransform(Transform): return isinstance(other, SigmoidTransform) def _call(self, x): - return torch.sigmoid(x) + return _clipped_sigmoid(x) def _inverse(self, y): + finfo = torch.finfo(y.dtype) + y = y.clamp(min=finfo.tiny, max=1. - finfo.eps) return y.log() - (-y).log1p() def log_abs_det_jacobian(self, x, y): @@ -495,23 +505,27 @@ class StickBreakingTransform(Transform): return isinstance(other, StickBreakingTransform) def _call(self, x): - offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1) - z = torch.sigmoid(x - offset.log()) + offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) + z = _clipped_sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1) return y def _inverse(self, y): - shape = y.shape[:-1] + (y.shape[-1] - 1,) - offset = (shape[-1] + 1) - y.new([1]).expand(shape).cumsum(-1) - sf = (1 - y.cumsum(-1))[..., :-1] - x = y[..., :-1].log() - sf.log() + offset.log() + y_crop = y[..., :-1] + offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1) + sf = 1 - y_crop.cumsum(-1) + # we clamp to make sure that sf is positive which sometimes does not + # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1 + sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny) + x = y_crop.log() - sf.log() + offset.log() return x def log_abs_det_jacobian(self, x, y): - offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1) - z = torch.sigmoid(x - offset.log()) - detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1) + offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) + x = x - offset.log() + # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x) + detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1) return detJ