Create placeholders using FunctionType in TracingCompiler.
PiperOrigin-RevId: 499337642
This commit is contained in:
parent 446878d523
commit 7db3792f5e
@@ -13,8 +13,10 @@
     signature is malformed, e.g.
     * Using functools.wraps on a function with different signature
     * Using functools.partial with an invalid tf.function input
-* tf.types.experimental.TraceType now requires an additional
+* tf.types.experimental.TraceType now requires an additional
   `placeholder_value` method to be defined.
+* tf.function now traces with placeholder values generated by TraceType
+  instead of the value itself.
 
 * `tf.config.experimental.enable_mlir_graph_optimization`:
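The release note above introduces a new abstract method on TraceType. For orientation, here is a minimal sketch of a custom TraceType that satisfies it; the class, its dtype field, and the zero-valued placeholder are assumptions for illustration, not code from this commit.

```python
# Hedged sketch: a TraceType with the newly required placeholder_value.
# Only the method names come from the release note; the rest is made up.
import tensorflow as tf


class ScalarType(tf.types.experimental.TraceType):
  """Hypothetical trace type for a scalar of a fixed dtype."""

  def __init__(self, dtype):
    self._dtype = dtype

  def is_subtype_of(self, other):
    return self == other

  def most_specific_common_supertype(self, others):
    return self if all(self == other for other in others) else None

  def placeholder_value(self, placeholder_context=None):
    # tf.function now traces against this value rather than the value
    # the caller actually passed in.
    return tf.constant(0, dtype=self._dtype)

  def __eq__(self, other):
    return isinstance(other, ScalarType) and self._dtype == other._dtype

  def __hash__(self):
    return hash(self._dtype)
```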
@@ -84,6 +84,10 @@ class Literal(trace.TraceType, serialization.Serializable):
                       type(self.value).__name__)
 
   def placeholder_value(self, placeholder_context=None) -> Any:
+    # TODO(b/263505796): Remove this check when a range's placeholder output
+    # is expected to be a range and not a list.
+    if isinstance(self.value, range):
+      return list(self.value)
     return self.value
 
   def __eq__(self, other) -> bool:
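The range special case above is easiest to see in isolation. A rough illustration, assuming a `Literal` instance can be constructed directly (it is an internal class, so this is for exposition only):

```python
# Illustration of the placeholder behavior added above; direct use of the
# internal Literal class is hypothetical.
lit = Literal(range(3))
assert lit.placeholder_value() == [0, 1, 2]  # ranges expand to lists for now

lit = Literal(42)
assert lit.placeholder_value() == 42  # other literals pass through unchanged
```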
@@ -141,64 +141,51 @@ class DebugEventsMonitorTest(dumping_callback_test_lib.DumpingCallbackTestBase,
     traces = test_monitor.graph_execution_traces
     if tensor_debug_mode == "CONCISE_HEALTH":
-      self.assertLen(traces, 3)  # [Placeholder:0, Unique:0 , Sum:0].
-      self.assertEqual(traces[0].op_type, "Placeholder")
+      self.assertLen(traces, 2)  # [Unique:0 , Sum:0].
+      self.assertEqual(traces[0].op_type, "Unique")
       self.assertEqual(traces[0].output_slot, 0)
-      self.assertEqual(traces[1].op_type, "Unique")
-      self.assertEqual(traces[1].output_slot, 0)
       # Unique:1 is not traced under CONCISE_HEALTH mode, as it's int-dtype.
-      self.assertEqual(traces[2].op_type, "Sum")
-      self.assertEqual(traces[2].output_slot, 0)
+      self.assertEqual(traces[1].op_type, "Sum")
+      self.assertEqual(traces[1].output_slot, 0)
       # [tensor_id, element_count, neg_inf_count, pos_inf_count, nan_count].
       self.assertLen(traces[0].debug_tensor_value, 5)
       self.assertLen(traces[1].debug_tensor_value, 5)
-      self.assertLen(traces[2].debug_tensor_value, 5)
     elif tensor_debug_mode == "FULL_HEALTH":
-      self.assertLen(traces, 3)  # [Placeholder:0, Unique:0 , Sum:0].
-      self.assertEqual(traces[0].op_type, "Placeholder")
+      self.assertLen(traces, 2)  # [Unique:0 , Sum:0].
+      self.assertEqual(traces[0].op_type, "Unique")
       self.assertEqual(traces[0].output_slot, 0)
-      self.assertEqual(traces[1].op_type, "Unique")
-      self.assertEqual(traces[1].output_slot, 0)
       # Unique:1 is not traced under FULL_HEALTH mode, as it's int-dtype.
-      self.assertEqual(traces[2].op_type, "Sum")
-      self.assertEqual(traces[2].output_slot, 0)
+      self.assertEqual(traces[1].op_type, "Sum")
+      self.assertEqual(traces[1].output_slot, 0)
       # [tensor_id, device_id, dtype, rank, element_count,
       #  neg_inf_count, pos_inf_count, nan_count,
       #  neg_finite_count, zero_count, pos_finite_count].
       self.assertLen(traces[0].debug_tensor_value, 11)
       self.assertLen(traces[1].debug_tensor_value, 11)
-      self.assertLen(traces[2].debug_tensor_value, 11)
     elif tensor_debug_mode == "FULL_TENSOR":
-      # [Placeholder:0, Unique:0, Unique:1, Const:0, Sum:0].
-      self.assertLen(traces, 5)
-      self.assertEqual(traces[0].op_type, "Placeholder")
+      # [Unique:0, Unique:1, Const:0, Sum:0].
+      self.assertEqual(traces[0].op_type, "Unique")
       self.assertEqual(traces[0].output_slot, 0)
       self.assertIsNone(traces[0].debug_tensor_value)
       self.assertAllEqual(
           reader.graph_execution_trace_to_tensor_value(traces[0]),
-          [2., 6., 8., 1., 2.])
+          [2., 6., 8., 1.])
       self.assertEqual(traces[1].op_type, "Unique")
-      self.assertEqual(traces[1].output_slot, 0)
+      self.assertEqual(traces[1].output_slot, 1)
       self.assertIsNone(traces[1].debug_tensor_value)
       self.assertAllEqual(
           reader.graph_execution_trace_to_tensor_value(traces[1]),
-          [2., 6., 8., 1.])
-      self.assertEqual(traces[2].op_type, "Unique")
-      self.assertEqual(traces[2].output_slot, 1)
-      self.assertIsNone(traces[2].debug_tensor_value)
-      self.assertAllEqual(
-          reader.graph_execution_trace_to_tensor_value(traces[2]),
           [0, 1, 2, 3, 0])
-      self.assertEqual(traces[3].op_type, "Const")
+      self.assertEqual(traces[2].op_type, "Const")
+      self.assertEqual(traces[2].output_slot, 0)
+      self.assertIsNone(traces[2].debug_tensor_value)
+      self.assertAllClose(
+          reader.graph_execution_trace_to_tensor_value(traces[2]), [0])
+      self.assertEqual(traces[3].op_type, "Sum")
       self.assertEqual(traces[3].output_slot, 0)
       self.assertIsNone(traces[3].debug_tensor_value)
       self.assertAllClose(
-          reader.graph_execution_trace_to_tensor_value(traces[3]), [0])
-      self.assertEqual(traces[4].op_type, "Sum")
-      self.assertEqual(traces[4].output_slot, 0)
-      self.assertIsNone(traces[4].debug_tensor_value)
-      self.assertAllClose(
-          reader.graph_execution_trace_to_tensor_value(traces[4]), 17.)
+          reader.graph_execution_trace_to_tensor_value(traces[3]), 17.)
 
 
 class AlertDataObjectsTest(test_util.TensorFlowTestCase):
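The assertions above track a small graph built from a unique-values op followed by a sum; the key behavioral change is that the input `Placeholder` op no longer appears among the executed traces. A sketch of the computation implied by the op types and expected values (the real test body is not part of this diff):

```python
# Reconstructed from the op types and values in the hunk above; this is
# not the actual test code.
import tensorflow as tf


@tf.function
def unique_sum(xs):
  unique_values, unique_indices = tf.unique(xs)  # Unique:0 and Unique:1
  return tf.reduce_sum(unique_values)            # Const (axis) and Sum

out = unique_sum(tf.constant([2., 6., 8., 1., 2.]))
# Unique:0 -> [2., 6., 8., 1.], Unique:1 -> [0, 1, 2, 3, 0], Sum -> 17.0
```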
@@ -285,10 +285,10 @@ class DumpingCallbackTest(
     reader.update()
     graph_exec_traces = reader.graph_execution_traces()
     executed_op_types = [trace.op_type for trace in graph_exec_traces
-                         if trace.op_type != "Const"]
+                         if trace.op_type not in ["Const", "Placeholder"]]
     self.assertCountEqual(
         executed_op_types,
-        ["Placeholder", "Placeholder", "AddV2", "Sub", "RealDiv"])
+        ["AddV2", "Sub", "RealDiv"])
     if tensor_debug_mode == "CURT_HEALTH":
       for trace in graph_exec_traces:
         # 1st element: tensor_id, should be >= 0.
@@ -404,10 +404,10 @@ class DumpingCallbackTest(
     reader.update()
     graph_exec_traces = reader.graph_execution_traces()
     executed_op_types = [trace.op_type for trace in graph_exec_traces
-                         if trace.op_type != "Const"]
+                         if trace.op_type not in ["Const", "Placeholder"]]
     self.assertEqual(
         executed_op_types,
-        ["Placeholder", "Placeholder", "LogicalAnd", "LogicalNot"])
+        ["LogicalAnd", "LogicalNot"])
     for trace in graph_exec_traces:
       tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
       self.assertGreaterEqual(tensor_id, 0)
@@ -502,7 +502,6 @@ class DumpingCallbackTest(
         set(reader.device_name_map().values()))
 
     # Verify the recorded graph-building history.
-    placeholder_op_digests = reader.graph_op_digests(op_type="Placeholder")
     add_op_digests = reader.graph_op_digests(op_type="AddV2")
     self.assertLen(add_op_digests, 2)
     self.assertEqual(
@@ -527,59 +526,33 @@ class DumpingCallbackTest(
     self._verifyStackFrames(stack_frames)
 
     graph_exec_traces = [trace for trace in reader.graph_execution_traces()
-                         if trace.op_type != "Const"]
+                         if trace.op_type not in ["Const", "Placeholder"]]
     executed_op_types = [digest.op_type for digest in graph_exec_traces]
     self.assertEqual(
-        executed_op_types,
-        ["Placeholder", "Placeholder", "Placeholder", "Placeholder",
-         "AddV2", "Log", "AddV2", "Sin"])
-    placeholder_traces = graph_exec_traces[:4]
-    non_placeholder_traces = graph_exec_traces[4:]
+        executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
 
     # Verify the graph ID stack of each op.
-    # The outer function's 1st Placeholder.
-    self.assertEqual(
-        reader.graph_by_id(placeholder_traces[0].graph_ids[-1]).name,
-        "sin1p_log_sum")
-    # The outer function's 2nd Placeholder.
-    self.assertEqual(
-        reader.graph_by_id(placeholder_traces[1].graph_ids[-1]).name,
-        "sin1p_log_sum")
-    # The inner function's 1st Placeholder.
-    self.assertEqual(
-        reader.graph_by_id(placeholder_traces[2].graph_ids[-1]).name,
-        "log_sum")
-    self.assertEqual(
-        reader.graph_by_id(placeholder_traces[2].graph_ids[-2]).name,
-        "sin1p_log_sum")
-    # The inner function's 2nd Placeholder.
-    self.assertEqual(
-        reader.graph_by_id(placeholder_traces[3].graph_ids[-1]).name,
-        "log_sum")
-    self.assertEqual(
-        reader.graph_by_id(placeholder_traces[3].graph_ids[-2]).name,
-        "sin1p_log_sum")
     # 1st AddV2 op.
     self.assertEqual(
-        reader.graph_by_id(non_placeholder_traces[0].graph_ids[-1]).name,
+        reader.graph_by_id(graph_exec_traces[0].graph_ids[-1]).name,
         "log_sum")
     self.assertEqual(
-        reader.graph_by_id(non_placeholder_traces[0].graph_ids[-2]).name,
+        reader.graph_by_id(graph_exec_traces[0].graph_ids[-2]).name,
         "sin1p_log_sum")
     # Log op.
     self.assertEqual(
-        reader.graph_by_id(non_placeholder_traces[1].graph_ids[-1]).name,
+        reader.graph_by_id(graph_exec_traces[1].graph_ids[-1]).name,
         "log_sum")
     self.assertEqual(
-        reader.graph_by_id(non_placeholder_traces[1].graph_ids[-2]).name,
+        reader.graph_by_id(graph_exec_traces[1].graph_ids[-2]).name,
         "sin1p_log_sum")
     # 2nd AddV2 op.
     self.assertEqual(
-        reader.graph_by_id(non_placeholder_traces[2].graph_ids[-1]).name,
+        reader.graph_by_id(graph_exec_traces[2].graph_ids[-1]).name,
         "sin1p_log_sum")
     # Sin op.
     self.assertEqual(
-        reader.graph_by_id(non_placeholder_traces[3].graph_ids[-1]).name,
+        reader.graph_by_id(graph_exec_traces[3].graph_ids[-1]).name,
         "sin1p_log_sum")
 
     if tensor_debug_mode == "NO_TENSOR":
@@ -592,61 +565,37 @@ class DumpingCallbackTest(
       # In each case, the 1st element of debug_tensor_value is the ID of the
      # symbolic tenosr and the 2nd element is a zero indicating there is no
       # inf or nan.
-      self.assertAllClose(  # 1st outer placeholder.
-          placeholder_traces[0].debug_tensor_value,
-          [placeholder_op_digests[0].output_tensor_ids[0], 0.0])
-      self.assertAllClose(  # 2nd outer placeholder.
-          placeholder_traces[1].debug_tensor_value,
-          [placeholder_op_digests[1].output_tensor_ids[0], 0.0])
-      self.assertAllClose(  # 1st inner placeholder.
-          placeholder_traces[2].debug_tensor_value,
-          [placeholder_op_digests[2].output_tensor_ids[0], 0.0])
-      self.assertAllClose(  # 2nd outer placeholder.
-          placeholder_traces[3].debug_tensor_value,
-          [placeholder_op_digests[3].output_tensor_ids[0], 0.0])
       self.assertAllClose(  # 1st AddV2 op.
-          non_placeholder_traces[0].debug_tensor_value,
+          graph_exec_traces[0].debug_tensor_value,
           [add_op_digests[0].output_tensor_ids[0], 0.0])
       self.assertAllClose(  # Log op.
-          non_placeholder_traces[1].debug_tensor_value,
+          graph_exec_traces[1].debug_tensor_value,
           [log_op_digests[0].output_tensor_ids[0], 0.0])
       self.assertAllClose(  # 2nd AddV2 op.
-          non_placeholder_traces[2].debug_tensor_value,
+          graph_exec_traces[2].debug_tensor_value,
           [add_op_digests[1].output_tensor_ids[0], 0.0])
       self.assertAllClose(  # Sin op.
-          non_placeholder_traces[3].debug_tensor_value,
+          graph_exec_traces[3].debug_tensor_value,
           [sin_op_digests[0].output_tensor_ids[0], 0.0])
     elif tensor_debug_mode == "CONCISE_HEALTH":
       # 1st element: tensor_id.
       # 2nd element: element count. Remaining elements: all zero because there
       # is no -inf, inf or nan.
-      self.assertAllClose(  # 1st outer placeholder.
-          placeholder_traces[0].debug_tensor_value,
-          [placeholder_op_digests[0].output_tensor_ids[0], 1., 0., 0., 0.])
-      self.assertAllClose(  # 2nd outer placeholder.
-          placeholder_traces[1].debug_tensor_value,
-          [placeholder_op_digests[1].output_tensor_ids[0], 1., 0., 0., 0.])
-      self.assertAllClose(  # 1st inner placeholder.
-          placeholder_traces[2].debug_tensor_value,
-          [placeholder_op_digests[2].output_tensor_ids[0], 1., 0., 0., 0.])
-      self.assertAllClose(  # 2nd outer placeholder.
-          placeholder_traces[3].debug_tensor_value,
-          [placeholder_op_digests[3].output_tensor_ids[0], 1., 0., 0., 0.])
       # 1st AddV2 op.
       self.assertAllClose(
-          non_placeholder_traces[0].debug_tensor_value,
+          graph_exec_traces[0].debug_tensor_value,
           [add_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
       # Log op.
       self.assertAllClose(
-          non_placeholder_traces[1].debug_tensor_value,
+          graph_exec_traces[1].debug_tensor_value,
           [log_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
       # 2nd AddV2 op.
       self.assertAllClose(
-          non_placeholder_traces[2].debug_tensor_value,
+          graph_exec_traces[2].debug_tensor_value,
           [add_op_digests[1].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
       # Sin op.
       self.assertAllClose(
-          non_placeholder_traces[3].debug_tensor_value,
+          graph_exec_traces[3].debug_tensor_value,
           [sin_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
     elif tensor_debug_mode == "FULL_HEALTH":
       # Elements: [
@@ -655,40 +604,24 @@ class DumpingCallbackTest(
       #   dtype, rank, element_count,
       #   neg_inf_count, pos_inf_count, nan_count
       #   neg_finite_count, zero_count, pos_finite_count]
-      self.assertAllClose(  # 1st outer placeholder.
-          placeholder_traces[0].debug_tensor_value,
-          [placeholder_op_digests[0].output_tensor_ids[0],
-           -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
-      self.assertAllClose(  # 2nd outer placeholder.
-          placeholder_traces[1].debug_tensor_value,
-          [placeholder_op_digests[1].output_tensor_ids[0],
-           -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
-      self.assertAllClose(  # 1st inner placeholder.
-          placeholder_traces[2].debug_tensor_value,
-          [placeholder_op_digests[2].output_tensor_ids[0],
-           -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
-      self.assertAllClose(  # 2nd outer placeholder.
-          placeholder_traces[3].debug_tensor_value,
-          [placeholder_op_digests[3].output_tensor_ids[0],
-           -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
       # 1st AddV2 op.
       self.assertAllClose(
-          non_placeholder_traces[0].debug_tensor_value,
+          graph_exec_traces[0].debug_tensor_value,
           [add_op_digests[0].output_tensor_ids[0],
            -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
       # Log op.
       self.assertAllClose(
-          non_placeholder_traces[1].debug_tensor_value,
+          graph_exec_traces[1].debug_tensor_value,
           [log_op_digests[0].output_tensor_ids[0],
            -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
       # 2nd AddV2 op.
       self.assertAllClose(
-          non_placeholder_traces[2].debug_tensor_value,
+          graph_exec_traces[2].debug_tensor_value,
           [add_op_digests[1].output_tensor_ids[0],
            -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
       # Sin op.
       self.assertAllClose(
-          non_placeholder_traces[3].debug_tensor_value,
+          graph_exec_traces[3].debug_tensor_value,
           [sin_op_digests[0].output_tensor_ids[0],
            -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
     elif tensor_debug_mode == "SHAPE":
@@ -697,58 +630,36 @@ class DumpingCallbackTest(
       # 3rd element: rank (scalar).
       # 4th element: element count (1).
       # Remaining elements: shape padded to fixed length (6).
-      self.assertAllClose(  # 1st outer placeholder.
-          placeholder_traces[0].debug_tensor_value,
-          [placeholder_op_digests[0].output_tensor_ids[0],
-           1, 0, 1, 0, 0, 0, 0, 0, 0])
-      self.assertAllClose(  # 2nd outer placeholder.
-          placeholder_traces[1].debug_tensor_value,
-          [placeholder_op_digests[1].output_tensor_ids[0],
-           1, 0, 1, 0, 0, 0, 0, 0, 0])
-      self.assertAllClose(  # 1st inner placeholder.
-          placeholder_traces[2].debug_tensor_value,
-          [placeholder_op_digests[2].output_tensor_ids[0],
-           1, 0, 1, 0, 0, 0, 0, 0, 0])
-      self.assertAllClose(  # 2nd outer placeholder.
-          placeholder_traces[3].debug_tensor_value,
-          [placeholder_op_digests[3].output_tensor_ids[0],
-           1, 0, 1, 0, 0, 0, 0, 0, 0])
       # 1st AddV2 op.
       self.assertAllClose(
-          non_placeholder_traces[0].debug_tensor_value,
+          graph_exec_traces[0].debug_tensor_value,
           [add_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
       # Log op.
       self.assertAllClose(
-          non_placeholder_traces[1].debug_tensor_value,
+          graph_exec_traces[1].debug_tensor_value,
           [log_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
       # 2nd AddV2 op.
       self.assertAllClose(
-          non_placeholder_traces[2].debug_tensor_value,
+          graph_exec_traces[2].debug_tensor_value,
           [add_op_digests[1].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
       # Sin op.
       self.assertAllClose(
-          non_placeholder_traces[3].debug_tensor_value,
+          graph_exec_traces[3].debug_tensor_value,
           [sin_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
     else:  # FULL_TENSOR.
-      placeholder_full_tensor_values = [
+      full_tensor_values = [
           reader.graph_execution_trace_to_tensor_value(trace)
-          for trace in placeholder_traces]
-      self.assertAllClose(placeholder_full_tensor_values[0], x)  # Input x.
-      self.assertAllClose(placeholder_full_tensor_values[1], y)  # Input y.
-      self.assertAllClose(placeholder_full_tensor_values[2], x)  # Input x.
-      self.assertAllClose(placeholder_full_tensor_values[3], y)  # Input y.
-      non_placeholder_full_tensor_values = [
-          reader.graph_execution_trace_to_tensor_value(trace)
-          for trace in non_placeholder_traces]
+          for trace in graph_exec_traces]
       self.assertAllClose(
-          non_placeholder_full_tensor_values[0], 5.0)  # 1st AddV2 op.
+          full_tensor_values[0], 5.0)  # 1st AddV2 op.
       self.assertAllClose(
-          non_placeholder_full_tensor_values[1], np.log(5.0))  # Log op.
+          full_tensor_values[1], np.log(5.0))  # Log op.
       self.assertAllClose(
-          non_placeholder_full_tensor_values[2],
+          full_tensor_values[2],
           np.log(5.0) + 1.0)  # 2nd AddV2 op.
       self.assertAllClose(
-          non_placeholder_full_tensor_values[3],
+          full_tensor_values[3],
           np.sin(np.log(5.0) + 1.0))  # Sin op.
 
   @parameterized.named_parameters(
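The graph names and expected values in the hunks above imply a nested pair of functions. A hedged reconstruction (the real test code is outside this diff), which also shows why the expected values are `5.0`, `np.log(5.0)`, and `np.sin(np.log(5.0) + 1.0)`:

```python
# Reconstructed from the graph names and values; not the actual test code.
import numpy as np
import tensorflow as tf


@tf.function
def log_sum(x, y):
  return tf.math.log(x + y)                # AddV2 then Log, in graph "log_sum"


@tf.function
def sin1p_log_sum(x, y):
  return tf.math.sin(1.0 + log_sum(x, y))  # AddV2 then Sin, in the outer graph

out = sin1p_log_sum(tf.constant(2.0), tf.constant(3.0))  # x + y == 5
np.testing.assert_allclose(out.numpy(), np.sin(np.log(5.0) + 1.0))
```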
@@ -531,10 +531,9 @@ def cast_inputs(args, kwargs, input_signature):
   """Casts args, kwargs to TF values based on an optional input_signature."""
   if input_signature is None:
     args = cast_numpy_inputs(args)
     kwargs = cast_numpy_inputs(kwargs)
   else:
     args = cast_inputs_to_signature(args, input_signature)
-    kwargs = {}
+    kwargs = cast_numpy_inputs(kwargs)
 
   return args, kwargs
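After this change the keyword arguments are numpy-cast in both branches instead of being discarded when an `input_signature` is present. The resulting control flow, sketched with the helper names from the hunk (helper bodies are elsewhere in the file):

```python
# Sketch of cast_inputs after the change; cast_numpy_inputs and
# cast_inputs_to_signature are the helpers referenced in the hunk.
def cast_inputs(args, kwargs, input_signature):
  """Casts args, kwargs to TF values based on an optional input_signature."""
  if input_signature is None:
    args = cast_numpy_inputs(args)
  else:
    args = cast_inputs_to_signature(args, input_signature)
  # kwargs are now cast in both cases rather than replaced with {}.
  kwargs = cast_numpy_inputs(kwargs)
  return args, kwargs
```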
@@ -1141,20 +1141,24 @@ class Function(core.GenericFunction, trackable.Trackable):
     Returns:
       A list of instances of `ConcreteFunction`.
     """
-    concrete_functions = self._list_all_concrete_functions()
     seen_signatures = []
-    for concrete_function in concrete_functions:
-      signature = concrete_function.structured_input_signature
-      flattened = nest.flatten(signature)
-      if any(
-          isinstance(arg, func_graph_module.UnknownArgument)
-          for arg in flattened):
-        logging.info("Unsupported signature for serialization: %s.", signature)
-        continue
-      equal_to_signature = functools.partial(
-          function_spec_lib.is_same_structure, signature, check_values=True)
-      if not any(equal_to_signature(s) for s in seen_signatures):
-        seen_signatures.append(signature)
+    if self.input_signature is not None:
+      seen_signatures.append((self.input_signature, {}))
+    else:
+      concrete_functions = self._list_all_concrete_functions()
+      for concrete_function in concrete_functions:
+        signature = concrete_function.structured_input_signature
+        flattened = nest.flatten(signature)
+        if any(
+            isinstance(arg, func_graph_module.UnknownArgument)
+            for arg in flattened):
+          logging.info("Unsupported signature for serialization: %s.",
+                       signature)
+          continue
+        equal_to_signature = functools.partial(
+            function_spec_lib.is_same_structure, signature, check_values=True)
+        if not any(equal_to_signature(s) for s in seen_signatures):
+          seen_signatures.append(signature)
 
     # Re-create concrete functions for these signatures. Re-creating ensures
     # that if the cache key has changed, the function will be traced again.
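The new early branch means that when an `input_signature` is fixed, the set of serializable signatures is known without walking every traced concrete function. A user-level illustration (the function and spec here are made up for the example):

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def scale(x):
  return x * 2.0

# However scale gets called, its only structured input signature is the
# declared one, so serialization can take the fast path above.
cf = scale.get_concrete_function()
print(cf.structured_input_signature)
# roughly: ((TensorSpec(shape=(None,), dtype=tf.float32, name='x'),), {})
```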
@@ -877,7 +877,9 @@ class DefunTest(test.TestCase, parameterized.TestCase):
     save(mod, '/tmp/kwonlyf', defined.get_concrete_function(*signature))
     loaded = load('/tmp/kwonlyf')
     result = loaded.signatures['serving_default'](
-        a=array_ops.constant(1), b=array_ops.constant(2))
+        a=array_ops.constant(1),
+        b=array_ops.constant(2),
+        d=array_ops.constant(5))
     self.assertEqual(result['output_0'].numpy(), 11)
 
   def testInputSignatureWithKeywordOnlyArgsNoDefaults(self):
@@ -126,9 +126,6 @@ class TracingCompiler:
     # create different functions for each instance.
     self._descriptor_cache = weakref.WeakKeyDictionary()
     self._jit_compile = jit_compile
-    # Flag for preventing recreating placeholders. Set to False when reduced
-    # retracing is True and input_signature is None
-    self._create_placeholders = True
 
   def __call__(self, *args, **kwargs):
     """Calls a graph function specialized to the inputs."""
@@ -296,7 +293,7 @@ class TracingCompiler:
             autograph_options=self._autograph_options,
             arg_names=arg_names,
             capture_by_value=self._capture_by_value,
-            create_placeholders=self._create_placeholders),
+            create_placeholders=False),
         self._function_attributes,
         spec=self.function_spec,
         # Tell the ConcreteFunction to clean up its graph once it goes out of
@@ -363,22 +360,23 @@ class TracingCompiler:
       func_graph = func_graph_module.FuncGraph(
           self._name, capture_by_value=self._capture_by_value)
     if self.input_signature is None and self._reduce_retracing:
-      self._create_placeholders = False
-      general_func_type = self._function_cache.generalize(
+      target_func_type = self._function_cache.generalize(
           current_func_context, lookup_func_type)
-      handledata_mapping = lookup_func_context.get_handledata_mapping()
-      placeholder_mapping = lookup_func_context.get_placeholder_mapping()
-      placeholder_context = trace_type.InternalPlaceholderContext(
-          func_graph, placeholder_mapping, handledata_mapping)
-      with func_graph.as_default():
-        placeholder_bound_args = general_func_type.placeholder_arguments(
-            placeholder_context)
-      if self.function_spec.is_method:
-        # TODO(fmuham): canonicalize_function_inputs removes self arg.
-        args = placeholder_bound_args.args[1:]
-      else:
-        args = placeholder_bound_args.args
-      kwargs = placeholder_bound_args.kwargs
+    else:
+      target_func_type = lookup_func_type
+    handledata_mapping = lookup_func_context.get_handledata_mapping()
+    placeholder_mapping = lookup_func_context.get_placeholder_mapping()
+    placeholder_context = trace_type.InternalPlaceholderContext(
+        func_graph, placeholder_mapping, handledata_mapping)
+    with func_graph.as_default():
+      placeholder_bound_args = target_func_type.placeholder_arguments(
+          placeholder_context)
+    if self.function_spec.is_method:
+      # TODO(fmuham): canonicalize_function_inputs removes self arg.
+      args = placeholder_bound_args.args[1:]
+    else:
+      args = placeholder_bound_args.args
+    kwargs = placeholder_bound_args.kwargs
 
     concrete_function = self._create_concrete_function(
         args, kwargs, func_graph)
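Both branches now funnel into a single `target_func_type` whose `placeholder_arguments` drive tracing; the generalized branch only differs in asking the function cache to widen the type first. At the user level, this is the machinery behind `reduce_retracing` (illustrative only, and the exact retracing heuristics can vary by TF version):

```python
import tensorflow as tf


@tf.function(reduce_retracing=True)
def double(x):
  return x * 2

double(tf.constant([1, 2]))        # traces with placeholders for this type
double(tf.constant([1, 2, 3]))     # may trace once more with a generalized
                                   # (shape-relaxed) placeholder type
double(tf.constant([1, 2, 3, 4]))  # can then reuse the generalized trace
```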
@@ -391,12 +389,8 @@ class TracingCompiler:
 
     # Create a cache_key with args and captures
     traced_func_deletion_observer = lookup_func_context.deletion_observer
-    if self.input_signature is None and self._reduce_retracing:
-      traced_func_type = _insert_capture_type(
-          general_func_type, captures, lookup_func_context)
-    else:
-      traced_func_type = _insert_capture_type(
-          lookup_func_type, captures, lookup_func_context)
+    traced_func_type = _insert_capture_type(
+        target_func_type, captures, lookup_func_context)
 
     self._function_cache.add(current_func_context, traced_func_type,
                              traced_func_deletion_observer,
@@ -1322,6 +1322,16 @@ def func_graph_from_py_func(name,
       func_outputs = nest.map_structure(
           convert, func_outputs, expand_composites=True)
 
+      # flatten and unflatten func_args and func_kwargs to maintain parity
+      # from flattening which sorts by key
+      func_args = nest.pack_sequence_as(
+          func_args,
+          nest.flatten(func_args, expand_composites=True),
+          expand_composites=True)
+      func_kwargs = nest.pack_sequence_as(
+          func_kwargs,
+          nest.flatten(func_kwargs, expand_composites=True),
+          expand_composites=True)
       check_func_mutation(func_args_before, func_kwargs_before, func_args,
                           func_kwargs, original_func)
     finally:
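The flatten/pack round-trip added above normalizes container ordering, since flattening sorts mapping keys while packing restores the original structure. In isolation, with the public `tf.nest` API:

```python
import tensorflow as tf

kwargs = {"b": 2, "a": 1}
flat = tf.nest.flatten(kwargs)                     # [1, 2], sorted by key
restored = tf.nest.pack_sequence_as(kwargs, flat)  # {'a': 1, 'b': 2}
```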
@@ -161,8 +161,8 @@ class TraceType(metaclass=abc.ABCMeta):
   def placeholder_value(self, placeholder_context=None) -> Any:
     """Creates a placeholder for tracing.
 
-    Often it is more useful to trace with a placeholder value than an actual
-    one. For example, a placeholder value can represent multiple different
+    tf.funcion traces with the placeholder value rather than the actual value.
+    For example, a placeholder value can represent multiple different
     actual values. This means that the trace generated with that placeholder
     value is more general and reusable which saves expensive retracing.
 
@@ -188,7 +188,7 @@ class TraceType(metaclass=abc.ABCMeta):
     ```python
     @tf.function
     def foo(x):
-      # Here `x` can be the placeholder value
+      # Here `x` is be the placeholder value
       ...
 
     foo(x) # Here `x` is the actual value