Register KL divergence for the Independent distribution (#17681)

Summary:
This addresses issue https://github.com/pytorch/pytorch/issues/13545 and implements the proposed fix together with a single test.
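Concretely, when two Independent distributions reinterpret the same number n of batch dimensions, the KL divergence factorizes over those dimensions, so the fix computes KL(Independent(p, n) || Independent(q, n)) = sum_rightmost(KL(p || q), n), i.e. the KL of the base distributions summed over the rightmost n batch dimensions.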
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17681

Differential Revision: D14360161

Pulled By: ezyang

fbshipit-source-id: 427afc88e9054b5b0dc39ebbab1087b990695ea5
Author: Nicki Skafte, 2019-03-11 08:07:22 -07:00 (committed by Facebook GitHub Bot)
parent c02369151d
commit 8045b3eb14
2 changed files with 11 additions and 0 deletions

test/test_distributions.py

@@ -3078,6 +3078,7 @@ class TestKL(TestCase):
         laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
         lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
         normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
+        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
         onehotcategorical = pairwise(OneHotCategorical, [[0.4, 0.3, 0.3],
                                                          [0.2, 0.7, 0.1],
                                                          [0.33, 0.33, 0.34],
@@ -3127,6 +3128,7 @@ class TestKL(TestCase):
             (gumbel, gumbel),
             (gumbel, normal),
             (halfnormal, halfnormal),
+            (independent, independent),
             (laplace, laplace),
             (lognormal, lognormal),
             (laplace, normal),

torch/distributions/kl.py

@@ -17,6 +17,7 @@ from .gamma import Gamma
 from .geometric import Geometric
 from .gumbel import Gumbel
 from .half_normal import HalfNormal
+from .independent import Independent
 from .laplace import Laplace
 from .logistic_normal import LogisticNormal
 from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
@@ -730,3 +731,11 @@ def _kl_uniform_pareto(p, q):
     result = t2 * (q.alpha + 1) - t1
     result[p.low < q.support.lower_bound] = inf
     return result
+
+
+@register_kl(Independent, Independent)
+def _kl_independent_independent(p, q):
+    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
+        raise NotImplementedError
+    result = kl_divergence(p.base_dist, q.base_dist)
+    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
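
For reference, here is a minimal usage sketch of the new registration (assuming a PyTorch build that includes this patch; the shapes and parameter values below are illustrative only):

import torch
from torch.distributions import Independent, Normal
from torch.distributions.kl import kl_divergence

# Wrap batches of diagonal Normals so the last batch dim becomes an event dim:
# base batch_shape (3, 4) -> Independent batch_shape (3,), event_shape (4,).
p = Independent(Normal(torch.zeros(3, 4), torch.ones(3, 4)), 1)
q = Independent(Normal(torch.ones(3, 4), 2.0 * torch.ones(3, 4)), 1)

kl = kl_divergence(p, q)  # shape (3,): elementwise Normal KLs summed over the event dim
assert torch.allclose(kl, kl_divergence(p.base_dist, q.base_dist).sum(-1))

Calling kl_divergence on a pair of Independent distributions whose reinterpreted_batch_ndims differ still raises NotImplementedError, per the guard above.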