optimize exp_family methods and reduce redundant computation

This commit is contained in:
Kim Juhyeong 2022-02-23 23:11:04 +09:00
parent 71f0a768cf
commit 3a37173665

View File

@ -277,16 +277,8 @@ class Wishart(ExponentialFamily):
@property @property
def _natural_params(self): def _natural_params(self):
nu = self.df # has shape (batch_shape) return self.precision_matrix, self.df
p = self._event_shape[-1] # has singleton shape
return (
- 0.5 * self.precision_matrix,
0.5 * nu,
)
def _log_normalizer(self, x, y): def _log_normalizer(self, x, y):
p = self._event_shape[-1] p = self._event_shape[-1]
return ( return 0.5 * y * (- torch.linalg.slogdet(-x).logabsdet + _log_2 * p) + torch.mvlgamma(y, p=p)
y * (- torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p)
+ torch.mvlgamma(y, p=p)
)