diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index af049724f43..9cc3c411d1b 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -277,16 +277,8 @@ class Wishart(ExponentialFamily): @property def _natural_params(self): - nu = self.df # has shape (batch_shape) - p = self._event_shape[-1] # has singleton shape - return ( - - 0.5 * self.precision_matrix, - 0.5 * nu, - ) + return self.precision_matrix, self.df def _log_normalizer(self, x, y): p = self._event_shape[-1] - return ( - y * (- torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p) - + torch.mvlgamma(y, p=p) - ) + return 0.5 * y * (- torch.linalg.slogdet(-x).logabsdet + _log_2 * p) + torch.mvlgamma(y, p=p)