mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
No need for unique variable names in eager.
PiperOrigin-RevId: 173954805
This commit is contained in:
parent
f17f389d88
commit
a60cd87c43
|
|
@ -121,9 +121,10 @@ class MetricsTest(test.TestCase):
|
|||
# accidentally share state.
|
||||
m1 = metrics.Mean()
|
||||
m1(0)
|
||||
with self.assertRaises(ValueError):
|
||||
m2 = metrics.Mean()
|
||||
m2(2)
|
||||
m2 = metrics.Mean()
|
||||
m2(2)
|
||||
self.assertAllEqual(0.0, m1.result())
|
||||
self.assertAllEqual(2.0, m2.result())
|
||||
|
||||
def testNamesWithSpaces(self):
|
||||
# Verify two metrics with the same class and name don't
|
||||
|
|
|
|||
|
|
@ -54,6 +54,18 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
0,
|
||||
dtype=dtypes.int32)).run()
|
||||
|
||||
def testEagerNameNotIdentity(self):
|
||||
with context.eager_mode():
|
||||
v0 = resource_variable_ops.ResourceVariable(1.0, name="a")
|
||||
v1 = resource_variable_ops.ResourceVariable(2.0, name="a")
|
||||
self.assertAllEqual(v0.numpy(), 1.0)
|
||||
self.assertAllEqual(v1.numpy(), 2.0)
|
||||
|
||||
def testEagerNameNotNeeded(self):
|
||||
with context.eager_mode():
|
||||
v0 = resource_variable_ops.ResourceVariable(1.0)
|
||||
self.assertAllEqual(v0.numpy(), 1.0)
|
||||
|
||||
def testReadVariableDtypeMismatchEager(self):
|
||||
with context.eager_mode():
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
|
|
@ -332,39 +344,38 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
|
||||
_ = w.value().op.get_attr("_class")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testSharedName(self):
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
|
||||
# Needed in Eager since we get a unique container name by default.
|
||||
container=ops.get_default_graph()._container)
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, self.evaluate(w_read))
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
|
||||
# Needed in Eager since we get a unique container name by default.
|
||||
container=ops.get_default_graph()._container)
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, w_read.eval())
|
||||
|
||||
x = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
|
||||
container=ops.get_default_graph()._container)
|
||||
with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
|
||||
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
|
||||
self.evaluate(x_read)
|
||||
x = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
|
||||
container=ops.get_default_graph()._container)
|
||||
with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
|
||||
resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testSharedNameWithNamescope(self):
|
||||
with ops.name_scope("foo"):
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var6")
|
||||
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
|
||||
self.assertEqual("foo/var6:0", v.name)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
with self.test_session():
|
||||
with ops.name_scope("foo"):
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var6")
|
||||
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
|
||||
self.assertEqual("foo/var6:0", v.name)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6",
|
||||
# Needed in Eager since we get a unique container name by default.
|
||||
container=ops.get_default_graph()._container)
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, self.evaluate(w_read))
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6",
|
||||
# Needed in Eager since we get a unique container name by default.
|
||||
container=ops.get_default_graph()._container)
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, self.evaluate(w_read))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testShape(self):
|
||||
|
|
@ -468,25 +479,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
name="var8")
|
||||
var.__del__()
|
||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
||||
r"Resource .*\/var8\/.* does not exist."):
|
||||
r"Resource .* does not exist."):
|
||||
resource_variable_ops.destroy_resource_op(var._handle,
|
||||
ignore_lookup_error=False)
|
||||
|
||||
def testSharingViaResourceVariableObject(self):
|
||||
with context.eager_mode():
|
||||
_ = resource_variable_ops.ResourceVariable(1.0, name="var0")
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"'var0' already created"):
|
||||
_ = resource_variable_ops.ResourceVariable(2.0, name="var0")
|
||||
with ops.Graph().as_default():
|
||||
_ = resource_variable_ops.ResourceVariable(2.0, name="var0")
|
||||
|
||||
def testVariableNameMissing(self):
|
||||
with context.eager_mode():
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Variables need to have explicit names"):
|
||||
_ = resource_variable_ops.ResourceVariable(1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
|
|||
container = ops.get_default_graph()._container # pylint: disable=protected-access
|
||||
if container is None:
|
||||
container = ""
|
||||
if not graph_mode:
|
||||
# When in eager mode use a uid for the shared_name, to prevent accidental
|
||||
# sharing.
|
||||
shared_name = str(ops.uid())
|
||||
handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
|
||||
shared_name=shared_name,
|
||||
name=name,
|
||||
|
|
@ -293,12 +297,6 @@ class ResourceVariable(variables.Variable):
|
|||
# Save the graph's container prefix for error checking. Reading the value of
|
||||
# the ResourceVariable from another Graph in Eager mode is an error.
|
||||
self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access
|
||||
if not self._in_graph_mode and not name:
|
||||
# TODO(ashankar,josh11b): make this unnecessary using the same
|
||||
# logic as in layer
|
||||
raise ValueError("Variables need to have explicit names when eager "
|
||||
"execution is enabled")
|
||||
|
||||
with ops.control_dependencies(None):
|
||||
with ops.name_scope(name, "Variable", []
|
||||
if init_from_fn else [initial_value]) as name:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user