Revert D19899550: [pytorch][PR] Second try on Von Mises: Make it JIT compatible

Test Plan: revert-hammer

Differential Revision: D19899550

Original commit changeset: fbcdd9bc9143

fbshipit-source-id: c8a675a8b53f884acd0e6c57bc7aa15faf83d5d6
George Guanheng Zhang 2020-02-14 08:40:22 -08:00 committed by Facebook Github Bot
parent ff5f38f53b
commit 0c98939b7b
4 changed files with 21 additions and 230 deletions

View File

@@ -302,15 +302,6 @@ Probability distributions - torch.distributions
     :undoc-members:
     :show-inheritance:
 
-:hidden:`VonMises`
-~~~~~~~~~~~~~~~~~~~~~~~
-
-.. currentmodule:: torch.distributions.von_mises
-.. autoclass:: VonMises
-    :members:
-    :undoc-members:
-    :show-inheritance:
-
 :hidden:`Weibull`
 ~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -50,7 +50,7 @@ from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
                                  NegativeBinomial, Normal, OneHotCategorical, Pareto,
                                  Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
                                  StudentT, TransformedDistribution, Uniform,
-                                 VonMises, Weibull, constraints, kl_divergence)
+                                 Weibull, constraints, kl_divergence)
 from torch.distributions.constraint_registry import biject_to, transform_to
 from torch.distributions.constraints import Constraint, is_dependent
 from torch.distributions.dirichlet import _Dirichlet_backward
@@ -425,16 +425,6 @@ EXAMPLES = [
             'high': torch.tensor([2.0, 3.0], requires_grad=True),
         },
     ]),
-    Example(VonMises, [
-        {
-            'loc': torch.tensor(1.0, requires_grad=True),
-            'concentration': torch.tensor(10.0, requires_grad=True)
-        },
-        {
-            'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True),
-            'concentration': torch.tensor([1.0, 10.0], requires_grad=True)
-        }
-    ]),
     Example(Weibull, [
         {
             'scale': torch.randn(5, 5).abs().requires_grad_(),
@@ -693,7 +683,7 @@ class TestDistributions(TestCase):
             asset_fn(i, val.squeeze(), log_prob)
 
     def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=False,
-                               circular=False, num_samples=10000, failure_rate=1e-3):
+                               num_samples=10000, failure_rate=1e-3):
         # Checks that the .sample() method matches a reference function.
         torch_samples = torch_dist.sample((num_samples,)).squeeze()
         torch_samples = torch_samples.cpu().numpy()
@@ -705,8 +695,6 @@ class TestDistributions(TestCase):
             torch_samples = np.dot(torch_samples, axis)
             ref_samples = np.dot(ref_samples, axis)
         samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
-        if circular:
-            samples = [(np.cos(x), v) for (x, v) in samples]
         shuffle(samples)  # necessary to prevent stable sort from making uneven bins for discrete
         samples.sort(key=lambda x: x[0])
         samples = np.array(samples)[:, 1]
@@ -1360,23 +1348,6 @@ class TestDistributions(TestCase):
             low.grad.zero_()
             high.grad.zero_()
 
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_vonmises_sample(self):
-        for loc in [0.0, math.pi / 2.0]:
-            for concentration in [0.03, 0.3, 1.0, 10.0, 100.0]:
-                self._check_sampler_sampler(VonMises(loc, concentration),
-                                            scipy.stats.vonmises(loc=loc, kappa=concentration),
-                                            "VonMises(loc={}, concentration={})".format(loc, concentration),
-                                            num_samples=int(1e5), circular=True)
-
-    def test_vonmises_logprob(self):
-        concentrations = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]
-        for concentration in concentrations:
-            grid = torch.arange(0., 2 * math.pi, 1e-4)
-            prob = VonMises(0.0, concentration).log_prob(grid).exp()
-            norm = prob.mean().item() * 2 * math.pi
-            self.assertLess(abs(norm - 1), 1e-3)
-
     def test_cauchy(self):
         loc = torch.zeros(5, 5, requires_grad=True)
         scale = torch.ones(5, 5, requires_grad=True)
@@ -3052,27 +3023,6 @@ class TestDistributionShapes(TestCase):
         self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
 
-    def test_vonmises_shape_tensor_params(self):
-        von_mises = VonMises(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
-        self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
-        self.assertEqual(von_mises._event_shape, torch.Size(()))
-        self.assertEqual(von_mises.sample().size(), torch.Size((2,)))
-        self.assertEqual(von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
-        self.assertEqual(von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
-        self.assertEqual(von_mises.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
-
-    def test_vonmises_shape_scalar_params(self):
-        von_mises = VonMises(0., 1.)
-        self.assertEqual(von_mises._batch_shape, torch.Size())
-        self.assertEqual(von_mises._event_shape, torch.Size())
-        self.assertEqual(von_mises.sample().size(), torch.Size())
-        self.assertEqual(von_mises.sample(torch.Size((3, 2))).size(),
-                         torch.Size((3, 2)))
-        self.assertEqual(von_mises.log_prob(self.tensor_sample_1).size(),
-                         torch.Size((3, 2)))
-        self.assertEqual(von_mises.log_prob(self.tensor_sample_2).size(),
-                         torch.Size((3, 2, 3)))
-
     def test_weibull_scale_scalar_params(self):
         weibull = Weibull(1, 1)
         self.assertEqual(weibull._batch_shape, torch.Size())
@@ -3823,10 +3773,6 @@ class TestAgainstScipy(TestCase):
                 Uniform(random_var, random_var + positive_var),
                 scipy.stats.uniform(random_var, positive_var)
             ),
-            (
-                VonMises(random_var, positive_var),
-                scipy.stats.vonmises(positive_var, loc=random_var)
-            ),
             (
                 Weibull(positive_var[0], positive_var2[0]),  # scipy var for Weibull only supports scalars
                 scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])
@@ -3845,9 +3791,8 @@ class TestAgainstScipy(TestCase):
 
     def test_variance_stddev(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
-            if isinstance(pytorch_dist, (Cauchy, HalfCauchy, VonMises)):
+            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
                 # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
-                # VonMises variance is circular and scipy doesn't produce a correct result
                 continue
             elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                 self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
@@ -4174,9 +4119,9 @@ class TestTransforms(TestCase):
 
 class TestFunctors(TestCase):
     def test_cat_transform(self):
-        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
-        x3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        x1 = -1 * torch.range(1, 100).view(-1, 100)
+        x2 = (torch.range(1, 100).view(-1, 100) - 1) / 100
+        x3 = torch.range(1, 100).view(-1, 100)
         t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
         dim = 0
         x = torch.cat([x1, x2, x3], dim=dim)
@@ -4189,9 +4134,9 @@ class TestFunctors(TestCase):
         actual = t(x)
         expected = torch.cat([t1(x1), t2(x2), t3(x3)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        y2 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        y3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        y1 = torch.range(1, 100).view(-1, 100)
+        y2 = torch.range(1, 100).view(-1, 100)
+        y3 = torch.range(1, 100).view(-1, 100)
         y = torch.cat([y1, y2, y3], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.cat([t1.codomain.check(y1),
@@ -4208,9 +4153,9 @@ class TestFunctors(TestCase):
         self.assertEqual(actual_jac, expected_jac)
 
     def test_cat_transform_non_uniform(self):
-        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        x2 = torch.cat([(torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100,
-                        torch.arange(1, 101, dtype=torch.float).view(-1, 100)])
+        x1 = -1 * torch.range(1, 100).view(-1, 100)
+        x2 = torch.cat([(torch.range(1, 100).view(-1, 100) - 1) / 100,
+                        torch.range(1, 100).view(-1, 100)])
         t1 = ExpTransform()
         t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
         dim = 0
@@ -4223,9 +4168,9 @@ class TestFunctors(TestCase):
         actual = t(x)
         expected = torch.cat([t1(x1), t2(x2)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
-        y2 = torch.cat([torch.arange(1, 101, dtype=torch.float).view(-1, 100),
-                        torch.arange(1, 101, dtype=torch.float).view(-1, 100)])
+        y1 = torch.range(1, 100).view(-1, 100)
+        y2 = torch.cat([torch.range(1, 100).view(-1, 100),
+                        torch.range(1, 100).view(-1, 100)])
         y = torch.cat([y1, y2], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.cat([t1.codomain.check(y1),
@@ -4240,9 +4185,9 @@ class TestFunctors(TestCase):
         self.assertEqual(actual_jac, expected_jac)
 
     def test_stack_transform(self):
-        x1 = -1 * torch.arange(1, 101, dtype=torch.float)
-        x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100
-        x3 = torch.arange(1, 101, dtype=torch.float)
+        x1 = -1 * torch.range(1, 100)
+        x2 = (torch.range(1, 100) - 1) / 100
+        x3 = torch.range(1, 100)
         t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
         dim = 0
         x = torch.stack([x1, x2, x3], dim=dim)
@@ -4255,9 +4200,9 @@ class TestFunctors(TestCase):
         actual = t(x)
         expected = torch.stack([t1(x1), t2(x2), t3(x3)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.arange(1, 101, dtype=torch.float)
-        y2 = torch.arange(1, 101, dtype=torch.float)
-        y3 = torch.arange(1, 101, dtype=torch.float)
+        y1 = torch.range(1, 100)
+        y2 = torch.range(1, 100)
+        y3 = torch.range(1, 100)
         y = torch.stack([y1, y2, y3], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.stack([t1.codomain.check(y1),
@@ -4450,7 +4395,6 @@ class TestJit(TestCase):
         xfail = [
             Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
            HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
-            VonMises  # Variance is not Euclidean
         ]
         if Dist in xfail:
             continue
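
Note on the removed ``circular=True`` path in ``_check_sampler_sampler`` above: it projected every sample through cos() before the sort-based two-sample comparison, because angles that differ only by wrap-around at +/- pi are nearly the same point on the circle yet sit at opposite ends of [-pi, pi]. A minimal sketch of the effect, using NumPy's von Mises sampler (illustration only, not part of this diff):

import numpy as np

rng = np.random.default_rng(0)
# Two samples from the same circular distribution, centered at the wrap point.
a = rng.vonmises(mu=np.pi, kappa=1.0, size=10000)
b = rng.vonmises(mu=np.pi, kappa=1.0, size=10000)
# On the raw angles the single circular mode splits into lumps near -pi and
# +pi; after cos() the lumps coincide, so a 1-D two-sample test is well posed.
print(np.mean(a), np.mean(b))                   # noisy, near 0 rather than near pi
print(np.mean(np.cos(a)), np.mean(np.cos(b)))   # stable and nearly equal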

View File

@@ -107,7 +107,6 @@ from .studentT import StudentT
 from .transformed_distribution import TransformedDistribution
 from .transforms import *
 from .uniform import Uniform
-from .von_mises import VonMises
 from .weibull import Weibull
 
 __all__ = [
@@ -143,7 +142,6 @@ __all__ = [
     'StudentT',
     'Poisson',
     'Uniform',
-    'VonMises',
     'Weibull',
     'TransformedDistribution',
     'biject_to',

View File

@@ -1,142 +0,0 @@
-from __future__ import absolute_import, division, print_function
-
-import math
-
-import torch
-import torch.jit
-from torch.distributions import constraints
-from torch.distributions.distribution import Distribution
-from torch.distributions.utils import broadcast_all, lazy_property
-
-
-def _eval_poly(y, coef):
-    coef = list(coef)
-    result = coef.pop()
-    while coef:
-        result = coef.pop() + y * result
-    return result
-
-
-_I0_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2]
-_I0_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2,
-                  -0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2]
-_I1_COEF_SMALL = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.2658733e-1, 0.301532e-2, 0.32411e-3]
-_I1_COEF_LARGE = [0.39894228, -0.3988024e-1, -0.362018e-2, 0.163801e-2, -0.1031555e-1,
-                  0.2282967e-1, -0.2895312e-1, 0.1787654e-1, -0.420059e-2]
-
-_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
-_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
-
-
-def _log_modified_bessel_fn(x, order=0):
-    """
-    Returns ``log(I_order(x))`` for ``x > 0``,
-    where `order` is either 0 or 1.
-    """
-    assert order == 0 or order == 1
-
-    # compute small solution
-    y = (x / 3.75).pow(2)
-    small = _eval_poly(y, _COEF_SMALL[order])
-    if order == 1:
-        small = x.abs() * small
-    small = small.log()
-
-    # compute large solution
-    y = 3.75 / x
-    large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
-
-    mask = (x < 3.75)
-    result = large
-    if mask.any():
-        result[mask] = small[mask]
-    return result
-
-
-@torch.jit.script
-def _rejection_sample(loc, concentration, proposal_r, x):
-    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
-    while not done.all():
-        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
-        u1, u2, u3 = u.unbind()
-        z = torch.cos(math.pi * u1)
-        f = (1 + proposal_r * z) / (proposal_r + z)
-        c = concentration * (proposal_r - f)
-        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
-        if accept.any():
-            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
-            done = done | accept
-    return (x + math.pi + loc) % (2 * math.pi) - math.pi
-
-
-class VonMises(Distribution):
-    """
-    A circular von Mises distribution.
-
-    This implementation uses polar coordinates. The ``loc`` and ``value`` args
-    can be any real number (to facilitate unconstrained optimization), but are
-    interpreted as angles modulo 2 pi.
-
-    Example::
-
-        >>> m = dist.VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
-        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
-        tensor([1.9777])
-
-    :param torch.Tensor loc: an angle in radians.
-    :param torch.Tensor concentration: concentration parameter
-    """
-    arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive}
-    support = constraints.real
-    has_rsample = False
-
-    def __init__(self, loc, concentration, validate_args=None):
-        self.loc, self.concentration = broadcast_all(loc, concentration)
-        batch_shape = self.loc.shape
-        event_shape = torch.Size()
-
-        # Parameters for sampling
-        tau = 1 + (1 + 4 * self.concentration ** 2).sqrt()
-        rho = (tau - (2 * tau).sqrt()) / (2 * self.concentration)
-        self._proposal_r = (1 + rho ** 2) / (2 * rho)
-
-        super(VonMises, self).__init__(batch_shape, event_shape, validate_args)
-
-    def log_prob(self, value):
-        log_prob = self.concentration * torch.cos(value - self.loc)
-        log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)
-        return log_prob
-
-    @torch.no_grad()
-    def sample(self, sample_shape=torch.Size()):
-        """
-        The sampling algorithm for the von Mises distribution is based on the following paper:
-        Best, D. J., and Nicholas I. Fisher.
-        "Efficient simulation of the von Mises distribution." Applied Statistics (1979): 152-157.
-        """
-        shape = self._extended_shape(sample_shape)
-        x = torch.empty(shape, dtype=self.loc.dtype, device=self.loc.device)
-        return _rejection_sample(self.loc, self.concentration, self._proposal_r, x)
-
-    def expand(self, batch_shape):
-        try:
-            return super(VonMises, self).expand(batch_shape)
-        except NotImplementedError:
-            validate_args = self.__dict__.get('_validate_args')
-            loc = self.loc.expand(batch_shape)
-            concentration = self.concentration.expand(batch_shape)
-            return type(self)(loc, concentration, validate_args=validate_args)
-
-    @property
-    def mean(self):
-        """
-        The provided mean is the circular one.
-        """
-        return self.loc
-
-    @lazy_property
-    def variance(self):
-        """
-        The provided variance is the circular one.
-        """
-        return 1 - (_log_modified_bessel_fn(self.concentration, order=1) -
-                    _log_modified_bessel_fn(self.concentration, order=0)).exp()
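
A note on ``_log_modified_bessel_fn`` in the deleted file above: it evaluates log(I_order(x)) via the classic piecewise polynomial approximations, one series in (x/3.75)^2 below x = 3.75 and one in 3.75/x (with the leading x - 0.5*log(x) term) above it. A quick sanity check against scipy, as a sketch assuming a pre-revert checkout where torch.distributions.von_mises still exists:

import numpy as np
import scipy.special
import torch
from torch.distributions.von_mises import _log_modified_bessel_fn  # pre-revert only

# Probe both branches of the approximation, straddling the 3.75 switch point.
x = torch.tensor([0.1, 1.0, 3.74, 3.76, 10.0, 50.0], dtype=torch.float64)
for order in (0, 1):
    approx = _log_modified_bessel_fn(x, order=order)
    exact = np.log(scipy.special.iv(order, x.numpy()))
    # The published polynomial fits are good to roughly 1e-7, so the gap
    # should be tiny.
    print(order, np.max(np.abs(approx.numpy() - exact)))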
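
And a short usage sketch under the same pre-revert assumption, mirroring the deleted test_vonmises_logprob: since log_prob subtracts log(2*pi) + log(I_0(concentration)), the density should integrate to ~1 over one period, and sample() returns angles in [-pi, pi) because of the final modulo in _rejection_sample:

import math
import torch
from torch.distributions import VonMises  # pre-revert only

d = VonMises(torch.tensor(0.0), torch.tensor(10.0))
grid = torch.arange(0., 2 * math.pi, 1e-4)
norm = d.log_prob(grid).exp().mean().item() * 2 * math.pi  # Riemann sum over one period
print(norm)  # should be close to 1.0

samples = d.sample(torch.Size((1000,)))
print(samples.min().item(), samples.max().item())  # both within [-pi, pi)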