mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix typing errors in the torch.distributions module (#45689)
Summary: Fixes https://github.com/pytorch/pytorch/issues/42979. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689 Reviewed By: agolynski Differential Revision: D24229870 Pulled By: xuzhao9 fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
This commit is contained in:
parent
6a001decf2
commit
146721f1df
3
mypy.ini
3
mypy.ini
|
|
@ -81,9 +81,6 @@ ignore_errors = True
|
|||
[mypy-torch.quantization.fx.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.distributions.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch._tensor_str]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ from .transforms import *
|
|||
from .uniform import Uniform
|
||||
from .von_mises import VonMises
|
||||
from .weibull import Weibull
|
||||
from . import transforms
|
||||
|
||||
__all__ = [
|
||||
'Bernoulli',
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from numbers import Number
|
||||
from numbers import Real, Number
|
||||
|
||||
import torch
|
||||
from torch.distributions import constraints
|
||||
|
|
@ -28,7 +28,7 @@ class Beta(ExponentialFamily):
|
|||
has_rsample = True
|
||||
|
||||
def __init__(self, concentration1, concentration0, validate_args=None):
|
||||
if isinstance(concentration1, Number) and isinstance(concentration0, Number):
|
||||
if isinstance(concentration1, Real) and isinstance(concentration0, Real):
|
||||
concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
|
||||
else:
|
||||
concentration1, concentration0 = broadcast_all(concentration1, concentration0)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import torch
|
|||
import warnings
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.utils import lazy_property
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
|
||||
class Distribution(object):
|
||||
|
|
@ -12,8 +13,6 @@ class Distribution(object):
|
|||
has_rsample = False
|
||||
has_enumerate_support = False
|
||||
_validate_args = False
|
||||
support = None
|
||||
arg_constraints = {}
|
||||
|
||||
@staticmethod
|
||||
def set_default_validate_args(value):
|
||||
|
|
@ -72,7 +71,7 @@ class Distribution(object):
|
|||
return self._event_shape
|
||||
|
||||
@property
|
||||
def arg_constraints(self):
|
||||
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
|
||||
"""
|
||||
Returns a dictionary from argument names to
|
||||
:class:`~torch.distributions.constraints.Constraint` objects that
|
||||
|
|
@ -82,7 +81,7 @@ class Distribution(object):
|
|||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def support(self):
|
||||
def support(self) -> Optional[Any]:
|
||||
"""
|
||||
Returns a :class:`~torch.distributions.constraints.Constraint` object
|
||||
representing this distribution's support.
|
||||
|
|
@ -248,7 +247,7 @@ class Distribution(object):
|
|||
if i != 1 and j != 1 and i != j:
|
||||
raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
|
||||
format(actual_shape, expected_shape))
|
||||
|
||||
assert self.support is not None
|
||||
if not self.support.check(value).all():
|
||||
raise ValueError('The value argument must be within the support')
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import _sum_rightmost
|
||||
|
||||
from typing import Dict
|
||||
|
||||
class Independent(Distribution):
|
||||
r"""
|
||||
|
|
@ -31,7 +31,7 @@ class Independent(Distribution):
|
|||
reinterpreted_batch_ndims (int): the number of batch dims to
|
||||
reinterpret as event dims
|
||||
"""
|
||||
arg_constraints = {}
|
||||
arg_constraints: Dict[str, constraints.Constraint] = {}
|
||||
|
||||
def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
|
||||
if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
import warnings
|
||||
from functools import total_ordering
|
||||
from typing import Type, Dict, Callable, Tuple
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
|
|
@ -33,7 +34,7 @@ from .uniform import Uniform
|
|||
from .utils import _sum_rightmost
|
||||
|
||||
_KL_REGISTRY = {} # Source of truth mapping a few general (type, type) pairs to functions.
|
||||
_KL_MEMOIZE = {} # Memoized version mapping many specific (type, type) pairs to functions.
|
||||
_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {} # Memoized version mapping many specific (type, type) pairs to functions.
|
||||
|
||||
|
||||
def register_kl(type_p, type_q):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import torch
|
|||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions import Categorical
|
||||
from torch.distributions import constraints
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class MixtureSameFamily(Distribution):
|
||||
|
|
@ -45,7 +46,7 @@ class MixtureSameFamily(Distribution):
|
|||
component_distribution: `torch.distributions.Distribution`-like
|
||||
instance. Right-most batch dimension indexes component.
|
||||
"""
|
||||
arg_constraints = {}
|
||||
arg_constraints: Dict[str, constraints.Constraint] = {}
|
||||
has_rsample = False
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import torch
|
|||
from torch._six import inf
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions import Categorical
|
||||
from numbers import Number
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.utils import broadcast_all
|
||||
|
||||
|
|
@ -40,6 +39,7 @@ class Multinomial(Distribution):
|
|||
"""
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real}
|
||||
total_count: int
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
|
|
@ -50,7 +50,7 @@ class Multinomial(Distribution):
|
|||
return self.total_count * self.probs * (1 - self.probs)
|
||||
|
||||
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
|
||||
if not isinstance(total_count, Number):
|
||||
if not isinstance(total_count, int):
|
||||
raise NotImplementedError('inhomogeneous total_count is not supported')
|
||||
self.total_count = total_count
|
||||
self._categorical = Categorical(probs=probs, logits=logits)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import math
|
||||
from numbers import Real
|
||||
from numbers import Number
|
||||
|
||||
import torch
|
||||
|
|
@ -72,7 +73,7 @@ class Normal(ExponentialFamily):
|
|||
self._validate_sample(value)
|
||||
# compute the variance
|
||||
var = (self.scale ** 2)
|
||||
log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
|
||||
log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
|
||||
return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
|
||||
|
||||
def cdf(self, value):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from torch.distributions import constraints
|
|||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.transforms import Transform
|
||||
from torch.distributions.utils import _sum_rightmost
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class TransformedDistribution(Distribution):
|
||||
|
|
@ -38,7 +39,7 @@ class TransformedDistribution(Distribution):
|
|||
:class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
|
||||
:class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
|
||||
"""
|
||||
arg_constraints = {}
|
||||
arg_constraints: Dict[str, constraints.Constraint] = {}
|
||||
|
||||
def __init__(self, base_distribution, transforms, validate_args=None):
|
||||
self.base_dist = base_distribution
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from torch.distributions.utils import (_sum_rightmost, broadcast_all,
|
|||
lazy_property)
|
||||
from torch.nn.functional import pad
|
||||
from torch.nn.functional import softplus
|
||||
from typing import List
|
||||
|
||||
__all__ = [
|
||||
'AbsTransform',
|
||||
|
|
@ -77,6 +78,7 @@ class Transform(object):
|
|||
transforms that act jointly on matrices, etc.
|
||||
"""
|
||||
bijective = False
|
||||
codomain: constraints.Constraint
|
||||
event_dim = 0
|
||||
|
||||
def __init__(self, cache_size=0):
|
||||
|
|
@ -185,22 +187,27 @@ class _InverseTransform(Transform):
|
|||
|
||||
@constraints.dependent_property
|
||||
def domain(self):
|
||||
assert self._inv is not None
|
||||
return self._inv.codomain
|
||||
|
||||
@constraints.dependent_property
|
||||
def codomain(self):
|
||||
assert self._inv is not None
|
||||
return self._inv.domain
|
||||
|
||||
@property
|
||||
def bijective(self):
|
||||
assert self._inv is not None
|
||||
return self._inv.bijective
|
||||
|
||||
@property
|
||||
def sign(self):
|
||||
assert self._inv is not None
|
||||
return self._inv.sign
|
||||
|
||||
@property
|
||||
def event_dim(self):
|
||||
assert self._inv is not None
|
||||
return self._inv.event_dim
|
||||
|
||||
@property
|
||||
|
|
@ -208,17 +215,21 @@ class _InverseTransform(Transform):
|
|||
return self._inv
|
||||
|
||||
def with_cache(self, cache_size=1):
|
||||
assert self._inv is not None
|
||||
return self.inv.with_cache(cache_size).inv
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, _InverseTransform):
|
||||
return False
|
||||
assert self._inv is not None
|
||||
return self._inv == other._inv
|
||||
|
||||
def __call__(self, x):
|
||||
assert self._inv is not None
|
||||
return self._inv._inv_call(x)
|
||||
|
||||
def log_abs_det_jacobian(self, x, y):
|
||||
assert self._inv is not None
|
||||
return -self._inv.log_abs_det_jacobian(y, x)
|
||||
|
||||
|
||||
|
|
@ -500,8 +511,8 @@ class AffineTransform(Transform):
|
|||
|
||||
@property
|
||||
def sign(self):
|
||||
if isinstance(self.scale, numbers.Number):
|
||||
return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
|
||||
if isinstance(self.scale, numbers.Real):
|
||||
return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
|
||||
return self.scale.sign()
|
||||
|
||||
def _call(self, x):
|
||||
|
|
@ -513,7 +524,7 @@ class AffineTransform(Transform):
|
|||
def log_abs_det_jacobian(self, x, y):
|
||||
shape = x.shape
|
||||
scale = self.scale
|
||||
if isinstance(scale, numbers.Number):
|
||||
if isinstance(scale, numbers.Real):
|
||||
result = torch.full_like(x, math.log(abs(scale)))
|
||||
else:
|
||||
result = torch.abs(scale).log()
|
||||
|
|
@ -575,7 +586,7 @@ class StickBreakingTransform(Transform):
|
|||
offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
|
||||
z = _clipped_sigmoid(x - offset.log())
|
||||
z_cumprod = (1 - z).cumprod(-1)
|
||||
y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
|
||||
y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
|
||||
return y
|
||||
|
||||
def _inverse(self, y):
|
||||
|
|
@ -619,6 +630,7 @@ class LowerCholeskyTransform(Transform):
|
|||
|
||||
|
||||
class CatTransform(Transform):
|
||||
tseq: List[numbers.Number]
|
||||
"""
|
||||
Transform functor that applies a sequence of transforms `tseq`
|
||||
component-wise to each submatrix at `dim`, of length `lengths[dim]`,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from functools import update_wrapper
|
|||
from numbers import Number
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def broadcast_all(*values):
|
||||
|
|
@ -23,13 +24,14 @@ def broadcast_all(*values):
|
|||
if not all(isinstance(v, torch.Tensor) or isinstance(v, Number) for v in values):
|
||||
raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
|
||||
if not all([isinstance(v, torch.Tensor) for v in values]):
|
||||
options = dict(dtype=torch.get_default_dtype())
|
||||
options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
|
||||
for value in values:
|
||||
if isinstance(value, torch.Tensor):
|
||||
options = dict(dtype=value.dtype, device=value.device)
|
||||
break
|
||||
values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options)
|
||||
new_values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options)
|
||||
for v in values]
|
||||
return torch.broadcast_tensors(*new_values)
|
||||
return torch.broadcast_tensors(*values)
|
||||
|
||||
|
||||
|
|
@ -94,7 +96,7 @@ class lazy_property(object):
|
|||
"""
|
||||
def __init__(self, wrapped):
|
||||
self.wrapped = wrapped
|
||||
update_wrapper(self, wrapped)
|
||||
update_wrapper(self, wrapped) # type: ignore[arg-type]
|
||||
|
||||
def __get__(self, instance, obj_type=None):
|
||||
if instance is None:
|
||||
|
|
|
|||
|
|
@ -510,7 +510,7 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
|
|||
signal_dim = input.dim()
|
||||
extended_shape = [1] * (3 - signal_dim) + list(input.size())
|
||||
pad = int(n_fft // 2)
|
||||
input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)
|
||||
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
|
||||
input = input.view(input.shape[-signal_dim:])
|
||||
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore
|
||||
normalized, onesided, return_complex)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user