This PR comprises a few small contributions:

1. `PowerTransform` returned a sign of `+1` irrespective of exponent. However, it should return the sign of the exponent because the gradient has the same sign as the exponent. That issue has been fixed.
2. Added tests to catch errors akin to 1. in the future.
3. Added an `InverseGamma` distribution as a `TransformedDistribution` with `PowerTransform(-1)` and `Gamma` base distribution. The `InverseGamma` is often used as a prior for the length scale of Gaussian processes to aggressively suppress short length scales (see [here](https://betanalpha.github.io/assets/case_studies/gaussian_processes.html#323_Informative_Prior_Model) for a discussion).

Note: I added a `positive` constraint for the support of the inverse gamma distribution because the `PowerTransform(-1)` can fail for `nonnegative` constraints if the random variable is zero.

```python
>>> torch.distributions.InverseGamma(0.5, 1.0).log_prob(torch.zeros(1))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-8-758aa22deacd> in <module>
----> 1 torch.distributions.InverseGamma(0.5, 1.0).log_prob(torch.zeros(1))

~/git/pytorch/torch/distributions/transformed_distribution.py in log_prob(self, value)
    140         """
    141         if self._validate_args:
--> 142             self._validate_sample(value)
    143         event_dim = len(self.event_shape)
    144         log_prob = 0.0

~/git/pytorch/torch/distributions/distribution.py in _validate_sample(self, value)
    298         valid = support.check(value)
    299         if not valid.all():
--> 300             raise ValueError(
    301                 "Expected value argument "
    302                 f"({type(value).__name__} of shape {tuple(value.shape)}) "

ValueError: Expected value argument (Tensor of shape (1,)) to be within the support (GreaterThan(lower_bound=0.0)) of the distribution InverseGamma(), but found invalid values:
tensor([0.])
```

This differs from the scipy implementation.

```python
>>> scipy.stats.invgamma(0.5).pdf(0)
0.0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104501
Approved by: https://github.com/fritzo, https://github.com/ezyang
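As a rough illustration of contribution 3 (a sketch, not the code from the PR itself), the construction can be reproduced by hand from `Gamma`, `PowerTransform`, and `TransformedDistribution`. The parameter values and evaluation points below are arbitrary, and the closed-form expression is the textbook inverse gamma log density, included only as a cross-check.

```python
import torch
from torch.distributions import Gamma, TransformedDistribution
from torch.distributions.transforms import PowerTransform

# Arbitrary illustrative parameters (assumptions, not taken from the PR).
concentration = torch.tensor(2.0)
rate = torch.tensor(3.0)

# If X ~ Gamma(concentration, rate), then 1 / X follows an inverse gamma
# distribution with the same parameters.
base = Gamma(concentration, rate)
power = PowerTransform(torch.tensor(-1.0))
inverse_gamma = TransformedDistribution(base, power)

# Strictly positive points: zero lies outside the `positive` support.
x = torch.tensor([0.5, 1.0, 2.0])
log_prob = inverse_gamma.log_prob(x)

# Textbook log density: a*log(b) - lgamma(a) - (a + 1)*log(x) - b/x
expected = (
    concentration * rate.log()
    - torch.lgamma(concentration)
    - (concentration + 1) * x.log()
    - rate / x
)

# Per the PR description, releases that include the fix report the sign
# of the exponent here; earlier releases reported +1.
sign = power.sign
```

On a build that includes the new class, `torch.distributions.InverseGamma(concentration, rate).log_prob(x)` should give the same values as `log_prob` above.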
.. role:: hidden
    :class: hidden-section

Probability distributions - torch.distributions
==================================================

.. automodule:: torch.distributions
.. currentmodule:: torch.distributions

:hidden:`Distribution`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.distribution
.. autoclass:: Distribution
    :members:
    :show-inheritance:

:hidden:`ExponentialFamily`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.exp_family
.. autoclass:: ExponentialFamily
    :members:
    :show-inheritance:

:hidden:`Bernoulli`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.bernoulli
.. autoclass:: Bernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Beta`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.beta
.. autoclass:: Beta
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Binomial`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.binomial
.. autoclass:: Binomial
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Categorical`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.categorical
.. autoclass:: Categorical
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Cauchy`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.cauchy
.. autoclass:: Cauchy
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Chi2`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.chi2
.. autoclass:: Chi2
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`ContinuousBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.continuous_bernoulli
.. autoclass:: ContinuousBernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Dirichlet`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.dirichlet
.. autoclass:: Dirichlet
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Exponential`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.exponential
.. autoclass:: Exponential
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`FisherSnedecor`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.fishersnedecor
.. autoclass:: FisherSnedecor
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Gamma`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.gamma
.. autoclass:: Gamma
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Geometric`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.geometric
.. autoclass:: Geometric
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Gumbel`
~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.gumbel
.. autoclass:: Gumbel
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`HalfCauchy`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.half_cauchy
.. autoclass:: HalfCauchy
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`HalfNormal`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.half_normal
.. autoclass:: HalfNormal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Independent`
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.independent
.. autoclass:: Independent
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`InverseGamma`
~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.inverse_gamma
.. autoclass:: InverseGamma
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Kumaraswamy`
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.kumaraswamy
.. autoclass:: Kumaraswamy
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`LKJCholesky`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.lkj_cholesky
.. autoclass:: LKJCholesky
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Laplace`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.laplace
.. autoclass:: Laplace
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`LogNormal`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.log_normal
.. autoclass:: LogNormal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`LowRankMultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.lowrank_multivariate_normal
.. autoclass:: LowRankMultivariateNormal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`MixtureSameFamily`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.mixture_same_family
.. autoclass:: MixtureSameFamily
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Multinomial`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.multinomial
.. autoclass:: Multinomial
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`MultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.multivariate_normal
.. autoclass:: MultivariateNormal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`NegativeBinomial`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.negative_binomial
.. autoclass:: NegativeBinomial
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Normal`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.normal
.. autoclass:: Normal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`OneHotCategorical`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.one_hot_categorical
.. autoclass:: OneHotCategorical
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Pareto`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.pareto
.. autoclass:: Pareto
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Poisson`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.poisson
.. autoclass:: Poisson
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`RelaxedBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.relaxed_bernoulli
.. autoclass:: RelaxedBernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`LogitRelaxedBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.relaxed_bernoulli
.. autoclass:: LogitRelaxedBernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`RelaxedOneHotCategorical`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.relaxed_categorical
.. autoclass:: RelaxedOneHotCategorical
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`StudentT`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.studentT
.. autoclass:: StudentT
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`TransformedDistribution`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.transformed_distribution
.. autoclass:: TransformedDistribution
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Uniform`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.uniform
.. autoclass:: Uniform
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`VonMises`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.von_mises
.. autoclass:: VonMises
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Weibull`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.weibull
.. autoclass:: Weibull
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Wishart`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.wishart
.. autoclass:: Wishart
    :members:
    :undoc-members:
    :show-inheritance:

`KL Divergence`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.kl
.. currentmodule:: torch.distributions.kl

.. autofunction:: kl_divergence
.. autofunction:: register_kl

`Transforms`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.transforms
    :members:
    :member-order: bysource

`Constraints`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.constraints
    :members:
    :member-order: bysource

`Constraint Registry`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.constraint_registry
    :members:
    :member-order: bysource

.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes
.. py:module:: torch.distributions.bernoulli
.. py:module:: torch.distributions.beta
.. py:module:: torch.distributions.binomial
.. py:module:: torch.distributions.categorical
.. py:module:: torch.distributions.cauchy
.. py:module:: torch.distributions.chi2
.. py:module:: torch.distributions.continuous_bernoulli
.. py:module:: torch.distributions.dirichlet
.. py:module:: torch.distributions.distribution
.. py:module:: torch.distributions.exp_family
.. py:module:: torch.distributions.exponential
.. py:module:: torch.distributions.fishersnedecor
.. py:module:: torch.distributions.gamma
.. py:module:: torch.distributions.geometric
.. py:module:: torch.distributions.gumbel
.. py:module:: torch.distributions.half_cauchy
.. py:module:: torch.distributions.half_normal
.. py:module:: torch.distributions.independent
.. py:module:: torch.distributions.inverse_gamma
.. py:module:: torch.distributions.kumaraswamy
.. py:module:: torch.distributions.laplace
.. py:module:: torch.distributions.lkj_cholesky
.. py:module:: torch.distributions.log_normal
.. py:module:: torch.distributions.logistic_normal
.. py:module:: torch.distributions.lowrank_multivariate_normal
.. py:module:: torch.distributions.mixture_same_family
.. py:module:: torch.distributions.multinomial
.. py:module:: torch.distributions.multivariate_normal
.. py:module:: torch.distributions.negative_binomial
.. py:module:: torch.distributions.normal
.. py:module:: torch.distributions.one_hot_categorical
.. py:module:: torch.distributions.pareto
.. py:module:: torch.distributions.poisson
.. py:module:: torch.distributions.relaxed_bernoulli
.. py:module:: torch.distributions.relaxed_categorical
.. py:module:: torch.distributions.studentT
.. py:module:: torch.distributions.transformed_distribution
.. py:module:: torch.distributions.uniform
.. py:module:: torch.distributions.utils
.. py:module:: torch.distributions.von_mises
.. py:module:: torch.distributions.weibull
.. py:module:: torch.distributions.wishart