Fix sign error in TransformedDistribution.cdf() and .icdf() (#5172)

This commit is contained in:
Fritz Obermeyer 2018-02-12 02:45:22 -08:00 committed by Adam Paszke
parent a061000250
commit b608ea9178
3 changed files with 59 additions and 16 deletions

View File

@ -285,9 +285,9 @@ EXAMPLES = [
'transforms': ExpTransform(),
},
{
'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
Variable(torch.randn(2, 3).abs(), requires_grad=True)),
'transforms': [AffineTransform(Variable(torch.randn(1)), Variable(torch.randn(1))),
'base_distribution': Normal(Variable(torch.randn(2, 3, 5), requires_grad=True),
Variable(torch.randn(2, 3, 5).abs(), requires_grad=True)),
'transforms': [AffineTransform(Variable(torch.randn(3, 5)), Variable(torch.randn(3, 5))),
ExpTransform()],
},
]),
@ -1317,7 +1317,7 @@ class TestDistributions(TestCase):
pdfs = dist.log_prob(samples).exp()
except NotImplementedError:
continue
cdfs_derivative = grad(cdfs.sum(), [samples])[0]
cdfs_derivative = grad(cdfs.sum(), [samples])[0] # this should not be wrapped in torch.abs()
self.assertEqual(cdfs_derivative, pdfs, message='\n'.join([
'{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1, len(params)),
'x = {}'.format(samples),

View File

@ -49,8 +49,9 @@ class TransformedDistribution(Distribution):
def sample(self, sample_shape=torch.Size()):
"""
Generates a sample_shape shaped sample or sample_shape shaped batch of
samples if the distribution parameters are batched. Samples first from base distribution
and applies `transform()` for every transform in the list.
samples if the distribution parameters are batched. Samples first from
base distribution and applies `transform()` for every transform in the
list.
"""
x = self.base_dist.sample(sample_shape)
for transform in self.transforms:
@ -61,8 +62,8 @@ class TransformedDistribution(Distribution):
"""
Generates a sample_shape shaped reparameterized sample or sample_shape
shaped batch of reparameterized samples if the distribution parameters
are batched. Samples first from base distribution and applies `transform()`
for every transform in the list.
are batched. Samples first from base distribution and applies
`transform()` for every transform in the list.
"""
x = self.base_dist.rsample(sample_shape)
for transform in self.transforms:
@ -71,8 +72,8 @@ class TransformedDistribution(Distribution):
def log_prob(self, value):
"""
Scores the sample by inverting the transform(s) and computing the score using the score
of the base distribution and the log abs det jacobian
Scores the sample by inverting the transform(s) and computing the score
using the score of the base distribution and the log abs det jacobian.
"""
self.base_dist._validate_log_prob_arg(value)
event_dim = len(self.event_shape)
@ -87,21 +88,35 @@ class TransformedDistribution(Distribution):
event_dim - len(self.base_dist.event_shape))
return log_prob
def _monotonize_cdf(self, value):
"""
This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
monotone increasing.
"""
sign = 1
for transform in self.transforms:
sign = sign * transform.sign
if sign is 1:
return value
return sign * (value - 0.5) + 0.5
def cdf(self, value):
"""
Computes the cumulative distribution function by inverting the transform(s) and computing
the score of the base distribution
Computes the cumulative distribution function by inverting the
transform(s) and computing the score of the base distribution.
"""
self.base_dist._validate_log_prob_arg(value)
for transform in self.transforms[::-1]:
value = transform.inv(value)
return self.base_dist.cdf(value)
value = self.base_dist.cdf(value)
value = self._monotonize_cdf(value)
return value
def icdf(self, value):
"""
Computes the inverse cumulative distribution function using transform(s) and computing
the score of the base distribution
Computes the inverse cumulative distribution function using
transform(s) and computing the score of the base distribution.
"""
value = self._monotonize_cdf(value)
value = self.base_dist.icdf(value)
for transform in self.transforms:
value = transform(value)

View File

@ -61,6 +61,9 @@ class Transform(object):
the codomain. Transforms that are not bijective should at least
maintain the weaker pseudoinverse properties
``t(t.inv(t(x)) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
sign (int or Variable): For bijective univariate transforms, this
should be +1 or -1 depending on whether transform is monotone
increasing or decreasing.
event_dim (int): Number of dimensions that are correlated together in
the transform ``event_shape``. This should be 0 for pointwise
transforms, 1 for transforms that act jointly on vectors, 2 for
@ -93,6 +96,14 @@ class Transform(object):
self._inv = weakref.ref(inv)
return inv
@property
def sign(self):
"""
Returns the sign of the determinant of the Jacobian, if applicable.
In general this only makes sense for bijective transforms.
"""
raise NotImplementedError
def __eq__(self, other):
return self is other
@ -166,6 +177,10 @@ class _InverseTransform(Transform):
def bijective(self):
return self._inv.bijective
@property
def sign(self):
return self._inv.sign
@property
def event_dim(self):
return self._inv.event_dim
@ -219,6 +234,13 @@ class ComposeTransform(Transform):
def bijective(self):
return all(p.bijective for p in self.parts)
@lazy_property
def sign(self):
sign = 1
for p in self.parts:
sign = sign * p.sign
return sign
@lazy_property
def event_dim(self):
return max(p.event_dim for p in self.parts) if self.parts else 0
@ -261,6 +283,7 @@ class ExpTransform(Transform):
domain = constraints.real
codomain = constraints.positive
bijective = True
sign = +1
def __eq__(self, other):
return isinstance(other, ExpTransform)
@ -282,6 +305,7 @@ class SigmoidTransform(Transform):
domain = constraints.real
codomain = constraints.unit_interval
bijective = True
sign = +1
def __eq__(self, other):
return isinstance(other, SigmoidTransform)
@ -341,6 +365,10 @@ class AffineTransform(Transform):
result = result.data.view(-1)[0]
return result
@property
def sign(self):
return self.scale.sign()
def _call(self, x):
return self.loc + self.scale * x