No need for unique variable names in eager.

PiperOrigin-RevId: 173954805
This commit is contained in:
Alexandre Passos 2017-10-30 15:00:41 -07:00 committed by TensorFlower Gardener
parent f17f389d88
commit a60cd87c43
3 changed files with 47 additions and 52 deletions

View File

@ -121,9 +121,10 @@ class MetricsTest(test.TestCase):
# accidentally share state.
m1 = metrics.Mean()
m1(0)
with self.assertRaises(ValueError):
m2 = metrics.Mean()
m2(2)
m2 = metrics.Mean()
m2(2)
self.assertAllEqual(0.0, m1.result())
self.assertAllEqual(2.0, m2.result())
def testNamesWithSpaces(self):
# Verify two metrics with the same class and name don't

View File

@ -54,6 +54,18 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
0,
dtype=dtypes.int32)).run()
def testEagerNameNotIdentity(self):
with context.eager_mode():
v0 = resource_variable_ops.ResourceVariable(1.0, name="a")
v1 = resource_variable_ops.ResourceVariable(2.0, name="a")
self.assertAllEqual(v0.numpy(), 1.0)
self.assertAllEqual(v1.numpy(), 2.0)
def testEagerNameNotNeeded(self):
with context.eager_mode():
v0 = resource_variable_ops.ResourceVariable(1.0)
self.assertAllEqual(v0.numpy(), 1.0)
def testReadVariableDtypeMismatchEager(self):
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
@ -332,39 +344,38 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
_ = w.value().op.get_attr("_class")
@test_util.run_in_graph_and_eager_modes()
def testSharedName(self):
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
self.evaluate(variables.global_variables_initializer())
with self.test_session():
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
variables.global_variables_initializer().run()
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
# Needed in Eager since we get a unique container name by default.
container=ops.get_default_graph()._container)
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, self.evaluate(w_read))
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
# Needed in Eager since we get a unique container name by default.
container=ops.get_default_graph()._container)
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, w_read.eval())
x = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
container=ops.get_default_graph()._container)
with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
self.evaluate(x_read)
x = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
container=ops.get_default_graph()._container)
with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
@test_util.run_in_graph_and_eager_modes()
def testSharedNameWithNamescope(self):
with ops.name_scope("foo"):
v = resource_variable_ops.ResourceVariable(300.0, name="var6")
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
self.assertEqual("foo/var6:0", v.name)
self.evaluate(variables.global_variables_initializer())
with self.test_session():
with ops.name_scope("foo"):
v = resource_variable_ops.ResourceVariable(300.0, name="var6")
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
self.assertEqual("foo/var6:0", v.name)
self.evaluate(variables.global_variables_initializer())
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6",
# Needed in Eager since we get a unique container name by default.
container=ops.get_default_graph()._container)
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, self.evaluate(w_read))
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6",
# Needed in Eager since we get a unique container name by default.
container=ops.get_default_graph()._container)
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, self.evaluate(w_read))
@test_util.run_in_graph_and_eager_modes()
def testShape(self):
@ -468,25 +479,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
name="var8")
var.__del__()
with self.assertRaisesRegexp(errors.NotFoundError,
r"Resource .*\/var8\/.* does not exist."):
r"Resource .* does not exist."):
resource_variable_ops.destroy_resource_op(var._handle,
ignore_lookup_error=False)
def testSharingViaResourceVariableObject(self):
with context.eager_mode():
_ = resource_variable_ops.ResourceVariable(1.0, name="var0")
with self.assertRaisesRegexp(ValueError,
"'var0' already created"):
_ = resource_variable_ops.ResourceVariable(2.0, name="var0")
with ops.Graph().as_default():
_ = resource_variable_ops.ResourceVariable(2.0, name="var0")
def testVariableNameMissing(self):
with context.eager_mode():
with self.assertRaisesRegexp(ValueError,
"Variables need to have explicit names"):
_ = resource_variable_ops.ResourceVariable(1.0)
if __name__ == "__main__":
test.main()

View File

@ -43,6 +43,10 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
container = ops.get_default_graph()._container # pylint: disable=protected-access
if container is None:
container = ""
if not graph_mode:
# When in eager mode use a uid for the shared_name, to prevent accidental
# sharing.
shared_name = str(ops.uid())
handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
shared_name=shared_name,
name=name,
@ -293,12 +297,6 @@ class ResourceVariable(variables.Variable):
# Save the graph's container prefix for error checking. Reading the value of
# the ResourceVariable from another Graph in Eager mode is an error.
self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access
if not self._in_graph_mode and not name:
# TODO(ashankar,josh11b): make this unnecessary using the same
# logic as in layer
raise ValueError("Variables need to have explicit names when eager "
"execution is enabled")
with ops.control_dependencies(None):
with ops.name_scope(name, "Variable", []
if init_from_fn else [initial_value]) as name: