mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #48653 from geetachavan1/cherrypicks_GFW62
[CherryPick:r2.5] Use ast instead of astunparse for python 3.9+.
This commit is contained in:
commit
f4b96ff8b3
|
|
@ -28,12 +28,20 @@ from tensorflow.python.autograph.pyct import anno
|
|||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class AstUtilTest(test.TestCase):
|
||||
|
||||
def assertAstMatches(self, actual_node, expected_node_src):
|
||||
expected_node = gast.parse('({})'.format(expected_node_src)).body[0]
|
||||
msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
|
||||
pretty_printer.fmt(expected_node),
|
||||
pretty_printer.fmt(actual_node))
|
||||
self.assertTrue(ast_util.matches(actual_node, expected_node), msg)
|
||||
|
||||
def setUp(self):
|
||||
super(AstUtilTest, self).setUp()
|
||||
self._invocation_counts = collections.defaultdict(lambda: 0)
|
||||
|
|
@ -44,10 +52,12 @@ class AstUtilTest(test.TestCase):
|
|||
|
||||
node = ast_util.rename_symbols(
|
||||
node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
|
||||
source = parser.unparse(node, include_encoding_marker=False)
|
||||
expected_node_src = 'renamed_a + b'
|
||||
|
||||
self.assertIsInstance(node.value.left.id, str)
|
||||
source = parser.unparse(node, include_encoding_marker=False)
|
||||
self.assertEqual(source.strip(), '(renamed_a + b)')
|
||||
self.assertAstMatches(node, source)
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
|
||||
def test_rename_symbols_attributes(self):
|
||||
node = parser.parse('b.c = b.c.d')
|
||||
|
|
|
|||
|
|
@ -41,14 +41,14 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import enum
|
||||
import weakref
|
||||
from enum import Enum
|
||||
|
||||
import astunparse
|
||||
import gast
|
||||
import six
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
|
||||
|
||||
class Node(object):
|
||||
|
|
@ -87,9 +87,9 @@ class Node(object):
|
|||
elif isinstance(self.ast_node, gast.ClassDef):
|
||||
return 'class %s' % self.ast_node.name
|
||||
elif isinstance(self.ast_node, gast.withitem):
|
||||
return parser.unparse(
|
||||
self.ast_node.context_expr, include_encoding_marker=False).strip()
|
||||
return parser.unparse(self.ast_node, include_encoding_marker=False).strip()
|
||||
# TODO(xjun): remove use of astunparse
|
||||
return astunparse.unparse(self.ast_node.context_expr).strip()
|
||||
return astunparse.unparse(self.ast_node).strip()
|
||||
|
||||
|
||||
class Graph(
|
||||
|
|
@ -142,7 +142,7 @@ class Graph(
|
|||
return result
|
||||
|
||||
|
||||
class _WalkMode(Enum):
|
||||
class _WalkMode(enum.Enum):
|
||||
FORWARD = 1
|
||||
REVERSE = 2
|
||||
|
||||
|
|
|
|||
|
|
@ -24,14 +24,24 @@ import textwrap
|
|||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
class LoaderTest(test.TestCase):
|
||||
|
||||
def assertAstMatches(self, actual_node, expected_node_src):
|
||||
expected_node = gast.parse(expected_node_src).body[0]
|
||||
|
||||
msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
|
||||
pretty_printer.fmt(expected_node),
|
||||
pretty_printer.fmt(actual_node))
|
||||
self.assertTrue(ast_util.matches(actual_node, expected_node), msg)
|
||||
|
||||
def test_parse_load_identity(self):
|
||||
|
||||
def test_fn(x):
|
||||
|
|
@ -43,11 +53,11 @@ class LoaderTest(test.TestCase):
|
|||
|
||||
node, _ = parser.parse_entity(test_fn, future_features=())
|
||||
module, _, _ = loader.load_ast(node)
|
||||
source = tf_inspect.getsource(module.test_fn)
|
||||
expected_node_src = textwrap.dedent(tf_inspect.getsource(test_fn))
|
||||
|
||||
# astunparse uses fixed 4-space indenting.
|
||||
self.assertEqual(
|
||||
textwrap.dedent(tf_inspect.getsource(test_fn)),
|
||||
tf_inspect.getsource(module.test_fn).replace(' ', ' '))
|
||||
self.assertAstMatches(node, source)
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
|
||||
def test_load_ast(self):
|
||||
node = gast.FunctionDef(
|
||||
|
|
@ -80,19 +90,19 @@ class LoaderTest(test.TestCase):
|
|||
|
||||
module, source, _ = loader.load_ast(node)
|
||||
|
||||
expected_source = """
|
||||
expected_node_src = """
|
||||
# coding=utf-8
|
||||
def f(a):
|
||||
return (a + 1)
|
||||
"""
|
||||
self.assertEqual(
|
||||
textwrap.dedent(expected_source).strip(),
|
||||
source.strip())
|
||||
expected_node_src = textwrap.dedent(expected_node_src)
|
||||
|
||||
self.assertAstMatches(node, source)
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
|
||||
self.assertEqual(2, module.f(1))
|
||||
with open(module.__file__, 'r') as temp_output:
|
||||
self.assertEqual(
|
||||
textwrap.dedent(expected_source).strip(),
|
||||
temp_output.read().strip())
|
||||
self.assertAstMatches(node, temp_output.read())
|
||||
|
||||
def test_load_source(self):
|
||||
test_source = textwrap.dedent(u"""
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import linecache
|
||||
import re
|
||||
|
|
@ -44,6 +45,9 @@ from __future__ import print_function
|
|||
PY3_PREAMBLE = ''
|
||||
MAX_SIZE = 0
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
astunparse = ast
|
||||
|
||||
if sys.version_info >= (3,):
|
||||
STANDARD_PREAMBLE = PY3_PREAMBLE
|
||||
MAX_SIZE = sys.maxsize
|
||||
|
|
@ -386,7 +390,12 @@ def unparse(node, indentation=None, include_encoding_marker=True):
|
|||
codes.append('# coding=utf-8')
|
||||
for n in node:
|
||||
if isinstance(n, gast.AST):
|
||||
n = gast.gast_to_ast(n)
|
||||
codes.append(astunparse.unparse(n).strip())
|
||||
ast_n = gast.gast_to_ast(n)
|
||||
else:
|
||||
ast_n = n
|
||||
|
||||
if astunparse is ast:
|
||||
ast.fix_missing_locations(ast_n) # Only ast needs to call this.
|
||||
codes.append(astunparse.unparse(ast_n).strip())
|
||||
|
||||
return '\n'.join(codes)
|
||||
|
|
|
|||
|
|
@ -23,13 +23,28 @@ import textwrap
|
|||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import errors
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ParserTest(test.TestCase):
|
||||
|
||||
def assertAstMatches(self, actual_node, expected_node_src, expr=True):
|
||||
if expr:
|
||||
# Ensure multi-line expressions parse.
|
||||
expected_node = gast.parse('({})'.format(expected_node_src)).body[0]
|
||||
expected_node = expected_node.value
|
||||
else:
|
||||
expected_node = gast.parse(expected_node_src).body[0]
|
||||
|
||||
msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
|
||||
pretty_printer.fmt(expected_node),
|
||||
pretty_printer.fmt(actual_node))
|
||||
self.assertTrue(ast_util.matches(actual_node, expected_node), msg)
|
||||
|
||||
def test_parse_entity(self):
|
||||
|
||||
def f(x):
|
||||
|
|
@ -41,33 +56,31 @@ class ParserTest(test.TestCase):
|
|||
def test_parse_lambda(self):
|
||||
|
||||
l = lambda x: x + 1
|
||||
expected_node_src = 'lambda x: (x + 1)'
|
||||
|
||||
node, source = parser.parse_entity(l, future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x: (x + 1))')
|
||||
self.assertEqual(source, 'lambda x: x + 1')
|
||||
self.assertAstMatches(node, source)
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
|
||||
def test_parse_lambda_prefix_cleanup(self):
|
||||
|
||||
lambda_lam = lambda x: x + 1
|
||||
expected_node_src = 'lambda x: (x + 1)'
|
||||
|
||||
node, source = parser.parse_entity(lambda_lam, future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x: (x + 1))')
|
||||
self.assertEqual(source, 'lambda x: x + 1')
|
||||
self.assertAstMatches(node, source)
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
|
||||
def test_parse_lambda_resolution_by_location(self):
|
||||
|
||||
_ = lambda x: x + 1
|
||||
l = lambda x: x + 1
|
||||
_ = lambda x: x + 1
|
||||
expected_node_src = 'lambda x: (x + 1)'
|
||||
|
||||
node, source = parser.parse_entity(l, future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x: (x + 1))')
|
||||
self.assertAstMatches(node, source)
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertEqual(source, 'lambda x: x + 1')
|
||||
|
||||
def test_parse_lambda_resolution_by_signature(self):
|
||||
|
|
@ -75,15 +88,15 @@ class ParserTest(test.TestCase):
|
|||
l = lambda x: lambda x, y: x + y
|
||||
|
||||
node, source = parser.parse_entity(l, future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x: (lambda x, y: (x + y)))')
|
||||
expected_node_src = 'lambda x: (lambda x, y: (x + y))'
|
||||
self.assertAstMatches(node, source)
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertEqual(source, 'lambda x: lambda x, y: x + y')
|
||||
|
||||
node, source = parser.parse_entity(l(0), future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x, y: (x + y))')
|
||||
expected_node_src = 'lambda x, y: (x + y)'
|
||||
self.assertAstMatches(node, source)
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertEqual(source, 'lambda x, y: x + y')
|
||||
|
||||
def test_parse_lambda_resolution_ambiguous(self):
|
||||
|
|
@ -92,9 +105,9 @@ class ParserTest(test.TestCase):
|
|||
|
||||
expected_exception_text = re.compile(r'found multiple definitions'
|
||||
r'.+'
|
||||
r'\(lambda x: \(lambda x'
|
||||
r'\(?lambda x: \(?lambda x'
|
||||
r'.+'
|
||||
r'\(lambda x: \(2', re.DOTALL)
|
||||
r'\(?lambda x: \(?2', re.DOTALL)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
errors.UnsupportedLanguageElementError,
|
||||
|
|
@ -118,17 +131,15 @@ class ParserTest(test.TestCase):
|
|||
- 1)
|
||||
|
||||
node, source = parser.parse_entity(l, future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x: (lambda y: ((x + y) - 1)))')
|
||||
expected_node_src = 'lambda x: (lambda y: ((x + y) - 1))'
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertMatchesWithPotentialGarbage(
|
||||
source, ('lambda x: lambda y: x + y # pylint:disable=g-long-lambda\n'
|
||||
' - 1'), ')')
|
||||
|
||||
node, source = parser.parse_entity(l(0), future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda y: ((x + y) - 1))')
|
||||
expected_node_src = 'lambda y: ((x + y) - 1)'
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertMatchesWithPotentialGarbage(
|
||||
source, ('lambda y: x + y # pylint:disable=g-long-lambda\n'
|
||||
' - 1'), ')')
|
||||
|
|
@ -141,30 +152,26 @@ class ParserTest(test.TestCase):
|
|||
)
|
||||
|
||||
node, source = parser.parse_entity(l[0], future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x: (lambda y: ((x + y) + 1)))')
|
||||
expected_node_src = 'lambda x: (lambda y: ((x + y) + 1))'
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertMatchesWithPotentialGarbage(
|
||||
source, 'lambda x: lambda y: x + y + 1', ',')
|
||||
|
||||
node, source = parser.parse_entity(l[0](0), future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda y: ((x + y) + 1))')
|
||||
expected_node_src = 'lambda y: ((x + y) + 1)'
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertMatchesWithPotentialGarbage(
|
||||
source, 'lambda y: x + y + 1', ',')
|
||||
|
||||
node, source = parser.parse_entity(l[1], future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x: (lambda y: ((x + y) + 2)))')
|
||||
expected_node_src = 'lambda x: (lambda y: ((x + y) + 2))'
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertMatchesWithPotentialGarbage(source,
|
||||
'lambda x: lambda y: x + y + 2', ',')
|
||||
|
||||
node, source = parser.parse_entity(l[1](0), future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda y: ((x + y) + 2))')
|
||||
expected_node_src = 'lambda y: ((x + y) + 2)'
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertMatchesWithPotentialGarbage(source, 'lambda y: x + y + 2', ',')
|
||||
|
||||
def test_parse_lambda_complex_body(self):
|
||||
|
|
@ -182,9 +189,9 @@ class ParserTest(test.TestCase):
|
|||
)
|
||||
|
||||
node, source = parser.parse_entity(l, future_features=())
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
"(lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1))")
|
||||
expected_node_src = "lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1)"
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
|
||||
base_source = ('lambda x: ( # pylint:disable=g-long-lambda\n'
|
||||
' x.y(\n'
|
||||
' [],\n'
|
||||
|
|
@ -197,16 +204,14 @@ class ParserTest(test.TestCase):
|
|||
' 1,')
|
||||
# The complete source includes the trailing parenthesis. But that is only
|
||||
# detected in runtimes which correctly track end_lineno for ASTs.
|
||||
self.assertIn(source, (base_source, base_source + '\n )'))
|
||||
self.assertMatchesWithPotentialGarbage(source, base_source, '\n )')
|
||||
|
||||
def test_parse_lambda_function_call_definition(self):
|
||||
|
||||
def do_parse_and_test(lam, **unused_kwargs):
|
||||
node, source = parser.parse_entity(lam, future_features=())
|
||||
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False),
|
||||
'(lambda x: x)')
|
||||
expected_node_src = 'lambda x: x'
|
||||
self.assertAstMatches(node, expected_node_src)
|
||||
self.assertMatchesWithPotentialGarbage(
|
||||
source, 'lambda x: x', ', named_arg=1)')
|
||||
|
||||
|
|
@ -372,6 +377,13 @@ string""")
|
|||
a = 'c'
|
||||
""").strip(), source.strip())
|
||||
|
||||
def test_ext_slice_roundtrip(self):
|
||||
def ext_slice(n):
|
||||
return n[:, :], n[0, :], n[:, 0]
|
||||
|
||||
node, _ = parser.parse_entity(ext_slice, future_features=())
|
||||
source = parser.unparse(node)
|
||||
self.assertAstMatches(node, source, expr=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
|
|
@ -180,10 +181,10 @@ class TransformerTest(test.TestCase):
|
|||
node = tr.visit(node)
|
||||
|
||||
self.assertEqual(len(node.body), 2)
|
||||
self.assertTrue(isinstance(node.body[0], gast.Assign))
|
||||
self.assertTrue(isinstance(node.body[1], gast.If))
|
||||
self.assertTrue(isinstance(node.body[1].body[0], gast.Assign))
|
||||
self.assertTrue(isinstance(node.body[1].body[1], gast.Return))
|
||||
self.assertIsInstance(node.body[0], gast.Assign)
|
||||
self.assertIsInstance(node.body[1], gast.If)
|
||||
self.assertIsInstance(node.body[1].body[0], gast.Assign)
|
||||
self.assertIsInstance(node.body[1].body[1], gast.Return)
|
||||
|
||||
def test_robust_error_on_list_visit(self):
|
||||
|
||||
|
|
@ -244,7 +245,7 @@ class TransformerTest(test.TestCase):
|
|||
# The message should reference the exception actually raised, not anything
|
||||
# from the exception handler.
|
||||
expected_substring = 'I blew up'
|
||||
self.assertTrue(expected_substring in obtained_message, obtained_message)
|
||||
self.assertIn(expected_substring, obtained_message)
|
||||
|
||||
def test_origin_info_propagated_to_new_nodes(self):
|
||||
|
||||
|
|
@ -347,20 +348,19 @@ class CodeGeneratorTest(test.TestCase):
|
|||
origin_info.resolve(node, source, 'test_file', 100, 0)
|
||||
tg.visit(node)
|
||||
|
||||
self.assertEqual(
|
||||
tg.code_buffer, '\n'.join([
|
||||
'x = 1',
|
||||
'if (x > 0) {',
|
||||
'x = 2',
|
||||
'if (x > 1) {',
|
||||
'x = 3',
|
||||
'} else {',
|
||||
'}',
|
||||
'} else {',
|
||||
'}',
|
||||
'return x',
|
||||
'',
|
||||
]))
|
||||
r = re.compile('.*'.join([
|
||||
r'x = 1',
|
||||
r'if \(?x > 0\)? {',
|
||||
r'x = 2',
|
||||
r'if \(?x > 1\)? {',
|
||||
r'x = 3',
|
||||
r'} else {',
|
||||
r'}',
|
||||
r'} else {',
|
||||
r'}',
|
||||
r'return x']), re.DOTALL)
|
||||
|
||||
self.assertRegex(tg.code_buffer, r)
|
||||
# TODO(mdan): Test the source map.
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user