LowRankMultivariateNormal cleanup

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11179

Differential Revision: D9627502

Pulled By: soumith

fbshipit-source-id: c7a4aa8be24bd8c688a7c655ff25ca901ed19704
This commit is contained in:
Samuel Ainsworth 2018-09-02 07:46:05 -07:00 committed by Facebook Github Bot
parent 4d28b65fb8
commit abe8b3391d
2 changed files with 3 additions and 2 deletions

View File

@ -128,6 +128,7 @@ __all__ = [
'Laplace', 'Laplace',
'LogNormal', 'LogNormal',
'LogisticNormal', 'LogisticNormal',
'LowRankMultivariateNormal',
'Multinomial', 'Multinomial',
'MultivariateNormal', 'MultivariateNormal',
'NegativeBinomial', 'NegativeBinomial',

View File

@ -62,7 +62,7 @@ class LowRankMultivariateNormal(Distribution):
Example: Example:
>>> m = MultivariateNormal(torch.zeros(2), torch.tensor([1, 0]), torch.tensor([1, 1])) >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([1, 0]), torch.tensor([1, 1]))
>>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[1,0]`, cov_diag=`[1,1]` >>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[1,0]`, cov_diag=`[1,1]`
tensor([-0.2102, -0.5429]) tensor([-0.2102, -0.5429])
@ -120,7 +120,7 @@ class LowRankMultivariateNormal(Distribution):
def mean(self): def mean(self):
return self.loc return self.loc
@property @lazy_property
def variance(self): def variance(self):
return self.cov_factor.pow(2).sum(-1) + self.cov_diag return self.cov_factor.pow(2).sum(-1) + self.cov_diag