From 3a37173665556c1010f535ed35965b95f4ab9013 Mon Sep 17 00:00:00 2001 From: Kim Juhyeong Date: Wed, 23 Feb 2022 23:11:04 +0900 Subject: [PATCH] optimize exp_family methods and reduce redundant computation --- torch/distributions/wishart.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index af049724f43..9cc3c411d1b 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -277,16 +277,8 @@ class Wishart(ExponentialFamily): @property def _natural_params(self): - nu = self.df # has shape (batch_shape) - p = self._event_shape[-1] # has singleton shape - return ( - - 0.5 * self.precision_matrix, - 0.5 * nu, - ) + return self.precision_matrix, self.df def _log_normalizer(self, x, y): p = self._event_shape[-1] - return ( - y * (- torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p) - + torch.mvlgamma(y, p=p) - ) + return 0.5 * y * (- torch.linalg.slogdet(-x).logabsdet + _log_2 * p) + torch.mvlgamma(y, p=p)