mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Remove unused arguments to call_cpp_shape_fn.
PiperOrigin-RevId: 157640125
This commit is contained in:
parent
9ddbf31feb
commit
8e868cf6a1
|
|
@ -553,24 +553,11 @@ def broadcast_shape(shape_x, shape_y):
|
||||||
return tensor_shape.TensorShape(return_dims)
|
return tensor_shape.TensorShape(return_dims)
|
||||||
|
|
||||||
|
|
||||||
def call_cpp_shape_fn(op,
|
def call_cpp_shape_fn(op, require_shape_fn=True):
|
||||||
input_tensors_needed=None,
|
|
||||||
input_tensors_as_shapes_needed=None,
|
|
||||||
debug_python_shape_fn=None,
|
|
||||||
require_shape_fn=True):
|
|
||||||
"""A shape function that delegates to the registered C++ shape function.
|
"""A shape function that delegates to the registered C++ shape function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
op: the node in the graph for which to compute output shapes.
|
op: the node in the graph for which to compute output shapes.
|
||||||
input_tensors_needed: a list of input tensor indices for which to compute
|
|
||||||
the input tensor's value and pass to the C++ shape function.
|
|
||||||
input_tensors_as_shapes_needed: a list of input tensor indices for which to
|
|
||||||
compute the constant_value_as_shape and pass to the C++ shape function.
|
|
||||||
debug_python_shape_fn: For testing only during migration to using
|
|
||||||
call_cpp_shape_fn. Do not submit calls that set this,
|
|
||||||
as the comparison is slow. If non-None, the python shape function;
|
|
||||||
this function will be called and its output compared to that of
|
|
||||||
the C++ shape function.
|
|
||||||
require_shape_fn: If true, and the C++ shape function is not registered
|
require_shape_fn: If true, and the C++ shape function is not registered
|
||||||
in the current binary then an exception is raised; otherwise, if the
|
in the current binary then an exception is raised; otherwise, if the
|
||||||
C++ shape function is not registered then unknown_shape is used.
|
C++ shape function is not registered then unknown_shape is used.
|
||||||
|
|
@ -599,13 +586,13 @@ def call_cpp_shape_fn(op,
|
||||||
"handle_data": [None]
|
"handle_data": [None]
|
||||||
}
|
}
|
||||||
|
|
||||||
input_tensors_needed = input_tensors_needed or []
|
input_tensors_needed = []
|
||||||
input_tensors_as_shapes_needed = input_tensors_as_shapes_needed or []
|
input_tensors_as_shapes_needed = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
|
res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
|
||||||
input_tensors_as_shapes_needed,
|
input_tensors_as_shapes_needed,
|
||||||
debug_python_shape_fn, require_shape_fn)
|
require_shape_fn)
|
||||||
if not isinstance(res, dict):
|
if not isinstance(res, dict):
|
||||||
# Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).
|
# Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).
|
||||||
return res
|
return res
|
||||||
|
|
@ -629,9 +616,7 @@ def call_cpp_shape_fn(op,
|
||||||
|
|
||||||
|
|
||||||
def _call_cpp_shape_fn_impl(
|
def _call_cpp_shape_fn_impl(
|
||||||
op, input_tensors_needed,
|
op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
|
||||||
input_tensors_as_shapes_needed,
|
|
||||||
debug_python_shape_fn, require_shape_fn):
|
|
||||||
"""Core implementaton of call_cpp_shape_fn."""
|
"""Core implementaton of call_cpp_shape_fn."""
|
||||||
graph_def_version = op.graph.graph_def_versions.producer
|
graph_def_version = op.graph.graph_def_versions.producer
|
||||||
node_def_str = op.node_def.SerializeToString()
|
node_def_str = op.node_def.SerializeToString()
|
||||||
|
|
@ -691,22 +676,6 @@ def _call_cpp_shape_fn_impl(
|
||||||
r.handle_data if r.handle_data.is_set else None for r in result_protos
|
r.handle_data if r.handle_data.is_set else None for r in result_protos
|
||||||
]
|
]
|
||||||
|
|
||||||
if debug_python_shape_fn:
|
|
||||||
try:
|
|
||||||
python_result = [tensor_shape.as_shape(s)
|
|
||||||
for s in debug_python_shape_fn(op)]
|
|
||||||
except Exception as err:
|
|
||||||
raise AssertionError("Python shape function return error but "
|
|
||||||
"C++ shape functon did not: %s" % str(err))
|
|
||||||
result_as_shapes = [tensor_shape.as_shape(s) for s in result]
|
|
||||||
if str(result_as_shapes) != str(python_result):
|
|
||||||
raise ValueError(
|
|
||||||
("Python vs CPP shape mismatch. "
|
|
||||||
"CPP: %s vs python: %s on node %s "
|
|
||||||
"with input shapes %s") % (
|
|
||||||
str(result_as_shapes), str(python_result), str(op.node_def),
|
|
||||||
",".join([str(i.get_shape()) for i in op.inputs])))
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"shapes": result,
|
"shapes": result,
|
||||||
"handle_data": result_handle_data,
|
"handle_data": result_handle_data,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user