mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Documenting torch.distributions.utils.clamp_probs (#128136)
Fixes https://github.com/pytorch/pytorch/issues/127889 This PR adds docstring to the `torch.distributions.utils.clamp_probs` function. Co-authored-by: Svetlana Karslioglu <svekars@meta.com> Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/128136 Approved by: https://github.com/janeyx99, https://github.com/svekars, https://github.com/malfet
This commit is contained in:
parent
740cd0559f
commit
f99409903c
|
|
@ -90,6 +90,27 @@ def logits_to_probs(logits, is_binary=False):
|
|||
|
||||
|
||||
def clamp_probs(probs):
|
||||
"""Clamps the probabilities to be in the open interval `(0, 1)`.
|
||||
|
||||
The probabilities would be clamped between `eps` and `1 - eps`,
|
||||
and `eps` would be the smallest representable positive number for the input data type.
|
||||
|
||||
Args:
|
||||
probs (Tensor): A tensor of probabilities.
|
||||
|
||||
Returns:
|
||||
Tensor: The clamped probabilities.
|
||||
|
||||
Examples:
|
||||
>>> probs = torch.tensor([0.0, 0.5, 1.0])
|
||||
>>> clamp_probs(probs)
|
||||
tensor([1.1921e-07, 5.0000e-01, 1.0000e+00])
|
||||
|
||||
>>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
|
||||
>>> clamp_probs(probs)
|
||||
tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64)
|
||||
|
||||
"""
|
||||
eps = torch.finfo(probs.dtype).eps
|
||||
return probs.clamp(min=eps, max=1 - eps)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user