mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add some missing docs for tensor methods and attributes, new unittest to enforce tensors.rst no longer miss anything (#16057)
Summary: This depend on https://github.com/pytorch/pytorch/pull/16039 This prevent people (reviewer, PR author) from forgetting adding things to `tensors.rst`. When something new is added to `_tensor_doc.py` or `tensor.py` but intentionally not in `tensors.rst`, people should manually whitelist it in `test_docs_coverage.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16057 Differential Revision: D14619550 Pulled By: ezyang fbshipit-source-id: e1c6dd6761142e2e48ec499e118df399e3949fcc
This commit is contained in:
parent
66e8c74814
commit
2ba41c5550
|
|
@ -141,6 +141,7 @@ view of a storage and defines numeric operations on it.
|
|||
|
||||
.. autoattribute:: is_cuda
|
||||
.. autoattribute:: device
|
||||
.. autoattribute:: grad
|
||||
|
||||
.. automethod:: abs
|
||||
.. automethod:: abs_
|
||||
|
|
@ -164,16 +165,19 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: apply_
|
||||
.. automethod:: argmax
|
||||
.. automethod:: argmin
|
||||
.. automethod:: argsort
|
||||
.. automethod:: asin
|
||||
.. automethod:: asin_
|
||||
.. automethod:: atan
|
||||
.. automethod:: atan2
|
||||
.. automethod:: atan2_
|
||||
.. automethod:: atan_
|
||||
.. automethod:: backward
|
||||
.. automethod:: baddbmm
|
||||
.. automethod:: baddbmm_
|
||||
.. automethod:: bernoulli
|
||||
.. automethod:: bernoulli_
|
||||
.. automethod:: bincount
|
||||
.. automethod:: bmm
|
||||
.. automethod:: byte
|
||||
.. automethod:: btrifact
|
||||
|
|
@ -202,8 +206,15 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: cumsum
|
||||
.. automethod:: data_ptr
|
||||
.. automethod:: det
|
||||
.. automethod:: dense_dim
|
||||
.. automethod:: detach
|
||||
.. automethod:: detach_
|
||||
.. automethod:: diag
|
||||
.. automethod:: diag_embed
|
||||
.. automethod:: diagflat
|
||||
.. automethod:: diagonal
|
||||
.. automethod:: digamma
|
||||
.. automethod:: digamma_
|
||||
.. automethod:: dim
|
||||
.. automethod:: dist
|
||||
.. automethod:: div
|
||||
|
|
@ -228,6 +239,7 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: expand
|
||||
.. automethod:: expand_as
|
||||
.. automethod:: exponential_
|
||||
.. automethod:: fft
|
||||
.. automethod:: fill_
|
||||
.. automethod:: flatten
|
||||
.. automethod:: flip
|
||||
|
|
@ -250,7 +262,9 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: gt
|
||||
.. automethod:: gt_
|
||||
.. automethod:: half
|
||||
.. automethod:: hardshrink
|
||||
.. automethod:: histc
|
||||
.. automethod:: ifft
|
||||
.. automethod:: index_add_
|
||||
.. automethod:: index_add
|
||||
.. automethod:: index_copy_
|
||||
|
|
@ -260,13 +274,18 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: index_put_
|
||||
.. automethod:: index_put
|
||||
.. automethod:: index_select
|
||||
.. automethod:: indices
|
||||
.. automethod:: int
|
||||
.. automethod:: inverse
|
||||
.. automethod:: irfft
|
||||
.. automethod:: is_contiguous
|
||||
.. automethod:: is_floating_point
|
||||
.. automethod:: is_leaf
|
||||
.. automethod:: is_pinned
|
||||
.. automethod:: is_set_to
|
||||
.. automethod:: is_shared
|
||||
.. automethod:: is_signed
|
||||
.. automethod:: is_sparse
|
||||
.. automethod:: item
|
||||
.. automethod:: kthvalue
|
||||
.. automethod:: le
|
||||
|
|
@ -308,6 +327,7 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: mvlgamma
|
||||
.. automethod:: mvlgamma_
|
||||
.. automethod:: narrow
|
||||
.. automethod:: narrow_copy
|
||||
.. automethod:: ndimension
|
||||
.. automethod:: ne
|
||||
.. automethod:: ne_
|
||||
|
|
@ -336,23 +356,28 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: random_
|
||||
.. automethod:: reciprocal
|
||||
.. automethod:: reciprocal_
|
||||
.. automethod:: register_hook
|
||||
.. automethod:: remainder
|
||||
.. automethod:: remainder_
|
||||
.. automethod:: renorm
|
||||
.. automethod:: renorm_
|
||||
.. automethod:: repeat
|
||||
.. automethod:: requires_grad
|
||||
.. automethod:: requires_grad_
|
||||
.. automethod:: reshape
|
||||
.. automethod:: reshape_as
|
||||
.. automethod:: resize_
|
||||
.. automethod:: resize_as_
|
||||
.. automethod:: retain_grad
|
||||
.. automethod:: rfft
|
||||
.. automethod:: roll
|
||||
.. automethod:: rot90
|
||||
.. automethod:: round
|
||||
.. automethod:: round_
|
||||
.. automethod:: rsqrt
|
||||
.. automethod:: rsqrt_
|
||||
.. automethod:: scatter_
|
||||
.. automethod:: scatter
|
||||
.. automethod:: scatter_
|
||||
.. automethod:: scatter_add_
|
||||
.. automethod:: scatter_add
|
||||
.. automethod:: select
|
||||
|
|
@ -373,11 +398,13 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: sort
|
||||
.. automethod:: split
|
||||
.. automethod:: sparse_mask
|
||||
.. automethod:: sparse_dim
|
||||
.. automethod:: sqrt
|
||||
.. automethod:: sqrt_
|
||||
.. automethod:: squeeze
|
||||
.. automethod:: squeeze_
|
||||
.. automethod:: std
|
||||
.. automethod:: stft
|
||||
.. automethod:: storage
|
||||
.. automethod:: storage_offset
|
||||
.. automethod:: storage_type
|
||||
|
|
@ -385,6 +412,7 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: sub
|
||||
.. automethod:: sub_
|
||||
.. automethod:: sum
|
||||
.. automethod:: sum_to_size
|
||||
.. automethod:: svd
|
||||
.. automethod:: symeig
|
||||
.. automethod:: t
|
||||
|
|
@ -411,14 +439,17 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: trunc_
|
||||
.. automethod:: type
|
||||
.. automethod:: type_as
|
||||
.. automethod:: unbind
|
||||
.. automethod:: unfold
|
||||
.. automethod:: uniform_
|
||||
.. automethod:: unique
|
||||
.. automethod:: unsqueeze
|
||||
.. automethod:: unsqueeze_
|
||||
.. automethod:: values
|
||||
.. automethod:: var
|
||||
.. automethod:: view
|
||||
.. automethod:: view_as
|
||||
.. automethod:: where
|
||||
.. automethod:: zero_
|
||||
|
||||
.. class:: ByteTensor()
|
||||
|
|
|
|||
|
|
@ -4,61 +4,80 @@ import os
|
|||
import re
|
||||
import ast
|
||||
import _ast
|
||||
import textwrap
|
||||
|
||||
|
||||
path = os.path.dirname(os.path.realpath(__file__))
|
||||
rstpath = os.path.join(path, '../docs/source/')
|
||||
pypath = os.path.join(path, '../torch/_torch_docs.py')
|
||||
r1 = re.compile(r'\.\. autofunction:: (\w*)')
|
||||
r2 = re.compile(r'\.\. auto(?:method|attribute):: (\w*)')
|
||||
|
||||
|
||||
class TestDocCoverage(unittest.TestCase):
|
||||
|
||||
def test_torch(self):
|
||||
# get symbols documented in torch.rst
|
||||
whitelist = [
|
||||
'set_printoptions', 'get_rng_state', 'is_storage', 'initial_seed',
|
||||
'set_default_tensor_type', 'load', 'save', 'set_default_dtype',
|
||||
'is_tensor', 'compiled_with_cxx11_abi', 'set_rng_state',
|
||||
'manual_seed'
|
||||
]
|
||||
everything = set()
|
||||
filename = os.path.join(rstpath, 'torch.rst')
|
||||
@staticmethod
|
||||
def parse_rst(filename, regex):
|
||||
filename = os.path.join(rstpath, filename)
|
||||
ret = set()
|
||||
with open(filename, 'r') as f:
|
||||
lines = f.readlines()
|
||||
for l in lines:
|
||||
l = l.strip()
|
||||
name = r1.findall(l)
|
||||
name = regex.findall(l)
|
||||
if name:
|
||||
everything.add(name[0])
|
||||
everything -= set(whitelist)
|
||||
ret.add(name[0])
|
||||
return ret
|
||||
|
||||
def test_torch(self):
|
||||
# get symbols documented in torch.rst
|
||||
in_rst = self.parse_rst('torch.rst', r1)
|
||||
# get symbols in functional.py and _torch_docs.py
|
||||
whitelist2 = ['product', 'inf', 'math', 'reduce', 'warnings', 'torch', 'annotate']
|
||||
everything2 = set()
|
||||
with open(pypath, 'r') as f:
|
||||
body = ast.parse(f.read()).body
|
||||
for i in body:
|
||||
if not isinstance(i, _ast.Expr):
|
||||
continue
|
||||
i = i.value
|
||||
if not isinstance(i, _ast.Call):
|
||||
continue
|
||||
if i.func.id != 'add_docstr':
|
||||
continue
|
||||
i = i.args[0]
|
||||
if i.value.id != 'torch':
|
||||
continue
|
||||
i = i.attr
|
||||
everything2.add(i)
|
||||
for p in dir(torch.functional):
|
||||
if not p.startswith('_') and p[0].islower():
|
||||
everything2.add(p)
|
||||
everything2 -= set(whitelist2)
|
||||
whitelist = {
|
||||
# below are some jit functions
|
||||
'wait', 'fork', 'parse_type_comment', 'import_ir_module',
|
||||
'to_batch_graph', 'import_ir_module_from_buffer',
|
||||
'register_batch_operator', 'merge_type_from_type_comment',
|
||||
|
||||
# below are symbols mistakely binded to torch.*, but should
|
||||
# go to torch.nn.functional.* instead
|
||||
'avg_pool1d', 'conv_transpose2d', 'conv_transpose1d', 'conv3d',
|
||||
'relu_', 'pixel_shuffle', 'conv2d', 'selu_', 'celu_', 'threshold_',
|
||||
'cosine_similarity', 'rrelu_', 'conv_transpose3d', 'conv1d', 'pdist',
|
||||
'adaptive_avg_pool1d', 'conv_tbc'
|
||||
}
|
||||
has_docstring = set(
|
||||
a for a in dir(torch)
|
||||
if getattr(torch, a).__doc__ and not a.startswith('_') and
|
||||
'function' in type(getattr(torch, a)).__name__)
|
||||
self.assertEqual(
|
||||
has_docstring & whitelist, whitelist,
|
||||
textwrap.dedent('''
|
||||
The whitelist in test_docs_coverage.py contains something
|
||||
that don't have docstring or not in torch.*. If you just
|
||||
removed something from torch.*, please remove it from whiltelist
|
||||
in test_docs_coverage.py'''))
|
||||
has_docstring -= whitelist
|
||||
# assert they are equal
|
||||
for p in everything:
|
||||
self.assertIn(p, everything2, 'in torch.rst but not in python')
|
||||
for p in everything2:
|
||||
self.assertIn(p, everything, 'in python but not in torch.rst')
|
||||
self.assertEqual(
|
||||
has_docstring, in_rst,
|
||||
textwrap.dedent('''
|
||||
List of functions documented in torch.rst and in python are different.
|
||||
Do you forget to add new thing to torch.rst, or whitelist things you
|
||||
don't want to document?''')
|
||||
)
|
||||
|
||||
def test_tensor(self):
|
||||
in_rst = self.parse_rst('tensors.rst', r2)
|
||||
classes = [torch.FloatTensor, torch.LongTensor, torch.ByteTensor]
|
||||
has_docstring = set(x for c in classes for x in dir(c) if not x.startswith('_') and getattr(c, x).__doc__)
|
||||
self.assertEqual(
|
||||
has_docstring, in_rst,
|
||||
textwrap.dedent('''
|
||||
List of tensor methods documented in tensor.rst and in python are
|
||||
different. Do you forget to add new thing to tensor.rst, or whitelist
|
||||
things you don't want to document?''')
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -186,20 +186,17 @@ class _TestTorchMixin(object):
|
|||
'as_strided_',
|
||||
re.compile('^clamp_(min|max)_?$'),
|
||||
'coalesce',
|
||||
'index_put',
|
||||
'is_coalesced',
|
||||
'is_distributed',
|
||||
'is_complex',
|
||||
'is_nonzero',
|
||||
'is_same_size',
|
||||
'is_signed',
|
||||
'isclose',
|
||||
'lgamma',
|
||||
'lgamma_',
|
||||
'log_softmax',
|
||||
'map2_',
|
||||
'new',
|
||||
'pin_memory',
|
||||
'polygamma',
|
||||
'polygamma_',
|
||||
'record_stream',
|
||||
|
|
@ -213,8 +210,6 @@ class _TestTorchMixin(object):
|
|||
'softmax',
|
||||
'split_with_sizes',
|
||||
'sspaddmm',
|
||||
'storage_type',
|
||||
'tan',
|
||||
'to_dense',
|
||||
'sparse_resize_',
|
||||
'sparse_resize_and_clear_',
|
||||
|
|
|
|||
|
|
@ -1203,6 +1203,13 @@ Args:
|
|||
accumulate (bool): whether to accumulate into self
|
||||
""")
|
||||
|
||||
add_docstr_all('index_put',
|
||||
r"""
|
||||
index_put(indices, value, accumulate=False) -> Tensor
|
||||
|
||||
Out-place version of :meth:`~Tensor.index_put_`
|
||||
""")
|
||||
|
||||
add_docstr_all('index_select',
|
||||
r"""
|
||||
index_select(dim, index) -> Tensor
|
||||
|
|
@ -1270,6 +1277,13 @@ is_floating_point() -> bool
|
|||
Returns True if the data type of :attr:`self` is a floating point data type.
|
||||
""")
|
||||
|
||||
add_docstr_all('is_signed',
|
||||
r"""
|
||||
is_signed() -> bool
|
||||
|
||||
Returns True if the data type of :attr:`self` is a signed data type.
|
||||
""")
|
||||
|
||||
add_docstr_all('is_set_to',
|
||||
r"""
|
||||
is_set_to(tensor) -> bool
|
||||
|
|
@ -2278,7 +2292,7 @@ add_docstr_all('storage',
|
|||
r"""
|
||||
storage() -> torch.Storage
|
||||
|
||||
Returns the underlying storage
|
||||
Returns the underlying storage.
|
||||
""")
|
||||
|
||||
add_docstr_all('storage_offset',
|
||||
|
|
@ -2298,6 +2312,13 @@ Example::
|
|||
|
||||
""")
|
||||
|
||||
add_docstr_all('storage_type',
|
||||
r"""
|
||||
storage_type() -> type
|
||||
|
||||
Returns the type of the underlying storage.
|
||||
""")
|
||||
|
||||
add_docstr_all('stride',
|
||||
r"""
|
||||
stride(dim) -> tuple or int
|
||||
|
|
@ -2505,6 +2526,13 @@ take(indices) -> Tensor
|
|||
See :func:`torch.take`
|
||||
""")
|
||||
|
||||
add_docstr_all('tan',
|
||||
r"""
|
||||
tan() -> Tensor
|
||||
|
||||
See :func:`torch.tan`
|
||||
""")
|
||||
|
||||
add_docstr_all('tan_',
|
||||
r"""
|
||||
tan_() -> Tensor
|
||||
|
|
@ -2980,6 +3008,13 @@ unbind(dim=0) -> seq
|
|||
See :func:`torch.unbind`
|
||||
""")
|
||||
|
||||
add_docstr_all('pin_memory',
|
||||
r"""
|
||||
pin_memory() -> Tensor
|
||||
|
||||
Copies the tensor to pinned memory, if it's not already pinned.
|
||||
""")
|
||||
|
||||
add_docstr_all('pinverse',
|
||||
r"""
|
||||
pinverse() -> Tensor
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user