mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Fix a tfprof bug. Throws an error when the flops cannot be calculated.
PiperOrigin-RevId: 173702740
This commit is contained in:
parent
73155f56a3
commit
3d39b32b9a
|
|
@ -373,6 +373,7 @@ def _max_pool_grad_flops(graph, node):
|
|||
kernel_area = _list_product(kernel_shape)
|
||||
orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
|
||||
node.input[1])
|
||||
orig_out_shape.assert_is_fully_defined()
|
||||
max_pool_ops = kernel_area * orig_out_shape.num_elements()
|
||||
return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user