Create placeholders using FunctionType in TracingCompiler.

PiperOrigin-RevId: 499337642
This commit is contained in:
Umer Javed 2023-01-03 16:05:48 -08:00 committed by TensorFlower Gardener
parent 446878d523
commit 7db3792f5e
10 changed files with 114 additions and 201 deletions

View File

@ -13,8 +13,10 @@
signature is malformed, e.g.
* Using functools.wraps on a function with different signature
* Using functools.partial with an invalid tf.function input
* tf.types.experimental.TraceType now requires an additional
* tf.types.experimental.TraceType now requires an additional
`placeholder_value` method to be defined.
* tf.function now traces with placeholder values generated by TraceType
instead of the value itself.
* `tf.config.experimental.enable_mlir_graph_optimization`:

View File

@ -84,6 +84,10 @@ class Literal(trace.TraceType, serialization.Serializable):
type(self.value).__name__)
def placeholder_value(self, placeholder_context=None) -> Any:
# TODO(b/263505796): Remove this check when a range's placeholder output
# is expected to be a range and not a list.
if isinstance(self.value, range):
return list(self.value)
return self.value
def __eq__(self, other) -> bool:

View File

@ -141,64 +141,51 @@ class DebugEventsMonitorTest(dumping_callback_test_lib.DumpingCallbackTestBase,
traces = test_monitor.graph_execution_traces
if tensor_debug_mode == "CONCISE_HEALTH":
self.assertLen(traces, 3) # [Placeholder:0, Unique:0 , Sum:0].
self.assertEqual(traces[0].op_type, "Placeholder")
self.assertLen(traces, 2) # [Unique:0 , Sum:0].
self.assertEqual(traces[0].op_type, "Unique")
self.assertEqual(traces[0].output_slot, 0)
self.assertEqual(traces[1].op_type, "Unique")
self.assertEqual(traces[1].output_slot, 0)
# Unique:1 is not traced under CONCISE_HEALTH mode, as it's int-dtype.
self.assertEqual(traces[2].op_type, "Sum")
self.assertEqual(traces[2].output_slot, 0)
self.assertEqual(traces[1].op_type, "Sum")
self.assertEqual(traces[1].output_slot, 0)
# [tensor_id, element_count, neg_inf_count, pos_inf_count, nan_count].
self.assertLen(traces[0].debug_tensor_value, 5)
self.assertLen(traces[1].debug_tensor_value, 5)
self.assertLen(traces[2].debug_tensor_value, 5)
elif tensor_debug_mode == "FULL_HEALTH":
self.assertLen(traces, 3) # [Placeholder:0, Unique:0 , Sum:0].
self.assertEqual(traces[0].op_type, "Placeholder")
self.assertLen(traces, 2) # [Unique:0 , Sum:0].
self.assertEqual(traces[0].op_type, "Unique")
self.assertEqual(traces[0].output_slot, 0)
self.assertEqual(traces[1].op_type, "Unique")
self.assertEqual(traces[1].output_slot, 0)
# Unique:1 is not traced under FULL_HEALTH mode, as it's int-dtype.
self.assertEqual(traces[2].op_type, "Sum")
self.assertEqual(traces[2].output_slot, 0)
self.assertEqual(traces[1].op_type, "Sum")
self.assertEqual(traces[1].output_slot, 0)
# [tensor_id, device_id, dtype, rank, element_count,
# neg_inf_count, pos_inf_count, nan_count,
# neg_finite_count, zero_count, pos_finite_count].
self.assertLen(traces[0].debug_tensor_value, 11)
self.assertLen(traces[1].debug_tensor_value, 11)
self.assertLen(traces[2].debug_tensor_value, 11)
elif tensor_debug_mode == "FULL_TENSOR":
# [Placeholder:0, Unique:0, Unique:1, Const:0, Sum:0].
self.assertLen(traces, 5)
self.assertEqual(traces[0].op_type, "Placeholder")
# [Unique:0, Unique:1, Const:0, Sum:0].
self.assertEqual(traces[0].op_type, "Unique")
self.assertEqual(traces[0].output_slot, 0)
self.assertIsNone(traces[0].debug_tensor_value)
self.assertAllEqual(
reader.graph_execution_trace_to_tensor_value(traces[0]),
[2., 6., 8., 1., 2.])
[2., 6., 8., 1.])
self.assertEqual(traces[1].op_type, "Unique")
self.assertEqual(traces[1].output_slot, 0)
self.assertEqual(traces[1].output_slot, 1)
self.assertIsNone(traces[1].debug_tensor_value)
self.assertAllEqual(
reader.graph_execution_trace_to_tensor_value(traces[1]),
[2., 6., 8., 1.])
self.assertEqual(traces[2].op_type, "Unique")
self.assertEqual(traces[2].output_slot, 1)
self.assertIsNone(traces[2].debug_tensor_value)
self.assertAllEqual(
reader.graph_execution_trace_to_tensor_value(traces[2]),
[0, 1, 2, 3, 0])
self.assertEqual(traces[3].op_type, "Const")
self.assertEqual(traces[2].op_type, "Const")
self.assertEqual(traces[2].output_slot, 0)
self.assertIsNone(traces[2].debug_tensor_value)
self.assertAllClose(
reader.graph_execution_trace_to_tensor_value(traces[2]), [0])
self.assertEqual(traces[3].op_type, "Sum")
self.assertEqual(traces[3].output_slot, 0)
self.assertIsNone(traces[3].debug_tensor_value)
self.assertAllClose(
reader.graph_execution_trace_to_tensor_value(traces[3]), [0])
self.assertEqual(traces[4].op_type, "Sum")
self.assertEqual(traces[4].output_slot, 0)
self.assertIsNone(traces[4].debug_tensor_value)
self.assertAllClose(
reader.graph_execution_trace_to_tensor_value(traces[4]), 17.)
reader.graph_execution_trace_to_tensor_value(traces[3]), 17.)
class AlertDataObjectsTest(test_util.TensorFlowTestCase):

