Summary:
Fixes https://github.com/pytorch/pytorch/issues/72765.
- [x] Improve `NotImplementedError` verbosity.
- [x] Automate the docstring generation process.
## Improved `NotImplementedError` verbosity
### Code
```python
import torch
dist = torch.distributions
torch_normal = dist.Normal(loc=0.0, scale=1.0)
torch_mixture = dist.MixtureSameFamily(
    dist.Categorical(torch.ones(5)),
    dist.Normal(torch.randn(5), torch.rand(5)),
)
dist.kl_divergence(torch_normal, torch_mixture)
```
#### Output before this PR
```python
NotImplementedError:
```
#### Output after this PR
```python
NotImplementedError: No KL(p || q) is implemented for p type Normal and q type MixtureSameFamily
```
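The message now names both concrete distribution types at the failing dispatch site. As a rough sketch of the pattern (the `registry` dict here is a simplified stand-in for the real KL dispatch table in `torch/distributions/kl.py`; only the message format is taken from the output above):

```python
def kl_dispatch_sketch(p, q, registry):
    # `registry` maps (type_p, type_q) pairs to KL implementations;
    # it is an illustrative stand-in, not the actual torch dispatch table.
    fun = registry.get((type(p), type(q)))
    if fun is None:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {type(p).__name__} "
            f"and q type {type(q).__name__}"
        )
    return fun(p, q)
```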
## Automate the docstring generation process
### Docstring before this PR
```python
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

.. math::

    KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

Args:
    p (Distribution): A :class:`~torch.distributions.Distribution` object.
    q (Distribution): A :class:`~torch.distributions.Distribution` object.

Returns:
    Tensor: A batch of KL divergences of shape `batch_shape`.

Raises:
    NotImplementedError: If the distribution types have not been registered via
        :meth:`register_kl`.
```
### Docstring after this PR
```python
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

.. math::

    KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

Args:
    p (Distribution): A :class:`~torch.distributions.Distribution` object.
    q (Distribution): A :class:`~torch.distributions.Distribution` object.

Returns:
    Tensor: A batch of KL divergences of shape `batch_shape`.

Raises:
    NotImplementedError: If the distribution types have not been registered via
        :meth:`register_kl`.

KL divergence is currently implemented for the following distribution pairs:

* :class:`~torch.distributions.Bernoulli` and :class:`~torch.distributions.Bernoulli`
* :class:`~torch.distributions.Bernoulli` and :class:`~torch.distributions.Poisson`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.Binomial` and :class:`~torch.distributions.Binomial`
* :class:`~torch.distributions.Categorical` and :class:`~torch.distributions.Categorical`
* :class:`~torch.distributions.Cauchy` and :class:`~torch.distributions.Cauchy`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.Dirichlet` and :class:`~torch.distributions.Dirichlet`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.ExponentialFamily` and :class:`~torch.distributions.ExponentialFamily`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.Geometric` and :class:`~torch.distributions.Geometric`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.HalfNormal` and :class:`~torch.distributions.HalfNormal`
* :class:`~torch.distributions.Independent` and :class:`~torch.distributions.Independent`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Laplace`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.LowRankMultivariateNormal` and :class:`~torch.distributions.LowRankMultivariateNormal`
* :class:`~torch.distributions.LowRankMultivariateNormal` and :class:`~torch.distributions.MultivariateNormal`
* :class:`~torch.distributions.MultivariateNormal` and :class:`~torch.distributions.LowRankMultivariateNormal`
* :class:`~torch.distributions.MultivariateNormal` and :class:`~torch.distributions.MultivariateNormal`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Laplace`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.OneHotCategorical` and :class:`~torch.distributions.OneHotCategorical`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Bernoulli`
* :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Binomial`
* :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Poisson`
* :class:`~torch.distributions.TransformedDistribution` and :class:`~torch.distributions.TransformedDistribution`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Uniform`
```
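The pair list above is generated rather than maintained by hand: at import time the package appends it to ``kl_divergence.__doc__`` via the `_add_kl_info()` call visible at the bottom of the module below. A minimal sketch of that generation step, assuming a registry keyed by `(p_type, q_type)` class pairs:

```python
def add_kl_info_sketch(registry, kl_divergence):
    # Sort registered (p, q) type pairs by class name so the generated
    # list is deterministic across runs.
    pairs = sorted(registry, key=lambda pq: (pq[0].__name__, pq[1].__name__))
    rows = ["KL divergence is currently implemented for the following distribution pairs:"]
    for p, q in pairs:
        rows.append(
            f"    * :class:`~torch.distributions.{p.__name__}` and "
            f":class:`~torch.distributions.{q.__name__}`"
        )
    # __doc__ is None when Python runs with -OO, so guard the append.
    if kl_divergence.__doc__ is not None:
        kl_divergence.__doc__ += "\n" + "\n".join(rows)
```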
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72845
Reviewed By: mikaylagawarecki
Differential Revision: D34344551
Pulled By: soulitzer
fbshipit-source-id: 7a603613a2f56f71138d56399c7c521e2238e8c5
(cherry picked from commit 6b2a51c796)
170 lines · 5.8 KiB · Python
```python
r"""
The ``distributions`` package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization. This package
generally follows the design of the `TensorFlow Distributions`_ package.

.. _`TensorFlow Distributions`:
    https://arxiv.org/abs/1711.10604

It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through. These are the score function estimator/likelihood ratio
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
seen as the basis for policy gradient methods in reinforcement learning, and the
pathwise derivative estimator is commonly seen in the reparameterization trick
in variational autoencoders. Whilst the score function only requires the value
of samples :math:`f(x)`, the pathwise derivative requires the derivative
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_ .

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
    https://arxiv.org/abs/1506.05254

Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()
"""

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .kl import kl_divergence, register_kl, _add_kl_info
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import *  # noqa: F403
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart
from . import transforms

_add_kl_info()
del _add_kl_info

__all__ = [
    'Bernoulli',
    'Beta',
    'Binomial',
    'Categorical',
    'Cauchy',
    'Chi2',
    'ContinuousBernoulli',
    'Dirichlet',
    'Distribution',
    'Exponential',
    'ExponentialFamily',
    'FisherSnedecor',
    'Gamma',
    'Geometric',
    'Gumbel',
    'HalfCauchy',
    'HalfNormal',
    'Independent',
    'Kumaraswamy',
    'LKJCholesky',
    'Laplace',
    'LogNormal',
    'LogisticNormal',
    'LowRankMultivariateNormal',
    'MixtureSameFamily',
    'Multinomial',
    'MultivariateNormal',
    'NegativeBinomial',
    'Normal',
    'OneHotCategorical',
    'OneHotCategoricalStraightThrough',
    'Pareto',
    'RelaxedBernoulli',
    'RelaxedOneHotCategorical',
    'StudentT',
    'Poisson',
    'Uniform',
    'VonMises',
    'Weibull',
    'Wishart',
    'TransformedDistribution',
    'biject_to',
    'kl_divergence',
    'register_kl',
    'transform_to',
]

__all__.extend(transforms.__all__)
```
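For reference, new pairs enter the KL registry through :meth:`register_kl`, which is what makes both the dispatch and the generated docstring list above extensible. A minimal, self-contained usage sketch (`MyNormal` is a hypothetical subclass used only for illustration; the formula is the standard closed-form KL between two univariate normals):

```python
import torch
from torch.distributions import Normal, kl_divergence, register_kl

class MyNormal(Normal):
    # Hypothetical subclass, used only to demonstrate registration.
    pass

@register_kl(MyNormal, MyNormal)
def _kl_mynormal_mynormal(p, q):
    # Closed-form KL(p || q) for two univariate normal distributions.
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())

p = MyNormal(torch.tensor(0.0), torch.tensor(1.0))
q = MyNormal(torch.tensor(1.0), torch.tensor(2.0))
print(kl_divergence(p, q))
```

Once registered, `kl_divergence` dispatches to the most specific matching pair, so the subclass implementation takes precedence over the built-in `Normal`/`Normal` rule.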