mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Expose the necessary functions to Python, and add round-trip tests for the function schema str() and parsing functions. We iterate over all the registered function schemas, get each one's string, and then parse that string. We compare the schema generated from parsing with the original one and make sure they are equal. Pull Request resolved: https://github.com/pytorch/pytorch/pull/23208 ghstack-source-id: 89638026 Test Plan: buck test //caffe2/test:function_schema Reviewed By: zrphercule Differential Revision: D16435471 fbshipit-source-id: 6961ab096335eb88a96b132575996c24090fd4c0
91 lines
3.6 KiB
Python
91 lines
3.6 KiB
Python
import io
import os
import re
import textwrap
import unittest

import torch
|
|
|
|
|
|
# Directory containing this file; the documentation locations below are
# resolved relative to it.
path = os.path.dirname(os.path.realpath(__file__))

# Sphinx source directory (torch.rst, tensors.rst, ...) and the module that
# holds the torch.* docstrings.
# NOTE(review): pypath is not referenced by the tests in this file -- confirm
# nothing else relies on it before removing it.
rstpath, pypath = (
    os.path.join(path, '../docs/source/'),
    os.path.join(path, '../torch/_torch_docs.py'),
)

# r1 captures <name> from `.. autofunction:: <name>` lines; r2 captures
# <name> from `.. automethod:: <name>` and `.. autoattribute:: <name>` lines.
r1, r2 = (
    re.compile(r'\.\. autofunction:: (\w*)'),
    re.compile(r'\.\. auto(?:method|attribute):: (\w*)'),
)
|
|
|
|
|
|
class TestDocCoverage(unittest.TestCase):
    """Check that the Sphinx .rst docs and the Python API list the same names.

    Each test collects the names referenced by Sphinx ``auto*`` directives in
    an .rst file under ``rstpath`` and compares them with the public
    attributes that carry a docstring in Python.
    """

    @staticmethod
    def parse_rst(filename, regex):
        """Collect the names matched by ``regex`` in an .rst file.

        Args:
            filename: name of the .rst file, joined onto ``rstpath``.
            regex: compiled pattern whose first capture group is the
                documented name (``r1`` or ``r2`` in this file).

        Returns:
            A set containing the first captured name of every matching line.
        """
        filename = os.path.join(rstpath, filename)
        ret = set()
        # Decode explicitly as UTF-8 rather than with the locale's default
        # encoding, so a non-ASCII character in the docs cannot make this
        # test fail with a UnicodeDecodeError on some platforms. io.open
        # accepts `encoding` on both Python 2 and Python 3.
        with io.open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                names = regex.findall(line.strip())
                if names:
                    ret.add(names[0])
        return ret

    def test_torch(self):
        # TODO: The algorithm here is kind of unsound; we don't assume
        # every identifier in torch.rst lives in torch by virtue of
        # where it lives; instead, it lives in torch because at the
        # beginning of the file we specified automodule. This means
        # that this script can get confused if you have, e.g., multiple
        # automodule directives in the torch file. "Don't do that."
        # (Or fix this to properly handle that case.)

        # get symbols documented in torch.rst
        in_rst = self.parse_rst('torch.rst', r1)
        # Documented public torch.* functions that are deliberately left out
        # of torch.rst.
        whitelist = {
            # below are some jit functions
            'wait', 'fork', 'parse_type_comment', 'import_ir_module',
            'import_ir_module_from_buffer', 'merge_type_from_type_comment',
            'parse_ir', 'parse_schema',

            # below are symbols mistakenly bound to torch.*, but should
            # go to torch.nn.functional.* instead
            'avg_pool1d', 'conv_transpose2d', 'conv_transpose1d', 'conv3d',
            'relu_', 'pixel_shuffle', 'conv2d', 'selu_', 'celu_', 'threshold_',
            'cosine_similarity', 'rrelu_', 'conv_transpose3d', 'conv1d', 'pdist',
            'adaptive_avg_pool1d', 'conv_tbc'
        }
        # Collect public torch.* attributes that have a docstring and whose
        # type name contains 'function' (e.g. Python functions and builtins).
        has_docstring = set()
        for name in dir(torch):
            # Filter on the name first so private attributes are never
            # fetched, and fetch each public attribute only once.
            if name.startswith('_'):
                continue
            obj = getattr(torch, name)
            if obj.__doc__ and 'function' in type(obj).__name__:
                has_docstring.add(name)
        # Every whitelisted name must still be a documented torch.* function;
        # otherwise the whitelist is stale.
        self.assertEqual(
            has_docstring & whitelist, whitelist,
            textwrap.dedent('''
            The whitelist in test_docs_coverage.py contains something
            that doesn't have a docstring or isn't in torch.*. If you just
            removed something from torch.*, please remove it from the whitelist
            in test_docs_coverage.py'''))
        has_docstring -= whitelist
        # assert they are equal
        self.assertEqual(
            has_docstring, in_rst,
            textwrap.dedent('''
            The lists of functions documented in torch.rst and in python are different.
            Did you forget to add a new thing to torch.rst, or whitelist things you
            don't want to document?''')
        )

    def test_tensor(self):
        # get tensor methods/attributes documented in tensors.rst
        in_rst = self.parse_rst('tensors.rst', r2)
        # Public attributes with a docstring on any of these tensor types.
        classes = [torch.FloatTensor, torch.LongTensor, torch.ByteTensor]
        has_docstring = set(
            x for c in classes for x in dir(c)
            if not x.startswith('_') and getattr(c, x).__doc__)
        self.assertEqual(
            has_docstring, in_rst,
            textwrap.dedent('''
            The lists of tensor methods documented in tensors.rst and in python are
            different. Did you forget to add a new thing to tensors.rst, or whitelist
            things you don't want to document?''')
        )
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the documentation-coverage tests when this file is executed directly.
    unittest.main()
|