mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
remove allow-untyped-defs from torch/distributions/pareto.py (#144624)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144624 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
80b756ed91
commit
ad221269b0
|
|
@ -1,4 +1,8 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
|
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
r"""
|
r"""
|
||||||
The following constraints are implemented:
|
The following constraints are implemented:
|
||||||
|
|
||||||
|
|
@ -198,13 +202,17 @@ class _DependentProperty(property, _Dependent):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
|
self,
|
||||||
):
|
fn: Optional[Callable[..., Any]] = None,
|
||||||
|
*,
|
||||||
|
is_discrete: Optional[bool] = NotImplemented,
|
||||||
|
event_dim: Optional[int] = NotImplemented,
|
||||||
|
) -> None:
|
||||||
super().__init__(fn)
|
super().__init__(fn)
|
||||||
self._is_discrete = is_discrete
|
self._is_discrete = is_discrete
|
||||||
self._event_dim = event_dim
|
self._event_dim = event_dim
|
||||||
|
|
||||||
def __call__(self, fn): # type: ignore[override]
|
def __call__(self, fn: Callable[..., Any]) -> "_DependentProperty": # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Support for syntax to customize static attributes::
|
Support for syntax to customize static attributes::
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
# mypy: allow-untyped-defs
|
from typing import Optional
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.exponential import Exponential
|
from torch.distributions.exponential import Exponential
|
||||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||||
from torch.distributions.transforms import AffineTransform, ExpTransform
|
from torch.distributions.transforms import AffineTransform, ExpTransform
|
||||||
from torch.distributions.utils import broadcast_all
|
from torch.distributions.utils import broadcast_all
|
||||||
|
from torch.types import _size
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Pareto"]
|
__all__ = ["Pareto"]
|
||||||
|
|
@ -27,13 +29,17 @@ class Pareto(TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
|
arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
|
||||||
|
|
||||||
def __init__(self, scale, alpha, validate_args=None):
|
def __init__(
|
||||||
|
self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None
|
||||||
|
) -> None:
|
||||||
self.scale, self.alpha = broadcast_all(scale, alpha)
|
self.scale, self.alpha = broadcast_all(scale, alpha)
|
||||||
base_dist = Exponential(self.alpha, validate_args=validate_args)
|
base_dist = Exponential(self.alpha, validate_args=validate_args)
|
||||||
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
|
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
|
||||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||||
|
|
||||||
def expand(self, batch_shape, _instance=None):
|
def expand(
|
||||||
|
self, batch_shape: _size, _instance: Optional["Pareto"] = None
|
||||||
|
) -> "Pareto":
|
||||||
new = self._get_checked_instance(Pareto, _instance)
|
new = self._get_checked_instance(Pareto, _instance)
|
||||||
new.scale = self.scale.expand(batch_shape)
|
new.scale = self.scale.expand(batch_shape)
|
||||||
new.alpha = self.alpha.expand(batch_shape)
|
new.alpha = self.alpha.expand(batch_shape)
|
||||||
|
|
@ -56,8 +62,8 @@ class Pareto(TransformedDistribution):
|
||||||
return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))
|
return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
||||||
def support(self):
|
def support(self) -> constraints.Constraint:
|
||||||
return constraints.greater_than_eq(self.scale)
|
return constraints.greater_than_eq(self.scale)
|
||||||
|
|
||||||
def entropy(self):
|
def entropy(self) -> Tensor:
|
||||||
return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())
|
return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user