Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Addresses issue https://github.com/pytorch/pytorch/issues/18125.

This implements a mixture distribution in which all components come from the same distribution family. The implementation currently supports the `mean`, `variance`, `sample`, and `log_prob` methods.

cc: fritzo and neerajprad

- [x] add import and `__all__` string in `torch/distributions/__init__.py`
- [x] register docs in docs/source/distributions.rst

### Tests (all tests live in tests/distributions.py)

- [x] add an `Example(MixtureSameFamily, [...])` to the `EXAMPLES` list, populating `[...]` with three examples: one with `Normal`, one with `Categorical`, and one with `MultivariateNormal` (to exercise `FloatTensor`, `LongTensor`, and nontrivial `event_dim`)
- [x] add a `test_mixture_same_family_shape()` to `TestDistributions`. It would be good to test this with both `Normal` and `MultivariateNormal`
- [x] add a `test_mixture_same_family_log_prob()` to `TestDistributions`.
- [x] add a `test_mixture_same_family_sample()` to `TestDistributions`.
- [x] add a `test_mixture_same_family_shape()` to `TestDistributionShapes`

### Triaged for follow-up PR?

- support batch shape
- implement `.expand()`
- implement `kl_divergence()` in torch/distributions/kl.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22742
Differential Revision: D19899726
Pulled By: ezyang
fbshipit-source-id: 9c816e83a2ef104fe3ea3117c95680b51c7a2fa4
| Name |
|---|
| caffe2 |
| cpp |
| source |
| .gitignore |
| libtorch.rst |
| make.bat |
| Makefile |
| requirements.txt |
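
The commit summary above lists `mean`, `variance`, `sample`, and `log_prob` as the methods supported by `MixtureSameFamily`. As a rough illustration of what those quantities mean for a same-family mixture, here is a minimal plain-Python sketch for a one-dimensional mixture of normals. It is not the code from the pull request and does not use PyTorch. All function names below are hypothetical helpers chosen for this example, and the formulas are the standard ones for a finite mixture (weighted mean, law of total variance, log-sum-exp of weighted component densities).

```python
import math
import random


def normal_log_pdf(x, mu, sigma):
    """Log density of a univariate normal N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def mixture_mean(weights, mus):
    """Mixture mean: the weight-averaged component means."""
    return sum(w * m for w, m in zip(weights, mus))


def mixture_variance(weights, mus, sigmas):
    """Mixture variance via the law of total variance:
    E[Var(component)] + Var(E[component])."""
    mean = mixture_mean(weights, mus)
    expected_var = sum(w * s ** 2 for w, s in zip(weights, sigmas))
    var_of_means = sum(w * (m - mean) ** 2 for w, m in zip(weights, mus))
    return expected_var + var_of_means


def mixture_log_prob(x, weights, mus, sigmas):
    """log sum_k w_k * N(x | mu_k, sigma_k), computed with a
    max-shifted log-sum-exp for numerical stability."""
    terms = [
        math.log(w) + normal_log_pdf(x, m, s)
        for w, m, s in zip(weights, mus, sigmas)
    ]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def mixture_sample(rng, weights, mus, sigmas):
    """Ancestral sampling: pick a component index by weight,
    then draw from that component."""
    k = rng.choices(range(len(weights)), weights=weights)[0]
    return rng.gauss(mus[k], sigmas[k])


# Illustrative parameters (assumed for this example only).
weights = [0.5, 0.5]
mus = [-1.0, 1.0]
sigmas = [1.0, 1.0]

rng = random.Random(0)
draws = [mixture_sample(rng, weights, mus, sigmas) for _ in range(5)]
```

With equal weights on N(-1, 1) and N(1, 1), `mixture_mean` gives 0 and `mixture_variance` gives 2: one unit from the within-component variance and one from the spread of the component means. The sketch handles only scalar events and a single mixture; the batch-shape and `event_dim` handling mentioned in the commit message is out of scope here.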