pytorch/torch/distributions/constraints.py
fehiepsi 91ea2cd5a7 clip sigmoid to prevent transforms return inf/nan values (#20288)
Summary:
This PR addresses some numerical issues in SigmoidTransform/StickBreakingTransform, where these transforms produce +-inf when the unconstrained values move into the region around +-20.

For example, with
```
t = torch.distributions.SigmoidTransform()
x = torch.tensor(20.)
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```
the current behaviour is that the inverse returns `inf` and the logdet returns `-inf`, while this PR makes them `15.9424` and `-15.9424`.

And for
```
t = torch.distributions.StickBreakingTransform()
x = torch.tensor([20., 20.])
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```
the current value of the inverse is `(inf, nan)` and the logdet is `-inf`, while this PR makes the inverse `[16.6355, 71.3942]` and the logdet `-47.8272`.

Although these finite values are wrong (which seems unavoidable), returning them is better than returning `inf` or `nan` in my opinion. This is useful in HMC: even though the gradient will be zero when the unconstrained parameter moves into the unstable area (due to clipping), the velocity variable will keep moving the parameter, which may by chance carry it out of the unstable area. On the other hand, inf/nan can be useful for stopping inference early, so the changes in this PR might be inappropriate.

I also fix some small issues in the `_Simplex` and `_RealVector` constraints, where the batch shape of the input was not respected when checking validity.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20288

Differential Revision: D15742047

Pulled By: ezyang

fbshipit-source-id: b427ed1752c41327abb3957f98d4b289307a7d17
2019-06-10 11:16:31 -07:00

366 lines
10 KiB
Python

r"""
The following constraints are implemented:
- ``constraints.boolean``
- ``constraints.cat``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.nonnegative_integer``
- ``constraints.positive``
- ``constraints.positive_definite``
- ``constraints.positive_integer``
- ``constraints.real``
- ``constraints.real_vector``
- ``constraints.simplex``
- ``constraints.stack``
- ``constraints.unit_interval``
"""
import torch
# Names exported by ``from torch.distributions.constraints import *``.
# Besides the constraints listed in the module docstring this also exports the
# ``Constraint`` base class and the helpers ``dependent_property`` and
# ``is_dependent``.
__all__ = [
'Constraint',
'boolean',
'cat',
'dependent',
'dependent_property',
'greater_than',
'greater_than_eq',
'integer_interval',
'interval',
'half_open_interval',
'is_dependent',
'less_than',
'lower_cholesky',
'lower_triangular',
'nonnegative_integer',
'positive',
'positive_definite',
'positive_integer',
'real',
'real_vector',
'simplex',
'stack',
'unit_interval',
]
class Constraint(object):
    """
    Base class that every constraint derives from.

    A constraint describes the region in which a variable is considered
    valid, e.g. the region over which it may be optimized.
    """

    def check(self, value):
        """
        Report, for each event in ``value``, whether the constraint holds.

        Subclasses are expected to override this and return a byte tensor of
        shape ``sample_shape + batch_shape``; the base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def __repr__(self):
        # Concrete constraints in this module are named with a leading
        # underscore (e.g. ``_Real``); dropping the first character of the
        # class name gives the displayed name.
        shown_name = type(self).__name__[1:]
        return '{}()'.format(shown_name)
class _Dependent(Constraint):
    """
    Stand-in constraint for variables whose support is determined by other
    variables, so that no simple coordinate-wise rule describes it.
    """

    def check(self, x):
        """Always raise ``ValueError``: validity cannot be decided here."""
        message = 'Cannot determine validity of dependent constraint'
        raise ValueError(message)
def is_dependent(constraint):
    """Return ``True`` when ``constraint`` is an instance of ``_Dependent``."""
    found = isinstance(constraint, _Dependent)
    return found
class _DependentProperty(property, _Dependent):
    """
    A ``property`` subclass that is also a ``_Dependent`` constraint.

    Intended to be used as a decorator: accessed through an instance it
    behaves like a regular property, while the object found on the class is
    itself a ``_Dependent`` constraint.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high

            @constraints.dependent_property
            def support(self):
                return constraints.interval(self.low, self.high)
    """
class _Boolean(Constraint):
    """
    Restrict values to the set `{0, 1}`.
    """

    def check(self, value):
        """Elementwise test that each entry equals 0 or equals 1."""
        is_zero = value == 0
        is_one = value == 1
        return is_zero | is_one
class _IntegerInterval(Constraint):
    """
    Restrict values to the integers in `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def check(self, value):
        """Elementwise: whole number and within both (inclusive) bounds."""
        is_whole = value % 1 == 0
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return is_whole & above_lower & below_upper

    def __repr__(self):
        # Class name without its leading underscore, followed by the bounds.
        return '{}(lower_bound={}, upper_bound={})'.format(
            self.__class__.__name__[1:], self.lower_bound, self.upper_bound)
class _IntegerLessThan(Constraint):
    """
    Restrict values to the integers in `(-inf, upper_bound]`.
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, value):
        """Elementwise: whole number and no larger than ``upper_bound``."""
        is_whole = value % 1 == 0
        below_upper = value <= self.upper_bound
        return is_whole & below_upper

    def __repr__(self):
        # Class name without its leading underscore, followed by the bound.
        return '{}(upper_bound={})'.format(
            self.__class__.__name__[1:], self.upper_bound)
class _IntegerGreaterThan(Constraint):
    """
    Restrict values to the integers in `[lower_bound, inf)`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        """Elementwise: whole number and no smaller than ``lower_bound``."""
        is_whole = value % 1 == 0
        above_lower = value >= self.lower_bound
        return is_whole & above_lower

    def __repr__(self):
        # Class name without its leading underscore, followed by the bound.
        return '{}(lower_bound={})'.format(
            self.__class__.__name__[1:], self.lower_bound)
class _Real(Constraint):
    """
    Trivial constraint accepting the whole extended real line `[-inf, inf]`.
    """

    def check(self, value):
        """Elementwise self-equality test, which only fails for NaN entries."""
        not_nan = value == value  # NaN is the only value unequal to itself.
        return not_nan
class _GreaterThan(Constraint):
    """
    Restrict values to the real half line `(lower_bound, inf]`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        """Elementwise strict comparison ``lower_bound < value``."""
        above_lower = self.lower_bound < value
        return above_lower

    def __repr__(self):
        # Class name without its leading underscore, followed by the bound.
        return '{}(lower_bound={})'.format(
            self.__class__.__name__[1:], self.lower_bound)
class _GreaterThanEq(Constraint):
    """
    Restrict values to the real half line `[lower_bound, inf)`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        """Elementwise non-strict comparison ``lower_bound <= value``."""
        at_or_above = self.lower_bound <= value
        return at_or_above

    def __repr__(self):
        # Class name without its leading underscore, followed by the bound.
        return '{}(lower_bound={})'.format(
            self.__class__.__name__[1:], self.lower_bound)
class _LessThan(Constraint):
    """
    Restrict values to the real half line `[-inf, upper_bound)`.
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, value):
        """Elementwise strict comparison ``value < upper_bound``."""
        below_upper = value < self.upper_bound
        return below_upper

    def __repr__(self):
        # Class name without its leading underscore, followed by the bound.
        return '{}(upper_bound={})'.format(
            self.__class__.__name__[1:], self.upper_bound)
class _Interval(Constraint):
    """
    Restrict values to the closed real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def check(self, value):
        """Elementwise: both bounds are inclusive."""
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        # Class name without its leading underscore, followed by the bounds.
        return '{}(lower_bound={}, upper_bound={})'.format(
            self.__class__.__name__[1:], self.lower_bound, self.upper_bound)
class _HalfOpenInterval(Constraint):
    """
    Restrict values to the half-open real interval
    `[lower_bound, upper_bound)`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def check(self, value):
        """Elementwise: lower bound inclusive, upper bound exclusive."""
        above_lower = self.lower_bound <= value
        below_upper = value < self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        # Class name without its leading underscore, followed by the bounds.
        return '{}(lower_bound={}, upper_bound={})'.format(
            self.__class__.__name__[1:], self.lower_bound, self.upper_bound)
class _Simplex(Constraint):
    """
    Restrict the innermost (rightmost) dimension to the unit simplex,
    i.e. `x >= 0` and `x.sum(-1) == 1`.
    """

    def check(self, value):
        """
        Reduce over the last dimension: every entry must be non-negative and
        the entries must sum to 1 within an absolute tolerance of 1e-6.
        """
        all_nonnegative = (value >= 0).all(dim=-1)
        sum_error = (value.sum(-1) - 1).abs()
        return all_nonnegative & (sum_error < 1e-6)
class _LowerTriangular(Constraint):
    """
    Restrict to lower-triangular square matrices.
    """

    def check(self, value):
        """
        Compare ``value`` with its own lower triangle and reduce over the
        last two dimensions, giving one result per matrix.
        """
        unchanged = value.tril() == value
        # Merge the two matrix dimensions into one so a single reduction
        # covers every entry of each matrix.
        per_matrix = unchanged.view(value.shape[:-2] + (-1,))
        return per_matrix.min(-1)[0]
class _LowerCholesky(Constraint):
    """
    Restrict to lower-triangular square matrices whose diagonal entries are
    all positive.
    """

    def check(self, value):
        """
        Return, per matrix, whether it equals its own lower triangle and has
        a strictly positive diagonal.
        """
        unchanged = value.tril() == value
        # Merge the two matrix dimensions so one reduction covers every entry.
        is_lower = unchanged.view(value.shape[:-2] + (-1,)).min(-1)[0]
        diagonal = value.diagonal(dim1=-2, dim2=-1)
        has_positive_diag = (diagonal > 0).min(-1)[0]
        return is_lower & has_positive_diag
class _PositiveDefinite(Constraint):
    """
    Constrain to positive-definite matrices.

    The trailing two dimensions of the checked value are treated as the
    matrix and all leading dimensions as batch dimensions. Only the upper
    triangle of each matrix is read by the eigenvalue routine, so inputs are
    assumed to be symmetric.
    """

    def check(self, value):
        """
        Test whether the smallest eigenvalue of every matrix is positive.

        Returns a tensor of shape ``(1,) + value.shape[:-2]``; the leading
        singleton dimension is kept for backward compatibility with earlier
        behaviour of this method.
        """
        matrix_shape = value.shape[-2:]
        # ``unsqueeze(0)`` prepends a singleton dimension, which is why the
        # result carries a leading dimension of size 1.
        batch_shape = value.unsqueeze(0).shape[:-2]
        flattened_value = value.reshape((-1,) + matrix_shape)
        linalg = getattr(torch, 'linalg', None)
        if linalg is not None and hasattr(linalg, 'eigvalsh'):
            # Batched routine; eigenvalues come back in ascending order, so
            # the first one is the smallest. ``UPLO='U'`` matches the default
            # ``upper=True`` of the legacy ``symeig`` call below.
            smallest = linalg.eigvalsh(flattened_value, UPLO='U')[..., :1]
            return (smallest > 0.0).reshape(batch_shape)
        # Fallback for PyTorch releases that predate ``torch.linalg``:
        # ``symeig()`` also returns eigenvalues in ascending order but has to
        # be applied one matrix at a time.
        return torch.stack([v.symeig(eigenvectors=False)[0][:1] > 0.0
                            for v in flattened_value]).view(batch_shape)
class _RealVector(Constraint):
    """
    Restrict to real-valued vectors. Equivalent to `constraints.real`, except
    that the result is additionally reduced over the `event_shape` dimension.
    """

    def check(self, value):
        """
        Return, per vector along the last dimension, whether it is free of
        NaN entries.
        """
        not_nan = value == value  # NaN is the only value unequal to itself.
        return not_nan.all(dim=-1)
class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        """
        Args:
            cseq: iterable of ``Constraint`` objects, one per segment.
            dim: dimension along which the checked value is segmented.
            lengths: size of each segment along ``dim``; defaults to 1 for
                every constraint in ``cseq``.
        """
        # Materialize before validating: ``cseq`` may be a one-shot iterator
        # (e.g. a generator), which the ``all(...)`` check would otherwise
        # exhaust and leave ``self.cseq`` empty.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim

    def check(self, value):
        """
        Check consecutive slices of ``value`` along ``self.dim``, each with
        its own constraint, and concatenate the per-slice results along the
        same dimension.
        """
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length  # avoid += for jit compat
        return torch.cat(checks, self.dim)
class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    in a way compatible with :func:`torch.stack`.
    """

    def __init__(self, cseq, dim=0):
        """
        Args:
            cseq: iterable of ``Constraint`` objects, one per slice.
            dim: dimension along which the checked value is sliced.
        """
        # Materialize before validating: ``cseq`` may be a one-shot iterator
        # (e.g. a generator), which the ``all(...)`` check would otherwise
        # exhaust and leave ``self.cseq`` empty.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        self.dim = dim

    def check(self, value):
        """
        Check each slice of ``value`` along ``self.dim`` with the constraint
        at the same position in ``cseq`` and stack the results along that
        dimension.
        """
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        return torch.stack([constr.check(v)
                            for v, constr in zip(vs, self.cseq)], self.dim)
# Public interface.
# Each lower-case name below is either a ready-made constraint instance (for
# constraints constructed here without caller-supplied arguments) or the
# constraint class itself, which callers instantiate with their own arguments.
dependent = _Dependent()
dependent_property = _DependentProperty
boolean = _Boolean()
nonnegative_integer = _IntegerGreaterThan(0)  # integers >= 0
positive_integer = _IntegerGreaterThan(1)  # integers >= 1
integer_interval = _IntegerInterval
real = _Real()
real_vector = _RealVector()
positive = _GreaterThan(0.)  # strictly greater than 0
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
unit_interval = _Interval(0., 1.)  # closed interval [0, 1]
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack