Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Addressing issue https://github.com/pytorch/pytorch/issues/18125

This implements mixture distributions in which all components are from the same distribution family. Right now the implementation supports the `mean`, `variance`, `sample`, and `log_prob` methods.

cc: fritzo and neerajprad

- [x] add import and `__all__` string in `torch/distributions/__init__.py`
- [x] register docs in docs/source/distributions.rst

### Tests
(all tests live in tests/distributions.py)

- [x] add an `Example(MixtureSameFamily, [...])` to the `EXAMPLES` list, populating `[...]` with three examples: one with `Normal`, one with `Categorical`, and one with `MultivariateNormal` (to exercise `FloatTensor`, `LongTensor`, and nontrivial `event_dim`)
- [x] add a `test_mixture_same_family_shape()` to `TestDistributions`. It would be good to test this with both `Normal` and `MultivariateNormal`
- [x] add a `test_mixture_same_family_log_prob()` to `TestDistributions`.
- [x] add a `test_mixture_same_family_sample()` to `TestDistributions`.
- [x] add a `test_mixture_same_family_shape()` to `TestDistributionShapes`

### Triaged for follow-up PR?
- support batch shape
- implement `.expand()`
- implement `kl_divergence()` in torch/distributions/kl.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22742

Differential Revision: D19899726

Pulled By: ezyang

fbshipit-source-id: 9c816e83a2ef104fe3ea3117c95680b51c7a2fa4

| Name | | |
|---|---|---|
| .. | ||
| _static | ||
| _templates | ||
| community | ||
| notes | ||
| org/pytorch | ||
| scripts | ||
| __config__.rst | ||
| amp.rst | ||
| autograd.rst | ||
| bottleneck.rst | ||
| checkpoint.rst | ||
| conf.py | ||
| cpp_extension.rst | ||
| cuda_deterministic_backward.rst | ||
| cuda_deterministic.rst | ||
| cuda.rst | ||
| cudnn_deterministic.rst | ||
| cudnn_persistent_rnn.rst | ||
| data.rst | ||
| distributed.rst | ||
| distributions.rst | ||
| dlpack.rst | ||
| hub.rst | ||
| index.rst | ||
| jit_builtin_functions.rst | ||
| jit_language_reference.rst | ||
| jit_python_reference.rst | ||
| jit_unsupported.rst | ||
| jit.rst | ||
| math-quantizer-equation.png | ||
| model_zoo.rst | ||
| multiprocessing.rst | ||
| name_inference.rst | ||
| named_tensor.rst | ||
| nn.functional.rst | ||
| nn.init.rst | ||
| nn.rst | ||
| onnx.rst | ||
| optim.rst | ||
| packages.rst | ||
| quantization.rst | ||
| random.rst | ||
| rpc.rst | ||
| sparse.rst | ||
| storage.rst | ||
| tensor_attributes.rst | ||
| tensorboard.rst | ||
| tensors.rst | ||
| torch.rst | ||
| type_info.rst | ||
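
The commit summary at the top of this listing says the mixture distribution supports `mean`, `variance`, `sample`, and `log_prob`. As a rough illustration of what those quantities are for a mixture whose components share one family, here is a minimal pure-Python sketch for univariate normal components. It does not call PyTorch and is not the code from the pull request; the function names (`mixture_mean`, `mixture_variance`, `mixture_log_prob`, `mixture_sample`, `normal_log_prob`) are hypothetical and chosen only for this example.

```python
import math
import random


def normal_log_prob(x, mu, sigma):
    """Log density of a univariate normal N(mu, sigma^2) at x."""
    return (-0.5 * ((x - mu) / sigma) ** 2
            - math.log(sigma)
            - 0.5 * math.log(2 * math.pi))


def mixture_mean(weights, mus):
    """Mixture mean: the weight-averaged component means."""
    return sum(w * m for w, m in zip(weights, mus))


def mixture_variance(weights, mus, sigmas):
    """Law of total variance: within-component plus between-component spread."""
    mean = mixture_mean(weights, mus)
    within = sum(w * s ** 2 for w, s in zip(weights, sigmas))
    between = sum(w * (m - mean) ** 2 for w, m in zip(weights, mus))
    return within + between


def mixture_log_prob(x, weights, mus, sigmas):
    """log sum_k w_k * N(x; mu_k, sigma_k), computed with a stable log-sum-exp."""
    terms = [math.log(w) + normal_log_prob(x, m, s)
             for w, m, s in zip(weights, mus, sigmas)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def mixture_sample(weights, mus, sigmas, rng=random):
    """Pick a component index by weight, then draw from that component."""
    k = rng.choices(range(len(weights)), weights=weights)[0]
    return rng.gauss(mus[k], sigmas[k])


# Illustrative values only: two equally weighted unit-variance components.
weights, mus, sigmas = [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0]
print(mixture_mean(weights, mus))
print(mixture_variance(weights, mus, sigmas))
print(mixture_log_prob(0.0, weights, mus, sigmas))
print(mixture_sample(weights, mus, sigmas))
```

The log-sum-exp form in `mixture_log_prob` avoids underflow when every component assigns a very small density to `x`. The sketch assumes the weights are positive and sum to one.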