Remove unused arguments to call_cpp_shape_fn.

PiperOrigin-RevId: 157640125
This commit is contained in:
A. Unique TensorFlower 2017-05-31 15:01:24 -07:00 committed by TensorFlower Gardener
parent 9ddbf31feb
commit 8e868cf6a1

View File

@ -553,24 +553,11 @@ def broadcast_shape(shape_x, shape_y):
return tensor_shape.TensorShape(return_dims)
def call_cpp_shape_fn(op,
input_tensors_needed=None,
input_tensors_as_shapes_needed=None,
debug_python_shape_fn=None,
require_shape_fn=True):
def call_cpp_shape_fn(op, require_shape_fn=True):
"""A shape function that delegates to the registered C++ shape function.
Args:
op: the node in the graph for which to compute output shapes.
input_tensors_needed: a list of input tensor indices for which to compute
the input tensor's value and pass to the C++ shape function.
input_tensors_as_shapes_needed: a list of input tensor indices for which to
compute the constant_value_as_shape and pass to the C++ shape function.
debug_python_shape_fn: For testing only during migration to using
call_cpp_shape_fn. Do not submit calls that set this,
as the comparison is slow. If non-None, the python shape function;
this function will be called and its output compared to that of
the C++ shape function.
require_shape_fn: If true, and the C++ shape function is not registered
in the current binary then an exception is raised; otherwise, if the
C++ shape function is not registered then unknown_shape is used.
@ -599,13 +586,13 @@ def call_cpp_shape_fn(op,
"handle_data": [None]
}
input_tensors_needed = input_tensors_needed or []
input_tensors_as_shapes_needed = input_tensors_as_shapes_needed or []
input_tensors_needed = []
input_tensors_as_shapes_needed = []
while True:
res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
input_tensors_as_shapes_needed,
debug_python_shape_fn, require_shape_fn)
require_shape_fn)
if not isinstance(res, dict):
# Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).
return res
@ -629,9 +616,7 @@ def call_cpp_shape_fn(op,
def _call_cpp_shape_fn_impl(
op, input_tensors_needed,
input_tensors_as_shapes_needed,
debug_python_shape_fn, require_shape_fn):
op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
"""Core implementaton of call_cpp_shape_fn."""
graph_def_version = op.graph.graph_def_versions.producer
node_def_str = op.node_def.SerializeToString()
@ -691,22 +676,6 @@ def _call_cpp_shape_fn_impl(
r.handle_data if r.handle_data.is_set else None for r in result_protos
]
if debug_python_shape_fn:
try:
python_result = [tensor_shape.as_shape(s)
for s in debug_python_shape_fn(op)]
except Exception as err:
raise AssertionError("Python shape function return error but "
"C++ shape functon did not: %s" % str(err))
result_as_shapes = [tensor_shape.as_shape(s) for s in result]
if str(result_as_shapes) != str(python_result):
raise ValueError(
("Python vs CPP shape mismatch. "
"CPP: %s vs python: %s on node %s "
"with input shapes %s") % (
str(result_as_shapes), str(python_result), str(op.node_def),
",".join([str(i.get_shape()) for i in op.inputs])))
return {
"shapes": result,
"handle_data": result_handle_data,