Fix typing errors in the torch.distributions module (#45689)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/42979.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689

Reviewed By: agolynski

Differential Revision: D24229870

Pulled By: xuzhao9

fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
Authored by Xu Zhao on 2020-10-12 10:23:42 -07:00; committed by Facebook GitHub Bot
parent 6a001decf2
commit 146721f1df
13 changed files with 42 additions and 27 deletions

View File: mypy.ini

@@ -81,9 +81,6 @@ ignore_errors = True
[mypy-torch.quantization.fx.*]
ignore_errors = True
[mypy-torch.distributions.*]
ignore_errors = True
[mypy-torch._tensor_str]
ignore_errors = True

View File: torch/distributions/__init__.py

@@ -111,6 +111,7 @@ from .transforms import *
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from . import transforms
__all__ = [
'Bernoulli',
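
The explicit `from . import transforms` is what makes the `transforms` submodule visible to mypy as an attribute of the package; `from .transforms import *` re-exports the classes at runtime but does not declare the submodule itself. A minimal sketch of the access this enables (illustrative usage, not part of the patch):

```python
import torch.distributions as D

# With the explicit submodule import in torch/distributions/__init__.py,
# mypy resolves D.transforms as an attribute of the package.
exp = D.transforms.ExpTransform()
print(exp(D.Normal(0.0, 1.0).sample()))
```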

View File: torch/distributions/beta.py

@@ -1,4 +1,4 @@
from numbers import Number
from numbers import Real, Number
import torch
from torch.distributions import constraints
@@ -28,7 +28,7 @@ class Beta(ExponentialFamily):
has_rsample = True
def __init__(self, concentration1, concentration0, validate_args=None):
if isinstance(concentration1, Number) and isinstance(concentration0, Number):
if isinstance(concentration1, Real) and isinstance(concentration0, Real):
concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
else:
concentration1, concentration0 = broadcast_all(concentration1, concentration0)
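
Narrowing the isinstance check from `Number` to `Real` matches the `float(...)` conversion on the next line: a `complex` value is a `Number` but cannot be passed to `float()`. A small illustration of the distinction (not part of the patch):

```python
from numbers import Number, Real

# complex counts as a Number but not as a Real, and float() rejects it,
# so isinstance(x, Real) is the right guard before calling float(x).
assert isinstance(1 + 2j, Number) and not isinstance(1 + 2j, Real)
assert isinstance(0.5, Real) and float(0.5) == 0.5
```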

View File: torch/distributions/distribution.py

@@ -2,6 +2,7 @@ import torch
import warnings
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from typing import Dict, Optional, Any
class Distribution(object):
@@ -12,8 +13,6 @@ class Distribution(object):
has_rsample = False
has_enumerate_support = False
_validate_args = False
support = None
arg_constraints = {}
@staticmethod
def set_default_validate_args(value):
@@ -72,7 +71,7 @@ class Distribution(object):
return self._event_shape
@property
def arg_constraints(self):
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
"""
Returns a dictionary from argument names to
:class:`~torch.distributions.constraints.Constraint` objects that
@@ -82,7 +81,7 @@ class Distribution(object):
raise NotImplementedError
@property
def support(self):
def support(self) -> Optional[Any]:
"""
Returns a :class:`~torch.distributions.constraints.Constraint` object
representing this distribution's support.
@@ -248,7 +247,7 @@ class Distribution(object):
if i != 1 and j != 1 and i != j:
raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
format(actual_shape, expected_shape))
assert self.support is not None
if not self.support.check(value).all():
raise ValueError('The value argument must be within the support')
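
Since `support` is now annotated as `Optional[Any]`, call sites narrow it before use, which is what the added `assert self.support is not None` does. A sketch of the same pattern in user code (the helper below is hypothetical, not part of the patch):

```python
import torch
from torch.distributions.distribution import Distribution

def check_sample(dist: Distribution, value: torch.Tensor) -> None:
    # Narrow the Optional support before calling .check(), mirroring the
    # assert added in Distribution._validate_sample.
    support = dist.support
    assert support is not None
    if not support.check(value).all():
        raise ValueError("value lies outside the distribution's support")
```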

View File: torch/distributions/independent.py

@@ -2,7 +2,7 @@ import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost
from typing import Dict
class Independent(Distribution):
r"""
@@ -31,7 +31,7 @@ class Independent(Distribution):
reinterpreted_batch_ndims (int): the number of batch dims to
reinterpret as event dims
"""
arg_constraints = {}
arg_constraints: Dict[str, constraints.Constraint] = {}
def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
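
The annotation on the empty dict is needed because mypy cannot infer key/value types for a bare `{}` assigned at class scope and asks for a type annotation; the same pattern is applied to `MixtureSameFamily` and `TransformedDistribution` below. A minimal sketch of the difference (illustrative class names, not part of the patch):

```python
from typing import Dict
from torch.distributions import constraints

class Untyped:
    arg_constraints = {}  # mypy: "Need type annotation for 'arg_constraints'"

class Typed:
    arg_constraints: Dict[str, constraints.Constraint] = {}  # checks cleanly
```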

View File: torch/distributions/kl.py

@@ -1,6 +1,7 @@
import math
import warnings
from functools import total_ordering
from typing import Type, Dict, Callable, Tuple
import torch
from torch._six import inf
@@ -33,7 +34,7 @@ from .uniform import Uniform
from .utils import _sum_rightmost
_KL_REGISTRY = {} # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE = {} # Memoized version mapping many specific (type, type) pairs to functions.
_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {} # Memoized version mapping many specific (type, type) pairs to functions.
def register_kl(type_p, type_q):
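
`_KL_MEMOIZE` caches the function resolved for each concrete `(type(p), type(q))` pair, hence the `Dict[Tuple[Type, Type], Callable]` annotation. A short sketch that exercises the lookup (illustrative usage, not part of the patch):

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

# kl_divergence resolves (Normal, Normal) through _KL_REGISTRY and memoizes
# the chosen rule in _KL_MEMOIZE on first use.
print(kl_divergence(Normal(0.0, 1.0), Normal(1.0, 2.0)))
```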

View File: torch/distributions/mixture_same_family.py

@@ -2,6 +2,7 @@ import torch
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from torch.distributions import constraints
from typing import Dict
class MixtureSameFamily(Distribution):
@@ -45,7 +46,7 @@ class MixtureSameFamily(Distribution):
component_distribution: `torch.distributions.Distribution`-like
instance. Right-most batch dimension indexes component.
"""
arg_constraints = {}
arg_constraints: Dict[str, constraints.Constraint] = {}
has_rsample = False
def __init__(self,

View File: torch/distributions/multinomial.py

@@ -2,7 +2,6 @@ import torch
from torch._six import inf
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from numbers import Number
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all
@@ -40,6 +39,7 @@ class Multinomial(Distribution):
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real}
total_count: int
@property
def mean(self):
@@ -50,7 +50,7 @@ class Multinomial(Distribution):
return self.total_count * self.probs * (1 - self.probs)
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
if not isinstance(total_count, Number):
if not isinstance(total_count, int):
raise NotImplementedError('inhomogeneous total_count is not supported')
self.total_count = total_count
self._categorical = Categorical(probs=probs, logits=logits)
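
`total_count` is now declared as an `int` attribute, and the runtime check is narrowed from `numbers.Number` to plain `int` to match. Illustrative usage under that contract (not part of the patch):

```python
import torch
from torch.distributions import Multinomial

# total_count must be a plain Python int, matching the `total_count: int` annotation.
m = Multinomial(total_count=10, probs=torch.tensor([0.2, 0.3, 0.5]))
counts = m.sample()
assert int(counts.sum()) == 10
```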

View File: torch/distributions/normal.py

@@ -1,4 +1,5 @@
import math
from numbers import Real
from numbers import Number
import torch
@@ -72,7 +73,7 @@ class Normal(ExponentialFamily):
self._validate_sample(value)
# compute the variance
var = (self.scale ** 2)
log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
def cdf(self, value):

View File: torch/distributions/transformed_distribution.py

@@ -3,6 +3,7 @@ from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.transforms import Transform
from torch.distributions.utils import _sum_rightmost
from typing import Dict
class TransformedDistribution(Distribution):
@@ -38,7 +39,7 @@ class TransformedDistribution(Distribution):
:class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
:class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
"""
arg_constraints = {}
arg_constraints: Dict[str, constraints.Constraint] = {}
def __init__(self, base_distribution, transforms, validate_args=None):
self.base_dist = base_distribution

View File: torch/distributions/transforms.py

@@ -9,6 +9,7 @@ from torch.distributions.utils import (_sum_rightmost, broadcast_all,
lazy_property)
from torch.nn.functional import pad
from torch.nn.functional import softplus
from typing import List
__all__ = [
'AbsTransform',
@@ -77,6 +78,7 @@ class Transform(object):
transforms that act jointly on matrices, etc.
"""
bijective = False
codomain: constraints.Constraint
event_dim = 0
def __init__(self, cache_size=0):
@@ -185,22 +187,27 @@ class _InverseTransform(Transform):
@constraints.dependent_property
def domain(self):
assert self._inv is not None
return self._inv.codomain
@constraints.dependent_property
def codomain(self):
assert self._inv is not None
return self._inv.domain
@property
def bijective(self):
assert self._inv is not None
return self._inv.bijective
@property
def sign(self):
assert self._inv is not None
return self._inv.sign
@property
def event_dim(self):
assert self._inv is not None
return self._inv.event_dim
@property
@@ -208,17 +215,21 @@ class _InverseTransform(Transform):
return self._inv
def with_cache(self, cache_size=1):
assert self._inv is not None
return self.inv.with_cache(cache_size).inv
def __eq__(self, other):
if not isinstance(other, _InverseTransform):
return False
assert self._inv is not None
return self._inv == other._inv
def __call__(self, x):
assert self._inv is not None
return self._inv._inv_call(x)
def log_abs_det_jacobian(self, x, y):
assert self._inv is not None
return -self._inv.log_abs_det_jacobian(y, x)
@@ -500,8 +511,8 @@ class AffineTransform(Transform):
@property
def sign(self):
if isinstance(self.scale, numbers.Number):
return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
if isinstance(self.scale, numbers.Real):
return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
return self.scale.sign()
def _call(self, x):
@@ -513,7 +524,7 @@ class AffineTransform(Transform):
def log_abs_det_jacobian(self, x, y):
shape = x.shape
scale = self.scale
if isinstance(scale, numbers.Number):
if isinstance(scale, numbers.Real):
result = torch.full_like(x, math.log(abs(scale)))
else:
result = torch.abs(scale).log()
@@ -575,7 +586,7 @@ class StickBreakingTransform(Transform):
offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
z = _clipped_sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
return y
def _inverse(self, y):
@@ -619,6 +630,7 @@ class LowerCholeskyTransform(Transform):
class CatTransform(Transform):
tseq: List[numbers.Number]
"""
Transform functor that applies a sequence of transforms `tseq`
component-wise to each submatrix at `dim`, of length `lengths[dim]`,
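
Two recurring patterns in this file: the `assert self._inv is not None` lines narrow an attribute that the base `Transform` initializes to `None` (so mypy treats it as optional even though `_InverseTransform` always sets it), and `numbers.Number` checks are narrowed to `numbers.Real` where the value feeds ordering or `math.log`/`abs`. A small sketch of the latter (illustrative usage, not part of the patch):

```python
import torch
from torch.distributions.transforms import AffineTransform

# sign is only meaningful for real scalars (or tensors); a complex Number has
# no ordering, which is why the isinstance check targets numbers.Real.
t = AffineTransform(loc=0.0, scale=-2.0)
assert t.sign == -1
x = torch.tensor([0.5, 1.5])
print(t.log_abs_det_jacobian(x, t(x)))  # log|scale|, broadcast to x's shape
```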

View File: torch/distributions/utils.py

@@ -2,6 +2,7 @@ from functools import update_wrapper
from numbers import Number
import torch
import torch.nn.functional as F
from typing import Dict, Any
def broadcast_all(*values):
@@ -23,13 +24,14 @@ def broadcast_all(*values):
if not all(isinstance(v, torch.Tensor) or isinstance(v, Number) for v in values):
raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
if not all([isinstance(v, torch.Tensor) for v in values]):
options = dict(dtype=torch.get_default_dtype())
options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
for value in values:
if isinstance(value, torch.Tensor):
options = dict(dtype=value.dtype, device=value.device)
break
values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options)
new_values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options)
for v in values]
return torch.broadcast_tensors(*new_values)
return torch.broadcast_tensors(*values)
@@ -94,7 +96,7 @@ class lazy_property(object):
"""
def __init__(self, wrapped):
self.wrapped = wrapped
update_wrapper(self, wrapped)
update_wrapper(self, wrapped) # type: ignore[arg-type]
def __get__(self, instance, obj_type=None):
if instance is None:
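
Binding the converted tensors to `new_values` rather than rebinding `values` presumably keeps each variable at a single type for mypy (the parameter tuple would otherwise be reassigned to a list of tensors); the behaviour of `broadcast_all` is unchanged. A quick usage sketch (illustrative, not part of the patch):

```python
import torch
from torch.distributions.utils import broadcast_all

# Python scalars are promoted to tensors (using the dtype/device of the first
# tensor argument, if any) and everything is broadcast to a common shape.
loc, scale = broadcast_all(0.0, torch.ones(3))
assert loc.shape == scale.shape == torch.Size([3])
```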

View File: torch/functional.py

@@ -510,7 +510,7 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
signal_dim = input.dim()
extended_shape = [1] * (3 - signal_dim) + list(input.size())
pad = int(n_fft // 2)
input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
input = input.view(input.shape[-signal_dim:])
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore
normalized, onesided, return_complex)
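
The tuple-to-list change here (and in `StickBreakingTransform._call` above) matches the `List[int]` type that `F.pad` declares for its `pad` argument; runtime behaviour is identical. A brief sketch (illustrative, not part of the patch):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8)
# A list matches F.pad's declared List[int] parameter; a tuple also works at
# runtime but does not match the annotation.
y = F.pad(x, [2, 2], mode='reflect')
assert y.shape[-1] == 12
```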