diff --git a/test/test_distributions.py b/test/test_distributions.py
index 891d44164f8..62a19cd1b85 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -3469,7 +3469,7 @@ class TestKL(TestCase):
 
 
 class TestConstraints(TestCase):
-    def test_params_contains(self):
+    def test_params_constraints(self):
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
@@ -3492,7 +3492,7 @@ class TestConstraints(TestCase):
                     Dist.__name__, i + 1, len(params), name, value)
                 self.assertTrue(constraint.check(value).all(), msg=message)
 
-    def test_support_contains(self):
+    def test_support_constraints(self):
         for Dist, params in EXAMPLES:
             self.assertIsInstance(Dist.support, Constraint)
             for i, param in enumerate(params):
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
index 88ca01f3ee2..7bcbc586434 100644
--- a/torch/distributions/constraints.py
+++ b/torch/distributions/constraints.py
@@ -251,7 +251,7 @@ class _Simplex(Constraint):
     Specifically: `x >= 0` and `x.sum(-1) == 1`.
     """
     def check(self, value):
-        return (value >= 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all()
+        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
 
 
 class _LowerTriangular(Constraint):
@@ -295,7 +295,7 @@ class _RealVector(Constraint):
     but additionally reduces across the `event_shape` dimension.
     """
     def check(self, value):
-        return (value == value).all()  # False for NANs.
+        return torch.all(value == value, dim=-1)  # False for NANs.
 
 
 class _Cat(Constraint):
diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py
index 5e0935a90de..c7e99e188ac 100644
--- a/torch/distributions/transforms.py
+++ b/torch/distributions/transforms.py
@@ -274,11 +274,14 @@ class ComposeTransform(Transform):
         if not self.parts:
             return torch.zeros_like(x)
         result = 0
-        for part in self.parts:
-            y = part(x)
-            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
+        for part in self.parts[:-1]:
+            y_tmp = part(x)
+            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y_tmp),
                                              self.event_dim - part.event_dim)
-            x = y
+            x = y_tmp
+        part = self.parts[-1]
+        result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
+                                         self.event_dim - part.event_dim)
         return result
 
     def __repr__(self):
@@ -341,6 +344,11 @@ class PowerTransform(Transform):
         return (self.exponent * y / x).abs().log()
 
 
+def _clipped_sigmoid(x):
+    finfo = torch.finfo(x.dtype)
+    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
+
+
 class SigmoidTransform(Transform):
     r"""
     Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
@@ -354,9 +362,11 @@ class SigmoidTransform(Transform):
         return isinstance(other, SigmoidTransform)
 
     def _call(self, x):
-        return torch.sigmoid(x)
+        return _clipped_sigmoid(x)
 
     def _inverse(self, y):
+        finfo = torch.finfo(y.dtype)
+        y = y.clamp(min=finfo.tiny, max=1. - finfo.eps)
         return y.log() - (-y).log1p()
 
     def log_abs_det_jacobian(self, x, y):
@@ -495,23 +505,27 @@ class StickBreakingTransform(Transform):
         return isinstance(other, StickBreakingTransform)
 
     def _call(self, x):
-        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
-        z = torch.sigmoid(x - offset.log())
+        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
+        z = _clipped_sigmoid(x - offset.log())
         z_cumprod = (1 - z).cumprod(-1)
         y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
         return y
 
     def _inverse(self, y):
-        shape = y.shape[:-1] + (y.shape[-1] - 1,)
-        offset = (shape[-1] + 1) - y.new([1]).expand(shape).cumsum(-1)
-        sf = (1 - y.cumsum(-1))[..., :-1]
-        x = y[..., :-1].log() - sf.log() + offset.log()
+        y_crop = y[..., :-1]
+        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
+        sf = 1 - y_crop.cumsum(-1)
+        # we clamp to make sure that sf is positive which sometimes does not
+        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
+        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
+        x = y_crop.log() - sf.log() + offset.log()
         return x
 
     def log_abs_det_jacobian(self, x, y):
-        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
-        z = torch.sigmoid(x - offset.log())
-        detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1)
+        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
+        x = x - offset.log()
+        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
+        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
        return detJ
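
Illustrative note, not part of the patch: the constraint changes above make `check()` reduce over the event dimension, so it now returns one result per batch element instead of a single scalar for the whole batch (the updated tests then call `.all()` themselves). A minimal sketch against the public `torch.distributions.constraints.simplex` singleton:

import torch
from torch.distributions import constraints

value = torch.tensor([[0.3, 0.7],    # valid: this row sums to 1
                      [0.8, 0.8]])   # invalid: this row sums to 1.6
per_batch = constraints.simplex.check(value)  # one result per row: [True, False]
print(per_batch)
print(per_batch.all())                        # scalar, as the updated tests use it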
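Why the clipping in `_clipped_sigmoid` and `SigmoidTransform._inverse` helps, as a rough standalone sketch (the helper below is copied from the patch for illustration): in float32, `torch.sigmoid` saturates to exactly 0 or 1 for large |x|, so the unclamped inverse `logit(y) = y.log() - (-y).log1p()` returns infinities; clamping to `[tiny, 1 - eps]` keeps the round trip finite.

import torch

def _clipped_sigmoid(x):
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)

x = torch.tensor([-200., 0., 200.])   # float32 by default
y_raw = torch.sigmoid(x)              # saturates to exactly 0. and 1.
print(y_raw.log() - (-y_raw).log1p())  # tensor([-inf, 0., inf])

y = _clipped_sigmoid(x)
print(y.log() - (-y).log1p())          # finite, roughly [-87.3, 0., 15.9]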
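The new `StickBreakingTransform.log_abs_det_jacobian` relies on the identity `1 - sigmoid(x) = exp(-x) * sigmoid(x)` (since `1 - 1/(1 + e^-x) = e^-x / (1 + e^-x)`), i.e. `log(1 - sigmoid(x)) = -x + F.logsigmoid(x)`, which stays finite where `(1 - z).log()` would underflow to -inf. A quick numerical check of that identity:

import torch
import torch.nn.functional as F

x = torch.tensor([-5., 0., 5.], dtype=torch.float64)
lhs = (1 - torch.sigmoid(x)).log()
rhs = -x + F.logsigmoid(x)
assert torch.allclose(lhs, rhs)  # identity holds where both sides are finite

x_big = torch.tensor([100.])               # float32: sigmoid(100.) == 1. exactly
print((1 - torch.sigmoid(x_big)).log())    # tensor([-inf]): naive form underflows
print(-x_big + F.logsigmoid(x_big))        # tensor([-100.]): still finite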