mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix sign error in TransformedDistribution.cdf() and .icdf() (#5172)
This commit is contained in:
parent
a061000250
commit
b608ea9178
|
|
@ -285,9 +285,9 @@ EXAMPLES = [
|
|||
'transforms': ExpTransform(),
|
||||
},
|
||||
{
|
||||
'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
|
||||
Variable(torch.randn(2, 3).abs(), requires_grad=True)),
|
||||
'transforms': [AffineTransform(Variable(torch.randn(1)), Variable(torch.randn(1))),
|
||||
'base_distribution': Normal(Variable(torch.randn(2, 3, 5), requires_grad=True),
|
||||
Variable(torch.randn(2, 3, 5).abs(), requires_grad=True)),
|
||||
'transforms': [AffineTransform(Variable(torch.randn(3, 5)), Variable(torch.randn(3, 5))),
|
||||
ExpTransform()],
|
||||
},
|
||||
]),
|
||||
|
|
@ -1317,7 +1317,7 @@ class TestDistributions(TestCase):
|
|||
pdfs = dist.log_prob(samples).exp()
|
||||
except NotImplementedError:
|
||||
continue
|
||||
cdfs_derivative = grad(cdfs.sum(), [samples])[0]
|
||||
cdfs_derivative = grad(cdfs.sum(), [samples])[0] # this should not be wrapped in torch.abs()
|
||||
self.assertEqual(cdfs_derivative, pdfs, message='\n'.join([
|
||||
'{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1, len(params)),
|
||||
'x = {}'.format(samples),
|
||||
|
|
|
|||
|
|
@ -49,8 +49,9 @@ class TransformedDistribution(Distribution):
|
|||
def sample(self, sample_shape=torch.Size()):
|
||||
"""
|
||||
Generates a sample_shape shaped sample or sample_shape shaped batch of
|
||||
samples if the distribution parameters are batched. Samples first from base distribution
|
||||
and applies `transform()` for every transform in the list.
|
||||
samples if the distribution parameters are batched. Samples first from
|
||||
base distribution and applies `transform()` for every transform in the
|
||||
list.
|
||||
"""
|
||||
x = self.base_dist.sample(sample_shape)
|
||||
for transform in self.transforms:
|
||||
|
|
@ -61,8 +62,8 @@ class TransformedDistribution(Distribution):
|
|||
"""
|
||||
Generates a sample_shape shaped reparameterized sample or sample_shape
|
||||
shaped batch of reparameterized samples if the distribution parameters
|
||||
are batched. Samples first from base distribution and applies `transform()`
|
||||
for every transform in the list.
|
||||
are batched. Samples first from base distribution and applies
|
||||
`transform()` for every transform in the list.
|
||||
"""
|
||||
x = self.base_dist.rsample(sample_shape)
|
||||
for transform in self.transforms:
|
||||
|
|
@ -71,8 +72,8 @@ class TransformedDistribution(Distribution):
|
|||
|
||||
def log_prob(self, value):
|
||||
"""
|
||||
Scores the sample by inverting the transform(s) and computing the score using the score
|
||||
of the base distribution and the log abs det jacobian
|
||||
Scores the sample by inverting the transform(s) and computing the score
|
||||
using the score of the base distribution and the log abs det jacobian.
|
||||
"""
|
||||
self.base_dist._validate_log_prob_arg(value)
|
||||
event_dim = len(self.event_shape)
|
||||
|
|
@ -87,21 +88,35 @@ class TransformedDistribution(Distribution):
|
|||
event_dim - len(self.base_dist.event_shape))
|
||||
return log_prob
|
||||
|
||||
def _monotonize_cdf(self, value):
    """
    This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
    monotone increasing.

    The ``sign`` of every transform in ``self.transforms`` is multiplied
    together. A product of ``+1`` leaves ``value`` unchanged; a product of
    ``-1`` maps it to ``1 - value``.
    """
    sign = 1
    for transform in self.transforms:
        sign = sign * transform.sign
    # Compare by value rather than identity: ``sign is 1`` only works by
    # accident of small-int caching. The isinstance guard keeps signs that
    # are not plain ints (such as the result of ``scale.sign()``) on the
    # arithmetic path below instead of truth-testing them.
    if isinstance(sign, int) and sign == 1:
        return value
    return sign * (value - 0.5) + 0.5
|
||||
|
||||
def cdf(self, value):
|
||||
"""
|
||||
Computes the cumulative distribution function by inverting the transform(s) and computing
|
||||
the score of the base distribution
|
||||
Computes the cumulative distribution function by inverting the
|
||||
transform(s) and computing the score of the base distribution.
|
||||
"""
|
||||
self.base_dist._validate_log_prob_arg(value)
|
||||
for transform in self.transforms[::-1]:
|
||||
value = transform.inv(value)
|
||||
return self.base_dist.cdf(value)
|
||||
value = self.base_dist.cdf(value)
|
||||
value = self._monotonize_cdf(value)
|
||||
return value
|
||||
|
||||
def icdf(self, value):
|
||||
"""
|
||||
Computes the inverse cumulative distribution function using transform(s) and computing
|
||||
the score of the base distribution
|
||||
Computes the inverse cumulative distribution function using
|
||||
transform(s) and computing the score of the base distribution.
|
||||
"""
|
||||
value = self._monotonize_cdf(value)
|
||||
value = self.base_dist.icdf(value)
|
||||
for transform in self.transforms:
|
||||
value = transform(value)
|
||||
|
|
|
|||
|
|
@ -61,6 +61,9 @@ class Transform(object):
|
|||
the codomain. Transforms that are not bijective should at least
|
||||
maintain the weaker pseudoinverse properties
|
||||
``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
|
||||
sign (int or Variable): For bijective univariate transforms, this
|
||||
should be +1 or -1 depending on whether transform is monotone
|
||||
increasing or decreasing.
|
||||
event_dim (int): Number of dimensions that are correlated together in
|
||||
the transform ``event_shape``. This should be 0 for pointwise
|
||||
transforms, 1 for transforms that act jointly on vectors, 2 for
|
||||
|
|
@ -93,6 +96,14 @@ class Transform(object):
|
|||
self._inv = weakref.ref(inv)
|
||||
return inv
|
||||
|
||||
@property
def sign(self):
    """
    Sign of the determinant of the Jacobian, where that notion applies.

    In general this only makes sense for bijective transforms. This base
    implementation provides no value.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
|
||||
|
|
@ -166,6 +177,10 @@ class _InverseTransform(Transform):
|
|||
def bijective(self):
|
||||
return self._inv.bijective
|
||||
|
||||
@property
def sign(self):
    """
    Sign of this inverse transform, read directly from the transform it
    wraps (``self._inv``).
    """
    wrapped = self._inv
    return wrapped.sign
|
||||
|
||||
@property
|
||||
def event_dim(self):
|
||||
return self._inv.event_dim
|
||||
|
|
@ -219,6 +234,13 @@ class ComposeTransform(Transform):
|
|||
def bijective(self):
|
||||
return all(p.bijective for p in self.parts)
|
||||
|
||||
@lazy_property
def sign(self):
    """
    Product of the ``sign`` of every part in ``self.parts``.

    Evaluates to ``1`` when there are no parts.
    """
    product = 1
    for part in self.parts:
        product = product * part.sign
    return product
|
||||
|
||||
@lazy_property
|
||||
def event_dim(self):
|
||||
return max(p.event_dim for p in self.parts) if self.parts else 0
|
||||
|
|
@ -261,6 +283,7 @@ class ExpTransform(Transform):
|
|||
domain = constraints.real
|
||||
codomain = constraints.positive
|
||||
bijective = True
|
||||
sign = +1
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ExpTransform)
|
||||
|
|
@ -282,6 +305,7 @@ class SigmoidTransform(Transform):
|
|||
domain = constraints.real
|
||||
codomain = constraints.unit_interval
|
||||
bijective = True
|
||||
sign = +1
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, SigmoidTransform)
|
||||
|
|
@ -341,6 +365,10 @@ class AffineTransform(Transform):
|
|||
result = result.data.view(-1)[0]
|
||||
return result
|
||||
|
||||
@property
def sign(self):
    """
    Sign of ``self.scale``.

    Returns:
        ``1``, ``-1`` or ``0`` as an int when ``scale`` is a plain Python
        number; otherwise whatever ``self.scale.sign()`` returns.
    """
    import numbers  # local import: the file's import block is not visible here

    scale = self.scale
    if isinstance(scale, numbers.Number):
        # Plain numbers have no ``.sign()`` method, so compute it directly.
        return 1 if scale > 0 else -1 if scale < 0 else 0
    return scale.sign()
|
||||
|
||||
def _call(self, x):
|
||||
return self.loc + self.scale * x
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user