Fix torch.max optional args dim, keepdim description (#147177)

[`torch.max`](https://pytorch.org/docs/stable/generated/torch.max.html#torch.max) optional args `dim`, `keepdim` not described in document, but users can ignore them.

```python
>>> import torch
>>> a = torch.randn(3,1,3)
>>> a.max()
tensor(1.9145)
>>> a.max(dim=1)
torch.return_types.max(
values=tensor([[ 1.1436, -0.0728,  1.3312],
        [-0.4049,  0.1792, -1.2247],
        [ 0.8767, -0.7888,  1.9145]]),
indices=tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]]))

```

## Changes

- Add `optional` description for `dim`, `keepdim`
- Add example of using `dim`, `keepdim`

## Test Result

### Before

![image](https://github.com/user-attachments/assets/3391bc45-b636-4e64-9406-04d80af0c087)

### After

![image](https://github.com/user-attachments/assets/1d70e282-409c-4573-b276-b8219fd6ef0a)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147177
Approved by: https://github.com/colesbury
This commit is contained in:
zeshengzong 2025-02-20 08:18:09 +00:00 committed by PyTorch MergeBot
parent 452315c84f
commit 6a72aaadae

View File

@ -71,6 +71,11 @@ output tensor having 1 (or ``len(dim)``) fewer dimension(s).
"opt_dim": """ "opt_dim": """
dim (int or tuple of ints, optional): the dimension or dimensions to reduce. dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
If ``None``, all dimensions are reduced. If ``None``, all dimensions are reduced.
"""
},
{
"opt_keepdim": """
keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``.
""" """
}, },
) )
@ -6483,8 +6488,8 @@ in the output tensors having 1 fewer dimension than ``input``.
Args: Args:
{input} {input}
{dim} {opt_dim}
{keepdim} Default: ``False``. {opt_keepdim}
Keyword args: Keyword args:
out (tuple, optional): the result tuple of two output tensors (max, max_indices) out (tuple, optional): the result tuple of two output tensors (max, max_indices)
@ -6499,13 +6504,22 @@ Example::
[-0.6172, 1.0036, -0.6060, -0.2432]]) [-0.6172, 1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1) >>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
>>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> a.max(dim=1, keepdim=True)
torch.return_types.max(
values=tensor([[2.], [4.]]),
indices=tensor([[1], [1]]))
>>> a.max(dim=1, keepdim=False)
torch.return_types.max(
values=tensor([2., 4.]),
indices=tensor([1, 1]))
.. function:: max(input, other, *, out=None) -> Tensor .. function:: max(input, other, *, out=None) -> Tensor
:noindex: :noindex:
See :func:`torch.maximum`. See :func:`torch.maximum`.
""".format(**single_dim_common), """.format(**multi_dim_common),
) )
add_docstr( add_docstr(