Documenting torch.distributions.utils.clamp_probs (#128136)

Fixes https://github.com/pytorch/pytorch/issues/127889

This PR adds a docstring to the `torch.distributions.utils.clamp_probs` function.

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128136
Approved by: https://github.com/janeyx99, https://github.com/svekars, https://github.com/malfet
Author: GdoongMathew
Date: 2024-06-07 00:49:40 +00:00 (committed by PyTorch MergeBot)
Parent: 740cd0559f
Commit: f99409903c


@@ -90,6 +90,27 @@ def logits_to_probs(logits, is_binary=False):
def clamp_probs(probs):
    """Clamps the probabilities to be in the open interval `(0, 1)`.

    The probabilities are clamped to the interval `[eps, 1 - eps]`,
    where `eps` is the machine epsilon of the input data type,
    i.e. `torch.finfo(probs.dtype).eps`.

    Args:
        probs (Tensor): A tensor of probabilities.

    Returns:
        Tensor: The clamped probabilities.

    Examples:
        >>> probs = torch.tensor([0.0, 0.5, 1.0])
        >>> clamp_probs(probs)
        tensor([1.1921e-07, 5.0000e-01, 1.0000e+00])

        >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
        >>> clamp_probs(probs)
        tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64)
    """
    eps = torch.finfo(probs.dtype).eps
    return probs.clamp(min=eps, max=1 - eps)
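For context, a minimal usage sketch (not part of the diff): one common motivation for this clamp, assumed here rather than stated in the PR, is keeping downstream operations such as `torch.log` finite when probabilities hit exactly 0 or 1. The clamp bounds match `torch.finfo(probs.dtype).eps` for each dtype, which is why the example outputs above show `1.1921e-07` (float32) and `2.2204e-16` (float64).

import torch
from torch.distributions.utils import clamp_probs

# The clamp bounds come from the machine epsilon of the input dtype.
torch.finfo(torch.float32).eps  # 1.1920928955078125e-07
torch.finfo(torch.float64).eps  # 2.220446049250313e-16

probs = torch.tensor([0.0, 0.5, 1.0])
torch.log(probs)               # tensor([   -inf, -0.6931,  0.0000]) -- log(0) is -inf
torch.log(clamp_probs(probs))  # all entries finite, e.g. log(eps) ~ -15.9424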