pytorch/docs/source/distributions.rst
Till Hoffmann 5296c14094 Add inverse gamma distribution and fix sign bug in PowerTransform. (#104501)
This PR comprises a few small contributions:

1. `PowerTransform` returned a sign of `+1` irrespective of the exponent. It should instead return the sign of the exponent, because for positive inputs the derivative of `x ** exponent` has the same sign as the exponent. That issue has been fixed.
2. Added tests to catch future regressions of the kind described in item 1.
3. Added an `InverseGamma` distribution, implemented as a `TransformedDistribution` of a `Gamma` base distribution under `PowerTransform(-1)`. The `InverseGamma` is often used as a prior for the length scale of Gaussian processes because it aggressively suppresses short length scales (see [this case study](https://betanalpha.github.io/assets/case_studies/gaussian_processes.html#323_Informative_Prior_Model) for a discussion).

Note: I used a `positive` rather than a `nonnegative` constraint for the support of the inverse gamma distribution, because `PowerTransform(-1)` can fail if the random variable is zero. As a consequence, `log_prob` rejects zero when argument validation is enabled:

```python
>>> torch.distributions.InverseGamma(0.5, 1.0).log_prob(torch.zeros(1))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-8-758aa22deacd> in <module>
----> 1 torch.distributions.InverseGamma(0.5, 1.0).log_prob(torch.zeros(1))

~/git/pytorch/torch/distributions/transformed_distribution.py in log_prob(self, value)
    140         """
    141         if self._validate_args:
--> 142             self._validate_sample(value)
    143         event_dim = len(self.event_shape)
    144         log_prob = 0.0

~/git/pytorch/torch/distributions/distribution.py in _validate_sample(self, value)
    298         valid = support.check(value)
    299         if not valid.all():
--> 300             raise ValueError(
    301                 "Expected value argument "
    302                 f"({type(value).__name__} of shape {tuple(value.shape)}) "

ValueError: Expected value argument (Tensor of shape (1,)) to be within the support (GreaterThan(lower_bound=0.0)) of the distribution InverseGamma(), but found invalid values:
tensor([0.])
```

This differs from the SciPy implementation, which returns a density of zero at zero:

```python
>>> scipy.stats.invgamma(0.5).pdf(0)
0.0
```
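The sketch below shows why zero is excluded. It uses only `torch.distributions.constraints` and `PowerTransform`, and the sample values are arbitrary. `positive` rejects zero while `nonnegative` admits it. The inverse of `PowerTransform(-1)`, which `TransformedDistribution.log_prob` evaluates to map a value back to the base distribution, is infinite at zero.

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import PowerTransform

values = torch.tensor([0.0, 0.5, 2.0])

# `positive` is a strict lower bound of 0; `nonnegative` also admits 0.
print(constraints.positive.check(values))     # False, True, True
print(constraints.nonnegative.check(values))  # True, True, True

# log_prob maps each value back through the inverse transform;
# for PowerTransform(-1) that is y -> y ** -1, which is inf at 0.
inverse = PowerTransform(torch.tensor(-1.0)).inv
print(inverse(values))  # inf, 2.0, 0.5
```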

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104501
Approved by: https://github.com/fritzo, https://github.com/ezyang
2023-11-01 02:26:25 +00:00


.. role:: hidden
    :class: hidden-section

Probability distributions - torch.distributions
==================================================

.. automodule:: torch.distributions
.. currentmodule:: torch.distributions

:hidden:`Distribution`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.distribution
.. autoclass:: Distribution
    :members:
    :show-inheritance:

:hidden:`ExponentialFamily`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.exp_family
.. autoclass:: ExponentialFamily
    :members:
    :show-inheritance:

:hidden:`Bernoulli`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.bernoulli
.. autoclass:: Bernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Beta`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.beta
.. autoclass:: Beta
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Binomial`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.binomial
.. autoclass:: Binomial
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Categorical`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.categorical
.. autoclass:: Categorical
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Cauchy`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.cauchy
.. autoclass:: Cauchy
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Chi2`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.chi2
.. autoclass:: Chi2
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`ContinuousBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.continuous_bernoulli
.. autoclass:: ContinuousBernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Dirichlet`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.dirichlet
.. autoclass:: Dirichlet
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Exponential`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.exponential
.. autoclass:: Exponential
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`FisherSnedecor`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.fishersnedecor
.. autoclass:: FisherSnedecor
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Gamma`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.gamma
.. autoclass:: Gamma
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Geometric`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.geometric
.. autoclass:: Geometric
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Gumbel`
~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.gumbel
.. autoclass:: Gumbel
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`HalfCauchy`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.half_cauchy
.. autoclass:: HalfCauchy
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`HalfNormal`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.half_normal
.. autoclass:: HalfNormal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Independent`
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.independent
.. autoclass:: Independent
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`InverseGamma`
~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.inverse_gamma
.. autoclass:: InverseGamma
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Kumaraswamy`
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.kumaraswamy
.. autoclass:: Kumaraswamy
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`LKJCholesky`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.lkj_cholesky
.. autoclass:: LKJCholesky
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Laplace`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.laplace
.. autoclass:: Laplace
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`LogNormal`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.log_normal
.. autoclass:: LogNormal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`LowRankMultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.lowrank_multivariate_normal
.. autoclass:: LowRankMultivariateNormal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`MixtureSameFamily`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.mixture_same_family
.. autoclass:: MixtureSameFamily
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Multinomial`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.multinomial
.. autoclass:: Multinomial
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`MultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.multivariate_normal
.. autoclass:: MultivariateNormal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`NegativeBinomial`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.negative_binomial
.. autoclass:: NegativeBinomial
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Normal`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.normal
.. autoclass:: Normal
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`OneHotCategorical`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.one_hot_categorical
.. autoclass:: OneHotCategorical
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Pareto`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.pareto
.. autoclass:: Pareto
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Poisson`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.poisson
.. autoclass:: Poisson
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`RelaxedBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.relaxed_bernoulli
.. autoclass:: RelaxedBernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`LogitRelaxedBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.relaxed_bernoulli
.. autoclass:: LogitRelaxedBernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`RelaxedOneHotCategorical`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.relaxed_categorical
.. autoclass:: RelaxedOneHotCategorical
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`StudentT`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.studentT
.. autoclass:: StudentT
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`TransformedDistribution`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.transformed_distribution
.. autoclass:: TransformedDistribution
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Uniform`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.uniform
.. autoclass:: Uniform
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`VonMises`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.von_mises
.. autoclass:: VonMises
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Weibull`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.weibull
.. autoclass:: Weibull
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Wishart`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.wishart
.. autoclass:: Wishart
    :members:
    :undoc-members:
    :show-inheritance:

`KL Divergence`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.kl
.. currentmodule:: torch.distributions.kl

.. autofunction:: kl_divergence
.. autofunction:: register_kl

`Transforms`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.transforms
    :members:
    :member-order: bysource

`Constraints`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.constraints
    :members:
    :member-order: bysource

`Constraint Registry`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.constraint_registry
    :members:
    :member-order: bysource

.. These modules need to be documented. They are listed here in the meantime
.. for tracking purposes.
.. py:module:: torch.distributions.bernoulli
.. py:module:: torch.distributions.beta
.. py:module:: torch.distributions.binomial
.. py:module:: torch.distributions.categorical
.. py:module:: torch.distributions.cauchy
.. py:module:: torch.distributions.chi2
.. py:module:: torch.distributions.continuous_bernoulli
.. py:module:: torch.distributions.dirichlet
.. py:module:: torch.distributions.distribution
.. py:module:: torch.distributions.exp_family
.. py:module:: torch.distributions.exponential
.. py:module:: torch.distributions.fishersnedecor
.. py:module:: torch.distributions.gamma
.. py:module:: torch.distributions.geometric
.. py:module:: torch.distributions.gumbel
.. py:module:: torch.distributions.half_cauchy
.. py:module:: torch.distributions.half_normal
.. py:module:: torch.distributions.independent
.. py:module:: torch.distributions.inverse_gamma
.. py:module:: torch.distributions.kumaraswamy
.. py:module:: torch.distributions.laplace
.. py:module:: torch.distributions.lkj_cholesky
.. py:module:: torch.distributions.log_normal
.. py:module:: torch.distributions.logistic_normal
.. py:module:: torch.distributions.lowrank_multivariate_normal
.. py:module:: torch.distributions.mixture_same_family
.. py:module:: torch.distributions.multinomial
.. py:module:: torch.distributions.multivariate_normal
.. py:module:: torch.distributions.negative_binomial
.. py:module:: torch.distributions.normal
.. py:module:: torch.distributions.one_hot_categorical
.. py:module:: torch.distributions.pareto
.. py:module:: torch.distributions.poisson
.. py:module:: torch.distributions.relaxed_bernoulli
.. py:module:: torch.distributions.relaxed_categorical
.. py:module:: torch.distributions.studentT
.. py:module:: torch.distributions.transformed_distribution
.. py:module:: torch.distributions.uniform
.. py:module:: torch.distributions.utils
.. py:module:: torch.distributions.von_mises
.. py:module:: torch.distributions.weibull
.. py:module:: torch.distributions.wishart