mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Deprecate experimental_follow_type_hints for tf.function
PiperOrigin-RevId: 479350246
This commit is contained in:
parent
5a2003d501
commit
a549623567
|
|
@ -115,6 +115,8 @@
|
|||
"composite" tensors, such as `tf.RaggedTensor`, as inputs.
|
||||
* Fix device placement issues related to datasets with ragged tensors of
|
||||
strings (i.e. variant encoded data with types not supported on GPU).
|
||||
* 'experimental_follow_type_hints' for tf.function has been deprecated.
|
||||
Please use input_signature or reduce_retracing to minimize retracing.
|
||||
|
||||
* `tf.SparseTensor`:
|
||||
* Introduced `set_shape`, which sets the static dense shape of the sparse tensor and has the same semantics as `tf.Tensor.set_shape`.
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ class FunctionSpec(object):
|
|||
def from_function_and_signature(cls, python_function,
|
||||
input_signature,
|
||||
is_pure=False,
|
||||
experimental_follow_type_hints=False,
|
||||
jit_compile=None):
|
||||
"""Creates a FunctionSpec instance given a python function and signature.
|
||||
|
||||
|
|
@ -58,7 +57,6 @@ class FunctionSpec(object):
|
|||
input_signature: a signature of the function (None, if variable)
|
||||
is_pure: if True all input arguments (including variables and constants)
|
||||
will be converted to tensors and no variable changes allowed.
|
||||
experimental_follow_type_hints: see `tf.function`
|
||||
jit_compile: see `tf.function`
|
||||
|
||||
Returns:
|
||||
|
|
@ -145,7 +143,6 @@ class FunctionSpec(object):
|
|||
input_signature,
|
||||
is_pure=is_pure,
|
||||
jit_compile=jit_compile,
|
||||
experimental_follow_type_hints=experimental_follow_type_hints,
|
||||
name=name)
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -153,7 +150,6 @@ class FunctionSpec(object):
|
|||
is_method,
|
||||
input_signature,
|
||||
is_pure=False,
|
||||
experimental_follow_type_hints=False,
|
||||
name=None,
|
||||
jit_compile=None):
|
||||
"""Constructs a FunctionSpec describing a python function.
|
||||
|
|
@ -164,7 +160,6 @@ class FunctionSpec(object):
|
|||
input_signature: a signature of the function (None, if variable)
|
||||
is_pure: if True all input arguments (including variables and constants)
|
||||
will be converted to tensors and no variable changes allowed.
|
||||
experimental_follow_type_hints: see `tf.function`.
|
||||
name: Name of the function
|
||||
jit_compile: see `tf.function`.
|
||||
"""
|
||||
|
|
@ -172,7 +167,6 @@ class FunctionSpec(object):
|
|||
self._is_method = is_method
|
||||
self._is_pure = is_pure
|
||||
self._jit_compile = jit_compile
|
||||
self._experimental_follow_type_hints = experimental_follow_type_hints
|
||||
|
||||
# TODO(edloper): Include name when serializing for SavedModel?
|
||||
self._name = name or "f"
|
||||
|
|
@ -385,35 +379,6 @@ class FunctionSpec(object):
|
|||
f"{len(missing_tensor_specs)} argument(s):"
|
||||
f" {missing_tensor_specs}.")
|
||||
|
||||
def _convert_annotated_args_to_tensors(self, args, kwargs):
|
||||
"""Attempts to autobox arguments annotated as tf.Tensor."""
|
||||
if self.input_signature is not None:
|
||||
return
|
||||
|
||||
args = list(args)
|
||||
for i, arg in enumerate(args):
|
||||
# See
|
||||
# https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
|
||||
if i < len(self._fullargspec.args):
|
||||
annotation_key = self._fullargspec.args[i]
|
||||
else:
|
||||
annotation_key = self._fullargspec.varargs
|
||||
arg_annotation = self._fullargspec.annotations.get(annotation_key, None)
|
||||
|
||||
# TODO(rahulkamat): Change to TensorLike (here ans below)
|
||||
if arg_annotation == ops.Tensor:
|
||||
args[i] = _to_tensor_or_tensor_spec(arg)
|
||||
|
||||
for kw, v in kwargs.items():
|
||||
if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args:
|
||||
annotation_key = kw
|
||||
else:
|
||||
annotation_key = self._fullargspec.varkw
|
||||
kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None)
|
||||
if kwarg_annotation == ops.Tensor:
|
||||
kwargs[kw] = _to_tensor_or_tensor_spec(v)
|
||||
return tuple(args), kwargs
|
||||
|
||||
def _validate_inputs(self, flat_inputs):
|
||||
"""Raises an error if inputs contain illegal values."""
|
||||
for inp in flat_inputs:
|
||||
|
|
@ -484,8 +449,7 @@ class FunctionSpec(object):
|
|||
kwargs = {key: kwargs[key] for key in kwargs}
|
||||
if self._is_pure:
|
||||
args, kwargs = _convert_variables_to_tensors(args, kwargs)
|
||||
if self._experimental_follow_type_hints:
|
||||
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
|
||||
|
||||
# Pre-calculate to reduce overhead
|
||||
arglen = len(args)
|
||||
if self._input_signature is not None:
|
||||
|
|
|
|||
|
|
@ -533,8 +533,7 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
jit_compile=None,
|
||||
reduce_retracing=False,
|
||||
experimental_implements=None,
|
||||
experimental_autograph_options=None,
|
||||
experimental_follow_type_hints=None):
|
||||
experimental_autograph_options=None):
|
||||
"""Initializes a `Function`.
|
||||
|
||||
Args:
|
||||
|
|
@ -546,7 +545,6 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
reduce_retracing: See the documentation for `tf.function`.
|
||||
experimental_implements: See the documentation for `tf.function`.
|
||||
experimental_autograph_options: See the documentation for `tf.function`.
|
||||
experimental_follow_type_hints: See the documentation for `tf.function`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `input_signature` is not None and the `python_function`'s
|
||||
|
|
@ -558,7 +556,6 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
python_function,
|
||||
input_signature,
|
||||
jit_compile=jit_compile,
|
||||
experimental_follow_type_hints=experimental_follow_type_hints,
|
||||
)
|
||||
self._implements = experimental_implements
|
||||
# If `True`, the function uses the rendezvous of the parent. This is only
|
||||
|
|
@ -569,9 +566,6 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
self._experimental_autograph_options = experimental_autograph_options
|
||||
self._reduce_retracing = reduce_retracing
|
||||
self._jit_compile = jit_compile
|
||||
if experimental_follow_type_hints is None:
|
||||
experimental_follow_type_hints = False
|
||||
self._experimental_follow_type_hints = experimental_follow_type_hints
|
||||
self._created_variables = None # GUARDED_BY(self._lock)
|
||||
self._variable_creation_fn = None # GUARDED_BY(self._lock)
|
||||
self._no_variable_creation_fn = None # GUARDED_BY(self._lock)
|
||||
|
|
@ -710,8 +704,7 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
autograph=self._autograph,
|
||||
jit_compile=self._jit_compile,
|
||||
reduce_retracing=self._reduce_retracing,
|
||||
autograph_options=self._experimental_autograph_options,
|
||||
experimental_follow_type_hints=self._experimental_follow_type_hints)
|
||||
autograph_options=self._experimental_autograph_options)
|
||||
|
||||
def _initialize(self, args, kwds, add_initializers_to=None):
|
||||
"""Initializes, on the first call.
|
||||
|
|
@ -781,8 +774,7 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
jit_compile=self._jit_compile,
|
||||
reduce_retracing=self._reduce_retracing,
|
||||
experimental_implements=self._implements,
|
||||
experimental_autograph_options=self._experimental_autograph_options,
|
||||
experimental_follow_type_hints=self._experimental_follow_type_hints)
|
||||
experimental_autograph_options=self._experimental_autograph_options)
|
||||
|
||||
if self._shared_rendezvous:
|
||||
f._shared_rendezvous = self._shared_rendezvous # pylint: disable=protected-access
|
||||
|
|
@ -1267,16 +1259,21 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
"experimental_relax_shapes is deprecated, use "
|
||||
"reduce_retracing instead",
|
||||
"experimental_relax_shapes")
|
||||
def function(func=None,
|
||||
input_signature=None,
|
||||
autograph=True,
|
||||
jit_compile=None,
|
||||
reduce_retracing=False,
|
||||
experimental_implements=None,
|
||||
experimental_autograph_options=None,
|
||||
experimental_relax_shapes=None,
|
||||
experimental_compile=None,
|
||||
experimental_follow_type_hints=None) -> core.GenericFunction:
|
||||
@deprecation.deprecated_args(None,
|
||||
"experimental_follow_type_hints is deprecated",
|
||||
"experimental_follow_type_hints")
|
||||
def function(
|
||||
func=None,
|
||||
input_signature=None,
|
||||
autograph=True,
|
||||
jit_compile=None,
|
||||
reduce_retracing=False,
|
||||
experimental_implements=None,
|
||||
experimental_autograph_options=None,
|
||||
experimental_relax_shapes=None,
|
||||
experimental_compile=None,
|
||||
experimental_follow_type_hints=None # pylint: disable=unused-argument
|
||||
) -> core.GenericFunction:
|
||||
"""Compiles a function into a callable TensorFlow graph.
|
||||
|
||||
`tf.function` constructs a `tf.types.experimental.GenericFunction` that
|
||||
|
|
@ -1518,32 +1515,6 @@ def function(func=None,
|
|||
>>> f(2, tf.constant(2))
|
||||
<tf.Tensor: shape=(), dtype=int32, numpy=2>
|
||||
|
||||
## Using type annotations to improve performance
|
||||
|
||||
`experimental_follow_type_hints` can be used along with type annotations to
|
||||
reduce retracing by automatically casting any Python values to `tf.Tensor`
|
||||
(something that is not done by default, unless you use input signatures).
|
||||
|
||||
>>> @tf.function(experimental_follow_type_hints=True)
|
||||
... def f_with_hints(x: tf.Tensor):
|
||||
... print('Tracing')
|
||||
... return x
|
||||
>>> @tf.function(experimental_follow_type_hints=False)
|
||||
... def f_no_hints(x: tf.Tensor):
|
||||
... print('Tracing')
|
||||
... return x
|
||||
>>> f_no_hints(1)
|
||||
Tracing
|
||||
<tf.Tensor: shape=(), dtype=int32, numpy=1>
|
||||
>>> f_no_hints(2)
|
||||
Tracing
|
||||
<tf.Tensor: shape=(), dtype=int32, numpy=2>
|
||||
>>> f_with_hints(1)
|
||||
Tracing
|
||||
<tf.Tensor: shape=(), dtype=int32, numpy=1>
|
||||
>>> f_with_hints(2)
|
||||
<tf.Tensor: shape=(), dtype=int32, numpy=2>
|
||||
|
||||
Args:
|
||||
func: The function to be compiled. If `func` is None, `tf.function` returns
|
||||
a decorator that can be invoked with a single argument - `func`. In other
|
||||
|
|
@ -1603,10 +1574,8 @@ def function(func=None,
|
|||
experimental_relax_shapes: Deprecated. Use `reduce_retracing`
|
||||
instead.
|
||||
experimental_compile: Deprecated alias to 'jit_compile'.
|
||||
experimental_follow_type_hints: When True, the function may use type
|
||||
annotations from `func` to optimize the tracing performance. For example,
|
||||
arguments annotated with `tf.Tensor` will automatically be converted
|
||||
to a Tensor.
|
||||
experimental_follow_type_hints: Deprecated. Please use input_signature or
|
||||
reduce_retracing instead.
|
||||
|
||||
Returns:
|
||||
If `func` is not None, returns a `tf.types.experimental.GenericFunction`.
|
||||
|
|
@ -1617,9 +1586,6 @@ def function(func=None,
|
|||
`ValueError` when attempting to use `jit_compile=True`, but XLA support is
|
||||
not available.
|
||||
"""
|
||||
if experimental_follow_type_hints is None:
|
||||
experimental_follow_type_hints = False
|
||||
|
||||
if jit_compile is None and JIT_COMPILE_FUNCTIONS:
|
||||
jit_compile = True
|
||||
|
||||
|
|
@ -1650,8 +1616,7 @@ def function(func=None,
|
|||
jit_compile,
|
||||
"experimental_compile",
|
||||
experimental_compile),
|
||||
experimental_implements=experimental_implements,
|
||||
experimental_follow_type_hints=experimental_follow_type_hints))
|
||||
experimental_implements=experimental_implements))
|
||||
|
||||
# This code path is for the `foo = tf.function(foo, ...)` use case
|
||||
if func is not None:
|
||||
|
|
|
|||
|
|
@ -2822,382 +2822,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||
|
||||
gradients(constant_op.constant([[[1.0], [2.0]]])) # No error is raised
|
||||
|
||||
def testFollowTypeHintsTraceBasic(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(x: ops.Tensor):
|
||||
trace_count[0] += 1
|
||||
return x
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
disabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=False)
|
||||
|
||||
enabled(1) # Initial call gets traced
|
||||
enabled(2)
|
||||
enabled(3)
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
|
||||
trace_count = [0]
|
||||
disabled(1)
|
||||
disabled(2) # Retrace
|
||||
disabled(3) # Retrace
|
||||
self.assertEqual(trace_count[0], 3)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgs(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(*args: ops.Tensor):
|
||||
trace_count[0] += 1
|
||||
return args
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
disabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=False)
|
||||
|
||||
args = (
|
||||
'abc',
|
||||
'def',
|
||||
) * 20
|
||||
args2 = (
|
||||
'def',
|
||||
'abc',
|
||||
) * 20
|
||||
|
||||
enabled(args)
|
||||
enabled(args2)
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
|
||||
trace_count = [0]
|
||||
disabled(args)
|
||||
disabled(args2) # Retrace
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
def testFollowTypeHintsTraceWithKwargs(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(t: ops.Tensor, **kwargs: ops.Tensor):
|
||||
del kwargs
|
||||
trace_count[0] += 1
|
||||
return t
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
disabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=False)
|
||||
|
||||
enabled(1, x=1, y=1.0, z='one')
|
||||
enabled(2, x=2, y=2.0, z='two')
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
|
||||
trace_count = [0]
|
||||
disabled(1, x=1, y=1.0, z='one')
|
||||
disabled(2, x=2, y=2.0, z='two') # Retrace
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
def testFollowTypeHintsTraceWithMultipleInputTypes(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor):
|
||||
del args, kwargs
|
||||
trace_count[0] += 1
|
||||
return t
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
disabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=False)
|
||||
|
||||
enabled(1, constant_op.constant(1), 'str', x=4.0)
|
||||
enabled(2, constant_op.constant(2), 'str2', x=5.0)
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
|
||||
trace_count = [0]
|
||||
disabled(1, constant_op.constant(1), 'str', x=4.0)
|
||||
disabled(2, constant_op.constant(2), 'str2', x=5.0) # Retrace
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
def testFollowTypeHintsTraceWithOnlyArgNamed(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(t: ops.Tensor, i: int = 1, **kwargs): # pylint: disable=bad-whitespace
|
||||
del i, kwargs
|
||||
trace_count[0] += 1
|
||||
return t
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(1, 3, x=4.0, y='str')
|
||||
enabled(2, 4, x=4.0, y='str') # Retrace
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
def testFollowTypeHintsTraceWithNotAllNamed(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(x, y: ops.Tensor, z: int):
|
||||
del y, z
|
||||
trace_count[0] += 1
|
||||
return x
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(1, 2, 3)
|
||||
enabled(1, 20, 3) # No retrace - change in ops.Tensor typed arg
|
||||
enabled(2, 2, 3) # Retrace - change in untyped arg
|
||||
enabled(2, 2, 4) # Retrace - change in typed arg
|
||||
self.assertEqual(trace_count[0], 3)
|
||||
|
||||
def testFollowTypeHintsTraceWithOnlyArgsNamed(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(x, y, *args: ops.Tensor):
|
||||
del y, args
|
||||
trace_count[0] += 1
|
||||
return x
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(1, 20, 3, 4, 5, 6)
|
||||
enabled(1, 20, 3, 4, 5, 60) # No retrace - change in *args
|
||||
enabled(1, 30, 7, 8, 9, 10) # Retrace - change in args
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
def testFollowTypeHintsTraceWithOnlyKwargsNamed(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(x, y, *args, **kwargs: ops.Tensor):
|
||||
del y, args, kwargs
|
||||
trace_count[0] += 1
|
||||
return x
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)
|
||||
enabled(
|
||||
1, 2, 3, 4, 5, 6, a=1.5, b=2.5,
|
||||
c=3.5) # No retrace - change in **kwargs
|
||||
enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) # Retrace - change in args
|
||||
enabled(
|
||||
1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0) # Retrace - change in *args
|
||||
self.assertEqual(trace_count[0], 3)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgsEquals(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(
|
||||
x: ops.Tensor = 0, # pylint:disable=bad-whitespace
|
||||
y: int = 1, # pylint:disable=bad-whitespace
|
||||
**kwargs: ops.Tensor):
|
||||
del y, kwargs
|
||||
trace_count[0] += 1
|
||||
return x
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(x=1, y=2, z=3)
|
||||
enabled(x=1, y=3, z=3) # Retrace - change in args
|
||||
enabled(x=2, y=2, z=4) # No retrace - change in args and **kwargs
|
||||
enabled(x=2, y=2, z=4, u=5) # Retrace - change in **kwargs
|
||||
self.assertEqual(trace_count[0], 3)
|
||||
|
||||
def testFollowTypeHintsWithTensorSpec(self):
|
||||
|
||||
def func(x: ops.Tensor, y):
|
||||
return x + y
|
||||
|
||||
v = polymorphic_function.function(experimental_follow_type_hints=True)(func)
|
||||
v = v.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 3)
|
||||
x = v(constant_op.constant(1.), 3)
|
||||
self.assertEqual(x.numpy(), 4.)
|
||||
|
||||
def testFollowTypeHintsTraceWithKwArgsAndNoVarKws(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(a: int, b: ops.Tensor, x: ops.Tensor = 0, y: int = 1):
|
||||
del a, b, y
|
||||
trace_count[0] += 1
|
||||
return x
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(0, 0, x=1, y=2)
|
||||
enabled(
|
||||
0,
|
||||
0,
|
||||
x=2,
|
||||
y=2,
|
||||
) # No retrace, since only tensor changed
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
|
||||
# Pass args as keyword args.
|
||||
enabled(
|
||||
a=0,
|
||||
b=0,
|
||||
x=2,
|
||||
y=2,
|
||||
) # No retrace, args are the same
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
|
||||
enabled(
|
||||
a=1,
|
||||
b=0,
|
||||
x=2,
|
||||
y=2,
|
||||
) # Retrace, since non-tensor arg changed
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
enabled(a=1, b=2, x=2, y=2) # No retrace, since only tensor changed
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
trace_count[0] = 0
|
||||
disabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=False)
|
||||
disabled(0, 0, x=1, y=2)
|
||||
disabled(
|
||||
0,
|
||||
0,
|
||||
x=2,
|
||||
y=2,
|
||||
) # Retrace
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(x, y, **kwargs: ops.Tensor):
|
||||
del y, kwargs
|
||||
trace_count[0] += 1
|
||||
return x
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(x=1, y=2, z=3)
|
||||
enabled(x=1, y=3, z=3) # Retrace
|
||||
enabled(x=1, y=2, z=4) # No retrace
|
||||
enabled(x=2, y=2, z=4) # Retrace
|
||||
enabled(x=2, y=2, z=4, u=5) # Retrace
|
||||
self.assertEqual(trace_count[0], 4)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(x: ops.Tensor, y: int, **kwargs):
|
||||
del y, kwargs
|
||||
trace_count[0] += 1
|
||||
return x
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(x=1, y=2, z=3)
|
||||
enabled(x=1, y=3, z=3) # Retrace
|
||||
enabled(x=1, y=2, z=4) # Retrace
|
||||
enabled(x=2, y=2, z=3) # No retrace
|
||||
enabled(x=2, y=2, z=4, u=5) # Retrace
|
||||
self.assertEqual(trace_count[0], 4)
|
||||
|
||||
def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(*, a: ops.Tensor = None, b=1): # pylint: disable=bad-whitespace
|
||||
del b
|
||||
trace_count[0] += 1
|
||||
return a
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(a=1, b=2)
|
||||
enabled(a=2, b=2) # No retrace
|
||||
enabled(a=1, b=1) # Retrace
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(arg: ops.Tensor, *args, kwonly, **kwargs):
|
||||
del args, kwonly, kwargs
|
||||
trace_count[0] += 1
|
||||
return arg
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
|
||||
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
|
||||
enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
|
||||
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
|
||||
self.assertEqual(trace_count[0], 4)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(arg, *args: ops.Tensor, kwonly, **kwargs):
|
||||
del args, kwonly, kwargs
|
||||
trace_count[0] += 1
|
||||
return arg
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
|
||||
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
|
||||
enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
|
||||
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
|
||||
self.assertEqual(trace_count[0], 4)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(arg, *args, kwonly: ops.Tensor, **kwargs):
|
||||
del args, kwonly, kwargs
|
||||
trace_count[0] += 1
|
||||
return arg
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
|
||||
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # No retrace
|
||||
enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7) # No retrace
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
|
||||
self.assertEqual(trace_count[0], 4)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self):
|
||||
trace_count = [0]
|
||||
|
||||
def func(arg, *args, kwonly, **kwargs: ops.Tensor):
|
||||
del args, kwonly, kwargs
|
||||
trace_count[0] += 1
|
||||
return arg
|
||||
|
||||
enabled = polymorphic_function.function(
|
||||
func, experimental_follow_type_hints=True)
|
||||
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
|
||||
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # No retrace
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace
|
||||
self.assertEqual(trace_count[0], 4)
|
||||
|
||||
def testWithExtraWrapper(self):
|
||||
|
||||
class Foo(module.Module):
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Internal APIs to be removed in the future."""
|
||||
|
||||
from tensorflow.python.eager.polymorphic_function import monomorphic_function
|
||||
|
|
@ -369,8 +368,7 @@ def defun_with_attributes(func=None,
|
|||
autograph=True,
|
||||
experimental_autograph_options=None,
|
||||
jit_compile=None,
|
||||
reduce_retracing=False,
|
||||
experimental_follow_type_hints=False):
|
||||
reduce_retracing=False):
|
||||
"""Compiles a Python function into a callable TensorFlow graph.
|
||||
|
||||
This function supports adding extra function attributes. See detailed
|
||||
|
|
@ -392,7 +390,6 @@ def defun_with_attributes(func=None,
|
|||
experimental_autograph_options.
|
||||
jit_compile: same as defun()'s jit_compile.
|
||||
reduce_retracing: same as defun()'s reduce_retracing
|
||||
experimental_follow_type_hints: see `tf.function`.
|
||||
|
||||
Returns:
|
||||
Same as the return value of defun, with attributes added to the function in
|
||||
|
|
@ -418,8 +415,7 @@ def defun_with_attributes(func=None,
|
|||
autograph=autograph,
|
||||
autograph_options=experimental_autograph_options,
|
||||
jit_compile=jit_compile,
|
||||
reduce_retracing=reduce_retracing,
|
||||
experimental_follow_type_hints=experimental_follow_type_hints))
|
||||
reduce_retracing=reduce_retracing))
|
||||
|
||||
# This code path is for the `foo = tfe.defun(foo, ...)` use case
|
||||
if func is not None:
|
||||
|
|
@ -513,8 +509,7 @@ def clear_function_callbacks():
|
|||
|
||||
|
||||
@deprecation.deprecated(
|
||||
None,
|
||||
"Use `tf.config.run_functions_eagerly` instead of the experimental "
|
||||
None, "Use `tf.config.run_functions_eagerly` instead of the experimental "
|
||||
"version.")
|
||||
@tf_export("config.experimental_run_functions_eagerly")
|
||||
def experimental_run_functions_eagerly(run_eagerly):
|
||||
|
|
|
|||
|
|
@ -73,8 +73,7 @@ class TracingCompiler:
|
|||
autograph_options=None,
|
||||
reduce_retracing=False,
|
||||
capture_by_value=None,
|
||||
jit_compile=None,
|
||||
experimental_follow_type_hints=False):
|
||||
jit_compile=None):
|
||||
"""Initializes a `TracingCompiler`.
|
||||
|
||||
Args:
|
||||
|
|
@ -98,7 +97,6 @@ class TracingCompiler:
|
|||
default to False.
|
||||
jit_compile: Force-compile the function with XLA, cf.
|
||||
tf.function doc on jit_compile.
|
||||
experimental_follow_type_hints: See the documentation for `tf.function`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `input_signature` is not None and the `python_function`'s
|
||||
|
|
@ -109,8 +107,7 @@ class TracingCompiler:
|
|||
self._function_spec = function_spec.FunctionSpec.from_function_and_signature(
|
||||
python_function,
|
||||
input_signature,
|
||||
is_pure=pure_function,
|
||||
experimental_follow_type_hints=experimental_follow_type_hints)
|
||||
is_pure=pure_function)
|
||||
self._name = name
|
||||
self._autograph = autograph
|
||||
self._autograph_options = autograph_options
|
||||
|
|
@ -128,7 +125,6 @@ class TracingCompiler:
|
|||
# create different functions for each instance.
|
||||
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||
self._jit_compile = jit_compile
|
||||
self._experimental_follow_type_hints = experimental_follow_type_hints
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls a graph function specialized to the inputs."""
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'python_function\', \'name\', \'input_signature\', \'autograph\', \'jit_compile\', \'reduce_retracing\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'python_function\', \'name\', \'input_signature\', \'autograph\', \'jit_compile\', \'reduce_retracing\', \'experimental_implements\', \'experimental_autograph_options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_get_compiler_ir"
|
||||
|
|
|
|||
|
|
@ -6,6 +6,6 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "defun_with_attributes"
|
||||
argspec: "args=[\'func\', \'input_signature\', \'attributes\', \'autograph\', \'experimental_autograph_options\', \'jit_compile\', \'reduce_retracing\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'False\', \'False\'], "
|
||||
argspec: "args=[\'func\', \'input_signature\', \'attributes\', \'autograph\', \'experimental_autograph_options\', \'jit_compile\', \'reduce_retracing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'False\'], "
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user