Merge pull request #48653 from geetachavan1/cherrypicks_GFW62

[CherryPick:r2.5] Use ast instead of astunparse for python 3.9+.
This commit is contained in:
Mihai Maruseac 2021-04-22 15:25:27 -07:00 committed by GitHub
commit f4b96ff8b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 126 additions and 85 deletions

View File

@ -28,12 +28,20 @@ from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.platform import test
class AstUtilTest(test.TestCase):
def assertAstMatches(self, actual_node, expected_node_src):
expected_node = gast.parse('({})'.format(expected_node_src)).body[0]
msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
pretty_printer.fmt(expected_node),
pretty_printer.fmt(actual_node))
self.assertTrue(ast_util.matches(actual_node, expected_node), msg)
def setUp(self):
super(AstUtilTest, self).setUp()
self._invocation_counts = collections.defaultdict(lambda: 0)
@ -44,10 +52,12 @@ class AstUtilTest(test.TestCase):
node = ast_util.rename_symbols(
node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
source = parser.unparse(node, include_encoding_marker=False)
expected_node_src = 'renamed_a + b'
self.assertIsInstance(node.value.left.id, str)
source = parser.unparse(node, include_encoding_marker=False)
self.assertEqual(source.strip(), '(renamed_a + b)')
self.assertAstMatches(node, source)
self.assertAstMatches(node, expected_node_src)
def test_rename_symbols_attributes(self):
node = parser.parse('b.c = b.c.d')

View File

@ -41,14 +41,14 @@ from __future__ import division
from __future__ import print_function
import collections
import enum
import weakref
from enum import Enum
import astunparse
import gast
import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
class Node(object):
@ -87,9 +87,9 @@ class Node(object):
elif isinstance(self.ast_node, gast.ClassDef):
return 'class %s' % self.ast_node.name
elif isinstance(self.ast_node, gast.withitem):
return parser.unparse(
self.ast_node.context_expr, include_encoding_marker=False).strip()
return parser.unparse(self.ast_node, include_encoding_marker=False).strip()
# TODO(xjun): remove use of astunparse
return astunparse.unparse(self.ast_node.context_expr).strip()
return astunparse.unparse(self.ast_node).strip()
class Graph(
@ -142,7 +142,7 @@ class Graph(
return result
class _WalkMode(Enum):
class _WalkMode(enum.Enum):
FORWARD = 1
REVERSE = 2

View File

@ -24,14 +24,24 @@ import textwrap
import gast
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
class LoaderTest(test.TestCase):
def assertAstMatches(self, actual_node, expected_node_src):
expected_node = gast.parse(expected_node_src).body[0]
msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
pretty_printer.fmt(expected_node),
pretty_printer.fmt(actual_node))
self.assertTrue(ast_util.matches(actual_node, expected_node), msg)
def test_parse_load_identity(self):
def test_fn(x):
@ -43,11 +53,11 @@ class LoaderTest(test.TestCase):
node, _ = parser.parse_entity(test_fn, future_features=())
module, _, _ = loader.load_ast(node)
source = tf_inspect.getsource(module.test_fn)
expected_node_src = textwrap.dedent(tf_inspect.getsource(test_fn))
# astunparse uses fixed 4-space indenting.
self.assertEqual(
textwrap.dedent(tf_inspect.getsource(test_fn)),
tf_inspect.getsource(module.test_fn).replace(' ', ' '))
self.assertAstMatches(node, source)
self.assertAstMatches(node, expected_node_src)
def test_load_ast(self):
node = gast.FunctionDef(
@ -80,19 +90,19 @@ class LoaderTest(test.TestCase):
module, source, _ = loader.load_ast(node)
expected_source = """
expected_node_src = """
# coding=utf-8
def f(a):
return (a + 1)
"""
self.assertEqual(
textwrap.dedent(expected_source).strip(),
source.strip())
expected_node_src = textwrap.dedent(expected_node_src)
self.assertAstMatches(node, source)
self.assertAstMatches(node, expected_node_src)
self.assertEqual(2, module.f(1))
with open(module.__file__, 'r') as temp_output:
self.assertEqual(
textwrap.dedent(expected_source).strip(),
temp_output.read().strip())
self.assertAstMatches(node, temp_output.read())
def test_load_source(self):
test_source = textwrap.dedent(u"""

View File

@ -21,6 +21,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import inspect
import linecache
import re
@ -44,6 +45,9 @@ from __future__ import print_function
PY3_PREAMBLE = ''
MAX_SIZE = 0
if sys.version_info >= (3, 9):
astunparse = ast
if sys.version_info >= (3,):
STANDARD_PREAMBLE = PY3_PREAMBLE
MAX_SIZE = sys.maxsize
@ -386,7 +390,12 @@ def unparse(node, indentation=None, include_encoding_marker=True):
codes.append('# coding=utf-8')
for n in node:
if isinstance(n, gast.AST):
n = gast.gast_to_ast(n)
codes.append(astunparse.unparse(n).strip())
ast_n = gast.gast_to_ast(n)
else:
ast_n = n
if astunparse is ast:
ast.fix_missing_locations(ast_n) # Only ast needs to call this.
codes.append(astunparse.unparse(ast_n).strip())
return '\n'.join(codes)

View File

@ -23,13 +23,28 @@ import textwrap
import gast
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.platform import test
class ParserTest(test.TestCase):
def assertAstMatches(self, actual_node, expected_node_src, expr=True):
if expr:
# Ensure multi-line expressions parse.
expected_node = gast.parse('({})'.format(expected_node_src)).body[0]
expected_node = expected_node.value
else:
expected_node = gast.parse(expected_node_src).body[0]
msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
pretty_printer.fmt(expected_node),
pretty_printer.fmt(actual_node))
self.assertTrue(ast_util.matches(actual_node, expected_node), msg)
def test_parse_entity(self):
def f(x):
@ -41,33 +56,31 @@ class ParserTest(test.TestCase):
def test_parse_lambda(self):
l = lambda x: x + 1
expected_node_src = 'lambda x: (x + 1)'
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (x + 1))')
self.assertEqual(source, 'lambda x: x + 1')
self.assertAstMatches(node, source)
self.assertAstMatches(node, expected_node_src)
def test_parse_lambda_prefix_cleanup(self):
lambda_lam = lambda x: x + 1
expected_node_src = 'lambda x: (x + 1)'
node, source = parser.parse_entity(lambda_lam, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (x + 1))')
self.assertEqual(source, 'lambda x: x + 1')
self.assertAstMatches(node, source)
self.assertAstMatches(node, expected_node_src)
def test_parse_lambda_resolution_by_location(self):
_ = lambda x: x + 1
l = lambda x: x + 1
_ = lambda x: x + 1
expected_node_src = 'lambda x: (x + 1)'
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (x + 1))')
self.assertAstMatches(node, source)
self.assertAstMatches(node, expected_node_src)
self.assertEqual(source, 'lambda x: x + 1')
def test_parse_lambda_resolution_by_signature(self):
@ -75,15 +88,15 @@ class ParserTest(test.TestCase):
l = lambda x: lambda x, y: x + y
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (lambda x, y: (x + y)))')
expected_node_src = 'lambda x: (lambda x, y: (x + y))'
self.assertAstMatches(node, source)
self.assertAstMatches(node, expected_node_src)
self.assertEqual(source, 'lambda x: lambda x, y: x + y')
node, source = parser.parse_entity(l(0), future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x, y: (x + y))')
expected_node_src = 'lambda x, y: (x + y)'
self.assertAstMatches(node, source)
self.assertAstMatches(node, expected_node_src)
self.assertEqual(source, 'lambda x, y: x + y')
def test_parse_lambda_resolution_ambiguous(self):
@ -92,9 +105,9 @@ class ParserTest(test.TestCase):
expected_exception_text = re.compile(r'found multiple definitions'
r'.+'
r'\(lambda x: \(lambda x'
r'\(?lambda x: \(?lambda x'
r'.+'
r'\(lambda x: \(2', re.DOTALL)
r'\(?lambda x: \(?2', re.DOTALL)
with self.assertRaisesRegex(
errors.UnsupportedLanguageElementError,
@ -118,17 +131,15 @@ class ParserTest(test.TestCase):
- 1)
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (lambda y: ((x + y) - 1)))')
expected_node_src = 'lambda x: (lambda y: ((x + y) - 1))'
self.assertAstMatches(node, expected_node_src)
self.assertMatchesWithPotentialGarbage(
source, ('lambda x: lambda y: x + y # pylint:disable=g-long-lambda\n'
' - 1'), ')')
node, source = parser.parse_entity(l(0), future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda y: ((x + y) - 1))')
expected_node_src = 'lambda y: ((x + y) - 1)'
self.assertAstMatches(node, expected_node_src)
self.assertMatchesWithPotentialGarbage(
source, ('lambda y: x + y # pylint:disable=g-long-lambda\n'
' - 1'), ')')
@ -141,30 +152,26 @@ class ParserTest(test.TestCase):
)
node, source = parser.parse_entity(l[0], future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (lambda y: ((x + y) + 1)))')
expected_node_src = 'lambda x: (lambda y: ((x + y) + 1))'
self.assertAstMatches(node, expected_node_src)
self.assertMatchesWithPotentialGarbage(
source, 'lambda x: lambda y: x + y + 1', ',')
node, source = parser.parse_entity(l[0](0), future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda y: ((x + y) + 1))')
expected_node_src = 'lambda y: ((x + y) + 1)'
self.assertAstMatches(node, expected_node_src)
self.assertMatchesWithPotentialGarbage(
source, 'lambda y: x + y + 1', ',')
node, source = parser.parse_entity(l[1], future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (lambda y: ((x + y) + 2)))')
expected_node_src = 'lambda x: (lambda y: ((x + y) + 2))'
self.assertAstMatches(node, expected_node_src)
self.assertMatchesWithPotentialGarbage(source,
'lambda x: lambda y: x + y + 2', ',')
node, source = parser.parse_entity(l[1](0), future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda y: ((x + y) + 2))')
expected_node_src = 'lambda y: ((x + y) + 2)'
self.assertAstMatches(node, expected_node_src)
self.assertMatchesWithPotentialGarbage(source, 'lambda y: x + y + 2', ',')
def test_parse_lambda_complex_body(self):
@ -182,9 +189,9 @@ class ParserTest(test.TestCase):
)
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
"(lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1))")
expected_node_src = "lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1)"
self.assertAstMatches(node, expected_node_src)
base_source = ('lambda x: ( # pylint:disable=g-long-lambda\n'
' x.y(\n'
' [],\n'
@ -197,16 +204,14 @@ class ParserTest(test.TestCase):
' 1,')
# The complete source includes the trailing parenthesis. But that is only
# detected in runtimes which correctly track end_lineno for ASTs.
self.assertIn(source, (base_source, base_source + '\n )'))
self.assertMatchesWithPotentialGarbage(source, base_source, '\n )')
def test_parse_lambda_function_call_definition(self):
def do_parse_and_test(lam, **unused_kwargs):
node, source = parser.parse_entity(lam, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: x)')
expected_node_src = 'lambda x: x'
self.assertAstMatches(node, expected_node_src)
self.assertMatchesWithPotentialGarbage(
source, 'lambda x: x', ', named_arg=1)')
@ -372,6 +377,13 @@ string""")
a = 'c'
""").strip(), source.strip())
def test_ext_slice_roundtrip(self):
def ext_slice(n):
return n[:, :], n[0, :], n[:, 0]
node, _ = parser.parse_entity(ext_slice, future_features=())
source = parser.unparse(node)
self.assertAstMatches(node, source, expr=False)
if __name__ == '__main__':
test.main()

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import gast
from tensorflow.python.autograph.pyct import anno
@ -180,10 +181,10 @@ class TransformerTest(test.TestCase):
node = tr.visit(node)
self.assertEqual(len(node.body), 2)
self.assertTrue(isinstance(node.body[0], gast.Assign))
self.assertTrue(isinstance(node.body[1], gast.If))
self.assertTrue(isinstance(node.body[1].body[0], gast.Assign))
self.assertTrue(isinstance(node.body[1].body[1], gast.Return))
self.assertIsInstance(node.body[0], gast.Assign)
self.assertIsInstance(node.body[1], gast.If)
self.assertIsInstance(node.body[1].body[0], gast.Assign)
self.assertIsInstance(node.body[1].body[1], gast.Return)
def test_robust_error_on_list_visit(self):
@ -244,7 +245,7 @@ class TransformerTest(test.TestCase):
# The message should reference the exception actually raised, not anything
# from the exception handler.
expected_substring = 'I blew up'
self.assertTrue(expected_substring in obtained_message, obtained_message)
self.assertIn(expected_substring, obtained_message)
def test_origin_info_propagated_to_new_nodes(self):
@ -347,20 +348,19 @@ class CodeGeneratorTest(test.TestCase):
origin_info.resolve(node, source, 'test_file', 100, 0)
tg.visit(node)
self.assertEqual(
tg.code_buffer, '\n'.join([
'x = 1',
'if (x > 0) {',
'x = 2',
'if (x > 1) {',
'x = 3',
'} else {',
'}',
'} else {',
'}',
'return x',
'',
]))
r = re.compile('.*'.join([
r'x = 1',
r'if \(?x > 0\)? {',
r'x = 2',
r'if \(?x > 1\)? {',
r'x = 3',
r'} else {',
r'}',
r'} else {',
r'}',
r'return x']), re.DOTALL)
self.assertRegex(tg.code_buffer, r)
# TODO(mdan): Test the source map.