View File

@ -285,10 +285,10 @@ class DumpingCallbackTest(
reader.update()
graph_exec_traces = reader.graph_execution_traces()
executed_op_types = [trace.op_type for trace in graph_exec_traces
if trace.op_type != "Const"]
if trace.op_type not in ["Const", "Placeholder"]]
self.assertCountEqual(
executed_op_types,
["Placeholder", "Placeholder", "AddV2", "Sub", "RealDiv"])
["AddV2", "Sub", "RealDiv"])
if tensor_debug_mode == "CURT_HEALTH":
for trace in graph_exec_traces:
# 1st element: tensor_id, should be >= 0.
@ -404,10 +404,10 @@ class DumpingCallbackTest(
reader.update()
graph_exec_traces = reader.graph_execution_traces()
executed_op_types = [trace.op_type for trace in graph_exec_traces
if trace.op_type != "Const"]
if trace.op_type not in ["Const", "Placeholder"]]
self.assertEqual(
executed_op_types,
["Placeholder", "Placeholder", "LogicalAnd", "LogicalNot"])
["LogicalAnd", "LogicalNot"])
for trace in graph_exec_traces:
tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
self.assertGreaterEqual(tensor_id, 0)
@ -502,7 +502,6 @@ class DumpingCallbackTest(
set(reader.device_name_map().values()))
# Verify the recorded graph-building history.
placeholder_op_digests = reader.graph_op_digests(op_type="Placeholder")
add_op_digests = reader.graph_op_digests(op_type="AddV2")
self.assertLen(add_op_digests, 2)
self.assertEqual(
@ -527,59 +526,33 @@ class DumpingCallbackTest(
self._verifyStackFrames(stack_frames)
graph_exec_traces = [trace for trace in reader.graph_execution_traces()
if trace.op_type != "Const"]
if trace.op_type not in ["Const", "Placeholder"]]
executed_op_types = [digest.op_type for digest in graph_exec_traces]
self.assertEqual(
executed_op_types,
["Placeholder", "Placeholder", "Placeholder", "Placeholder",
"AddV2", "Log", "AddV2", "Sin"])
placeholder_traces = graph_exec_traces[:4]
non_placeholder_traces = graph_exec_traces[4:]
executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
# Verify the graph ID stack of each op.
# The outer function's 1st Placeholder.
self.assertEqual(
reader.graph_by_id(placeholder_traces[0].graph_ids[-1]).name,
"sin1p_log_sum")
# The outer function's 2nd Placeholder.
self.assertEqual(
reader.graph_by_id(placeholder_traces[1].graph_ids[-1]).name,
"sin1p_log_sum")
# The inner function's 1st Placeholder.
self.assertEqual(
reader.graph_by_id(placeholder_traces[2].graph_ids[-1]).name,
"log_sum")
self.assertEqual(
reader.graph_by_id(placeholder_traces[2].graph_ids[-2]).name,
"sin1p_log_sum")
# The inner function's 2nd Placeholder.
self.assertEqual(
reader.graph_by_id(placeholder_traces[3].graph_ids[-1]).name,
"log_sum")
self.assertEqual(
reader.graph_by_id(placeholder_traces[3].graph_ids[-2]).name,
"sin1p_log_sum")
# 1st AddV2 op.
self.assertEqual(
reader.graph_by_id(non_placeholder_traces[0].graph_ids[-1]).name,
reader.graph_by_id(graph_exec_traces[0].graph_ids[-1]).name,
"log_sum")
self.assertEqual(
reader.graph_by_id(non_placeholder_traces[0].graph_ids[-2]).name,
reader.graph_by_id(graph_exec_traces[0].graph_ids[-2]).name,
"sin1p_log_sum")
# Log op.
self.assertEqual(
reader.graph_by_id(non_placeholder_traces[1].graph_ids[-1]).name,
reader.graph_by_id(graph_exec_traces[1].graph_ids[-1]).name,
"log_sum")
self.assertEqual(
reader.graph_by_id(non_placeholder_traces[1].graph_ids[-2]).name,
reader.graph_by_id(graph_exec_traces[1].graph_ids[-2]).name,
"sin1p_log_sum")
# 2nd AddV2 op.
self.assertEqual(
reader.graph_by_id(non_placeholder_traces[2].graph_ids[-1]).name,
reader.graph_by_id(graph_exec_traces[2].graph_ids[-1]).name,
"sin1p_log_sum")
# Sin op.
self.assertEqual(
reader.graph_by_id(non_placeholder_traces[3].graph_ids[-1]).name,
reader.graph_by_id(graph_exec_traces[3].graph_ids[-1]).name,
"sin1p_log_sum")
if tensor_debug_mode == "NO_TENSOR":
@ -592,61 +565,37 @@ class DumpingCallbackTest(
# In each case, the 1st element of debug_tensor_value is the ID of the
# symbolic tenosr and the 2nd element is a zero indicating there is no
# inf or nan.
self.assertAllClose( # 1st outer placeholder.
placeholder_traces[0].debug_tensor_value,
[placeholder_op_digests[0].output_tensor_ids[0], 0.0])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[1].debug_tensor_value,
[placeholder_op_digests[1].output_tensor_ids[0], 0.0])
self.assertAllClose( # 1st inner placeholder.
placeholder_traces[2].debug_tensor_value,
[placeholder_op_digests[2].output_tensor_ids[0], 0.0])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[3].debug_tensor_value,
[placeholder_op_digests[3].output_tensor_ids[0], 0.0])
self.assertAllClose( # 1st AddV2 op.
non_placeholder_traces[0].debug_tensor_value,
graph_exec_traces[0].debug_tensor_value,
[add_op_digests[0].output_tensor_ids[0], 0.0])
self.assertAllClose( # Log op.
non_placeholder_traces[1].debug_tensor_value,
graph_exec_traces[1].debug_tensor_value,
[log_op_digests[0].output_tensor_ids[0], 0.0])
self.assertAllClose( # 2nd AddV2 op.
non_placeholder_traces[2].debug_tensor_value,
graph_exec_traces[2].debug_tensor_value,
[add_op_digests[1].output_tensor_ids[0], 0.0])
self.assertAllClose( # Sin op.
non_placeholder_traces[3].debug_tensor_value,
graph_exec_traces[3].debug_tensor_value,
[sin_op_digests[0].output_tensor_ids[0], 0.0])
elif tensor_debug_mode == "CONCISE_HEALTH":
# 1st element: tensor_id.
# 2nd element: element count. Remaining elements: all zero because there
# is no -inf, inf or nan.
self.assertAllClose( # 1st outer placeholder.
placeholder_traces[0].debug_tensor_value,
[placeholder_op_digests[0].output_tensor_ids[0], 1., 0., 0., 0.])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[1].debug_tensor_value,
[placeholder_op_digests[1].output_tensor_ids[0], 1., 0., 0., 0.])
self.assertAllClose( # 1st inner placeholder.
placeholder_traces[2].debug_tensor_value,
[placeholder_op_digests[2].output_tensor_ids[0], 1., 0., 0., 0.])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[3].debug_tensor_value,
[placeholder_op_digests[3].output_tensor_ids[0], 1., 0., 0., 0.])
# 1st AddV2 op.
self.assertAllClose(
non_placeholder_traces[0].debug_tensor_value,
graph_exec_traces[0].debug_tensor_value,
[add_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
# Log op.
self.assertAllClose(
non_placeholder_traces[1].debug_tensor_value,
graph_exec_traces[1].debug_tensor_value,
[log_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
# 2nd AddV2 op.
self.assertAllClose(
non_placeholder_traces[2].debug_tensor_value,
graph_exec_traces[2].debug_tensor_value,
[add_op_digests[1].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
# Sin op.
self.assertAllClose(
non_placeholder_traces[3].debug_tensor_value,
graph_exec_traces[3].debug_tensor_value,
[sin_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
elif tensor_debug_mode == "FULL_HEALTH":
# Elements: [
@ -655,40 +604,24 @@ class DumpingCallbackTest(
# dtype, rank, element_count,
# neg_inf_count, pos_inf_count, nan_count
# neg_finite_count, zero_count, pos_finite_count]
self.assertAllClose( # 1st outer placeholder.
placeholder_traces[0].debug_tensor_value,
[placeholder_op_digests[0].output_tensor_ids[0],
-1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[1].debug_tensor_value,
[placeholder_op_digests[1].output_tensor_ids[0],
-1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
self.assertAllClose( # 1st inner placeholder.
placeholder_traces[2].debug_tensor_value,
[placeholder_op_digests[2].output_tensor_ids[0],
-1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[3].debug_tensor_value,
[placeholder_op_digests[3].output_tensor_ids[0],
-1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
# 1st AddV2 op.
self.assertAllClose(
non_placeholder_traces[0].debug_tensor_value,
graph_exec_traces[0].debug_tensor_value,
[add_op_digests[0].output_tensor_ids[0],
-1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
# Log op.
self.assertAllClose(
non_placeholder_traces[1].debug_tensor_value,
graph_exec_traces[1].debug_tensor_value,
[log_op_digests[0].output_tensor_ids[0],
-1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
# 2nd AddV2 op.
self.assertAllClose(
non_placeholder_traces[2].debug_tensor_value,
graph_exec_traces[2].debug_tensor_value,
[add_op_digests[1].output_tensor_ids[0],
-1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
# Sin op.
self.assertAllClose(
non_placeholder_traces[3].debug_tensor_value,
graph_exec_traces[3].debug_tensor_value,
[sin_op_digests[0].output_tensor_ids[0],
-1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
elif tensor_debug_mode == "SHAPE":
@ -697,58 +630,36 @@ class DumpingCallbackTest(
# 3rd element: rank (scalar).
# 4th element: element count (1).
# Remaining elements: shape padded to fixed length (6).
self.assertAllClose( # 1st outer placeholder.
placeholder_traces[0].debug_tensor_value,
[placeholder_op_digests[0].output_tensor_ids[0],
1, 0, 1, 0, 0, 0, 0, 0, 0])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[1].debug_tensor_value,
[placeholder_op_digests[1].output_tensor_ids[0],
1, 0, 1, 0, 0, 0, 0, 0, 0])
self.assertAllClose( # 1st inner placeholder.
placeholder_traces[2].debug_tensor_value,
[placeholder_op_digests[2].output_tensor_ids[0],
1, 0, 1, 0, 0, 0, 0, 0, 0])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[3].debug_tensor_value,
[placeholder_op_digests[3].output_tensor_ids[0],
1, 0, 1, 0, 0, 0, 0, 0, 0])
# 1st AddV2 op.
self.assertAllClose(
non_placeholder_traces[0].debug_tensor_value,
graph_exec_traces[0].debug_tensor_value,
[add_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
# Log op.
self.assertAllClose(
non_placeholder_traces[1].debug_tensor_value,
graph_exec_traces[1].debug_tensor_value,
[log_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
# 2nd AddV2 op.
self.assertAllClose(
non_placeholder_traces[2].debug_tensor_value,
graph_exec_traces[2].debug_tensor_value,
[add_op_digests[1].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
# Sin op.
self.assertAllClose(
non_placeholder_traces[3].debug_tensor_value,
graph_exec_traces[3].debug_tensor_value,
[sin_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
else: # FULL_TENSOR.
placeholder_full_tensor_values = [
full_tensor_values = [
reader.graph_execution_trace_to_tensor_value(trace)
for trace in placeholder_traces]
self.assertAllClose(placeholder_full_tensor_values[0], x) # Input x.
self.assertAllClose(placeholder_full_tensor_values[1], y) # Input y.
self.assertAllClose(placeholder_full_tensor_values[2], x) # Input x.
self.assertAllClose(placeholder_full_tensor_values[3], y) # Input y.
non_placeholder_full_tensor_values = [
reader.graph_execution_trace_to_tensor_value(trace)
for trace in non_placeholder_traces]
for trace in graph_exec_traces]
self.assertAllClose(
non_placeholder_full_tensor_values[0], 5.0) # 1st AddV2 op.
full_tensor_values[0], 5.0) # 1st AddV2 op.
self.assertAllClose(
non_placeholder_full_tensor_values[1], np.log(5.0)) # Log op.
full_tensor_values[1], np.log(5.0)) # Log op.
self.assertAllClose(
non_placeholder_full_tensor_values[2],
full_tensor_values[2],
np.log(5.0) + 1.0) # 2nd AddV2 op.
self.assertAllClose(
non_placeholder_full_tensor_values[3],
full_tensor_values[3],
np.sin(np.log(5.0) + 1.0)) # Sin op.
@parameterized.named_parameters(

View File

@ -531,10 +531,9 @@ def cast_inputs(args, kwargs, input_signature):
"""Casts args, kwargs to TF values based on an optional input_signature."""
if input_signature is None:
args = cast_numpy_inputs(args)
kwargs = cast_numpy_inputs(kwargs)
else:
args = cast_inputs_to_signature(args, input_signature)
kwargs = {}
kwargs = cast_numpy_inputs(kwargs)
return args, kwargs

View File

@ -1141,20 +1141,24 @@ class Function(core.GenericFunction, trackable.Trackable):
Returns:
A list of instances of `ConcreteFunction`.
"""
concrete_functions = self._list_all_concrete_functions()
seen_signatures = []
for concrete_function in concrete_functions:
signature = concrete_function.structured_input_signature
flattened = nest.flatten(signature)
if any(
isinstance(arg, func_graph_module.UnknownArgument)
for arg in flattened):
logging.info("Unsupported signature for serialization: %s.", signature)
continue
equal_to_signature = functools.partial(
function_spec_lib.is_same_structure, signature, check_values=True)
if not any(equal_to_signature(s) for s in seen_signatures):
seen_signatures.append(signature)
if self.input_signature is not None:
seen_signatures.append((self.input_signature, {}))
else:
concrete_functions = self._list_all_concrete_functions()
for concrete_function in concrete_functions:
signature = concrete_function.structured_input_signature
flattened = nest.flatten(signature)
if any(
isinstance(arg, func_graph_module.UnknownArgument)
for arg in flattened):
logging.info("Unsupported signature for serialization: %s.",
signature)
continue
equal_to_signature = functools.partial(
function_spec_lib.is_same_structure, signature, check_values=True)
if not any(equal_to_signature(s) for s in seen_signatures):
seen_signatures.append(signature)
# Re-create concrete functions for these signatures. Re-creating ensures
# that if the cache key has changed, the function will be traced again.

View File

@ -877,7 +877,9 @@ class DefunTest(test.TestCase, parameterized.TestCase):
save(mod, '/tmp/kwonlyf', defined.get_concrete_function(*signature))
loaded = load('/tmp/kwonlyf')
result = loaded.signatures['serving_default'](
a=array_ops.constant(1), b=array_ops.constant(2))
a=array_ops.constant(1),
b=array_ops.constant(2),
d=array_ops.constant(5))
self.assertEqual(result['output_0'].numpy(), 11)
def testInputSignatureWithKeywordOnlyArgsNoDefaults(self):

View File

@ -126,9 +126,6 @@ class TracingCompiler:
# create different functions for each instance.
self._descriptor_cache = weakref.WeakKeyDictionary()
self._jit_compile = jit_compile
# Flag for preventing recreating placeholders. Set to False when reduced
# retracing is True and input_signature is None
self._create_placeholders = True
def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
@ -296,7 +293,7 @@ class TracingCompiler:
autograph_options=self._autograph_options,
arg_names=arg_names,
capture_by_value=self._capture_by_value,
create_placeholders=self._create_placeholders),
create_placeholders=False),
self._function_attributes,
spec=self.function_spec,
# Tell the ConcreteFunction to clean up its graph once it goes out of
@ -363,22 +360,23 @@ class TracingCompiler:
func_graph = func_graph_module.FuncGraph(
self._name, capture_by_value=self._capture_by_value)
if self.input_signature is None and self._reduce_retracing:
self._create_placeholders = False
general_func_type = self._function_cache.generalize(
target_func_type = self._function_cache.generalize(
current_func_context, lookup_func_type)
handledata_mapping = lookup_func_context.get_handledata_mapping()
placeholder_mapping = lookup_func_context.get_placeholder_mapping()
placeholder_context = trace_type.InternalPlaceholderContext(
func_graph, placeholder_mapping, handledata_mapping)
with func_graph.as_default():
placeholder_bound_args = general_func_type.placeholder_arguments(
placeholder_context)
if self.function_spec.is_method:
# TODO(fmuham): canonicalize_function_inputs removes self arg.
args = placeholder_bound_args.args[1:]
else:
args = placeholder_bound_args.args
kwargs = placeholder_bound_args.kwargs
else:
target_func_type = lookup_func_type
handledata_mapping = lookup_func_context.get_handledata_mapping()
placeholder_mapping = lookup_func_context.get_placeholder_mapping()
placeholder_context = trace_type.InternalPlaceholderContext(
func_graph, placeholder_mapping, handledata_mapping)
with func_graph.as_default():
placeholder_bound_args = target_func_type.placeholder_arguments(
placeholder_context)
if self.function_spec.is_method:
# TODO(fmuham): canonicalize_function_inputs removes self arg.
args = placeholder_bound_args.args[1:]
else:
args = placeholder_bound_args.args
kwargs = placeholder_bound_args.kwargs
concrete_function = self._create_concrete_function(
args, kwargs, func_graph)
@ -391,12 +389,8 @@ class TracingCompiler:
# Create a cache_key with args and captures
traced_func_deletion_observer = lookup_func_context.deletion_observer
if self.input_signature is None and self._reduce_retracing:
traced_func_type = _insert_capture_type(
general_func_type, captures, lookup_func_context)
else:
traced_func_type = _insert_capture_type(
lookup_func_type, captures, lookup_func_context)
traced_func_type = _insert_capture_type(
target_func_type, captures, lookup_func_context)
self._function_cache.add(current_func_context, traced_func_type,
traced_func_deletion_observer,

View File

@ -1322,6 +1322,16 @@ def func_graph_from_py_func(name,
func_outputs = nest.map_structure(
convert, func_outputs, expand_composites=True)
# flatten and unflatten func_args and func_kwargs to maintain parity
# from flattening which sorts by key
func_args = nest.pack_sequence_as(
func_args,
nest.flatten(func_args, expand_composites=True),
expand_composites=True)
func_kwargs = nest.pack_sequence_as(
func_kwargs,
nest.flatten(func_kwargs, expand_composites=True),
expand_composites=True)
check_func_mutation(func_args_before, func_kwargs_before, func_args,
func_kwargs, original_func)
finally:

View File

@ -161,8 +161,8 @@ class TraceType(metaclass=abc.ABCMeta):
def placeholder_value(self, placeholder_context=None) -> Any:
"""Creates a placeholder for tracing.
Often it is more useful to trace with a placeholder value than an actual
one. For example, a placeholder value can represent multiple different
tf.funcion traces with the placeholder value rather than the actual value.
For example, a placeholder value can represent multiple different
actual values. This means that the trace generated with that placeholder
value is more general and reusable which saves expensive retracing.
@ -188,7 +188,7 @@ class TraceType(metaclass=abc.ABCMeta):
```python
@tf.function
def foo(x):
# Here `x` can be the placeholder value
# Here `x` is be the placeholder value
...
foo(x) # Here `x` is the actual value