remove allow-untyped-defs from torch/distributions/pareto.py (#144624)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144624
Approved by: https://github.com/Skylion007
bobrenjc93 2025-01-11 11:40:31 -08:00 committed by PyTorch MergeBot
parent 80b756ed91
commit ad221269b0
2 changed files with 22 additions and 8 deletions

torch/distributions/constraints.py

@@ -1,4 +1,8 @@
-# mypy: allow-untyped-defs
+from typing import Any, Callable, Optional
 r"""
 The following constraints are implemented:
@@ -198,13 +202,17 @@ class _DependentProperty(property, _Dependent):
     """

     def __init__(
-        self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
-    ):
+        self,
+        fn: Optional[Callable[..., Any]] = None,
+        *,
+        is_discrete: Optional[bool] = NotImplemented,
+        event_dim: Optional[int] = NotImplemented,
+    ) -> None:
         super().__init__(fn)
         self._is_discrete = is_discrete
         self._event_dim = event_dim

-    def __call__(self, fn):  # type: ignore[override]
+    def __call__(self, fn: Callable[..., Any]) -> "_DependentProperty":  # type: ignore[override]
         """
         Support for syntax to customize static attributes::
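
The two hunks above annotate both entry points of the decorator. A minimal sketch of the parametrized form that `__call__` exists to support, mirroring the `support` property in the pareto.py diff below (the class `_Example` and its `low` attribute are made up for illustration):

    import torch
    from torch.distributions import constraints

    class _Example:
        def __init__(self, low: torch.Tensor) -> None:
            self.low = low

        # dependent_property(...) first builds a _DependentProperty with fn=None;
        # its __call__(fn) then wraps the decorated method as the property getter.
        @constraints.dependent_property(is_discrete=False, event_dim=0)
        def support(self) -> constraints.Constraint:
            return constraints.greater_than_eq(self.low)

    _Example(torch.tensor(1.0)).support  # -> a GreaterThanEq constraint with lower bound 1.0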

torch/distributions/pareto.py

@@ -1,10 +1,12 @@
-# mypy: allow-untyped-defs
+from typing import Optional
+
 from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.exponential import Exponential
 from torch.distributions.transformed_distribution import TransformedDistribution
 from torch.distributions.transforms import AffineTransform, ExpTransform
 from torch.distributions.utils import broadcast_all
 from torch.types import _size

 __all__ = ["Pareto"]
@@ -27,13 +29,17 @@ class Pareto(TransformedDistribution):
     """
     arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}

-    def __init__(self, scale, alpha, validate_args=None):
+    def __init__(
+        self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None
+    ) -> None:
         self.scale, self.alpha = broadcast_all(scale, alpha)
         base_dist = Exponential(self.alpha, validate_args=validate_args)
         transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
         super().__init__(base_dist, transforms, validate_args=validate_args)

-    def expand(self, batch_shape, _instance=None):
+    def expand(
+        self, batch_shape: _size, _instance: Optional["Pareto"] = None
+    ) -> "Pareto":
         new = self._get_checked_instance(Pareto, _instance)
         new.scale = self.scale.expand(batch_shape)
         new.alpha = self.alpha.expand(batch_shape)
@@ -56,8 +62,8 @@ class Pareto(TransformedDistribution):
         return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))

     @constraints.dependent_property(is_discrete=False, event_dim=0)
-    def support(self):
+    def support(self) -> constraints.Constraint:
         return constraints.greater_than_eq(self.scale)

-    def entropy(self):
+    def entropy(self) -> Tensor:
         return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())
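
With these annotations, the public surface touched in pareto.py type-checks as below; a minimal usage sketch (tensor values are illustrative, not taken from the commit):

    import torch
    from torch.distributions import Pareto

    m = Pareto(scale=torch.tensor([1.0]), alpha=torch.tensor([2.0]))
    m.sample()                    # Tensor drawn from the transformed Exponential
    m.entropy()                   # Tensor, matching the new `entropy` return type
    m.support                     # constraints.Constraint: greater_than_eq(scale)
    m.expand(torch.Size([3, 1]))  # Pareto, matching the new `expand` annotations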