Deprecate experimental_follow_type_hints for tf.function

PiperOrigin-RevId: 479350246
This commit is contained in:
Faizan Muhammad 2022-10-06 10:33:27 -07:00 committed by TensorFlower Gardener
parent 5a2003d501
commit a549623567
8 changed files with 31 additions and 485 deletions

View File

@ -115,6 +115,8 @@
"composite" tensors, such as `tf.RaggedTensor`, as inputs.
* Fix device placement issues related to datasets with ragged tensors of
strings (i.e. variant encoded data with types not supported on GPU).
* 'experimental_follow_type_hints' for tf.function has been deprecated.
Please use input_signature or reduce_retracing to minimize retracing.
* `tf.SparseTensor`:
* Introduced `set_shape`, which sets the static dense shape of the sparse tensor and has the same semantics as `tf.Tensor.set_shape`.

View File

@ -49,7 +49,6 @@ class FunctionSpec(object):
def from_function_and_signature(cls, python_function,
input_signature,
is_pure=False,
experimental_follow_type_hints=False,
jit_compile=None):
"""Creates a FunctionSpec instance given a python function and signature.
@ -58,7 +57,6 @@ class FunctionSpec(object):
input_signature: a signature of the function (None, if variable)
is_pure: if True all input arguments (including variables and constants)
will be converted to tensors and no variable changes allowed.
experimental_follow_type_hints: see `tf.function`
jit_compile: see `tf.function`
Returns:
@ -145,7 +143,6 @@ class FunctionSpec(object):
input_signature,
is_pure=is_pure,
jit_compile=jit_compile,
experimental_follow_type_hints=experimental_follow_type_hints,
name=name)
def __init__(self,
@ -153,7 +150,6 @@ class FunctionSpec(object):
is_method,
input_signature,
is_pure=False,
experimental_follow_type_hints=False,
name=None,
jit_compile=None):
"""Constructs a FunctionSpec describing a python function.
@ -164,7 +160,6 @@ class FunctionSpec(object):
input_signature: a signature of the function (None, if variable)
is_pure: if True all input arguments (including variables and constants)
will be converted to tensors and no variable changes allowed.
experimental_follow_type_hints: see `tf.function`.
name: Name of the function
jit_compile: see `tf.function`.
"""
@ -172,7 +167,6 @@ class FunctionSpec(object):
self._is_method = is_method
self._is_pure = is_pure
self._jit_compile = jit_compile
self._experimental_follow_type_hints = experimental_follow_type_hints
# TODO(edloper): Include name when serializing for SavedModel?
self._name = name or "f"
@ -385,35 +379,6 @@ class FunctionSpec(object):
f"{len(missing_tensor_specs)} argument(s):"
f" {missing_tensor_specs}.")
def _convert_annotated_args_to_tensors(self, args, kwargs):
"""Attempts to autobox arguments annotated as tf.Tensor."""
if self.input_signature is not None:
return
args = list(args)
for i, arg in enumerate(args):
# See
# https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
if i < len(self._fullargspec.args):
annotation_key = self._fullargspec.args[i]
else:
annotation_key = self._fullargspec.varargs
arg_annotation = self._fullargspec.annotations.get(annotation_key, None)
# TODO(rahulkamat): Change to TensorLike (here ans below)
if arg_annotation == ops.Tensor:
args[i] = _to_tensor_or_tensor_spec(arg)
for kw, v in kwargs.items():
if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args:
annotation_key = kw
else:
annotation_key = self._fullargspec.varkw
kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None)
if kwarg_annotation == ops.Tensor:
kwargs[kw] = _to_tensor_or_tensor_spec(v)
return tuple(args), kwargs
def _validate_inputs(self, flat_inputs):
"""Raises an error if inputs contain illegal values."""
for inp in flat_inputs:
@ -484,8 +449,7 @@ class FunctionSpec(object):
kwargs = {key: kwargs[key] for key in kwargs}
if self._is_pure:
args, kwargs = _convert_variables_to_tensors(args, kwargs)
if self._experimental_follow_type_hints:
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
# Pre-calculate to reduce overhead
arglen = len(args)
if self._input_signature is not None:

View File

@ -533,8 +533,7 @@ class Function(core.GenericFunction, trackable.Trackable):
jit_compile=None,
reduce_retracing=False,
experimental_implements=None,
experimental_autograph_options=None,
experimental_follow_type_hints=None):
experimental_autograph_options=None):
"""Initializes a `Function`.
Args:
@ -546,7 +545,6 @@ class Function(core.GenericFunction, trackable.Trackable):
reduce_retracing: See the documentation for `tf.function`.
experimental_implements: See the documentation for `tf.function`.
experimental_autograph_options: See the documentation for `tf.function`.
experimental_follow_type_hints: See the documentation for `tf.function`.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@ -558,7 +556,6 @@ class Function(core.GenericFunction, trackable.Trackable):
python_function,
input_signature,
jit_compile=jit_compile,
experimental_follow_type_hints=experimental_follow_type_hints,
)
self._implements = experimental_implements
# If `True`, the function uses the rendezvous of the parent. This is only
@ -569,9 +566,6 @@ class Function(core.GenericFunction, trackable.Trackable):
self._experimental_autograph_options = experimental_autograph_options
self._reduce_retracing = reduce_retracing
self._jit_compile = jit_compile
if experimental_follow_type_hints is None:
experimental_follow_type_hints = False
self._experimental_follow_type_hints = experimental_follow_type_hints
self._created_variables = None # GUARDED_BY(self._lock)
self._variable_creation_fn = None # GUARDED_BY(self._lock)
self._no_variable_creation_fn = None # GUARDED_BY(self._lock)
@ -710,8 +704,7 @@ class Function(core.GenericFunction, trackable.Trackable):
autograph=self._autograph,
jit_compile=self._jit_compile,
reduce_retracing=self._reduce_retracing,
autograph_options=self._experimental_autograph_options,
experimental_follow_type_hints=self._experimental_follow_type_hints)
autograph_options=self._experimental_autograph_options)
def _initialize(self, args, kwds, add_initializers_to=None):
"""Initializes, on the first call.
@ -781,8 +774,7 @@ class Function(core.GenericFunction, trackable.Trackable):
jit_compile=self._jit_compile,
reduce_retracing=self._reduce_retracing,
experimental_implements=self._implements,
experimental_autograph_options=self._experimental_autograph_options,
experimental_follow_type_hints=self._experimental_follow_type_hints)
experimental_autograph_options=self._experimental_autograph_options)
if self._shared_rendezvous:
f._shared_rendezvous = self._shared_rendezvous # pylint: disable=protected-access
@ -1267,16 +1259,21 @@ class Function(core.GenericFunction, trackable.Trackable):
"experimental_relax_shapes is deprecated, use "
"reduce_retracing instead",
"experimental_relax_shapes")
def function(func=None,
input_signature=None,
autograph=True,
jit_compile=None,
reduce_retracing=False,
experimental_implements=None,
experimental_autograph_options=None,
experimental_relax_shapes=None,
experimental_compile=None,
experimental_follow_type_hints=None) -> core.GenericFunction:
@deprecation.deprecated_args(None,
"experimental_follow_type_hints is deprecated",
"experimental_follow_type_hints")
def function(
func=None,
input_signature=None,
autograph=True,
jit_compile=None,
reduce_retracing=False,
experimental_implements=None,
experimental_autograph_options=None,
experimental_relax_shapes=None,
experimental_compile=None,
experimental_follow_type_hints=None # pylint: disable=unused-argument
) -> core.GenericFunction:
"""Compiles a function into a callable TensorFlow graph.
`tf.function` constructs a `tf.types.experimental.GenericFunction` that
@ -1518,32 +1515,6 @@ def function(func=None,
>>> f(2, tf.constant(2))
<tf.Tensor: shape=(), dtype=int32, numpy=2>
## Using type annotations to improve performance
`experimental_follow_type_hints` can be used along with type annotations to
reduce retracing by automatically casting any Python values to `tf.Tensor`
(something that is not done by default, unless you use input signatures).
>>> @tf.function(experimental_follow_type_hints=True)
... def f_with_hints(x: tf.Tensor):
... print('Tracing')
... return x
>>> @tf.function(experimental_follow_type_hints=False)
... def f_no_hints(x: tf.Tensor):
... print('Tracing')
... return x
>>> f_no_hints(1)
Tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>
>>> f_no_hints(2)
Tracing
<tf.Tensor: shape=(), dtype=int32, numpy=2>
>>> f_with_hints(1)
Tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>
>>> f_with_hints(2)
<tf.Tensor: shape=(), dtype=int32, numpy=2>
Args:
func: The function to be compiled. If `func` is None, `tf.function` returns
a decorator that can be invoked with a single argument - `func`. In other
@ -1603,10 +1574,8 @@ def function(func=None,
experimental_relax_shapes: Deprecated. Use `reduce_retracing`
instead.
experimental_compile: Deprecated alias to 'jit_compile'.
experimental_follow_type_hints: When True, the function may use type
annotations from `func` to optimize the tracing performance. For example,
arguments annotated with `tf.Tensor` will automatically be converted
to a Tensor.
experimental_follow_type_hints: Deprecated. Please use input_signature or
reduce_retracing instead.
Returns:
If `func` is not None, returns a `tf.types.experimental.GenericFunction`.
@ -1617,9 +1586,6 @@ def function(func=None,
`ValueError` when attempting to use `jit_compile=True`, but XLA support is
not available.
"""
if experimental_follow_type_hints is None:
experimental_follow_type_hints = False
if jit_compile is None and JIT_COMPILE_FUNCTIONS:
jit_compile = True
@ -1650,8 +1616,7 @@ def function(func=None,
jit_compile,
"experimental_compile",
experimental_compile),
experimental_implements=experimental_implements,
experimental_follow_type_hints=experimental_follow_type_hints))
experimental_implements=experimental_implements))
# This code path is for the `foo = tf.function(foo, ...)` use case
if func is not None:

View File

@ -2822,382 +2822,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
gradients(constant_op.constant([[[1.0], [2.0]]])) # No error is raised
def testFollowTypeHintsTraceBasic(self):
trace_count = [0]
def func(x: ops.Tensor):
trace_count[0] += 1
return x
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
disabled = polymorphic_function.function(
func, experimental_follow_type_hints=False)
enabled(1) # Initial call gets traced
enabled(2)
enabled(3)
self.assertEqual(trace_count[0], 1)
trace_count = [0]
disabled(1)
disabled(2) # Retrace
disabled(3) # Retrace
self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithArgs(self):
trace_count = [0]
def func(*args: ops.Tensor):
trace_count[0] += 1
return args
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
disabled = polymorphic_function.function(
func, experimental_follow_type_hints=False)
args = (
'abc',
'def',
) * 20
args2 = (
'def',
'abc',
) * 20
enabled(args)
enabled(args2)
self.assertEqual(trace_count[0], 1)
trace_count = [0]
disabled(args)
disabled(args2) # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithKwargs(self):
trace_count = [0]
def func(t: ops.Tensor, **kwargs: ops.Tensor):
del kwargs
trace_count[0] += 1
return t
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
disabled = polymorphic_function.function(
func, experimental_follow_type_hints=False)
enabled(1, x=1, y=1.0, z='one')
enabled(2, x=2, y=2.0, z='two')
self.assertEqual(trace_count[0], 1)
trace_count = [0]
disabled(1, x=1, y=1.0, z='one')
disabled(2, x=2, y=2.0, z='two') # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithMultipleInputTypes(self):
trace_count = [0]
def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor):
del args, kwargs
trace_count[0] += 1
return t
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
disabled = polymorphic_function.function(
func, experimental_follow_type_hints=False)
enabled(1, constant_op.constant(1), 'str', x=4.0)
enabled(2, constant_op.constant(2), 'str2', x=5.0)
self.assertEqual(trace_count[0], 1)
trace_count = [0]
disabled(1, constant_op.constant(1), 'str', x=4.0)
disabled(2, constant_op.constant(2), 'str2', x=5.0) # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithOnlyArgNamed(self):
trace_count = [0]
def func(t: ops.Tensor, i: int = 1, **kwargs): # pylint: disable=bad-whitespace
del i, kwargs
trace_count[0] += 1
return t
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(1, 3, x=4.0, y='str')
enabled(2, 4, x=4.0, y='str') # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithNotAllNamed(self):
trace_count = [0]
def func(x, y: ops.Tensor, z: int):
del y, z
trace_count[0] += 1
return x
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(1, 2, 3)
enabled(1, 20, 3) # No retrace - change in ops.Tensor typed arg
enabled(2, 2, 3) # Retrace - change in untyped arg
enabled(2, 2, 4) # Retrace - change in typed arg
self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithOnlyArgsNamed(self):
trace_count = [0]
def func(x, y, *args: ops.Tensor):
del y, args
trace_count[0] += 1
return x
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(1, 20, 3, 4, 5, 6)
enabled(1, 20, 3, 4, 5, 60) # No retrace - change in *args
enabled(1, 30, 7, 8, 9, 10) # Retrace - change in args
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithOnlyKwargsNamed(self):
trace_count = [0]
def func(x, y, *args, **kwargs: ops.Tensor):
del y, args, kwargs
trace_count[0] += 1
return x
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)
enabled(
1, 2, 3, 4, 5, 6, a=1.5, b=2.5,
c=3.5) # No retrace - change in **kwargs
enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) # Retrace - change in args
enabled(
1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0) # Retrace - change in *args
self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithArgsEquals(self):
trace_count = [0]
def func(
x: ops.Tensor = 0, # pylint:disable=bad-whitespace
y: int = 1, # pylint:disable=bad-whitespace
**kwargs: ops.Tensor):
del y, kwargs
trace_count[0] += 1
return x
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(x=1, y=2, z=3)
enabled(x=1, y=3, z=3) # Retrace - change in args
enabled(x=2, y=2, z=4) # No retrace - change in args and **kwargs
enabled(x=2, y=2, z=4, u=5) # Retrace - change in **kwargs
self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsWithTensorSpec(self):
def func(x: ops.Tensor, y):
return x + y
v = polymorphic_function.function(experimental_follow_type_hints=True)(func)
v = v.get_concrete_function(
tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 3)
x = v(constant_op.constant(1.), 3)
self.assertEqual(x.numpy(), 4.)
def testFollowTypeHintsTraceWithKwArgsAndNoVarKws(self):
trace_count = [0]
def func(a: int, b: ops.Tensor, x: ops.Tensor = 0, y: int = 1):
del a, b, y
trace_count[0] += 1
return x
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(0, 0, x=1, y=2)
enabled(
0,
0,
x=2,
y=2,
) # No retrace, since only tensor changed
self.assertEqual(trace_count[0], 1)
# Pass args as keyword args.
enabled(
a=0,
b=0,
x=2,
y=2,
) # No retrace, args are the same
self.assertEqual(trace_count[0], 1)
enabled(
a=1,
b=0,
x=2,
y=2,
) # Retrace, since non-tensor arg changed
self.assertEqual(trace_count[0], 2)
enabled(a=1, b=2, x=2, y=2) # No retrace, since only tensor changed
self.assertEqual(trace_count[0], 2)
trace_count[0] = 0
disabled = polymorphic_function.function(
func, experimental_follow_type_hints=False)
disabled(0, 0, x=1, y=2)
disabled(
0,
0,
x=2,
y=2,
) # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
trace_count = [0]
def func(x, y, **kwargs: ops.Tensor):
del y, kwargs
trace_count[0] += 1
return x
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(x=1, y=2, z=3)
enabled(x=1, y=3, z=3) # Retrace
enabled(x=1, y=2, z=4) # No retrace
enabled(x=2, y=2, z=4) # Retrace
enabled(x=2, y=2, z=4, u=5) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self):
trace_count = [0]
def func(x: ops.Tensor, y: int, **kwargs):
del y, kwargs
trace_count[0] += 1
return x
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(x=1, y=2, z=3)
enabled(x=1, y=3, z=3) # Retrace
enabled(x=1, y=2, z=4) # Retrace
enabled(x=2, y=2, z=3) # No retrace
enabled(x=2, y=2, z=4, u=5) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self):
trace_count = [0]
def func(*, a: ops.Tensor = None, b=1): # pylint: disable=bad-whitespace
del b
trace_count[0] += 1
return a
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(a=1, b=2)
enabled(a=2, b=2) # No retrace
enabled(a=1, b=1) # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self):
trace_count = [0]
def func(arg: ops.Tensor, *args, kwonly, **kwargs):
del args, kwonly, kwargs
trace_count[0] += 1
return arg
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self):
trace_count = [0]
def func(arg, *args: ops.Tensor, kwonly, **kwargs):
del args, kwonly, kwargs
trace_count[0] += 1
return arg
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self):
trace_count = [0]
def func(arg, *args, kwonly: ops.Tensor, **kwargs):
del args, kwonly, kwargs
trace_count[0] += 1
return arg
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self):
trace_count = [0]
def func(arg, *args, kwonly, **kwargs: ops.Tensor):
del args, kwonly, kwargs
trace_count[0] += 1
return arg
enabled = polymorphic_function.function(
func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # No retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace
self.assertEqual(trace_count[0], 4)
def testWithExtraWrapper(self):
class Foo(module.Module):

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal APIs to be removed in the future."""
from tensorflow.python.eager.polymorphic_function import monomorphic_function
@ -369,8 +368,7 @@ def defun_with_attributes(func=None,
autograph=True,
experimental_autograph_options=None,
jit_compile=None,
reduce_retracing=False,
experimental_follow_type_hints=False):
reduce_retracing=False):
"""Compiles a Python function into a callable TensorFlow graph.
This function supports adding extra function attributes. See detailed
@ -392,7 +390,6 @@ def defun_with_attributes(func=None,
experimental_autograph_options.
jit_compile: same as defun()'s jit_compile.
reduce_retracing: same as defun()'s reduce_retracing
experimental_follow_type_hints: see `tf.function`.
Returns:
Same as the return value of defun, with attributes added to the function in
@ -418,8 +415,7 @@ def defun_with_attributes(func=None,
autograph=autograph,
autograph_options=experimental_autograph_options,
jit_compile=jit_compile,
reduce_retracing=reduce_retracing,
experimental_follow_type_hints=experimental_follow_type_hints))
reduce_retracing=reduce_retracing))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@ -513,8 +509,7 @@ def clear_function_callbacks():
@deprecation.deprecated(
None,
"Use `tf.config.run_functions_eagerly` instead of the experimental "
None, "Use `tf.config.run_functions_eagerly` instead of the experimental "
"version.")
@tf_export("config.experimental_run_functions_eagerly")
def experimental_run_functions_eagerly(run_eagerly):

View File

@ -73,8 +73,7 @@ class TracingCompiler:
autograph_options=None,
reduce_retracing=False,
capture_by_value=None,
jit_compile=None,
experimental_follow_type_hints=False):
jit_compile=None):
"""Initializes a `TracingCompiler`.
Args:
@ -98,7 +97,6 @@ class TracingCompiler:
default to False.
jit_compile: Force-compile the function with XLA, cf.
tf.function doc on jit_compile.
experimental_follow_type_hints: See the documentation for `tf.function`.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@ -109,8 +107,7 @@ class TracingCompiler:
self._function_spec = function_spec.FunctionSpec.from_function_and_signature(
python_function,
input_signature,
is_pure=pure_function,
experimental_follow_type_hints=experimental_follow_type_hints)
is_pure=pure_function)
self._name = name
self._autograph = autograph
self._autograph_options = autograph_options
@ -128,7 +125,6 @@ class TracingCompiler:
# create different functions for each instance.
self._descriptor_cache = weakref.WeakKeyDictionary()
self._jit_compile = jit_compile
self._experimental_follow_type_hints = experimental_follow_type_hints
def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""

View File

@ -23,7 +23,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'python_function\', \'name\', \'input_signature\', \'autograph\', \'jit_compile\', \'reduce_retracing\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'python_function\', \'name\', \'input_signature\', \'autograph\', \'jit_compile\', \'reduce_retracing\', \'experimental_implements\', \'experimental_autograph_options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "experimental_get_compiler_ir"

View File

@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "defun_with_attributes"
argspec: "args=[\'func\', \'input_signature\', \'attributes\', \'autograph\', \'experimental_autograph_options\', \'jit_compile\', \'reduce_retracing\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'False\', \'False\'], "
argspec: "args=[\'func\', \'input_signature\', \'attributes\', \'autograph\', \'experimental_autograph_options\', \'jit_compile\', \'reduce_retracing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'False\'], "
}
}