[TF-numpy] Adds a switch to control whether to inline the original numpy docstrings or to just add links.

PiperOrigin-RevId: 320514706
Change-Id: Ibc12d6b4fb84fbfcd27076de244363b63d3f86c1
This commit is contained in:
Peng Wang 2020-07-09 18:16:30 -07:00 committed by TensorFlower Gardener
parent 902443f6bb
commit d0c63bb151
2 changed files with 137 additions and 17 deletions

View File

@ -22,6 +22,7 @@ from __future__ import print_function
import inspect
import numbers
import os
import re
import numpy as np
from tensorflow.python.framework import dtypes
@ -220,12 +221,79 @@ def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
doc = _add_blank_line(doc)
# TODO(wangpeng): Re-enable the following and choose inlined vs. link to numpy
# doc according to some global switch.
# if _has_docstring(np_f):
# doc += 'Documentation for `numpy.%s`:\n\n' % np_f.__name__
# # TODO(wangpeng): It looks like code snippets in numpy doc don't work
# # correctly with doctest. Fix that and remove the reformatting of the np_f
# # comment.
# doc += np_f.__doc__.replace('>>>', '>')
doc = _add_np_doc(doc, np_fun_name, np_f)
return doc
_np_doc_form = os.getenv('TF_NP_DOC_FORM', '1.16')
def get_np_doc_form():
"""Gets the form of the original numpy docstrings.
Returns:
See `set_np_doc_form` for the list of valid values.
"""
return _np_doc_form
def set_np_doc_form(value):
r"""Selects the form of the original numpy docstrings.
This function sets a global variable that controls how a tf-numpy symbol's
docstring should refer to the original numpy docstring. If `value` is
`'inlined'`, the numpy docstring will be verbatim copied into the tf-numpy
docstring. Otherwise, a link to the original numpy docstring will be
added. Which numpy version the link points to depends on `value`:
* `'stable'`: the current stable version;
* `'dev'`: the current development version;
* pattern `\d+(\.\d+(\.\d+)?)?`: `value` will be treated as a version number,
e.g. '1.16'.
Args:
value: the value to set the global variable to.
"""
global _np_doc_form
_np_doc_form = value
def _add_np_doc(doc, np_fun_name, np_f):
"""Appends the numpy docstring to `doc`, according to `set_np_doc_form`.
See `set_np_doc_form` for how it controls the form of the numpy docstring.
Args:
doc: the docstring to be appended to.
np_fun_name: the name of the numpy function.
np_f: (optional) the numpy function.
Returns:
`doc` with numpy docstring appended.
"""
flag = get_np_doc_form()
if flag == 'inlined':
if _has_docstring(np_f):
doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
# TODO(wangpeng): It looks like code snippets in numpy doc don't work
# correctly with doctest. Fix that and remove the reformatting of the np_f
# comment.
doc += np_f.__doc__.replace('>>>', '>')
elif isinstance(flag, str):
# Only adds link in this case
if flag == 'dev':
template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html'
elif flag == 'stable':
template = (
'https://numpy.org/doc/stable/reference/generated/numpy.%s.html')
elif re.match(r'\d+(\.\d+(\.\d+)?)?$', flag):
# `flag` is the version number
template = ('https://numpy.org/doc/' + flag +
'/reference/generated/numpy.%s.html')
else:
template = None
if template is not None:
link = template % np_fun_name
doc += 'See the documentation for `numpy.%s`: [%s]' % (np_fun_name, link)
return doc

View File

@ -18,24 +18,82 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.platform import test
class UtilsTest(test.TestCase):
class UtilsTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(UtilsTest, self).setUp()
self._old_np_doc_form = np_utils.get_np_doc_form()
self._old_is_sig_mismatch_an_error = np_utils.is_sig_mismatch_an_error()
def tearDown(self):
np_utils.set_np_doc_form(self._old_np_doc_form)
np_utils.set_is_sig_mismatch_an_error(self._old_is_sig_mismatch_an_error)
super(UtilsTest, self).tearDown()
# pylint: disable=unused-argument
def testNpDoc(self):
def testNpDocInlined(self):
def np_fun(x):
"""np_fun docstring."""
return
np_utils.set_np_doc_form('inlined')
@np_utils.np_doc(None, np_fun=np_fun)
def f():
"""f docstring."""
return
expected = """TensorFlow variant of `numpy.np_fun`.
Unsupported arguments: `x`.
f docstring.
Documentation for `numpy.np_fun`:
np_fun docstring."""
self.assertEqual(expected, f.__doc__)
@parameterized.named_parameters([
(version, version, link) for version, link in # pylint: disable=g-complex-comprehension
[('dev',
'https://numpy.org/devdocs/reference/generated/numpy.np_fun.html'),
('stable',
'https://numpy.org/doc/stable/reference/generated/numpy.np_fun.html'),
('1.16',
'https://numpy.org/doc/1.16/reference/generated/numpy.np_fun.html')
]])
def testNpDocLink(self, version, link):
def np_fun(x):
"""np_fun docstring."""
return
np_utils.set_np_doc_form(version)
@np_utils.np_doc(None, np_fun=np_fun)
def f():
"""f docstring."""
return
expected = """TensorFlow variant of `numpy.np_fun`.
Unsupported arguments: `x`.
f docstring.
See the documentation for `numpy.np_fun`: [%s]""" % link
self.assertEqual(expected, f.__doc__)
@parameterized.parameters([None, 1, 'a', '1a', '1.1a', '1.1.1a'])
def testNpDocInvalid(self, invalid_flag):
def np_fun(x):
"""np_fun docstring."""
return
np_utils.set_np_doc_form(invalid_flag)
@np_utils.np_doc(None, np_fun=np_fun)
def f():
"""f docstring."""
return
expected = """TensorFlow variant of `numpy.np_fun`.
Unsupported arguments: `x`.
@ -46,7 +104,7 @@ f docstring.
self.assertEqual(expected, f.__doc__)
def testNpDocName(self):
np_utils.set_np_doc_form('inlined')
@np_utils.np_doc('foo')
def f():
"""f docstring."""
@ -64,7 +122,6 @@ f docstring.
if not np_utils._supports_signature():
self.skipTest('inspect.signature not supported')
old_flag = np_utils.is_sig_mismatch_an_error()
np_utils.set_is_sig_mismatch_an_error(True)
def np_fun(x, y=1, **kwargs):
@ -86,11 +143,8 @@ f docstring.
def f3(x, y):
return
np_utils.set_is_sig_mismatch_an_error(old_flag)
def testSigMismatchIsNotError(self):
"""Tests that signature mismatch is not an error (when configured so)."""
old_flag = np_utils.is_sig_mismatch_an_error()
np_utils.set_is_sig_mismatch_an_error(False)
def np_fun(x, y=1, **kwargs):
@ -110,8 +164,6 @@ f docstring.
def f3(x, y):
return
np_utils.set_is_sig_mismatch_an_error(old_flag)
# pylint: enable=unused-variable