Revert D19899550: [pytorch][PR] Second try on Von Mises: Make it JIT compatible
Test Plan: revert-hammer

Differential Revision: D19899550

Original commit changeset: fbcdd9bc9143

fbshipit-source-id: c8a675a8b53f884acd0e6c57bc7aa15faf83d5d6
This commit is contained in:
parent ff5f38f53b
commit 0c98939b7b
docs/source/distributions.rst
@@ -302,15 +302,6 @@ Probability distributions - torch.distributions
     :undoc-members:
     :show-inheritance:
 
-:hidden:`VonMises`
-~~~~~~~~~~~~~~~~~~~~~~~
-
-.. currentmodule:: torch.distributions.von_mises
-.. autoclass:: VonMises
-    :members:
-    :undoc-members:
-    :show-inheritance:
-
 :hidden:`Weibull`
 ~~~~~~~~~~~~~~~~~~~~~~~
 
test/test_distributions.py
@@ -50,7 +50,7 @@ from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
                                  NegativeBinomial, Normal, OneHotCategorical, Pareto,
                                  Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
                                  StudentT, TransformedDistribution, Uniform,
-                                 VonMises, Weibull, constraints, kl_divergence)
+                                 Weibull, constraints, kl_divergence)
 from torch.distributions.constraint_registry import biject_to, transform_to
 from torch.distributions.constraints import Constraint, is_dependent
 from torch.distributions.dirichlet import _Dirichlet_backward
@@ -425,16 +425,6 @@ EXAMPLES = [
             'high': torch.tensor([2.0, 3.0], requires_grad=True),
         },
     ]),
-    Example(VonMises, [
-        {
-            'loc': torch.tensor(1.0, requires_grad=True),
-            'concentration': torch.tensor(10.0, requires_grad=True)
-        },
-        {
-            'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True),
-            'concentration': torch.tensor([1.0, 10.0], requires_grad=True)
-        }
-    ]),
     Example(Weibull, [
         {
             'scale': torch.randn(5, 5).abs().requires_grad_(),
@@ -693,7 +683,7 @@ class TestDistributions(TestCase):
             asset_fn(i, val.squeeze(), log_prob)
 
     def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=False,
-                               circular=False, num_samples=10000, failure_rate=1e-3):
+                               num_samples=10000, failure_rate=1e-3):
         # Checks that the .sample() method matches a reference function.
         torch_samples = torch_dist.sample((num_samples,)).squeeze()
         torch_samples = torch_samples.cpu().numpy()
@@ -705,8 +695,6 @@ class TestDistributions(TestCase):
             torch_samples = np.dot(torch_samples, axis)
             ref_samples = np.dot(ref_samples, axis)
         samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
-        if circular:
-            samples = [(np.cos(x), v) for (x, v) in samples]
         shuffle(samples)  # necessary to prevent stable sort from making uneven bins for discrete
         samples.sort(key=lambda x: x[0])
         samples = np.array(samples)[:, 1]
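The `circular` branch removed above existed because a sorted-bin comparison on raw angles breaks at the wrap-around: −π and +π are the same point on the circle but land at opposite ends of the sorted sample list. Mapping every angle through cos() before sorting fixes that, since cosine is continuous across the seam. A minimal sketch of the idea (hypothetical helper, not code from the test suite):

    import numpy as np

    def project_circular(angles):
        # cos() identifies -pi with +pi, so points adjacent on the circle
        # stay adjacent in the projected statistic.
        return np.cos(angles)

    a, b = -np.pi + 0.01, np.pi - 0.01            # neighbours on the circle
    print(abs(a - b))                             # ~6.26: looks far apart on the line
    print(abs(project_circular(a) - project_circular(b)))  # ~0.0: correctly adjacent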
@@ -1360,23 +1348,6 @@ class TestDistributions(TestCase):
         low.grad.zero_()
         high.grad.zero_()
 
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_vonmises_sample(self):
-        for loc in [0.0, math.pi / 2.0]:
-            for concentration in [0.03, 0.3, 1.0, 10.0, 100.0]:
-                self._check_sampler_sampler(VonMises(loc, concentration),
-                                            scipy.stats.vonmises(loc=loc, kappa=concentration),
-                                            "VonMises(loc={}, concentration={})".format(loc, concentration),
-                                            num_samples=int(1e5), circular=True)
-
-    def test_vonmises_logprob(self):
-        concentrations = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]
-        for concentration in concentrations:
-            grid = torch.arange(0., 2 * math.pi, 1e-4)
-            prob = VonMises(0.0, concentration).log_prob(grid).exp()
-            norm = prob.mean().item() * 2 * math.pi
-            self.assertLess(abs(norm - 1), 1e-3)
-
     def test_cauchy(self):
         loc = torch.zeros(5, 5, requires_grad=True)
         scale = torch.ones(5, 5, requires_grad=True)
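The deleted `test_vonmises_logprob` checks normalization with a Riemann sum: averaging `exp(log_prob)` over a fine grid on [0, 2π) and multiplying by the period should give 1. The same check can be reproduced standalone from the closed-form density exp(κ·cos(x − μ)) / (2π·I0(κ)); this is a sketch of the test's logic in NumPy/SciPy, not the original code:

    import numpy as np
    from scipy.special import i0  # modified Bessel function I_0

    def vonmises_pdf(x, loc, concentration):
        # Closed-form von Mises density on one period.
        return np.exp(concentration * np.cos(x - loc)) / (2 * np.pi * i0(concentration))

    for concentration in [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]:
        grid = np.arange(0.0, 2 * np.pi, 1e-4)
        # Riemann sum: mean(pdf) * period should be ~1 for every concentration.
        norm = vonmises_pdf(grid, 0.0, concentration).mean() * 2 * np.pi
        assert abs(norm - 1) < 1e-3, (concentration, norm)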
@@ -3052,27 +3023,6 @@ class TestDistributionShapes(TestCase):
         self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
 
-    def test_vonmises_shape_tensor_params(self):
-        von_mises = VonMises(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
-        self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
-        self.assertEqual(von_mises._event_shape, torch.Size(()))
-        self.assertEqual(von_mises.sample().size(), torch.Size((2,)))
-        self.assertEqual(von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
-        self.assertEqual(von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
-        self.assertEqual(von_mises.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
-
-    def test_vonmises_shape_scalar_params(self):
-        von_mises = VonMises(0., 1.)
-        self.assertEqual(von_mises._batch_shape, torch.Size())
-        self.assertEqual(von_mises._event_shape, torch.Size())
-        self.assertEqual(von_mises.sample().size(), torch.Size())
-        self.assertEqual(von_mises.sample(torch.Size((3, 2))).size(),
-                         torch.Size((3, 2)))
-        self.assertEqual(von_mises.log_prob(self.tensor_sample_1).size(),
-                         torch.Size((3, 2)))
-        self.assertEqual(von_mises.log_prob(self.tensor_sample_2).size(),
-                         torch.Size((3, 2, 3)))
-
     def test_weibull_scale_scalar_params(self):
         weibull = Weibull(1, 1)
         self.assertEqual(weibull._batch_shape, torch.Size())
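The shape assertions being deleted follow the general `torch.distributions` contract: `sample(S)` returns a tensor of shape `S + batch_shape + event_shape`, and `log_prob` broadcasts its argument against `batch_shape`. The same contract can be illustrated with `Normal`, which survives this revert (a sketch, not test-suite code):

    import torch
    from torch.distributions import Normal

    d = Normal(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))  # batch_shape (2,)
    assert d.batch_shape == torch.Size((2,))
    assert d.event_shape == torch.Size(())
    assert d.sample().size() == torch.Size((2,))
    assert d.sample(torch.Size((3, 2))).size() == torch.Size((3, 2, 2))
    # A (2, 1) input broadcasts against the (2,) batch to give a (2, 2) result.
    assert d.log_prob(torch.ones(2, 1)).size() == torch.Size((2, 2))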
@@ -3823,10 +3773,6 @@ class TestAgainstScipy(TestCase):
                 Uniform(random_var, random_var + positive_var),
                 scipy.stats.uniform(random_var, positive_var)
             ),
-            (
-                VonMises(random_var, positive_var),
-                scipy.stats.vonmises(positive_var, loc=random_var)
-            ),
             (
                 Weibull(positive_var[0], positive_var2[0]),  # scipy var for Weibull only supports scalars
                 scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])
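The removed pair depends on `scipy.stats.vonmises` taking the concentration `kappa` as its first (shape) argument and the mean direction as `loc`; swapping the two silently changes the reference distribution. A quick sanity check of that parameterization against the closed-form density (a sketch, with arbitrary values):

    import numpy as np
    import scipy.stats
    from scipy.special import i0

    kappa, mu, x = 2.0, 0.5, 1.2
    manual = np.exp(kappa * np.cos(x - mu)) / (2 * np.pi * i0(kappa))
    # Shape parameter first, mean direction as loc.
    assert np.isclose(scipy.stats.vonmises(kappa, loc=mu).pdf(x), manual)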
@@ -3845,9 +3791,8 @@
 
     def test_variance_stddev(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
-            if isinstance(pytorch_dist, (Cauchy, HalfCauchy, VonMises)):
+            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
                 # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
-                # VonMises variance is circular and scipy doesn't produce a correct result
                 continue
             elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                 self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
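The removed comment refers to the deleted `VonMises.variance`, which returns the circular variance 1 − I1(κ)/I0(κ) rather than the linearized variance `scipy_dist.var()` reports, so the generic check could not pass. A sketch of the circular quantity, computed analytically and from samples (assuming κ = 1 for illustration):

    import numpy as np
    import scipy.stats
    from scipy.special import i0, i1

    kappa = 1.0
    # Circular variance of a von Mises distribution: 1 - I1(k)/I0(k).
    analytic = 1 - i1(kappa) / i0(kappa)
    # Empirical circular variance: 1 minus the mean resultant length.
    theta = scipy.stats.vonmises(kappa).rvs(size=200_000)
    empirical = 1 - np.hypot(np.cos(theta).mean(), np.sin(theta).mean())
    print(analytic, empirical)  # both ~0.55; per the removed comment, scipy's .var() disagreed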
@@ -4174,9 +4119,9 @@ class TestTransforms(TestCase):
 
 class TestFunctors(TestCase):
     def test_cat_transform(self):
-        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
-        x3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        x1 = -1 * torch.range(1, 100).view(-1, 100)
+        x2 = (torch.range(1, 100).view(-1, 100) - 1) / 100
+        x3 = torch.range(1, 100).view(-1, 100)
         t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
         dim = 0
         x = torch.cat([x1, x2, x3], dim=dim)
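The hunks in this file swap `torch.arange(1, 101, dtype=torch.float)` back to the deprecated `torch.range(1, 100)` that the original PR had cleaned up. The two calls build the same 100 floats: `torch.range` includes its endpoint, while `torch.arange` uses Python's half-open convention, hence the differing bounds. A quick equivalence check (`torch.range` emits a deprecation warning):

    import torch

    a = torch.arange(1, 101, dtype=torch.float)  # half-open: 1, 2, ..., 100
    r = torch.range(1, 100)                      # inclusive endpoint; deprecated
    assert torch.equal(a, r) and a.numel() == 100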
@@ -4189,9 +4134,9 @@ class TestFunctors(TestCase):
         actual = t(x)
         expected = torch.cat([t1(x1), t2(x2), t3(x3)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        y2 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        y3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        y1 = torch.range(1, 100).view(-1, 100)
+        y2 = torch.range(1, 100).view(-1, 100)
+        y3 = torch.range(1, 100).view(-1, 100)
         y = torch.cat([y1, y2, y3], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.cat([t1.codomain.check(y1),
|
@ -4208,9 +4153,9 @@ class TestFunctors(TestCase):
|
||||||
self.assertEqual(actual_jac, expected_jac)
|
self.assertEqual(actual_jac, expected_jac)
|
||||||
|
|
||||||
def test_cat_transform_non_uniform(self):
|
def test_cat_transform_non_uniform(self):
|
||||||
x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
|
x1 = -1 * torch.range(1, 100).view(-1, 100)
|
||||||
x2 = torch.cat([(torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100,
|
x2 = torch.cat([(torch.range(1, 100).view(-1, 100) - 1) / 100,
|
||||||
torch.arange(1, 101, dtype=torch.float).view(-1, 100)])
|
torch.range(1, 100).view(-1, 100)])
|
||||||
t1 = ExpTransform()
|
t1 = ExpTransform()
|
||||||
t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
|
t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
|
||||||
dim = 0
|
dim = 0
|
||||||
|
|
@@ -4223,9 +4168,9 @@ class TestFunctors(TestCase):
         actual = t(x)
         expected = torch.cat([t1(x1), t2(x2)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        y2 = torch.cat([torch.arange(1, 101, dtype=torch.float).view(-1, 100),
-                        torch.arange(1, 101, dtype=torch.float).view(-1, 100)])
+        y1 = torch.range(1, 100).view(-1, 100)
+        y2 = torch.cat([torch.range(1, 100).view(-1, 100),
+                        torch.range(1, 100).view(-1, 100)])
         y = torch.cat([y1, y2], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.cat([t1.codomain.check(y1),
@@ -4240,9 +4185,9 @@ class TestFunctors(TestCase):
         self.assertEqual(actual_jac, expected_jac)
 
     def test_stack_transform(self):
-        x1 = -1 * torch.arange(1, 101, dtype=torch.float)
-        x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100
-        x3 = torch.arange(1, 101, dtype=torch.float)
+        x1 = -1 * torch.range(1, 100)
+        x2 = (torch.range(1, 100) - 1) / 100
+        x3 = torch.range(1, 100)
         t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
         dim = 0
         x = torch.stack([x1, x2, x3], dim=dim)
@@ -4255,9 +4200,9 @@ class TestFunctors(TestCase):
         actual = t(x)
         expected = torch.stack([t1(x1), t2(x2), t3(x3)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.arange(1, 101, dtype=torch.float)
-        y2 = torch.arange(1, 101, dtype=torch.float)
-        y3 = torch.arange(1, 101, dtype=torch.float)
+        y1 = torch.range(1, 100)
+        y2 = torch.range(1, 100)
+        y3 = torch.range(1, 100)
         y = torch.stack([y1, y2, y3], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.stack([t1.codomain.check(y1),
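For context on these tests: `CatTransform` and its stack counterpart apply a sequence of component transforms along one dimension, each to its own slice, compatibly with `torch.cat`/`torch.stack`. A minimal usage sketch mirroring the test setup (assuming the default of one slice per transform):

    import torch
    from torch.distributions.transforms import AffineTransform, CatTransform, ExpTransform

    x1 = torch.tensor([[-1.0, -2.0]])
    x2 = torch.tensor([[0.5, 0.25]])
    t = CatTransform([ExpTransform(), AffineTransform(1, 100)], dim=0)
    y = t(torch.cat([x1, x2], dim=0))
    # Each row is transformed by its own component: exp on row 0, 1 + 100*x on row 1.
    assert torch.allclose(y[0], x1[0].exp())
    assert torch.allclose(y[1], 1 + 100 * x2[0])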
|
@ -4450,7 +4395,6 @@ class TestJit(TestCase):
|
||||||
xfail = [
|
xfail = [
|
||||||
Cauchy, # aten::cauchy(Double(2,1), float, float, Generator)
|
Cauchy, # aten::cauchy(Double(2,1), float, float, Generator)
|
||||||
HalfCauchy, # aten::cauchy(Double(2, 1), float, float, Generator)
|
HalfCauchy, # aten::cauchy(Double(2, 1), float, float, Generator)
|
||||||
VonMises # Variance is not Euclidean
|
|
||||||
]
|
]
|
||||||
if Dist in xfail:
|
if Dist in xfail:
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
torch/distributions/__init__.py
@@ -107,7 +107,6 @@ from .studentT import StudentT
 from .transformed_distribution import TransformedDistribution
 from .transforms import *
 from .uniform import Uniform
-from .von_mises import VonMises
 from .weibull import Weibull
 
 __all__ = [
@@ -143,7 +142,6 @@ __all__ = [
     'StudentT',
     'Poisson',
     'Uniform',
-    'VonMises',
     'Weibull',
     'TransformedDistribution',
     'biject_to',
torch/distributions/von_mises.py (deleted file)
@@ -1,142 +0,0 @@
-from __future__ import absolute_import, division, print_function
-
-import math
-
-import torch
-import torch.jit
-from torch.distributions import constraints
-from torch.distributions.distribution import Distribution
-from torch.distributions.utils import broadcast_all, lazy_property
-
-
-def _eval_poly(y, coef):
-    coef = list(coef)
-    result = coef.pop()
-    while coef:
-        result = coef.pop() + y * result
-    return result
-
-
-_I0_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2]
-_I0_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2,
-                  -0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2]
-_I1_COEF_SMALL = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.2658733e-1, 0.301532e-2, 0.32411e-3]
-_I1_COEF_LARGE = [0.39894228, -0.3988024e-1, -0.362018e-2, 0.163801e-2, -0.1031555e-1,
-                  0.2282967e-1, -0.2895312e-1, 0.1787654e-1, -0.420059e-2]
-
-_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
-_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
-
-
-def _log_modified_bessel_fn(x, order=0):
-    """
-    Returns ``log(I_order(x))`` for ``x > 0``,
-    where `order` is either 0 or 1.
-    """
-    assert order == 0 or order == 1
-
-    # compute small solution
-    y = (x / 3.75).pow(2)
-    small = _eval_poly(y, _COEF_SMALL[order])
-    if order == 1:
-        small = x.abs() * small
-    small = small.log()
-
-    # compute large solution
-    y = 3.75 / x
-    large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
-
-    mask = (x < 3.75)
-    result = large
-    if mask.any():
-        result[mask] = small[mask]
-    return result
-
-
-@torch.jit.script
-def _rejection_sample(loc, concentration, proposal_r, x):
-    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
-    while not done.all():
-        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
-        u1, u2, u3 = u.unbind()
-        z = torch.cos(math.pi * u1)
-        f = (1 + proposal_r * z) / (proposal_r + z)
-        c = concentration * (proposal_r - f)
-        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
-        if accept.any():
-            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
-            done = done | accept
-    return (x + math.pi + loc) % (2 * math.pi) - math.pi
-
-
-class VonMises(Distribution):
-    """
-    A circular von Mises distribution.
-
-    This implementation uses polar coordinates. The ``loc`` and ``value`` args
-    can be any real number (to facilitate unconstrained optimization), but are
-    interpreted as angles modulo 2 pi.
-
-    Example::
-        >>> m = dist.VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
-        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
-        tensor([1.9777])
-
-    :param torch.Tensor loc: an angle in radians.
-    :param torch.Tensor concentration: concentration parameter
-    """
-    arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive}
-    support = constraints.real
-    has_rsample = False
-
-    def __init__(self, loc, concentration, validate_args=None):
-        self.loc, self.concentration = broadcast_all(loc, concentration)
-        batch_shape = self.loc.shape
-        event_shape = torch.Size()
-
-        # Parameters for sampling
-        tau = 1 + (1 + 4 * self.concentration ** 2).sqrt()
-        rho = (tau - (2 * tau).sqrt()) / (2 * self.concentration)
-        self._proposal_r = (1 + rho ** 2) / (2 * rho)
-
-        super(VonMises, self).__init__(batch_shape, event_shape, validate_args)
-
-    def log_prob(self, value):
-        log_prob = self.concentration * torch.cos(value - self.loc)
-        log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)
-        return log_prob
-
-    @torch.no_grad()
-    def sample(self, sample_shape=torch.Size()):
-        """
-        The sampling algorithm for the von Mises distribution is based on the following paper:
-        Best, D. J., and Nicholas I. Fisher.
-        "Efficient simulation of the von Mises distribution." Applied Statistics (1979): 152-157.
-        """
-        shape = self._extended_shape(sample_shape)
-        x = torch.empty(shape, dtype=self.loc.dtype, device=self.loc.device)
-        return _rejection_sample(self.loc, self.concentration, self._proposal_r, x)
-
-    def expand(self, batch_shape):
-        try:
-            return super(VonMises, self).expand(batch_shape)
-        except NotImplementedError:
-            validate_args = self.__dict__.get('_validate_args')
-            loc = self.loc.expand(batch_shape)
-            concentration = self.concentration.expand(batch_shape)
-            return type(self)(loc, concentration, validate_args=validate_args)
-
-    @property
-    def mean(self):
-        """
-        The provided mean is the circular one.
-        """
-        return self.loc
-
-    @lazy_property
-    def variance(self):
-        """
-        The provided variance is the circular one.
-        """
-        return 1 - (_log_modified_bessel_fn(self.concentration, order=1) -
-                    _log_modified_bessel_fn(self.concentration, order=0)).exp()
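The file deleted above implements the Best-Fisher (1979) rejection sampler: draw from a wrapped-Cauchy-like proposal with parameter `proposal_r` derived from the concentration, accept with a correction toward the von Mises density, then shift by `loc` and wrap into [−π, π). A scalar NumPy transcription of the same algorithm, for illustration only (the deleted version is vectorized and TorchScript-compiled):

    import math
    import numpy as np

    def sample_von_mises(loc, concentration, rng=None):
        rng = np.random.default_rng() if rng is None else rng
        # Proposal parameter, as computed in the deleted __init__.
        tau = 1 + math.sqrt(1 + 4 * concentration ** 2)
        rho = (tau - math.sqrt(2 * tau)) / (2 * concentration)
        r = (1 + rho ** 2) / (2 * rho)
        while True:
            u1, u2, u3 = rng.uniform(size=3)
            z = math.cos(math.pi * u1)
            f = (1 + r * z) / (r + z)
            c = concentration * (r - f)
            # Cheap quadratic test first; fall back to the exact log test.
            if c * (2 - c) - u2 > 0 or math.log(c / u2) + 1 - c >= 0:
                x = math.copysign(math.acos(f), u3 - 0.5)
                return (x + math.pi + loc) % (2 * math.pi) - math.pi

    samples = np.array([sample_von_mises(1.0, 4.0) for _ in range(10_000)])
    # Circular mean of the samples should sit near loc = 1.0.
    print(math.atan2(np.sin(samples).mean(), np.cos(samples).mean()))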