mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Remove unused arguments to call_cpp_shape_fn.
PiperOrigin-RevId: 157640125
This commit is contained in:
parent
9ddbf31feb
commit
8e868cf6a1
|
|
@ -553,24 +553,11 @@ def broadcast_shape(shape_x, shape_y):
|
|||
return tensor_shape.TensorShape(return_dims)
|
||||
|
||||
|
||||
def call_cpp_shape_fn(op,
|
||||
input_tensors_needed=None,
|
||||
input_tensors_as_shapes_needed=None,
|
||||
debug_python_shape_fn=None,
|
||||
require_shape_fn=True):
|
||||
def call_cpp_shape_fn(op, require_shape_fn=True):
|
||||
"""A shape function that delegates to the registered C++ shape function.
|
||||
|
||||
Args:
|
||||
op: the node in the graph for which to compute output shapes.
|
||||
input_tensors_needed: a list of input tensor indices for which to compute
|
||||
the input tensor's value and pass to the C++ shape function.
|
||||
input_tensors_as_shapes_needed: a list of input tensor indices for which to
|
||||
compute the constant_value_as_shape and pass to the C++ shape function.
|
||||
debug_python_shape_fn: For testing only during migration to using
|
||||
call_cpp_shape_fn. Do not submit calls that set this,
|
||||
as the comparison is slow. If non-None, the python shape function;
|
||||
this function will be called and its output compared to that of
|
||||
the C++ shape function.
|
||||
require_shape_fn: If true, and the C++ shape function is not registered
|
||||
in the current binary then an exception is raised; otherwise, if the
|
||||
C++ shape function is not registered then unknown_shape is used.
|
||||
|
|
@ -599,13 +586,13 @@ def call_cpp_shape_fn(op,
|
|||
"handle_data": [None]
|
||||
}
|
||||
|
||||
input_tensors_needed = input_tensors_needed or []
|
||||
input_tensors_as_shapes_needed = input_tensors_as_shapes_needed or []
|
||||
input_tensors_needed = []
|
||||
input_tensors_as_shapes_needed = []
|
||||
|
||||
while True:
|
||||
res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
|
||||
input_tensors_as_shapes_needed,
|
||||
debug_python_shape_fn, require_shape_fn)
|
||||
require_shape_fn)
|
||||
if not isinstance(res, dict):
|
||||
# Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).
|
||||
return res
|
||||
|
|
@ -629,9 +616,7 @@ def call_cpp_shape_fn(op,
|
|||
|
||||
|
||||
def _call_cpp_shape_fn_impl(
|
||||
op, input_tensors_needed,
|
||||
input_tensors_as_shapes_needed,
|
||||
debug_python_shape_fn, require_shape_fn):
|
||||
op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
|
||||
"""Core implementaton of call_cpp_shape_fn."""
|
||||
graph_def_version = op.graph.graph_def_versions.producer
|
||||
node_def_str = op.node_def.SerializeToString()
|
||||
|
|
@ -691,22 +676,6 @@ def _call_cpp_shape_fn_impl(
|
|||
r.handle_data if r.handle_data.is_set else None for r in result_protos
|
||||
]
|
||||
|
||||
if debug_python_shape_fn:
|
||||
try:
|
||||
python_result = [tensor_shape.as_shape(s)
|
||||
for s in debug_python_shape_fn(op)]
|
||||
except Exception as err:
|
||||
raise AssertionError("Python shape function return error but "
|
||||
"C++ shape functon did not: %s" % str(err))
|
||||
result_as_shapes = [tensor_shape.as_shape(s) for s in result]
|
||||
if str(result_as_shapes) != str(python_result):
|
||||
raise ValueError(
|
||||
("Python vs CPP shape mismatch. "
|
||||
"CPP: %s vs python: %s on node %s "
|
||||
"with input shapes %s") % (
|
||||
str(result_as_shapes), str(python_result), str(op.node_def),
|
||||
",".join([str(i.get_shape()) for i in op.inputs])))
|
||||
|
||||
return {
|
||||
"shapes": result,
|
||||
"handle_data": result_handle_data,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user