mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
No need for unique variable names in eager.
PiperOrigin-RevId: 173954805
This commit is contained in:
parent
f17f389d88
commit
a60cd87c43
|
|
@ -121,9 +121,10 @@ class MetricsTest(test.TestCase):
|
||||||
# accidentally share state.
|
# accidentally share state.
|
||||||
m1 = metrics.Mean()
|
m1 = metrics.Mean()
|
||||||
m1(0)
|
m1(0)
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
m2 = metrics.Mean()
|
m2 = metrics.Mean()
|
||||||
m2(2)
|
m2(2)
|
||||||
|
self.assertAllEqual(0.0, m1.result())
|
||||||
|
self.assertAllEqual(2.0, m2.result())
|
||||||
|
|
||||||
def testNamesWithSpaces(self):
|
def testNamesWithSpaces(self):
|
||||||
# Verify two metrics with the same class and name don't
|
# Verify two metrics with the same class and name don't
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,18 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||||
0,
|
0,
|
||||||
dtype=dtypes.int32)).run()
|
dtype=dtypes.int32)).run()
|
||||||
|
|
||||||
|
def testEagerNameNotIdentity(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
v0 = resource_variable_ops.ResourceVariable(1.0, name="a")
|
||||||
|
v1 = resource_variable_ops.ResourceVariable(2.0, name="a")
|
||||||
|
self.assertAllEqual(v0.numpy(), 1.0)
|
||||||
|
self.assertAllEqual(v1.numpy(), 2.0)
|
||||||
|
|
||||||
|
def testEagerNameNotNeeded(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
v0 = resource_variable_ops.ResourceVariable(1.0)
|
||||||
|
self.assertAllEqual(v0.numpy(), 1.0)
|
||||||
|
|
||||||
def testReadVariableDtypeMismatchEager(self):
|
def testReadVariableDtypeMismatchEager(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = resource_variable_ops.var_handle_op(
|
||||||
|
|
@ -332,27 +344,26 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||||
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
|
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
|
||||||
_ = w.value().op.get_attr("_class")
|
_ = w.value().op.get_attr("_class")
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testSharedName(self):
|
def testSharedName(self):
|
||||||
|
with self.test_session():
|
||||||
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
|
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
|
||||||
self.evaluate(variables.global_variables_initializer())
|
variables.global_variables_initializer().run()
|
||||||
|
|
||||||
w = resource_variable_ops.var_handle_op(
|
w = resource_variable_ops.var_handle_op(
|
||||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
|
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
|
||||||
# Needed in Eager since we get a unique container name by default.
|
# Needed in Eager since we get a unique container name by default.
|
||||||
container=ops.get_default_graph()._container)
|
container=ops.get_default_graph()._container)
|
||||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||||
self.assertEqual(300.0, self.evaluate(w_read))
|
self.assertEqual(300.0, w_read.eval())
|
||||||
|
|
||||||
x = resource_variable_ops.var_handle_op(
|
x = resource_variable_ops.var_handle_op(
|
||||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
|
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
|
||||||
container=ops.get_default_graph()._container)
|
container=ops.get_default_graph()._container)
|
||||||
with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
|
with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
|
||||||
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
|
resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
|
||||||
self.evaluate(x_read)
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testSharedNameWithNamescope(self):
|
def testSharedNameWithNamescope(self):
|
||||||
|
with self.test_session():
|
||||||
with ops.name_scope("foo"):
|
with ops.name_scope("foo"):
|
||||||
v = resource_variable_ops.ResourceVariable(300.0, name="var6")
|
v = resource_variable_ops.ResourceVariable(300.0, name="var6")
|
||||||
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
|
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
|
||||||
|
|
@ -468,25 +479,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||||
name="var8")
|
name="var8")
|
||||||
var.__del__()
|
var.__del__()
|
||||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
with self.assertRaisesRegexp(errors.NotFoundError,
|
||||||
r"Resource .*\/var8\/.* does not exist."):
|
r"Resource .* does not exist."):
|
||||||
resource_variable_ops.destroy_resource_op(var._handle,
|
resource_variable_ops.destroy_resource_op(var._handle,
|
||||||
ignore_lookup_error=False)
|
ignore_lookup_error=False)
|
||||||
|
|
||||||
def testSharingViaResourceVariableObject(self):
|
|
||||||
with context.eager_mode():
|
|
||||||
_ = resource_variable_ops.ResourceVariable(1.0, name="var0")
|
|
||||||
with self.assertRaisesRegexp(ValueError,
|
|
||||||
"'var0' already created"):
|
|
||||||
_ = resource_variable_ops.ResourceVariable(2.0, name="var0")
|
|
||||||
with ops.Graph().as_default():
|
|
||||||
_ = resource_variable_ops.ResourceVariable(2.0, name="var0")
|
|
||||||
|
|
||||||
def testVariableNameMissing(self):
|
|
||||||
with context.eager_mode():
|
|
||||||
with self.assertRaisesRegexp(ValueError,
|
|
||||||
"Variables need to have explicit names"):
|
|
||||||
_ = resource_variable_ops.ResourceVariable(1.0)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,10 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
|
||||||
container = ops.get_default_graph()._container # pylint: disable=protected-access
|
container = ops.get_default_graph()._container # pylint: disable=protected-access
|
||||||
if container is None:
|
if container is None:
|
||||||
container = ""
|
container = ""
|
||||||
|
if not graph_mode:
|
||||||
|
# When in eager mode use a uid for the shared_name, to prevent accidental
|
||||||
|
# sharing.
|
||||||
|
shared_name = str(ops.uid())
|
||||||
handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
|
handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
|
||||||
shared_name=shared_name,
|
shared_name=shared_name,
|
||||||
name=name,
|
name=name,
|
||||||
|
|
@ -293,12 +297,6 @@ class ResourceVariable(variables.Variable):
|
||||||
# Save the graph's container prefix for error checking. Reading the value of
|
# Save the graph's container prefix for error checking. Reading the value of
|
||||||
# the ResourceVariable from another Graph in Eager mode is an error.
|
# the ResourceVariable from another Graph in Eager mode is an error.
|
||||||
self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access
|
self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access
|
||||||
if not self._in_graph_mode and not name:
|
|
||||||
# TODO(ashankar,josh11b): make this unnecessary using the same
|
|
||||||
# logic as in layer
|
|
||||||
raise ValueError("Variables need to have explicit names when eager "
|
|
||||||
"execution is enabled")
|
|
||||||
|
|
||||||
with ops.control_dependencies(None):
|
with ops.control_dependencies(None):
|
||||||
with ops.name_scope(name, "Variable", []
|
with ops.name_scope(name, "Variable", []
|
||||||
if init_from_fn else [initial_value]) as name:
|
if init_from_fn else [initial_value]) as name:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user