mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix torch.max optional args dim, keepdim description (#147177)
[`torch.max`](https://pytorch.org/docs/stable/generated/torch.max.html#torch.max) optional args `dim`, `keepdim` not described in document, but users can ignore them. ```python >>> import torch >>> a = torch.randn(3,1,3) >>> a.max() tensor(1.9145) >>> a.max(dim=1) torch.return_types.max( values=tensor([[ 1.1436, -0.0728, 1.3312], [-0.4049, 0.1792, -1.2247], [ 0.8767, -0.7888, 1.9145]]), indices=tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])) ``` ## Changes - Add `optional` description for `dim`, `keepdim` - Add example of using `dim`, `keepdim` ## Test Result ### Before  ### After  Pull Request resolved: https://github.com/pytorch/pytorch/pull/147177 Approved by: https://github.com/colesbury
This commit is contained in:
parent
452315c84f
commit
6a72aaadae
|
|
@ -71,6 +71,11 @@ output tensor having 1 (or ``len(dim)``) fewer dimension(s).
|
||||||
"opt_dim": """
|
"opt_dim": """
|
||||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||||
If ``None``, all dimensions are reduced.
|
If ``None``, all dimensions are reduced.
|
||||||
|
"""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"opt_keepdim": """
|
||||||
|
keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``.
|
||||||
"""
|
"""
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -6483,8 +6488,8 @@ in the output tensors having 1 fewer dimension than ``input``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
{input}
|
{input}
|
||||||
{dim}
|
{opt_dim}
|
||||||
{keepdim} Default: ``False``.
|
{opt_keepdim}
|
||||||
|
|
||||||
Keyword args:
|
Keyword args:
|
||||||
out (tuple, optional): the result tuple of two output tensors (max, max_indices)
|
out (tuple, optional): the result tuple of two output tensors (max, max_indices)
|
||||||
|
|
@ -6499,13 +6504,22 @@ Example::
|
||||||
[-0.6172, 1.0036, -0.6060, -0.2432]])
|
[-0.6172, 1.0036, -0.6060, -0.2432]])
|
||||||
>>> torch.max(a, 1)
|
>>> torch.max(a, 1)
|
||||||
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
|
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
|
||||||
|
>>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
>>> a.max(dim=1, keepdim=True)
|
||||||
|
torch.return_types.max(
|
||||||
|
values=tensor([[2.], [4.]]),
|
||||||
|
indices=tensor([[1], [1]]))
|
||||||
|
>>> a.max(dim=1, keepdim=False)
|
||||||
|
torch.return_types.max(
|
||||||
|
values=tensor([2., 4.]),
|
||||||
|
indices=tensor([1, 1]))
|
||||||
|
|
||||||
.. function:: max(input, other, *, out=None) -> Tensor
|
.. function:: max(input, other, *, out=None) -> Tensor
|
||||||
:noindex:
|
:noindex:
|
||||||
|
|
||||||
See :func:`torch.maximum`.
|
See :func:`torch.maximum`.
|
||||||
|
|
||||||
""".format(**single_dim_common),
|
""".format(**multi_dim_common),
|
||||||
)
|
)
|
||||||
|
|
||||||
add_docstr(
|
add_docstr(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user