No need for unique variable names in eager.

PiperOrigin-RevId: 173954805
This commit is contained in:
Alexandre Passos 2017-10-30 15:00:41 -07:00 committed by TensorFlower Gardener
parent f17f389d88
commit a60cd87c43
3 changed files with 47 additions and 52 deletions

View File

@ -121,9 +121,10 @@ class MetricsTest(test.TestCase):
# accidentally share state. # accidentally share state.
m1 = metrics.Mean() m1 = metrics.Mean()
m1(0) m1(0)
with self.assertRaises(ValueError):
m2 = metrics.Mean() m2 = metrics.Mean()
m2(2) m2(2)
self.assertAllEqual(0.0, m1.result())
self.assertAllEqual(2.0, m2.result())
def testNamesWithSpaces(self): def testNamesWithSpaces(self):
# Verify two metrics with the same class and name don't # Verify two metrics with the same class and name don't

View File

@ -54,6 +54,18 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
0, 0,
dtype=dtypes.int32)).run() dtype=dtypes.int32)).run()
def testEagerNameNotIdentity(self):
with context.eager_mode():
v0 = resource_variable_ops.ResourceVariable(1.0, name="a")
v1 = resource_variable_ops.ResourceVariable(2.0, name="a")
self.assertAllEqual(v0.numpy(), 1.0)
self.assertAllEqual(v1.numpy(), 2.0)
def testEagerNameNotNeeded(self):
with context.eager_mode():
v0 = resource_variable_ops.ResourceVariable(1.0)
self.assertAllEqual(v0.numpy(), 1.0)
def testReadVariableDtypeMismatchEager(self): def testReadVariableDtypeMismatchEager(self):
with context.eager_mode(): with context.eager_mode():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
@ -332,27 +344,26 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"): with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
_ = w.value().op.get_attr("_class") _ = w.value().op.get_attr("_class")
@test_util.run_in_graph_and_eager_modes()
def testSharedName(self): def testSharedName(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(300.0, name="var4") v = resource_variable_ops.ResourceVariable(300.0, name="var4")
self.evaluate(variables.global_variables_initializer()) variables.global_variables_initializer().run()
w = resource_variable_ops.var_handle_op( w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4", dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
# Needed in Eager since we get a unique container name by default. # Needed in Eager since we get a unique container name by default.
container=ops.get_default_graph()._container) container=ops.get_default_graph()._container)
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, self.evaluate(w_read)) self.assertEqual(300.0, w_read.eval())
x = resource_variable_ops.var_handle_op( x = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5", dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
container=ops.get_default_graph()._container) container=ops.get_default_graph()._container)
with self.assertRaisesOpError("Resource .*/var5/.* does not exist"): with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype) resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
self.evaluate(x_read)
@test_util.run_in_graph_and_eager_modes()
def testSharedNameWithNamescope(self): def testSharedNameWithNamescope(self):
with self.test_session():
with ops.name_scope("foo"): with ops.name_scope("foo"):
v = resource_variable_ops.ResourceVariable(300.0, name="var6") v = resource_variable_ops.ResourceVariable(300.0, name="var6")
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
@ -468,25 +479,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
name="var8") name="var8")
var.__del__() var.__del__()
with self.assertRaisesRegexp(errors.NotFoundError, with self.assertRaisesRegexp(errors.NotFoundError,
r"Resource .*\/var8\/.* does not exist."): r"Resource .* does not exist."):
resource_variable_ops.destroy_resource_op(var._handle, resource_variable_ops.destroy_resource_op(var._handle,
ignore_lookup_error=False) ignore_lookup_error=False)
def testSharingViaResourceVariableObject(self):
with context.eager_mode():
_ = resource_variable_ops.ResourceVariable(1.0, name="var0")
with self.assertRaisesRegexp(ValueError,
"'var0' already created"):
_ = resource_variable_ops.ResourceVariable(2.0, name="var0")
with ops.Graph().as_default():
_ = resource_variable_ops.ResourceVariable(2.0, name="var0")
def testVariableNameMissing(self):
with context.eager_mode():
with self.assertRaisesRegexp(ValueError,
"Variables need to have explicit names"):
_ = resource_variable_ops.ResourceVariable(1.0)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -43,6 +43,10 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
container = ops.get_default_graph()._container # pylint: disable=protected-access container = ops.get_default_graph()._container # pylint: disable=protected-access
if container is None: if container is None:
container = "" container = ""
if not graph_mode:
# When in eager mode use a uid for the shared_name, to prevent accidental
# sharing.
shared_name = str(ops.uid())
handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
shared_name=shared_name, shared_name=shared_name,
name=name, name=name,
@ -293,12 +297,6 @@ class ResourceVariable(variables.Variable):
# Save the graph's container prefix for error checking. Reading the value of # Save the graph's container prefix for error checking. Reading the value of
# the ResourceVariable from another Graph in Eager mode is an error. # the ResourceVariable from another Graph in Eager mode is an error.
self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access
if not self._in_graph_mode and not name:
# TODO(ashankar,josh11b): make this unnecessary using the same
# logic as in layer
raise ValueError("Variables need to have explicit names when eager "
"execution is enabled")
with ops.control_dependencies(None): with ops.control_dependencies(None):
with ops.name_scope(name, "Variable", [] with ops.name_scope(name, "Variable", []
if init_from_fn else [initial_value]) as name: if init_from_fn else [initial_value]) as name: