mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27173 `docs/source/named_tensor.rst` is the entry point; most users will land either here or the named tensor tutorial when looking to use named tensors. We should strive to make this as readable, concise, and understandable as possible. `docs/source/name_inference.rst` lists all of the name inference rules. It should be clear but it's hard to make it concise. Please let me know if anything doesn't make sense and please propose alternative wordings and/or restructuring to improve the documentation. This should ultimately get cherry-picked into the 1.3 branch as one monolithic commit so it would be good to get all necessary changes made in this PR and not have any follow ups. Test Plan: - built and reviewed locally with `cd docs/ && make html`. Differential Revision: D17763046 Pulled By: zou3519 fbshipit-source-id: c7872184fc4b189d405b18dad77cad6899ae1522
96 lines
3.8 KiB
Python
96 lines
3.8 KiB
Python
import io
import os
import re
import textwrap
import unittest

import torch
|
|
|
|
|
|
# Absolute directory containing this test file; the other paths hang off it.
path = os.path.dirname(os.path.realpath(__file__))

# Sphinx source directory scanned for ``.. auto*::`` directives, and the
# location of torch/_torch_docs.py (the latter is not referenced elsewhere in
# this file).
rstpath, pypath = (
    os.path.join(path, relative)
    for relative in ('../docs/source/', '../torch/_torch_docs.py')
)

# r1 captures NAME from ``.. autofunction:: NAME`` lines; r2 captures NAME
# from ``.. automethod:: NAME`` and ``.. autoattribute:: NAME`` lines.
r1, r2 = map(re.compile, (
    r'\.\. autofunction:: (\w*)',
    r'\.\. auto(?:method|attribute):: (\w*)',
))
|
|
|
|
|
|
class TestDocCoverage(unittest.TestCase):
|
|
|
|
@staticmethod
|
|
def parse_rst(filename, regex):
|
|
filename = os.path.join(rstpath, filename)
|
|
ret = set()
|
|
with open(filename, 'r') as f:
|
|
lines = f.readlines()
|
|
for l in lines:
|
|
l = l.strip()
|
|
name = regex.findall(l)
|
|
if name:
|
|
ret.add(name[0])
|
|
return ret
|
|
|
|
def test_torch(self):
|
|
# TODO: The algorithm here is kind of unsound; we don't assume
|
|
# every identifier in torch.rst lives in torch by virtue of
|
|
# where it lives; instead, it lives in torch because at the
|
|
# beginning of the file we specified automodule. This means
|
|
# that this script can get confused if you have, e.g., multiple
|
|
# automodule directives in the torch file. "Don't do that."
|
|
# (Or fix this to properly handle that case.)
|
|
|
|
# get symbols documented in torch.rst
|
|
in_rst = self.parse_rst('torch.rst', r1)
|
|
# get symbols in functional.py and _torch_docs.py
|
|
whitelist = {
|
|
# below are some jit functions
|
|
'wait', 'fork', 'parse_type_comment', 'import_ir_module',
|
|
'import_ir_module_from_buffer', 'merge_type_from_type_comment',
|
|
'parse_ir', 'parse_schema',
|
|
|
|
# below are symbols mistakely binded to torch.*, but should
|
|
# go to torch.nn.functional.* instead
|
|
'avg_pool1d', 'conv_transpose2d', 'conv_transpose1d', 'conv3d',
|
|
'relu_', 'pixel_shuffle', 'conv2d', 'selu_', 'celu_', 'threshold_',
|
|
'cosine_similarity', 'rrelu_', 'conv_transpose3d', 'conv1d', 'pdist',
|
|
'adaptive_avg_pool1d', 'conv_tbc'
|
|
}
|
|
has_docstring = set(
|
|
a for a in dir(torch)
|
|
if getattr(torch, a).__doc__ and not a.startswith('_') and
|
|
'function' in type(getattr(torch, a)).__name__)
|
|
self.assertEqual(
|
|
has_docstring & whitelist, whitelist,
|
|
textwrap.dedent('''
|
|
The whitelist in test_docs_coverage.py contains something
|
|
that doesn't have a docstring or isn't in torch.*. If you just
|
|
removed something from torch.*, please remove it from the whitelist
|
|
in test_docs_coverage.py'''))
|
|
has_docstring -= whitelist
|
|
# assert they are equal
|
|
self.assertEqual(
|
|
has_docstring, in_rst,
|
|
textwrap.dedent('''
|
|
The lists of functions documented in torch.rst and in python are different.
|
|
Did you forget to add a new thing to torch.rst, or whitelist things you
|
|
don't want to document?''')
|
|
)
|
|
|
|
def test_tensor(self):
|
|
in_rst = self.parse_rst('tensors.rst', r2)
|
|
whitelist = {
|
|
'names', 'unflatten', 'align_as', 'rename_', 'refine_names', 'align_to',
|
|
'has_names', 'rename',
|
|
}
|
|
classes = [torch.FloatTensor, torch.LongTensor, torch.ByteTensor]
|
|
has_docstring = set(x for c in classes for x in dir(c) if not x.startswith('_') and getattr(c, x).__doc__)
|
|
has_docstring -= whitelist
|
|
self.assertEqual(
|
|
has_docstring, in_rst,
|
|
textwrap.dedent('''
|
|
The lists of tensor methods documented in tensors.rst and in python are
|
|
different. Did you forget to add a new thing to tensors.rst, or whitelist
|
|
things you don't want to document?''')
|
|
)
|
|
|
|
|
|
# Run the coverage tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|