Steps toward making ResourceVariables compatible with Eager.

This change forces the value of the reuse flag in variable scopes to be tf.AUTO_REUSE when in Eager mode.

This change also adds comprehensive Eager tests for ResourceVariable.

PiperOrigin-RevId: 166408161
Ali Yahya 2017-08-24 16:02:09 -07:00 committed by TensorFlower Gardener
parent b2ce451502
commit 3142f8ef5d
7 changed files with 176 additions and 148 deletions
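A minimal sketch of the behavior this change enables, using the internal context.eager_mode() helper that the tests below also use (hypothetical usage, not part of the commit itself):

from tensorflow.python.eager import context
from tensorflow.python.ops import variable_scope

with context.eager_mode():
  with variable_scope.variable_scope("foo"):
    v1 = variable_scope.get_variable("x", shape=[])
  # With reuse forced to tf.AUTO_REUSE in eager mode, re-entering the scope
  # and requesting the same name does not raise the usual
  # "Variable foo/x already exists" error.
  with variable_scope.variable_scope("foo"):
    v2 = variable_scope.get_variable("x", shape=[])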


@@ -2212,6 +2212,7 @@ py_library(
":resource_variable_ops",
":tensor_shape",
":variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/estimator:util",
"@six_archive//:six",
],


@@ -4923,21 +4923,21 @@ def name_scope(name, default_name=None, values=None):
ValueError: if neither `name` nor `default_name` is provided
but `values` are.
"""
n = default_name if name is None else name
name = default_name if name is None else name
if context.in_eager_mode():
ctx = context.context()
old_name = ctx.scope_name
if n is None:
if name is None:
scope_name = ""
else:
scope_name = "%s%s/" % (old_name, n) if old_name else "%s/" % n
scope_name = "%s%s/" % (old_name, name) if old_name else "%s/" % name
ctx.scope_name = scope_name
try:
yield scope_name
finally:
ctx.scope_name = old_name
else:
if n is None and values is not None:
if name is None and values is not None:
# We only raise an error if values is not None (provided) because
# currently tf.name_scope(None) (values=None then) is sometimes used as an
# idiom to reset to top scope.
@@ -4947,7 +4947,7 @@ def name_scope(name, default_name=None, values=None):
if values is None:
values = []
g = _get_graph_from_inputs(values)
with g.as_default(), g.name_scope(n) as scope:
with g.as_default(), g.name_scope(name) as scope:
yield scope
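The eager branch above tracks the current scope name on the eager context rather than on a graph. A small sketch of the resulting behavior, derived directly from the code above (not an official example):

from tensorflow.python.eager import context
from tensorflow.python.framework import ops

with context.eager_mode():
  with ops.name_scope("outer"):
    with ops.name_scope("inner") as scope:
      # Nested names concatenate onto ctx.scope_name with "/" separators.
      assert scope == "outer/inner/"
    # The finally block restores the enclosing scope name on exit.
    assert context.context().scope_name == "outer/"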


@@ -638,6 +638,7 @@ cuda_py_test(
srcs = ["resource_variable_ops_test.py"],
additional_deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",


@@ -53,86 +53,78 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
0,
dtype=dtypes.int32)).run()
@test_util.run_in_graph_and_eager_modes()
def testDtypeSurvivesIdentity(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
id_handle = array_ops.identity(handle)
resource_variable_ops.assign_variable_op(id_handle,
constant_op.constant(
0,
dtype=dtypes.int32)).run()
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
id_handle = array_ops.identity(handle)
self.evaluate(resource_variable_ops.assign_variable_op(
id_handle, constant_op.constant(0, dtype=dtypes.int32)))
@test_util.run_in_graph_and_eager_modes()
def testCreateRead(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
resource_variable_ops.assign_variable_op(handle,
constant_op.constant(
1,
dtype=dtypes.int32)).run()
value = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32).eval()
self.assertAllEqual(1, value)
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
self.evaluate(resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)))
value = self.evaluate(
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
self.assertAllEqual(1, value)
@test_util.run_in_graph_and_eager_modes()
def testManyAssigns(self):
with self.test_session() as session:
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
create = resource_variable_ops.assign_variable_op(handle,
constant_op.constant(
1,
dtype=dtypes.int32))
with ops.control_dependencies([create]):
first_read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
with ops.control_dependencies([first_read]):
write = resource_variable_ops.assign_variable_op(
handle, constant_op.constant(2, dtype=dtypes.int32))
with ops.control_dependencies([write]):
second_read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
f, s = session.run([first_read, second_read])
self.assertEqual(f, 1)
self.assertEqual(s, 2)
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
create = resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32))
with ops.control_dependencies([create]):
first_read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
with ops.control_dependencies([first_read]):
write = resource_variable_ops.assign_variable_op(
handle, constant_op.constant(2, dtype=dtypes.int32))
with ops.control_dependencies([write]):
second_read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
f, s = self.evaluate([first_read, second_read])
self.assertEqual(f, 1)
self.assertEqual(s, 2)
@test_util.run_in_graph_and_eager_modes()
def testAssignAdd(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
resource_variable_ops.assign_variable_op(handle,
constant_op.constant(
1,
dtype=dtypes.int32)).run()
resource_variable_ops.assign_add_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)).run()
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(read.eval(), 2)
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
self.evaluate(resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)))
self.evaluate(resource_variable_ops.assign_add_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)))
read = self.evaluate(
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
self.assertEqual(read, 2)
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testScatterAdd(self):
with self.test_session(use_gpu=True):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
resource_variable_ops.assign_variable_op(handle,
constant_op.constant(
[[1]],
dtype=dtypes.int32)).run()
resource_variable_ops.resource_scatter_add(handle, [0],
constant_op.constant(
[[2]],
dtype=dtypes.int32)).run()
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(read.eval(), [[3]])
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
self.evaluate(resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
self.evaluate(resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
# TODO(alive): get this to work in Eager mode.
def testGPU(self):
with self.test_session(use_gpu=True) as sess:
with self.test_session(use_gpu=True):
abc = variable_scope.get_variable(
"abc",
shape=[1],
initializer=init_ops.ones_initializer(),
use_resource=True)
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
self.assertEqual(
resource_variable_ops.var_is_initialized_op(abc.handle).eval(), True)
print(sess.run(abc))
self.evaluate(
resource_variable_ops.var_is_initialized_op(abc.handle)),
True)
# TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
def testConstraintArg(self):
constraint = lambda x: x
v = resource_variable_ops.ResourceVariable(
@@ -144,6 +136,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(
initial_value=lambda: 1, constraint=constraint)
# TODO(alive): how should this work in Eager mode?
def testInitFn(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(
@@ -151,53 +144,54 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(v.handle.op.colocation_groups(),
v.initializer.inputs[1].op.colocation_groups())
# TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
def testInitFnDtype(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(
initial_value=lambda: 1, dtype=dtypes.float32)
self.assertEqual(dtypes.float32, v.value().dtype)
# TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
def testInitFnNoDtype(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1)
self.assertEqual(dtypes.int32, v.value().dtype)
@test_util.run_in_graph_and_eager_modes()
def testInitializeAllVariables(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32)
with self.assertRaises(errors.NotFoundError):
v.value().eval()
variables.global_variables_initializer().run()
self.assertEqual(1.0, v.value().eval())
v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(v.value()))
@test_util.run_in_graph_and_eager_modes()
def testOperatorOverload(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
self.assertEqual(2.0, (v + v).eval())
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(2.0, self.evaluate(v + v))
@test_util.run_in_graph_and_eager_modes()
def testAssignMethod(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
v.assign(2.0).eval()
self.assertEqual(2.0, v.value().eval())
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
self.evaluate(v.assign(2.0))
self.assertEqual(2.0, self.evaluate(v.value()))
@test_util.run_in_graph_and_eager_modes()
def testLoad(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
v.load(2.0)
self.assertEqual(2.0, v.value().eval())
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
v.load(2.0)
self.assertEqual(2.0, self.evaluate(v.value()))
@test_util.run_in_graph_and_eager_modes()
def testSparseRead(self):
with self.test_session():
init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
v = resource_variable_ops.ResourceVariable(
constant_op.constant(init_value, dtype=dtypes.int32))
variables.global_variables_initializer().run()
self.evaluate(variables.global_variables_initializer())
value = v.sparse_read([0, 3, 1, 2]).eval()
value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
def testToFromProto(self):
@@ -208,34 +202,33 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto())
self.assertEquals(2, math_ops.add(w, 1).eval())
@test_util.run_in_graph_and_eager_modes()
def testAssignAddMethod(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
v.assign_add(1.0).eval()
self.assertEqual(2.0, v.value().eval())
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
self.evaluate(v.assign_add(1.0))
self.assertEqual(2.0, self.evaluate(v.value()))
@test_util.run_in_graph_and_eager_modes()
def testAssignSubMethod(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(3.0)
variables.global_variables_initializer().run()
v.assign_sub(1.0).eval()
self.assertEqual(2.0, v.value().eval())
v = resource_variable_ops.ResourceVariable(3.0)
self.evaluate(variables.global_variables_initializer())
self.evaluate(v.assign_sub(1.0))
self.assertEqual(2.0, self.evaluate(v.value()))
@test_util.run_in_graph_and_eager_modes()
def testDestroyResource(self):
with self.test_session() as sess:
v = resource_variable_ops.ResourceVariable(3.0)
variables.global_variables_initializer().run()
self.assertEqual(3.0, v.value().eval())
sess.run(resource_variable_ops.destroy_resource_op(v.handle))
with self.assertRaises(errors.NotFoundError):
v.value().eval()
# Handle to a resource not actually created.
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
# Should raise no exception
sess.run(
resource_variable_ops.destroy_resource_op(
handle, ignore_lookup_error=True))
v = resource_variable_ops.ResourceVariable(3.0)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(3.0, self.evaluate(v.value()))
self.evaluate(resource_variable_ops.destroy_resource_op(v.handle))
with self.assertRaises(errors.NotFoundError):
self.evaluate(v.value())
# Handle to a resource not actually created.
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
# Should raise no exception
self.evaluate(resource_variable_ops.destroy_resource_op(
handle, ignore_lookup_error=True))
def testAssignDifferentShapes(self):
with self.test_session() as sess, variable_scope.variable_scope(
@@ -247,12 +240,21 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
[assign],
feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)})
def testAssignDifferentShapesEager(self):
with context.eager_mode():
with variable_scope.variable_scope("foo"):
var = variable_scope.get_variable("x", shape=[1, 1],
dtype=dtypes.float32)
assign = var.assign(np.zeros(shape=[2, 2]))
self.evaluate(assign)
def testDtypeAfterFromProto(self):
v = resource_variable_ops.ResourceVariable(2.0)
w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto())
self.assertIsInstance(w.dtype, dtypes.DType)
self.assertEqual(v.dtype, w.dtype)
# TODO(alive): get caching to work in eager mode.
def testCachingDevice(self):
with ops.device("/job:server/task:1"):
v = resource_variable_ops.ResourceVariable(
@@ -268,41 +270,47 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
_ = w.value().op.get_attr("_class")
@test_util.run_in_graph_and_eager_modes()
def testSharedName(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
v.initializer.run()
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
self.evaluate(variables.global_variables_initializer())
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1")
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, w_read.eval())
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1")
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, self.evaluate(w_read))
x = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1/")
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
_ = x_read.eval()
x = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var2")
if context.in_graph_mode():
with self.assertRaisesOpError("Resource .*/var2/.* does not exist"):
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
self.evaluate(x_read)
else:
with self.assertRaisesRegexp(errors.NotFoundError,
"Attempted to read a nonexistent variable."):
_ = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
@test_util.run_in_graph_and_eager_modes()
def testSharedNameWithNamescope(self):
with self.test_session():
with ops.name_scope("foo"):
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
v.initializer.run()
with ops.name_scope("foo"):
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
self.evaluate(variables.global_variables_initializer())
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var1")
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, w_read.eval())
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var1")
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, self.evaluate(w_read))
@test_util.run_in_graph_and_eager_modes()
def testShape(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(
name="var1", initial_value=array_ops.ones(shape=[10, 20, 35]))
self.assertEqual("(10, 20, 35)", str(v.shape))
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
self.assertEqual("(10, 20, 35)", str(v.value().shape))
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
v = resource_variable_ops.ResourceVariable(
name="var1", initial_value=array_ops.ones(shape=[10, 20, 35]))
self.assertEqual("(10, 20, 35)", str(v.shape))
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
self.assertEqual("(10, 20, 35)", str(v.value().shape))
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
if context.in_graph_mode():
self.assertEqual(
"<unknown>",
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
@@ -329,13 +337,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
control_flow_ops.while_loop(cond, body, [0, 0])
# TODO(agarwal,apassos): Add more comprehensive tests and/or translate the above
# tests to work in both GRAPH and EAGER modes.
# TODO(agarwal): Add tests for sparse_read, scatter_sub
class ResourceVariableOpsEagerTest(test_util.TensorFlowTestCase):
def testVariable(self):
def testVariableEager(self):
with context.eager_mode():
init = array_ops.ones(shape=[10, 20, 35], dtype=dtypes.int32)
constraint = lambda x: x
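The mechanical pattern in the test rewrites above is to replace session-bound .run()/.eval() calls with self.evaluate(), which works in both execution modes. Roughly how that helper behaves (a simplified sketch, not the actual test_util.py implementation; nest here is tensorflow.python.util.nest):

def evaluate(self, tensors):
  if context.in_eager_mode():
    # Eager tensors already hold concrete values; just unwrap to numpy.
    return nest.map_structure(lambda t: t.numpy(), tensors)
  # In graph mode, fall back to running the ops in a session.
  with self.test_session() as sess:
    return sess.run(tensors)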


@@ -251,6 +251,8 @@ class ResourceVariable(variables.Variable):
name=name)
else:
initial_value = initial_value()
initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
self._handle = gen_resource_variable_ops.var_handle_op(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
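The added convert_to_tensor call is what lets a callable initial_value return a plain Python value: the var_handle_op construction below it reads get_shape() and dtype, which only exist once the value is a Tensor. A small standalone illustration (not from the commit):

from tensorflow.python.framework import ops

initial_value = (lambda: 1)()  # a plain Python int: no get_shape() or dtype
initial_value = ops.convert_to_tensor(initial_value, name="initial_value")
print(initial_value.get_shape())       # ()
print(initial_value.dtype.base_dtype)  # <dtype: 'int32'>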


@@ -28,6 +28,7 @@ import traceback
import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -840,6 +841,7 @@ def no_regularizer(_):
return None
# TODO(alive): support caching devices and partitioned variables in Eager mode.
class VariableScope(object):
"""Variable scope object to carry defaults to provide to `get_variable`.
@@ -894,6 +896,14 @@ class VariableScope(object):
self._dtype = dtype
self._use_resource = use_resource
self._constraint = constraint
if context.in_eager_mode():
if self._caching_device is not None:
raise NotImplementedError("Caching devices is not yet supported "
"in Eager mode.")
if self._partitioner is not None:
raise NotImplementedError("Partitioned variables are not yet supported "
"in Eager mode.")
self._use_resource = True
@property
def name(self):
@@ -961,10 +971,16 @@ class VariableScope(object):
def set_caching_device(self, caching_device):
"""Set caching_device for this scope."""
if context.in_eager_mode():
raise NotImplementedError("Partitioned variables are not yet supported "
"in Eager mode.")
self._caching_device = caching_device
def set_partitioner(self, partitioner):
"""Set partitioner for this scope."""
if context.in_eager_mode():
raise NotImplementedError("Partitioned variables are not yet supported "
"in Eager mode.")
self._partitioner = partitioner
def set_custom_getter(self, custom_getter):
@@ -1034,8 +1050,11 @@ class VariableScope(object):
constraint = self._constraint
if dtype is None:
dtype = self._dtype
if use_resource is None:
use_resource = self._use_resource
if context.in_graph_mode():
if use_resource is None:
use_resource = self._use_resource
else:
use_resource = True
return var_store.get_variable(
full_name, shape=shape, dtype=dtype, initializer=initializer,
@@ -1060,6 +1079,9 @@ class VariableScope(object):
use_resource=None,
constraint=None):
"""Gets an existing variable with this name or create a new one."""
if context.in_eager_mode():
raise NotImplementedError("Partitioned variables are not yet supported "
"in Eager mode.")
if initializer is None:
initializer = self._initializer
if regularizer is None:


@@ -1364,7 +1364,7 @@ def variables_initializer(var_list, name="init"):
Returns:
An Op that run the initializers of all the specified variables.
"""
if var_list:
if var_list and context.in_graph_mode():
return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
return control_flow_ops.no_op(name=name)
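Since eager variables are initialized at construction time, this guard lets shared graph/eager test code call the initializer unconditionally; in eager mode it degrades to a no-op. For example (a sketch, assuming an active eager context):

from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables

with context.eager_mode():
  v = resource_variable_ops.ResourceVariable(3.0)  # initialized on creation
  # Returns a no_op rather than grouping (nonexistent) initializer ops,
  # so calling it here is harmless.
  variables.global_variables_initializer()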