mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: This PR addresses some numerical issues in Sigmoid/StickBreakingTransform, where these transforms return +-inf when the unconstrained values move into the +-20 range. For example, with ``` t = torch.distributions.SigmoidTransform() x = torch.tensor(20.) t.inv(t(x)), t.log_abs_det_jacobian(x, t(x)) ``` the current behaviour is that the inverse returns `inf` and the logdet returns `-inf`, while this PR makes them `15.9424` and `-15.9424`. And for ``` t = torch.distributions.StickBreakingTransform() x = torch.tensor([20., 20.]) t.inv(t(x)), t.log_abs_det_jacobian(x, t(x)) ``` the current value is `(inf, nan)` for the inverse and `-inf` for the logdet, while this PR makes them `[16.6355, 71.3942]` and `-47.8272` respectively. Although these finite values are wrong, which seems unavoidable, they are better than returning `inf` or `nan` in my opinion. This is useful in HMC: even though the grad will be zero when the unconstrained parameter moves into an unstable area (due to clipping), the velocity variable will force the parameter to move to another area, which by chance can move the parameter out of the unstable area. On the other hand, inf/nan can be useful for stopping inference early, so the changes in this PR might be inappropriate. I also fix some small issues in the `_Simplex` and `_RealVector` constraints, where the batch shape of the input was not respected during validation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20288 Differential Revision: D15742047 Pulled By: ezyang fbshipit-source-id: b427ed1752c41327abb3957f98d4b289307a7d17
366 lines
10 KiB
Python
366 lines
10 KiB
Python
r"""
The following constraints are implemented:

- ``constraints.boolean``
- ``constraints.cat``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.nonnegative_integer``
- ``constraints.positive``
- ``constraints.positive_definite``
- ``constraints.positive_integer``
- ``constraints.real``
- ``constraints.real_vector``
- ``constraints.simplex``
- ``constraints.stack``
- ``constraints.unit_interval``
"""
|
|
|
|
import torch
|
|
|
|
# Names exported by `from ... import *`: the `Constraint` base class, the
# `is_dependent` helper, and the public constraint aliases that are bound in
# the "Public interface" section at the bottom of this module.
__all__ = [
    'Constraint',
    'boolean',
    'cat',
    'dependent',
    'dependent_property',
    'greater_than',
    'greater_than_eq',
    'integer_interval',
    'interval',
    'half_open_interval',
    'is_dependent',
    'less_than',
    'lower_cholesky',
    'lower_triangular',
    'nonnegative_integer',
    'positive',
    'positive_definite',
    'positive_integer',
    'real',
    'real_vector',
    'simplex',
    'stack',
    'unit_interval',
]
|
|
|
|
|
|
class Constraint(object):
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized.
    """
    def check(self, value):
        """
        Returns a byte tensor of `sample_shape + batch_shape` indicating
        whether each event in value satisfies this constraint.

        This base implementation always raises `NotImplementedError`;
        subclasses provide the actual test.
        """
        raise NotImplementedError

    def __repr__(self):
        # The concrete constraints in this module are private classes named
        # `_Foo` and are displayed as `Foo()`. Strip the first character only
        # when it really is that underscore, so that a class without the
        # prefix (including this base class) keeps its full name.
        name = self.__class__.__name__
        if name.startswith('_'):
            name = name[1:]
        return name + '()'
|
|
|
|
|
|
class _Dependent(Constraint):
    """
    Stand-in constraint for variables whose support is determined by other
    variables, so that no simple coordinate-wise rule applies to them.
    """
    def check(self, x):
        # Validity cannot be decided from `x` alone, so checking always fails
        # loudly instead of returning a mask.
        message = 'Cannot determine validity of dependent constraint'
        raise ValueError(message)
|
|
|
|
|
|
def is_dependent(constraint):
    """
    Return ``True`` if `constraint` is an instance of `_Dependent`
    (or of one of its subclasses), ``False`` otherwise.
    """
    depends_on_others = isinstance(constraint, _Dependent)
    return depends_on_others
|
|
|
|
|
|
class _DependentProperty(property, _Dependent):
    """
    Decorator combining :class:`property` with the `Dependent` constraint:
    accessed on a class it behaves like a `Dependent` constraint, accessed
    on an object it behaves like an ordinary property.

    The class adds no members of its own; all behavior comes from its two
    base classes.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high
            @constraints.dependent_property
            def support(self):
                return constraints.interval(self.low, self.high)
    """
|
|
|
|
|
|
class _Boolean(Constraint):
    """
    Constraint satisfied only by the two values `{0, 1}`.
    """
    def check(self, value):
        # Elementwise: an entry passes when it equals 0 or equals 1.
        is_zero = value == 0
        is_one = value == 1
        return is_zero | is_one
|
|
|
|
|
|
class _IntegerInterval(Constraint):
    """
    Constraint satisfied by integers in the closed interval
    `[lower_bound, upper_bound]`.
    """
    def __init__(self, lower_bound, upper_bound):
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        # An entry passes when it has no fractional part and lies within
        # both (inclusive) bounds.
        is_integral = value % 1 == 0
        not_below = self.lower_bound <= value
        not_above = value <= self.upper_bound
        return is_integral & not_below & not_above

    def __repr__(self):
        # The leading underscore of the private class name is dropped for display.
        display_name = self.__class__.__name__[1:]
        bounds = '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
        return display_name + bounds
|
|
|
|
|
|
class _IntegerLessThan(Constraint):
    """
    Constraint satisfied by integers in the interval `(-inf, upper_bound]`.
    """
    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, value):
        # An entry passes when it has no fractional part and does not
        # exceed the (inclusive) upper bound.
        is_integral = value % 1 == 0
        not_above = value <= self.upper_bound
        return is_integral & not_above

    def __repr__(self):
        # The leading underscore of the private class name is dropped for display.
        display_name = self.__class__.__name__[1:]
        bound = '(upper_bound={})'.format(self.upper_bound)
        return display_name + bound
|
|
|
|
|
|
class _IntegerGreaterThan(Constraint):
    """
    Constraint satisfied by integers in the interval `[lower_bound, inf)`.
    """
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        # An entry passes when it has no fractional part and is at least
        # the (inclusive) lower bound.
        is_integral = value % 1 == 0
        not_below = value >= self.lower_bound
        return is_integral & not_below

    def __repr__(self):
        # The leading underscore of the private class name is dropped for display.
        display_name = self.__class__.__name__[1:]
        bound = '(lower_bound={})'.format(self.lower_bound)
        return display_name + bound
|
|
|
|
|
|
class _Real(Constraint):
    """
    Trivial constraint covering the extended real line `[-inf, inf]`.
    """
    def check(self, value):
        # Self-comparison is False for NANs, so only NAN entries fail.
        equals_itself = value == value
        return equals_itself
|
|
|
|
|
|
class _GreaterThan(Constraint):
    """
    Constraint satisfied on the real half line `(lower_bound, inf]`.
    """
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        # Strict comparison: the bound itself is excluded.
        above_bound = self.lower_bound < value
        return above_bound

    def __repr__(self):
        # The leading underscore of the private class name is dropped for display.
        display_name = self.__class__.__name__[1:]
        bound = '(lower_bound={})'.format(self.lower_bound)
        return display_name + bound
|
|
|
|
|
|
class _GreaterThanEq(Constraint):
    """
    Constraint satisfied on the real half line `[lower_bound, inf)`.
    """
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        # Non-strict comparison: the bound itself is included.
        at_least_bound = self.lower_bound <= value
        return at_least_bound

    def __repr__(self):
        # The leading underscore of the private class name is dropped for display.
        display_name = self.__class__.__name__[1:]
        bound = '(lower_bound={})'.format(self.lower_bound)
        return display_name + bound
|
|
|
|
|
|
class _LessThan(Constraint):
    """
    Constraint satisfied on the real half line `[-inf, upper_bound)`.
    """
    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, value):
        # Strict comparison: the bound itself is excluded.
        below_bound = value < self.upper_bound
        return below_bound

    def __repr__(self):
        # The leading underscore of the private class name is dropped for display.
        display_name = self.__class__.__name__[1:]
        bound = '(upper_bound={})'.format(self.upper_bound)
        return display_name + bound
|
|
|
|
|
|
class _Interval(Constraint):
    """
    Constraint satisfied on the closed real interval
    `[lower_bound, upper_bound]`.
    """
    def __init__(self, lower_bound, upper_bound):
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        # Both endpoints are included.
        not_below = self.lower_bound <= value
        not_above = value <= self.upper_bound
        return not_below & not_above

    def __repr__(self):
        # The leading underscore of the private class name is dropped for display.
        display_name = self.__class__.__name__[1:]
        bounds = '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
        return display_name + bounds
|
|
|
|
|
|
class _HalfOpenInterval(Constraint):
    """
    Constraint satisfied on the half-open real interval
    `[lower_bound, upper_bound)`.
    """
    def __init__(self, lower_bound, upper_bound):
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        # The lower endpoint is included, the upper endpoint is excluded.
        not_below = self.lower_bound <= value
        strictly_below_upper = value < self.upper_bound
        return not_below & strictly_below_upper

    def __repr__(self):
        # The leading underscore of the private class name is dropped for display.
        display_name = self.__class__.__name__[1:]
        bounds = '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
        return display_name + bounds
|
|
|
|
|
|
class _Simplex(Constraint):
    """
    Constraint to the unit simplex along the innermost (rightmost) dimension.
    Specifically: `x >= 0` and `x.sum(-1) == 1`.
    """
    def check(self, value):
        # Both tests reduce over the last dimension, so the result has one
        # entry per vector rather than per coordinate.
        nonnegative = torch.all(value >= 0, dim=-1)
        # The sum is compared to 1 with an absolute tolerance of 1e-6.
        sums_to_one = (value.sum(-1) - 1).abs() < 1e-6
        return nonnegative & sums_to_one
|
|
|
|
|
|
class _LowerTriangular(Constraint):
    """
    Constraint satisfied by lower-triangular square matrices.
    """
    def check(self, value):
        # A matrix is lower triangular exactly when it equals its own
        # lower-triangular part.
        unchanged_by_tril = value.tril() == value
        # Collapse the two matrix dimensions into one and reduce with `min`,
        # leaving one flag per matrix in the batch.
        per_matrix_shape = value.shape[:-2] + (-1,)
        return unchanged_by_tril.view(per_matrix_shape).min(-1)[0]
|
|
|
|
|
|
class _LowerCholesky(Constraint):
    """
    Constraint satisfied by lower-triangular square matrices whose diagonal
    entries are all positive.
    """
    def check(self, value):
        # Lower-triangular test: the matrix equals its own lower-triangular
        # part; the two matrix dimensions are flattened and reduced with `min`.
        unchanged_by_tril = value.tril() == value
        per_matrix_shape = value.shape[:-2] + (-1,)
        is_lower = unchanged_by_tril.view(per_matrix_shape).min(-1)[0]

        # Diagonal test: every entry on the main diagonal is strictly positive.
        diag_entries = value.diagonal(dim1=-2, dim2=-1)
        diag_positive = (diag_entries > 0).min(-1)[0]
        return is_lower & diag_positive
|
|
|
|
|
|
class _PositiveDefinite(Constraint):
    """
    Constrain to positive-definite matrices.

    A matrix passes when its smallest eigenvalue is strictly positive. The
    eigenvalues are computed with a symmetric eigensolver, so the input is
    treated as symmetric; no separate symmetry test is performed here.
    """
    def check(self, value):
        """
        Returns a tensor of shape `(1,) + value.shape[:-2]` holding one flag
        per matrix in `value`. The extra leading dimension of size 1 is kept
        from the original implementation for backward compatibility.
        """
        linalg = getattr(torch, 'linalg', None)
        if hasattr(linalg, 'eigvalsh'):
            # Batched path. `eigvalsh` returns eigenvalues in ascending order,
            # so index 0 is the smallest one. UPLO='U' reads the upper
            # triangle, matching the default `upper=True` of `symeig()` used
            # by the fallback below.
            smallest = linalg.eigvalsh(value, UPLO='U')[..., 0]
            return (smallest > 0.0).unsqueeze(0)

        # Fallback for PyTorch builds without `torch.linalg.eigvalsh`:
        # decompose the matrices one at a time.
        matrix_shape = value.shape[-2:]
        batch_shape = value.unsqueeze(0).shape[:-2]
        # note that `symeig()` returns eigenvalues in ascending order
        flattened_value = value.reshape((-1,) + matrix_shape)
        return torch.stack([v.symeig(eigenvectors=False)[0][:1] > 0.0
                            for v in flattened_value]).view(batch_shape)
|
|
|
|
|
|
class _RealVector(Constraint):
    """
    Constraint satisfied by real-valued vectors. This is the same as
    `constraints.real`, but additionally reduces across the `event_shape`
    dimension.
    """
    def check(self, value):
        # Self-comparison is False for NANs; a vector passes only when every
        # entry along the last dimension passes.
        equals_itself = value == value
        return torch.all(equals_itself, dim=-1)
|
|
|
|
|
|
class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    in a way compatible with :func:`torch.cat`.

    The i-th constraint is applied to a slice of size `lengths[i]` along
    `dim`; consecutive slices start where the previous one ended. When
    `lengths` is omitted, every slice has size 1.
    """
    def __init__(self, cseq, dim=0, lengths=None):
        # Materialize before validating: running `all(...)` over a one-shot
        # iterable (e.g. a generator) would exhaust it and leave `cseq` empty.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim

    def check(self, value):
        """
        Checks each slice of `value` against its constraint and concatenates
        the per-slice results along `dim`.
        """
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length  # avoid += for jit compat
        return torch.cat(checks, self.dim)
