mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Refactoring of layer name autogeneration, to remove a graph serialization warning.
PiperOrigin-RevId: 157520123
This commit is contained in:
parent
5784e1e35e
commit
e405b0f6b1
|
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import constant_op
|
|||
from tensorflow.python.framework import dtypes as dtypes_module
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.layers import base as tf_base_layers
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
|
|
@ -261,16 +262,9 @@ def get_uid(prefix=''):
|
|||
2
|
||||
```
|
||||
"""
|
||||
layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS')
|
||||
if not layer_name_uids_collection:
|
||||
layer_name_uids = {}
|
||||
ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids)
|
||||
else:
|
||||
layer_name_uids = layer_name_uids_collection[0]
|
||||
if prefix not in layer_name_uids:
|
||||
layer_name_uids[prefix] = 1
|
||||
else:
|
||||
layer_name_uids[prefix] += 1
|
||||
graph = ops.get_default_graph()
|
||||
layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph]
|
||||
layer_name_uids[prefix] += 1
|
||||
return layer_name_uids[prefix]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,11 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
import re
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import numpy as np
|
||||
import six
|
||||
|
|
@ -650,10 +652,10 @@ def _to_list(x):
|
|||
return [x]
|
||||
|
||||
|
||||
def _add_elements_to_collection(elements, collections):
|
||||
def _add_elements_to_collection(elements, collection_list):
|
||||
elements = _to_list(elements)
|
||||
collections = _to_list(collections)
|
||||
for name in collections:
|
||||
collection_list = _to_list(collection_list)
|
||||
for name in collection_list:
|
||||
collection = ops.get_collection_ref(name)
|
||||
collection_set = set(collection)
|
||||
for element in elements:
|
||||
|
|
@ -666,6 +668,13 @@ def _object_list_uid(object_list):
|
|||
return ', '.join([str(abs(id(x))) for x in object_list])
|
||||
|
||||
|
||||
# A global dictionary mapping graph objects to an index of counters used
|
||||
# for various layer names in each graph.
|
||||
# Allows to give unique autogenerated names to layers, in a graph-specific way.
|
||||
PER_GRAPH_LAYER_NAME_UIDS = collections.defaultdict(
|
||||
lambda: collections.defaultdict(int))
|
||||
|
||||
|
||||
def _unique_layer_name(name):
|
||||
"""Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
|
||||
|
||||
|
|
@ -684,14 +693,7 @@ def _unique_layer_name(name):
|
|||
dense_2
|
||||
```
|
||||
"""
|
||||
layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS')
|
||||
if not layer_name_uids_collection:
|
||||
layer_name_uids = {}
|
||||
ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids)
|
||||
else:
|
||||
layer_name_uids = layer_name_uids_collection[0]
|
||||
if name not in layer_name_uids:
|
||||
layer_name_uids[name] = 1
|
||||
else:
|
||||
layer_name_uids[name] += 1
|
||||
graph = ops.get_default_graph()
|
||||
layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS[graph]
|
||||
layer_name_uids[name] += 1
|
||||
return name + '_' + str(layer_name_uids[name])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user