mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
optimize exp_family methods and reduce redundant computation
This commit is contained in:
parent
71f0a768cf
commit
3a37173665
|
|
@ -277,16 +277,8 @@ class Wishart(ExponentialFamily):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _natural_params(self):
|
def _natural_params(self):
|
||||||
nu = self.df # has shape (batch_shape)
|
return self.precision_matrix, self.df
|
||||||
p = self._event_shape[-1] # has singleton shape
|
|
||||||
return (
|
|
||||||
- 0.5 * self.precision_matrix,
|
|
||||||
0.5 * nu,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _log_normalizer(self, x, y):
|
def _log_normalizer(self, x, y):
|
||||||
p = self._event_shape[-1]
|
p = self._event_shape[-1]
|
||||||
return (
|
return 0.5 * y * (- torch.linalg.slogdet(-x).logabsdet + _log_2 * p) + torch.mvlgamma(y, p=p)
|
||||||
y * (- torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p)
|
|
||||||
+ torch.mvlgamma(y, p=p)
|
|
||||||
)
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user