|
|
|
|
|
|
class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    in a way compatible with :func:`torch.stack`.

    The i-th constraint is applied to the i-th slice selected along `dim`.
    """
    def __init__(self, cseq, dim=0):
        # Materialize before validating: running `all(...)` over a one-shot
        # iterable (e.g. a generator) would exhaust it and leave `cseq` empty.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        self.dim = dim

    def check(self, value):
        """
        Checks each slice of `value` along `dim` against the constraint at
        the same position and stacks the per-slice results along `dim`.
        """
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        # Slices and constraints are paired positionally; `zip` stops at the
        # shorter of the two sequences.
        return torch.stack([constr.check(v)
                            for v, constr in zip(vs, self.cseq)], self.dim)
|
|
|
|
# Public interface.
# Names bound to an instance are ready-to-use constraints; names bound to a
# class must be called with bounds or sub-constraints to build a constraint.
dependent = _Dependent()
dependent_property = _DependentProperty
boolean = _Boolean()
nonnegative_integer = _IntegerGreaterThan(0)  # integers >= 0
positive_integer = _IntegerGreaterThan(1)  # integers >= 1
integer_interval = _IntegerInterval
real = _Real()
real_vector = _RealVector()
positive = _GreaterThan(0.)  # strictly greater than 0
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
unit_interval = _Interval(0., 1.)  # closed interval [0, 1]
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack
|