mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
LowRankMultivariateNormal cleanup
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11179 Differential Revision: D9627502 Pulled By: soumith fbshipit-source-id: c7a4aa8be24bd8c688a7c655ff25ca901ed19704
This commit is contained in:
parent
4d28b65fb8
commit
abe8b3391d
|
|
@ -128,6 +128,7 @@ __all__ = [
|
||||||
'Laplace',
|
'Laplace',
|
||||||
'LogNormal',
|
'LogNormal',
|
||||||
'LogisticNormal',
|
'LogisticNormal',
|
||||||
|
'LowRankMultivariateNormal',
|
||||||
'Multinomial',
|
'Multinomial',
|
||||||
'MultivariateNormal',
|
'MultivariateNormal',
|
||||||
'NegativeBinomial',
|
'NegativeBinomial',
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ class LowRankMultivariateNormal(Distribution):
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
>>> m = MultivariateNormal(torch.zeros(2), torch.tensor([1, 0]), torch.tensor([1, 1]))
|
>>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([1, 0]), torch.tensor([1, 1]))
|
||||||
>>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[1,0]`, cov_diag=`[1,1]`
|
>>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[1,0]`, cov_diag=`[1,1]`
|
||||||
tensor([-0.2102, -0.5429])
|
tensor([-0.2102, -0.5429])
|
||||||
|
|
||||||
|
|
@ -120,7 +120,7 @@ class LowRankMultivariateNormal(Distribution):
|
||||||
def mean(self):
|
def mean(self):
|
||||||
return self.loc
|
return self.loc
|
||||||
|
|
||||||
@property
|
@lazy_property
|
||||||
def variance(self):
|
def variance(self):
|
||||||
return self.cov_factor.pow(2).sum(-1) + self.cov_diag
|
return self.cov_factor.pow(2).sum(-1) + self.cov_diag
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user