mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
remove allow-untyped-defs from torch/distributions/pareto.py (#144624)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144624 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
80b756ed91
commit
ad221269b0
|
|
@ -1,4 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
r"""
|
||||
The following constraints are implemented:
|
||||
|
||||
|
|
@ -198,13 +202,17 @@ class _DependentProperty(property, _Dependent):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
|
||||
):
|
||||
self,
|
||||
fn: Optional[Callable[..., Any]] = None,
|
||||
*,
|
||||
is_discrete: Optional[bool] = NotImplemented,
|
||||
event_dim: Optional[int] = NotImplemented,
|
||||
) -> None:
|
||||
super().__init__(fn)
|
||||
self._is_discrete = is_discrete
|
||||
self._event_dim = event_dim
|
||||
|
||||
def __call__(self, fn): # type: ignore[override]
|
||||
def __call__(self, fn: Callable[..., Any]) -> "_DependentProperty": # type: ignore[override]
|
||||
"""
|
||||
Support for syntax to customize static attributes::
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Optional
|
||||
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exponential import Exponential
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import AffineTransform, ExpTransform
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["Pareto"]
|
||||
|
|
@ -27,13 +29,17 @@ class Pareto(TransformedDistribution):
|
|||
"""
|
||||
arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
|
||||
|
||||
def __init__(self, scale, alpha, validate_args=None):
|
||||
def __init__(
|
||||
self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None
|
||||
) -> None:
|
||||
self.scale, self.alpha = broadcast_all(scale, alpha)
|
||||
base_dist = Exponential(self.alpha, validate_args=validate_args)
|
||||
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
|
||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
def expand(
|
||||
self, batch_shape: _size, _instance: Optional["Pareto"] = None
|
||||
) -> "Pareto":
|
||||
new = self._get_checked_instance(Pareto, _instance)
|
||||
new.scale = self.scale.expand(batch_shape)
|
||||
new.alpha = self.alpha.expand(batch_shape)
|
||||
|
|
@ -56,8 +62,8 @@ class Pareto(TransformedDistribution):
|
|||
return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))
|
||||
|
||||
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
||||
def support(self):
|
||||
def support(self) -> constraints.Constraint:
|
||||
return constraints.greater_than_eq(self.scale)
|
||||
|
||||
def entropy(self):
|
||||
def entropy(self) -> Tensor:
|
||||
return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user