Refactoring of layer name autogeneration, to remove a graph serialization warning.

PiperOrigin-RevId: 157520123
This commit is contained in:
Francois Chollet 2017-05-30 15:20:53 -07:00 committed by TensorFlower Gardener
parent 5784e1e35e
commit e405b0f6b1
2 changed files with 19 additions and 23 deletions

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
@ -261,16 +262,9 @@ def get_uid(prefix=''):
2
```
"""
layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS')
if not layer_name_uids_collection:
layer_name_uids = {}
ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids)
else:
layer_name_uids = layer_name_uids_collection[0]
if prefix not in layer_name_uids:
layer_name_uids[prefix] = 1
else:
layer_name_uids[prefix] += 1
graph = ops.get_default_graph()
layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph]
layer_name_uids[prefix] += 1
return layer_name_uids[prefix]

View File

@ -23,9 +23,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import functools
import re
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
import six
@ -650,10 +652,10 @@ def _to_list(x):
return [x]
def _add_elements_to_collection(elements, collections):
def _add_elements_to_collection(elements, collection_list):
elements = _to_list(elements)
collections = _to_list(collections)
for name in collections:
collection_list = _to_list(collection_list)
for name in collection_list:
collection = ops.get_collection_ref(name)
collection_set = set(collection)
for element in elements:
@ -666,6 +668,13 @@ def _object_list_uid(object_list):
return ', '.join([str(abs(id(x))) for x in object_list])
# A global dictionary mapping graph objects to an index of counters used
# for various layer names in each graph.
# Allows to give unique autogenerated names to layers, in a graph-specific way.
PER_GRAPH_LAYER_NAME_UIDS = collections.defaultdict(
lambda: collections.defaultdict(int))
def _unique_layer_name(name):
"""Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
@ -684,14 +693,7 @@ def _unique_layer_name(name):
dense_2
```
"""
layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS')
if not layer_name_uids_collection:
layer_name_uids = {}
ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids)
else:
layer_name_uids = layer_name_uids_collection[0]
if name not in layer_name_uids:
layer_name_uids[name] = 1
else:
layer_name_uids[name] += 1
graph = ops.get_default_graph()
layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS[graph]
layer_name_uids[name] += 1
return name + '_' + str(layer_name_uids[name])