mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Registering of kl-divergence for independent distribution (#17681)
Summary: This address issue https://github.com/pytorch/pytorch/issues/13545 and implements the proposed fix together with a single test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17681 Differential Revision: D14360161 Pulled By: ezyang fbshipit-source-id: 427afc88e9054b5b0dc39ebbab1087b990695ea5
This commit is contained in:
parent
c02369151d
commit
8045b3eb14
|
|
@ -3078,6 +3078,7 @@ class TestKL(TestCase):
|
|||
laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
|
||||
lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
|
||||
normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
|
||||
independent = (Independent(normal[0], 1), Independent(normal[1], 1))
|
||||
onehotcategorical = pairwise(OneHotCategorical, [[0.4, 0.3, 0.3],
|
||||
[0.2, 0.7, 0.1],
|
||||
[0.33, 0.33, 0.34],
|
||||
|
|
@ -3127,6 +3128,7 @@ class TestKL(TestCase):
|
|||
(gumbel, gumbel),
|
||||
(gumbel, normal),
|
||||
(halfnormal, halfnormal),
|
||||
(independent, independent),
|
||||
(laplace, laplace),
|
||||
(lognormal, lognormal),
|
||||
(laplace, normal),
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from .gamma import Gamma
|
|||
from .geometric import Geometric
|
||||
from .gumbel import Gumbel
|
||||
from .half_normal import HalfNormal
|
||||
from .independent import Independent
|
||||
from .laplace import Laplace
|
||||
from .logistic_normal import LogisticNormal
|
||||
from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
|
||||
|
|
@ -730,3 +731,11 @@ def _kl_uniform_pareto(p, q):
|
|||
result = t2 * (q.alpha + 1) - t1
|
||||
result[p.low < q.support.lower_bound] = inf
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Independent, Independent)
|
||||
def _kl_independent_independent(p, q):
|
||||
if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
|
||||
raise NotImplementedError
|
||||
result = kl_divergence(p.base_dist, q.base_dist)
|
||||
return _sum_rightmost(result, p.reinterpreted_batch_ndims)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user