mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Fixes https://github.com/pytorch/pytorch/issues/42979. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689 Reviewed By: agolynski Differential Revision: D24229870 Pulled By: xuzhao9 fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
160 lines
5.5 KiB
Python
160 lines
5.5 KiB
Python
r"""
|
|
The ``distributions`` package contains parameterizable probability distributions
|
|
and sampling functions. This allows the construction of stochastic computation
|
|
graphs and stochastic gradient estimators for optimization. This package
|
|
generally follows the design of the `TensorFlow Distributions`_ package.
|
|
|
|
.. _`TensorFlow Distributions`:
|
|
https://arxiv.org/abs/1711.10604
|
|
|
|
It is not possible to directly backpropagate through random samples. However,
|
|
there are two main methods for creating surrogate functions that can be
|
|
backpropagated through. These are the score function estimator/likelihood ratio
|
|
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
|
|
seen as the basis for policy gradient methods in reinforcement learning, and the
|
|
pathwise derivative estimator is commonly seen in the reparameterization trick
|
|
in variational autoencoders. Whilst the score function only requires the value
|
|
of samples :math:`f(x)`, the pathwise derivative requires the derivative
|
|
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
|
|
example. For more details see
|
|
`Gradient Estimation Using Stochastic Computation Graphs`_.
|
|
|
|
.. _`Gradient Estimation Using Stochastic Computation Graphs`:
|
|
https://arxiv.org/abs/1506.05254
|
|
|
|
Score function
|
|
^^^^^^^^^^^^^^
|
|
|
|
When the probability density function is differentiable with respect to its
|
|
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
|
|
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:
|
|
|
|
.. math::
|
|
|
|
\Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}
|
|
|
|
where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
|
|
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
|
|
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.
|
|
|
|
In practice we would sample an action from the output of a network, apply this
|
|
action in an environment, and then use ``log_prob`` to construct an equivalent
|
|
loss function. Note that we use a negative because optimizers use gradient
|
|
descent, whilst the rule above assumes gradient ascent. With a categorical
|
|
policy, the code for implementing REINFORCE would be as follows::
|
|
|
|
probs = policy_network(state)
|
|
# Note that this is equivalent to what used to be called multinomial
|
|
m = Categorical(probs)
|
|
action = m.sample()
|
|
next_state, reward = env.step(action)
|
|
loss = -m.log_prob(action) * reward
|
|
loss.backward()
|
|
|
|
Pathwise derivative
|
|
^^^^^^^^^^^^^^^^^^^
|
|
|
|
The other way to implement these stochastic/policy gradients would be to use the
|
|
reparameterization trick from the
|
|
:meth:`~torch.distributions.Distribution.rsample` method, where the
|
|
parameterized random variable can be constructed via a parameterized
|
|
deterministic function of a parameter-free random variable. The reparameterized
|
|
sample therefore becomes differentiable. The code for implementing the pathwise
|
|
derivative would be as follows::
|
|
|
|
params = policy_network(state)
|
|
m = Normal(*params)
|
|
# Any distribution with .has_rsample == True could work based on the application
|
|
action = m.rsample()
|
|
next_state, reward = env.step(action) # Assuming that reward is differentiable
|
|
loss = -reward
|
|
loss.backward()
|
|
"""
|
|
|
|
from .bernoulli import Bernoulli
|
|
from .beta import Beta
|
|
from .binomial import Binomial
|
|
from .categorical import Categorical
|
|
from .cauchy import Cauchy
|
|
from .chi2 import Chi2
|
|
from .constraint_registry import biject_to, transform_to
|
|
from .continuous_bernoulli import ContinuousBernoulli
|
|
from .dirichlet import Dirichlet
|
|
from .distribution import Distribution
|
|
from .exp_family import ExponentialFamily
|
|
from .exponential import Exponential
|
|
from .fishersnedecor import FisherSnedecor
|
|
from .gamma import Gamma
|
|
from .geometric import Geometric
|
|
from .gumbel import Gumbel
|
|
from .half_cauchy import HalfCauchy
|
|
from .half_normal import HalfNormal
|
|
from .independent import Independent
|
|
from .kl import kl_divergence, register_kl
|
|
from .laplace import Laplace
|
|
from .log_normal import LogNormal
|
|
from .logistic_normal import LogisticNormal
|
|
from .lowrank_multivariate_normal import LowRankMultivariateNormal
|
|
from .mixture_same_family import MixtureSameFamily
|
|
from .multinomial import Multinomial
|
|
from .multivariate_normal import MultivariateNormal
|
|
from .negative_binomial import NegativeBinomial
|
|
from .normal import Normal
|
|
from .one_hot_categorical import OneHotCategorical
|
|
from .pareto import Pareto
|
|
from .poisson import Poisson
|
|
from .relaxed_bernoulli import RelaxedBernoulli
|
|
from .relaxed_categorical import RelaxedOneHotCategorical
|
|
from .studentT import StudentT
|
|
from .transformed_distribution import TransformedDistribution
|
|
from .transforms import *
|
|
from .uniform import Uniform
|
|
from .von_mises import VonMises
|
|
from .weibull import Weibull
|
|
from . import transforms
|
|
|
|
# Public names exported by ``from torch.distributions import *``.
# The first group is the distribution classes plus the constraint/KL helpers
# imported above; the order below is kept exactly as before.
__all__ = [
    'Bernoulli', 'Beta', 'Binomial', 'Categorical', 'Cauchy', 'Chi2',
    'ContinuousBernoulli', 'Dirichlet', 'Distribution', 'Exponential',
    'ExponentialFamily', 'FisherSnedecor', 'Gamma', 'Geometric', 'Gumbel',
    'HalfCauchy', 'HalfNormal', 'Independent', 'Laplace', 'LogNormal',
    'LogisticNormal', 'LowRankMultivariateNormal', 'MixtureSameFamily',
    'Multinomial', 'MultivariateNormal', 'NegativeBinomial', 'Normal',
    'OneHotCategorical', 'Pareto', 'RelaxedBernoulli',
    'RelaxedOneHotCategorical', 'StudentT', 'Poisson', 'Uniform', 'VonMises',
    'Weibull', 'TransformedDistribution',
    'biject_to', 'kl_divergence', 'register_kl', 'transform_to',
]
# ``from .transforms import *`` above pulls the transform names into this
# namespace; append that module's own export list so they are re-exported too.
# (In-place ``+=`` on a list mutates it exactly like ``list.extend``.)
__all__ += transforms.__all__
|