remove allow-untyped-defs from torch/distributions/pareto.py (#144624)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144624
Approved by: https://github.com/Skylion007
This commit is contained in:
bobrenjc93 2025-01-11 11:40:31 -08:00 committed by PyTorch MergeBot
parent 80b756ed91
commit ad221269b0
2 changed files with 22 additions and 8 deletions

View File

@ -1,4 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from typing import Any, Callable, Optional
r""" r"""
The following constraints are implemented: The following constraints are implemented:
@ -198,13 +202,17 @@ class _DependentProperty(property, _Dependent):
""" """
def __init__( def __init__(
self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented self,
): fn: Optional[Callable[..., Any]] = None,
*,
is_discrete: Optional[bool] = NotImplemented,
event_dim: Optional[int] = NotImplemented,
) -> None:
super().__init__(fn) super().__init__(fn)
self._is_discrete = is_discrete self._is_discrete = is_discrete
self._event_dim = event_dim self._event_dim = event_dim
def __call__(self, fn): # type: ignore[override] def __call__(self, fn: Callable[..., Any]) -> "_DependentProperty": # type: ignore[override]
""" """
Support for syntax to customize static attributes:: Support for syntax to customize static attributes::

View File

@ -1,10 +1,12 @@
# mypy: allow-untyped-defs from typing import Optional
from torch import Tensor from torch import Tensor
from torch.distributions import constraints from torch.distributions import constraints
from torch.distributions.exponential import Exponential from torch.distributions.exponential import Exponential
from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.utils import broadcast_all from torch.distributions.utils import broadcast_all
from torch.types import _size
__all__ = ["Pareto"] __all__ = ["Pareto"]
@ -27,13 +29,17 @@ class Pareto(TransformedDistribution):
""" """
arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive} arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
def __init__(self, scale, alpha, validate_args=None): def __init__(
self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None
) -> None:
self.scale, self.alpha = broadcast_all(scale, alpha) self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args) base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)] transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
super().__init__(base_dist, transforms, validate_args=validate_args) super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None): def expand(
self, batch_shape: _size, _instance: Optional["Pareto"] = None
) -> "Pareto":
new = self._get_checked_instance(Pareto, _instance) new = self._get_checked_instance(Pareto, _instance)
new.scale = self.scale.expand(batch_shape) new.scale = self.scale.expand(batch_shape)
new.alpha = self.alpha.expand(batch_shape) new.alpha = self.alpha.expand(batch_shape)
@ -56,8 +62,8 @@ class Pareto(TransformedDistribution):
return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2)) return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))
@constraints.dependent_property(is_discrete=False, event_dim=0) @constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self): def support(self) -> constraints.Constraint:
return constraints.greater_than_eq(self.scale) return constraints.greater_than_eq(self.scale)
def entropy(self): def entropy(self) -> Tensor:
return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()) return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())