mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Generate static docstrings for torch._masked functions. (#72865)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72865
Fixes #72636
Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D34286183
Pulled By: cpuhrsch
fbshipit-source-id: 9cf81bfed6ba8c82593f6a1d9e0b20d0a083310d
(cherry picked from commit 0a3f57896b)
This commit is contained in:
parent
1f74e082e2
commit
456d96d544
58
tools/update_masked_docs.py
Normal file
58
tools/update_masked_docs.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""This script updates the file torch/_masked/_docs.py that contains
|
||||
the generated doc-strings for various masked operations. The update
|
||||
should be triggered whenever a new masked operation is introduced to
|
||||
torch._masked package. Running the script requires that torch package
|
||||
is functional.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
def main() -> None:
|
||||
|
||||
target = os.path.join('torch', '_masked', '_docs.py')
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError as msg:
|
||||
print(f'Failed to import torch required to build {target}: {msg}')
|
||||
return
|
||||
|
||||
if os.path.isfile(target):
|
||||
with open(target) as _f:
|
||||
current_content = _f.read()
|
||||
else:
|
||||
current_content = ''
|
||||
|
||||
_new_content = []
|
||||
_new_content.append('''\
|
||||
# -*- coding: utf-8 -*-
|
||||
# This file is generated, do not modify it!
|
||||
#
|
||||
# To update this file, run the update masked docs script as follows:
|
||||
#
|
||||
# python tools/update_masked_docs.py
|
||||
#
|
||||
# The script must be called from an environment where the development
|
||||
# version of torch package can be imported and is functional.
|
||||
#
|
||||
''')
|
||||
|
||||
for func_name in sorted(torch._masked.__all__):
|
||||
func = getattr(torch._masked, func_name)
|
||||
func_doc = torch._masked._generate_docstring(func)
|
||||
_new_content.append(f'{func_name}_docstring = """{func_doc}"""\n')
|
||||
|
||||
new_content = '\n'.join(_new_content)
|
||||
|
||||
if new_content == current_content:
|
||||
print(f'Nothing to update in {target}')
|
||||
return
|
||||
|
||||
with open(target, 'w') as _f:
|
||||
_f.write(new_content)
|
||||
|
||||
print(f'Successfully updated {target}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -2,8 +2,10 @@
|
|||
|
||||
from typing import Optional, Tuple, List, Union, Any
|
||||
|
||||
import warnings
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from . import _docs
|
||||
|
||||
# A workaround to support both TorchScript and MyPy:
|
||||
from typing import TYPE_CHECKING
|
||||
|
|
@ -27,6 +29,26 @@ def _apply_docstring_templates(func):
|
|||
"""Decorator that applies docstring templates to function docstring
|
||||
and returns the function instance.
|
||||
"""
|
||||
|
||||
doc_string = getattr(_docs, f'{func.__name__}_docstring', None)
|
||||
if doc_string is None:
|
||||
warnings.warn(
|
||||
f'No documentation string available for {func.__name__}.'
|
||||
' PyTorch team should run `python tools/update_masked_docs.py`'
|
||||
' to generate the missing docstrings.')
|
||||
else:
|
||||
func.__doc__ = doc_string
|
||||
|
||||
# Expose function as public symbol
|
||||
__all__.append(func.__name__)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def _generate_docstring(func):
|
||||
"""An utility function called from tools/update_masked_docs.py
|
||||
script to update the module torch._masked._docs.py
|
||||
"""
|
||||
docstring_templates = dict(
|
||||
reduction_signature='''\
|
||||
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor''',
|
||||
|
|
@ -297,12 +319,7 @@ defined as ``x[i]/max(norm(x, p), eps)``.''')
|
|||
doc_template = '\n\n'.join([f'{{{op_kind}_{sec}}}' for sec in doc_sections])
|
||||
else:
|
||||
doc_template = func.__doc__
|
||||
func.__doc__ = doc_template.format_map(templates)
|
||||
|
||||
# Expose function as public symbol
|
||||
__all__.append(func.__name__)
|
||||
|
||||
return func
|
||||
return doc_template.format_map(templates)
|
||||
|
||||
|
||||
def _reduction_identity(op_name: str, input: Tensor, *args):
|
||||
|
|
|
|||
734
torch/_masked/_docs.py
Normal file
734
torch/_masked/_docs.py
Normal file
|
|
@ -0,0 +1,734 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This file is generated, do not modify it!
|
||||
#
|
||||
# To update this file, run the update masked docs script as follows:
|
||||
#
|
||||
# python tools/update_masked_docs.py
|
||||
#
|
||||
# The script must be called from an environment where the development
|
||||
# version of torch package can be imported and is functional.
|
||||
#
|
||||
|
||||
amax_docstring = """amax(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns maximum of all the elements in the :attr:`input`
|
||||
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
|
||||
elements are masked out according to the boolean tensor
|
||||
:attr:`mask`.
|
||||
|
||||
The identity value of maximum operation, which is used to start the
|
||||
reduction, depends on input dtype. For instance, for float32, uint8,
|
||||
and int32 dtypes, the identity values are ``-inf``, ``0``, and ``-2147483648``, respectively.
|
||||
|
||||
If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
||||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
|
||||
size 1. Otherwise, :attr:`dim` is squeezed (see
|
||||
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
|
||||
``len(dim)``) fewer dimension(s).
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True
|
||||
then the corresponding element in :attr:`input` tensor will be
|
||||
included in maximum computation, otherwise the element is
|
||||
ignored.
|
||||
|
||||
When all elements of :attr:`input` along the given dimension
|
||||
:attr:`dim` are ignored (fully masked-out), the corresponding element
|
||||
of the output tensor will have undefined value: it may or may not
|
||||
correspond to the identity value of maximum operation; the
|
||||
choice may correspond to the value that leads to the most efficient
|
||||
storage of :attr:`output` tensor.
|
||||
|
||||
The mask of the output tensor can be computed as
|
||||
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
|
||||
dtype=torch.bool)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||
Default: None that is equivalent to ``tuple(range(input.ndim))``.
|
||||
|
||||
Keyword args:
|
||||
keepdim (bool, optional): whether the output tensor has
|
||||
:attr:`dim` retained or not. Default: False.
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
|
||||
>>> input
|
||||
tensor([[-3, -2, -1],
|
||||
[ 0, 1, 2]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.amax(input, 1, mask=mask)
|
||||
tensor([ -1, -9223372036854775808])
|
||||
"""
|
||||
|
||||
amin_docstring = """amin(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns minimum of all the elements in the :attr:`input`
|
||||
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
|
||||
elements are masked out according to the boolean tensor
|
||||
:attr:`mask`.
|
||||
|
||||
The identity value of minimum operation, which is used to start the
|
||||
reduction, depends on input dtype. For instance, for float32, uint8,
|
||||
and int32 dtypes, the identity values are ``inf``, ``255``, and ``2147483647``, respectively.
|
||||
|
||||
If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
||||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
|
||||
size 1. Otherwise, :attr:`dim` is squeezed (see
|
||||
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
|
||||
``len(dim)``) fewer dimension(s).
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True
|
||||
then the corresponding element in :attr:`input` tensor will be
|
||||
included in minimum computation, otherwise the element is
|
||||
ignored.
|
||||
|
||||
When all elements of :attr:`input` along the given dimension
|
||||
:attr:`dim` are ignored (fully masked-out), the corresponding element
|
||||
of the output tensor will have undefined value: it may or may not
|
||||
correspond to the identity value of minimum operation; the
|
||||
choice may correspond to the value that leads to the most efficient
|
||||
storage of :attr:`output` tensor.
|
||||
|
||||
The mask of the output tensor can be computed as
|
||||
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
|
||||
dtype=torch.bool)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||
Default: None that is equivalent to ``tuple(range(input.ndim))``.
|
||||
|
||||
Keyword args:
|
||||
keepdim (bool, optional): whether the output tensor has
|
||||
:attr:`dim` retained or not. Default: False.
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
|
||||
>>> input
|
||||
tensor([[-3, -2, -1],
|
||||
[ 0, 1, 2]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.amin(input, 1, mask=mask)
|
||||
tensor([ -3, 9223372036854775807])
|
||||
"""
|
||||
|
||||
log_softmax_docstring = """log_softmax(input, dim, *, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns log_softmax of all the slices in the :attr:`input` tensor
|
||||
along :attr:`dim` while the :attr:`input` elements are masked out
|
||||
according to the boolean tensor :attr:`mask`.
|
||||
|
||||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||||
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
|
||||
defined as ``log(exp(x[i])/sum(exp(x)))``.
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True then
|
||||
the corresponding element in :attr:`input` tensor will be included in
|
||||
log_softmax computation, otherwise the element is ignored.
|
||||
|
||||
The values of masked-out elements of the output tensor have undefined
|
||||
value: it may or may not be set to zero or nan; the choice may correspond to
|
||||
the value that leads to the most efficient storage of :attr:`output`
|
||||
tensor.
|
||||
|
||||
The mask of the log_softmax output tensor can be computed as
|
||||
``torch.broadcast_to(mask, input.shape)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int): the dimension along which log_softmax is computed.
|
||||
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]])
|
||||
>>> input
|
||||
tensor([[-3., -2., -1.],
|
||||
[ 0., 1., 2.]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.log_softmax(input, 1, mask=mask)
|
||||
tensor([[-2.1269, -inf, -0.1269],
|
||||
[ nan, nan, nan]])
|
||||
"""
|
||||
|
||||
mean_docstring = """mean(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns mean of all the elements in the :attr:`input`
|
||||
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
|
||||
elements are masked out according to the boolean tensor
|
||||
:attr:`mask`.
|
||||
|
||||
By definition, the identity value of a mean operation is the mean
|
||||
value of the tensor. If all elements of the input tensor along given
|
||||
dimension(s) :attr:`dim` are masked-out, the identity value of the
|
||||
mean is undefined. Due to this ambiguity, the elements of output
|
||||
tensor with strided layout, that correspond to fully masked-out
|
||||
elements, have ``nan`` values.
|
||||
|
||||
If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
||||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
|
||||
size 1. Otherwise, :attr:`dim` is squeezed (see
|
||||
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
|
||||
``len(dim)``) fewer dimension(s).
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True
|
||||
then the corresponding element in :attr:`input` tensor will be
|
||||
included in mean computation, otherwise the element is
|
||||
ignored.
|
||||
|
||||
When all elements of :attr:`input` along the given dimension
|
||||
:attr:`dim` are ignored (fully masked-out), the corresponding element
|
||||
of the output tensor will have undefined value: it may or may not
|
||||
correspond to the identity value of mean operation; the
|
||||
choice may correspond to the value that leads to the most efficient
|
||||
storage of :attr:`output` tensor.
|
||||
|
||||
The mask of the output tensor can be computed as
|
||||
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
|
||||
dtype=torch.bool)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||
Default: None that is equivalent to ``tuple(range(input.ndim))``.
|
||||
|
||||
Keyword args:
|
||||
keepdim (bool, optional): whether the output tensor has
|
||||
:attr:`dim` retained or not. Default: False.
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
|
||||
>>> input
|
||||
tensor([[-3, -2, -1],
|
||||
[ 0, 1, 2]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.mean(input, 1, mask=mask)
|
||||
tensor([-2., nan])
|
||||
"""
|
||||
|
||||
norm_docstring = """norm(input, ord, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns norm of all the elements in the :attr:`input`
|
||||
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
|
||||
elements are masked out according to the boolean tensor
|
||||
:attr:`mask`.
|
||||
|
||||
The identity value of norm operation, which is used to start the
|
||||
reduction, is ``0.0``, except for ``ord=-inf`` it is
|
||||
``inf``.
|
||||
|
||||
If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
||||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
|
||||
size 1. Otherwise, :attr:`dim` is squeezed (see
|
||||
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
|
||||
``len(dim)``) fewer dimension(s).
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True
|
||||
then the corresponding element in :attr:`input` tensor will be
|
||||
included in norm computation, otherwise the element is
|
||||
ignored.
|
||||
|
||||
When all elements of :attr:`input` along the given dimension
|
||||
:attr:`dim` are ignored (fully masked-out), the corresponding element
|
||||
of the output tensor will have undefined value: it may or may not
|
||||
correspond to the identity value of norm operation; the
|
||||
choice may correspond to the value that leads to the most efficient
|
||||
storage of :attr:`output` tensor.
|
||||
|
||||
The mask of the output tensor can be computed as
|
||||
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
|
||||
dtype=torch.bool)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
ord (int, float, optional): the order of vector norm. Default: 2.
|
||||
See :func:`torch.linalg.vector_norm` for a list of supported norms.
|
||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||
Default: None that is equivalent to ``tuple(range(input.ndim))``.
|
||||
|
||||
Keyword args:
|
||||
keepdim (bool, optional): whether the output tensor has
|
||||
:attr:`dim` retained or not. Default: False.
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]])
|
||||
>>> input
|
||||
tensor([[-3., -2., -1.],
|
||||
[ 0., 1., 2.]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.norm(input, 2.0, 1, mask=mask)
|
||||
tensor([3.1623, 0.0000])
|
||||
"""
|
||||
|
||||
normalize_docstring = """normalize(input, ord, dim, *, eps=1e-12, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns normalize of all the slices in the :attr:`input` tensor
|
||||
along :attr:`dim` while the :attr:`input` elements are masked out
|
||||
according to the boolean tensor :attr:`mask`.
|
||||
|
||||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||||
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
|
||||
defined as ``x[i]/max(norm(x, p), eps)``.
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True then
|
||||
the corresponding element in :attr:`input` tensor will be included in
|
||||
normalize computation, otherwise the element is ignored.
|
||||
|
||||
The values of masked-out elements of the output tensor have undefined
|
||||
value: it may or may not be set to zero or nan; the choice may correspond to
|
||||
the value that leads to the most efficient storage of :attr:`output`
|
||||
tensor.
|
||||
|
||||
The mask of the normalize output tensor can be computed as
|
||||
``torch.broadcast_to(mask, input.shape)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
ord (int, float): the order of vector norm. Default: 2.
|
||||
See :func:`torch.linalg.vector_norm` for a list of supported norms.
|
||||
dim (int): the dimension along which normalize is computed.
|
||||
|
||||
Keyword args:
|
||||
eps (float, optional): small value to avoid division by zero. Default: 1e-12.
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]])
|
||||
>>> input
|
||||
tensor([[-3., -2., -1.],
|
||||
[ 0., 1., 2.]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.normalize(input, 2.0, 1, mask=mask)
|
||||
tensor([[-0.9487, 0.0000, -0.3162],
|
||||
[ 0.0000, 0.0000, 0.0000]])
|
||||
"""
|
||||
|
||||
prod_docstring = """prod(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns product of all the elements in the :attr:`input`
|
||||
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
|
||||
elements are masked out according to the boolean tensor
|
||||
:attr:`mask`.
|
||||
|
||||
The identity value of product operation, which is used to start the reduction, is ``1``.
|
||||
|
||||
If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
||||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
|
||||
size 1. Otherwise, :attr:`dim` is squeezed (see
|
||||
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
|
||||
``len(dim)``) fewer dimension(s).
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True
|
||||
then the corresponding element in :attr:`input` tensor will be
|
||||
included in product computation, otherwise the element is
|
||||
ignored.
|
||||
|
||||
When all elements of :attr:`input` along the given dimension
|
||||
:attr:`dim` are ignored (fully masked-out), the corresponding element
|
||||
of the output tensor will have undefined value: it may or may not
|
||||
correspond to the identity value of product operation; the
|
||||
choice may correspond to the value that leads to the most efficient
|
||||
storage of :attr:`output` tensor.
|
||||
|
||||
The mask of the output tensor can be computed as
|
||||
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
|
||||
dtype=torch.bool)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||
Default: None that is equivalent to ``tuple(range(input.ndim))``.
|
||||
|
||||
Keyword args:
|
||||
keepdim (bool, optional): whether the output tensor has
|
||||
:attr:`dim` retained or not. Default: False.
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
|
||||
>>> input
|
||||
tensor([[-3, -2, -1],
|
||||
[ 0, 1, 2]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.prod(input, 1, mask=mask)
|
||||
tensor([3, 1])
|
||||
"""
|
||||
|
||||
softmax_docstring = """softmax(input, dim, *, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns softmax of all the slices in the :attr:`input` tensor
|
||||
along :attr:`dim` while the :attr:`input` elements are masked out
|
||||
according to the boolean tensor :attr:`mask`.
|
||||
|
||||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||||
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
|
||||
defined as ``exp(x[i])/sum(exp(x))``.
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True then
|
||||
the corresponding element in :attr:`input` tensor will be included in
|
||||
softmax computation, otherwise the element is ignored.
|
||||
|
||||
The values of masked-out elements of the output tensor have undefined
|
||||
value: it may or may not be set to zero or nan; the choice may correspond to
|
||||
the value that leads to the most efficient storage of :attr:`output`
|
||||
tensor.
|
||||
|
||||
The mask of the softmax output tensor can be computed as
|
||||
``torch.broadcast_to(mask, input.shape)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int): the dimension along which softmax is computed.
|
||||
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]])
|
||||
>>> input
|
||||
tensor([[-3., -2., -1.],
|
||||
[ 0., 1., 2.]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.softmax(input, 1, mask=mask)
|
||||
tensor([[0.1192, 0.0000, 0.8808],
|
||||
[ nan, nan, nan]])
|
||||
"""
|
||||
|
||||
softmin_docstring = """softmin(input, dim, *, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns softmin of all the slices in the :attr:`input` tensor
|
||||
along :attr:`dim` while the :attr:`input` elements are masked out
|
||||
according to the boolean tensor :attr:`mask`.
|
||||
|
||||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||||
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
|
||||
defined as ``exp(-x[i])/sum(exp(-x))``.
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True then
|
||||
the corresponding element in :attr:`input` tensor will be included in
|
||||
softmin computation, otherwise the element is ignored.
|
||||
|
||||
The values of masked-out elements of the output tensor have undefined
|
||||
value: it may or may not be set to zero or nan; the choice may correspond to
|
||||
the value that leads to the most efficient storage of :attr:`output`
|
||||
tensor.
|
||||
|
||||
The mask of the softmin output tensor can be computed as
|
||||
``torch.broadcast_to(mask, input.shape)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int): the dimension along which softmin is computed.
|
||||
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]])
|
||||
>>> input
|
||||
tensor([[-3., -2., -1.],
|
||||
[ 0., 1., 2.]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.softmin(input, 1, mask=mask)
|
||||
tensor([[0.8808, 0.0000, 0.1192],
|
||||
[ nan, nan, nan]])
|
||||
"""
|
||||
|
||||
sum_docstring = """sum(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns sum of all the elements in the :attr:`input`
|
||||
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
|
||||
elements are masked out according to the boolean tensor
|
||||
:attr:`mask`.
|
||||
|
||||
The identity value of sum operation, which is used to start the reduction, is ``0``.
|
||||
|
||||
If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
||||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
|
||||
size 1. Otherwise, :attr:`dim` is squeezed (see
|
||||
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
|
||||
``len(dim)``) fewer dimension(s).
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True
|
||||
then the corresponding element in :attr:`input` tensor will be
|
||||
included in sum computation, otherwise the element is
|
||||
ignored.
|
||||
|
||||
When all elements of :attr:`input` along the given dimension
|
||||
:attr:`dim` are ignored (fully masked-out), the corresponding element
|
||||
of the output tensor will have undefined value: it may or may not
|
||||
correspond to the identity value of sum operation; the
|
||||
choice may correspond to the value that leads to the most efficient
|
||||
storage of :attr:`output` tensor.
|
||||
|
||||
The mask of the output tensor can be computed as
|
||||
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
|
||||
dtype=torch.bool)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||
Default: None that is equivalent to ``tuple(range(input.ndim))``.
|
||||
|
||||
Keyword args:
|
||||
keepdim (bool, optional): whether the output tensor has
|
||||
:attr:`dim` retained or not. Default: False.
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
|
||||
>>> input
|
||||
tensor([[-3, -2, -1],
|
||||
[ 0, 1, 2]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.sum(input, 1, mask=mask)
|
||||
tensor([-4, 0])
|
||||
"""
|
||||
|
||||
var_docstring = """var(input, dim, unbiased, *, keepdim=False, dtype=None, mask=None) -> Tensor
|
||||
|
||||
Returns variance of all the elements in the :attr:`input`
|
||||
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
|
||||
elements are masked out according to the boolean tensor
|
||||
:attr:`mask`.
|
||||
|
||||
The identity value of sample variance operation is undefined. The
|
||||
elements of output tensor with strided layout, that correspond to
|
||||
fully masked-out elements, have ``nan`` values.
|
||||
|
||||
If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
||||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
|
||||
size 1. Otherwise, :attr:`dim` is squeezed (see
|
||||
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
|
||||
``len(dim)``) fewer dimension(s).
|
||||
|
||||
The boolean tensor :attr:`mask` defines the "validity" of
|
||||
:attr:`input` tensor elements: if :attr:`mask` element is True
|
||||
then the corresponding element in :attr:`input` tensor will be
|
||||
included in variance computation, otherwise the element is
|
||||
ignored.
|
||||
|
||||
When all elements of :attr:`input` along the given dimension
|
||||
:attr:`dim` are ignored (fully masked-out), the corresponding element
|
||||
of the output tensor will have undefined value: it may or may not
|
||||
correspond to the identity value of variance operation; the
|
||||
choice may correspond to the value that leads to the most efficient
|
||||
storage of :attr:`output` tensor.
|
||||
|
||||
The mask of the output tensor can be computed as
|
||||
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
|
||||
dtype=torch.bool)``.
|
||||
|
||||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||||
tensor must not be greater than of the :attr:`input` tensor.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||
Default: None that is equivalent to ``tuple(range(input.ndim))``.
|
||||
unbiased (bool): when True, use Bessel’s correction, otherwise, compute
|
||||
the uncorrected sample variance.
|
||||
|
||||
Keyword args:
|
||||
keepdim (bool, optional): whether the output tensor has
|
||||
:attr:`dim` retained or not. Default: False.
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the input tensor is
|
||||
casted to :attr:`dtype` before the operation is
|
||||
performed. Default: None.
|
||||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of input tensor
|
||||
elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
|
||||
>>> input
|
||||
tensor([[-3, -2, -1],
|
||||
[ 0, 1, 2]])
|
||||
>>> mask = tensor([[ True, False, True], [False, False, False]])
|
||||
>>> mask
|
||||
tensor([[ True, False, True],
|
||||
[False, False, False]])
|
||||
>>> torch._masked.var(input, 1, False, mask=mask)
|
||||
tensor([1., nan])
|
||||
"""
|
||||
Loading…
Reference in New Issue
Block a user