EagerVariableStore, for compatibility with functional layers.

PiperOrigin-RevId: 173915730
This commit is contained in:
Alexandre Passos 2017-10-30 10:49:40 -07:00 committed by TensorFlower Gardener
parent cef680b532
commit 89582677c3
8 changed files with 73 additions and 131 deletions

View File

@ -52,6 +52,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@restore_variables_on_create
@@Variable
@@get_optimizer_variables
@@EagerVariableStore
@@in_eager_mode
@@in_graph_mode
@ -100,6 +101,7 @@ from tensorflow.python.framework.ops import eager_run as run
from tensorflow.python.framework.test_util import IsolateTest
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.ops.variable_scope import EagerVariableStore
from tensorflow.python.util.all_util import remove_undocumented
defun = function.defun

View File

@ -386,15 +386,7 @@ def conv1d(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.Conv1D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.Conv1D instead.')
layer = Conv1D(
filters=filters,
kernel_size=kernel_size,
@ -597,15 +589,7 @@ def conv2d(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.Conv2D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.Conv2D instead.')
layer = Conv2D(
filters=filters,
kernel_size=kernel_size,
@ -810,15 +794,7 @@ def conv3d(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.Conv3D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.Conv3D instead.')
layer = Conv3D(
filters=filters,
kernel_size=kernel_size,
@ -1140,15 +1116,7 @@ def separable_conv2d(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.SeparableConv2d` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.SeparableConv2D instead.')
layer = SeparableConv2D(
filters=filters,
kernel_size=kernel_size,
@ -1446,15 +1414,7 @@ def conv2d_transpose(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.Conv2DTranspose` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.Conv2DTranspose instead.')
layer = Conv2DTranspose(
filters=filters,
kernel_size=kernel_size,
@ -1768,15 +1728,7 @@ def conv3d_transpose(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.Conv3DTranspose` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.Conv3DTranspose instead.')
layer = Conv3DTranspose(
filters=filters,
kernel_size=kernel_size,

View File

@ -234,15 +234,7 @@ def dense(
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.Dense` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.Dense instead.')
layer = Dense(units,
activation=activation,
use_bias=use_bias,
@ -347,15 +339,7 @@ def dropout(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.Dropout` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.Dropout instead.')
layer = Dropout(rate, noise_shape=noise_shape, seed=seed, name=name)
return layer.apply(inputs, training=training)

View File

@ -23,6 +23,7 @@ import collections
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -258,6 +259,23 @@ class DenseTest(test.TestCase):
self.assertAllClose(weights['scope/dense/bias'].read_value().eval(),
np.zeros((2)))
def testEagerExecution(self):
with context.eager_mode():
container = variable_scope.EagerVariableStore()
x = constant_op.constant([[2.0]])
with container.as_default():
y = core_layers.dense(
x, 1, name='my_dense',
kernel_initializer=init_ops.ones_initializer())
self.assertAllEqual(y, [[2.0]])
self.assertEqual(len(container.variables()), 2)
# Recreate the layer to test reuse.
with container.as_default():
core_layers.dense(
x, 1, name='my_dense',
kernel_initializer=init_ops.ones_initializer())
self.assertEqual(len(container.variables()), 2)
def testFunctionalDenseWithCustomGetter(self):
called = [0]

View File

@ -50,15 +50,7 @@ def maxout(inputs, num_units, axis=-1, name=None):
Raises:
ValueError: if num_units is not multiple of number of features.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.MaxOut` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'use tf.contrib.layers.MaxOut instead')
return MaxOut(num_units=num_units, axis=axis, name=name)(inputs)

View File

@ -729,16 +729,7 @@ def batch_normalization(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.BatchNormalization`
instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.BactchNormalization instead.')
layer = BatchNormalization(
axis=axis,
momentum=momentum,

View File

@ -148,15 +148,7 @@ def average_pooling1d(inputs, pool_size, strides,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.AveragePooling1D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.AveragePooling1D instead.')
layer = AveragePooling1D(pool_size=pool_size,
strides=strides,
padding=padding,
@ -221,15 +213,7 @@ def max_pooling1d(inputs, pool_size, strides,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.MaxPooling1D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.MaxPooling1D instead.')
layer = MaxPooling1D(pool_size=pool_size,
strides=strides,
padding=padding,
@ -370,15 +354,7 @@ def average_pooling2d(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.AveragePooling2D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.AveragePooling2D instead.')
layer = AveragePooling2D(pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format,
name=name)
@ -446,15 +422,7 @@ def max_pooling2d(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.MaxPooling2D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.MaxPooling2D instead.')
layer = MaxPooling2D(pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format,
name=name)
@ -608,15 +576,7 @@ def average_pooling3d(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.AveragePooling3D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.AveragePooling3D instead.')
layer = AveragePooling3D(pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format,
name=name)
@ -688,15 +648,7 @@ def max_pooling3d(inputs,
Raises:
ValueError: if eager execution is enabled.
@compatibility(eager)
Not compatible with eager execution. Use `tf.layers.MaxPooling3D` instead.
@end_compatibility
"""
if context.in_eager_mode():
raise ValueError(
'Functional layers are currently not compatible with eager execution.'
'Use tf.layers.MaxPooling3D instead.')
layer = MaxPooling3D(pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format,
name=name)

View File

@ -208,6 +208,7 @@ class _VariableStore(object):
self._vars = {} # A dictionary of the stored TensorFlow variables.
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
self.variable_scopes_count = {} # Count re-used variable scopes.
self._store_eager_variables = False
def open_variable_scope(self, scope_name):
if scope_name in self.variable_scopes_count:
@ -309,13 +310,21 @@ class _VariableStore(object):
ValueError: when creating a new variable and shape is not declared,
when reusing a variable and specifying a conflicting shape,
or when violating reuse during variable creation.
RuntimeError: when eager execution is enabled and not called from an
EagerVariableStore.
"""
if custom_getter is not None and not callable(custom_getter):
raise ValueError(
"Passed a custom_getter which is not callable: %s" % custom_getter)
if context.in_eager_mode():
reuse = False
if not self._store_eager_variables and reuse:
raise RuntimeError(
"When eager execution is enabled variable reuse is only supported"
" when an EagerVariableStore is active. See the documentation on"
" EagerVariableStore for example usage.")
if self._store_eager_variables:
reuse = AUTO_REUSE
use_resource = True
# If a *_ref type is passed in an error would be triggered further down the
@ -795,7 +804,7 @@ class _VariableStore(object):
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint)
if context.in_graph_mode():
if context.in_graph_mode() or self._store_eager_variables:
# In eager mode we do not want to keep default references to Variable
# objects as this will prevent their memory from being released.
self._vars[name] = v
@ -1177,6 +1186,48 @@ def _get_default_variable_store():
return store
@tf_contextlib.contextmanager
def with_variable_store(store):
store_collection = ops.get_collection_ref(_VARSTORE_KEY)
old = list(store_collection)
store_collection[:] = [store]
try:
yield
finally:
store_collection[:] = old
class EagerVariableStore(object):
"""Wrapper allowing functional layers to be used with eager execution.
When eager execution is enabled Variables get deleted when they go out of
scope, and are not stored in global collections by default. A lot of code
(mostly the functional layers in tf.layers) assumes that variables are kept in
a global list.
EagerVariableStore can be used in conjunction with this code to make it
eager-friendly. For example, to create a dense layer, use:
```
container = tfe.EagerVariableStore()
for input in dataset_iterator:
with container.as_default():
x = tf.layers.dense(input, name="l1")
print(container.variables) # Should print the variables used in the layer.
```
"""
def __init__(self):
self._store = _VariableStore()
self._store._store_eager_variables = True # pylint: disable=protected-access
def as_default(self):
return with_variable_store(self._store)
def variables(self):
return self._store._vars.values() # pylint: disable=protected-access
def get_variable(name,
shape=None,
dtype=None,