mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Steps toward making ResourceVariables compatible with Eager.
This change forces the value of the reuse flag in variable scopes to be tf.AUTO_REUSE when in Eager mode. This change also adds comprehensive Eager tests for ResourceVariable. PiperOrigin-RevId: 166408161
This commit is contained in:
parent
b2ce451502
commit
3142f8ef5d
|
|
@ -2212,6 +2212,7 @@ py_library(
|
|||
":resource_variable_ops",
|
||||
":tensor_shape",
|
||||
":variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/estimator:util",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -4923,21 +4923,21 @@ def name_scope(name, default_name=None, values=None):
|
|||
ValueError: if neither `name` nor `default_name` is provided
|
||||
but `values` are.
|
||||
"""
|
||||
n = default_name if name is None else name
|
||||
name = default_name if name is None else name
|
||||
if context.in_eager_mode():
|
||||
ctx = context.context()
|
||||
old_name = ctx.scope_name
|
||||
if n is None:
|
||||
if name is None:
|
||||
scope_name = ""
|
||||
else:
|
||||
scope_name = "%s%s/" % (old_name, n) if old_name else "%s/" % n
|
||||
scope_name = "%s%s/" % (old_name, name) if old_name else "%s/" % name
|
||||
ctx.scope_name = scope_name
|
||||
try:
|
||||
yield scope_name
|
||||
finally:
|
||||
ctx.scope_name = old_name
|
||||
else:
|
||||
if n is None and values is not None:
|
||||
if name is None and values is not None:
|
||||
# We only raise an error if values is not None (provided) because
|
||||
# currently tf.name_scope(None) (values=None then) is sometimes used as an
|
||||
# idiom to reset to top scope.
|
||||
|
|
@ -4947,7 +4947,7 @@ def name_scope(name, default_name=None, values=None):
|
|||
if values is None:
|
||||
values = []
|
||||
g = _get_graph_from_inputs(values)
|
||||
with g.as_default(), g.name_scope(n) as scope:
|
||||
with g.as_default(), g.name_scope(name) as scope:
|
||||
yield scope
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -638,6 +638,7 @@ cuda_py_test(
|
|||
srcs = ["resource_variable_ops_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
|
|
|
|||
|
|
@ -53,86 +53,78 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
0,
|
||||
dtype=dtypes.int32)).run()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testDtypeSurvivesIdentity(self):
|
||||
with self.test_session():
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
id_handle = array_ops.identity(handle)
|
||||
resource_variable_ops.assign_variable_op(id_handle,
|
||||
constant_op.constant(
|
||||
0,
|
||||
dtype=dtypes.int32)).run()
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
id_handle = array_ops.identity(handle)
|
||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||
id_handle, constant_op.constant(0, dtype=dtypes.int32)))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testCreateRead(self):
|
||||
with self.test_session():
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
resource_variable_ops.assign_variable_op(handle,
|
||||
constant_op.constant(
|
||||
1,
|
||||
dtype=dtypes.int32)).run()
|
||||
value = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.int32).eval()
|
||||
self.assertAllEqual(1, value)
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
||||
value = self.evaluate(
|
||||
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
|
||||
self.assertAllEqual(1, value)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testManyAssigns(self):
|
||||
with self.test_session() as session:
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
create = resource_variable_ops.assign_variable_op(handle,
|
||||
constant_op.constant(
|
||||
1,
|
||||
dtype=dtypes.int32))
|
||||
with ops.control_dependencies([create]):
|
||||
first_read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.int32)
|
||||
with ops.control_dependencies([first_read]):
|
||||
write = resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(2, dtype=dtypes.int32))
|
||||
with ops.control_dependencies([write]):
|
||||
second_read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.int32)
|
||||
f, s = session.run([first_read, second_read])
|
||||
self.assertEqual(f, 1)
|
||||
self.assertEqual(s, 2)
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
create = resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(1, dtype=dtypes.int32))
|
||||
with ops.control_dependencies([create]):
|
||||
first_read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.int32)
|
||||
with ops.control_dependencies([first_read]):
|
||||
write = resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(2, dtype=dtypes.int32))
|
||||
with ops.control_dependencies([write]):
|
||||
second_read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.int32)
|
||||
f, s = self.evaluate([first_read, second_read])
|
||||
self.assertEqual(f, 1)
|
||||
self.assertEqual(s, 2)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testAssignAdd(self):
|
||||
with self.test_session():
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
resource_variable_ops.assign_variable_op(handle,
|
||||
constant_op.constant(
|
||||
1,
|
||||
dtype=dtypes.int32)).run()
|
||||
resource_variable_ops.assign_add_variable_op(
|
||||
handle, constant_op.constant(1, dtype=dtypes.int32)).run()
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(read.eval(), 2)
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
||||
self.evaluate(resource_variable_ops.assign_add_variable_op(
|
||||
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
||||
read = self.evaluate(
|
||||
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
|
||||
self.assertEqual(read, 2)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testScatterAdd(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
resource_variable_ops.assign_variable_op(handle,
|
||||
constant_op.constant(
|
||||
[[1]],
|
||||
dtype=dtypes.int32)).run()
|
||||
resource_variable_ops.resource_scatter_add(handle, [0],
|
||||
constant_op.constant(
|
||||
[[2]],
|
||||
dtype=dtypes.int32)).run()
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(read.eval(), [[3]])
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||
self.evaluate(resource_variable_ops.resource_scatter_add(
|
||||
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
# TODO(alive): get this to work in Eager mode.
|
||||
def testGPU(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with self.test_session(use_gpu=True):
|
||||
abc = variable_scope.get_variable(
|
||||
"abc",
|
||||
shape=[1],
|
||||
initializer=init_ops.ones_initializer(),
|
||||
use_resource=True)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(
|
||||
resource_variable_ops.var_is_initialized_op(abc.handle).eval(), True)
|
||||
print(sess.run(abc))
|
||||
self.evaluate(
|
||||
resource_variable_ops.var_is_initialized_op(abc.handle)),
|
||||
True)
|
||||
|
||||
# TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
|
||||
def testConstraintArg(self):
|
||||
constraint = lambda x: x
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
|
|
@ -144,6 +136,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
v = resource_variable_ops.ResourceVariable(
|
||||
initial_value=lambda: 1, constraint=constraint)
|
||||
|
||||
# TODO(alive): how should this work in Eager mode?
|
||||
def testInitFn(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
|
|
@ -151,53 +144,54 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
self.assertEqual(v.handle.op.colocation_groups(),
|
||||
v.initializer.inputs[1].op.colocation_groups())
|
||||
|
||||
# TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
|
||||
def testInitFnDtype(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
initial_value=lambda: 1, dtype=dtypes.float32)
|
||||
self.assertEqual(dtypes.float32, v.value().dtype)
|
||||
|
||||
# TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
|
||||
def testInitFnNoDtype(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1)
|
||||
self.assertEqual(dtypes.int32, v.value().dtype)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testInitializeAllVariables(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32)
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
v.value().eval()
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertEqual(1.0, v.value().eval())
|
||||
v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(1.0, self.evaluate(v.value()))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testOperatorOverload(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertEqual(2.0, (v + v).eval())
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(2.0, self.evaluate(v + v))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testAssignMethod(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
variables.global_variables_initializer().run()
|
||||
v.assign(2.0).eval()
|
||||
self.assertEqual(2.0, v.value().eval())
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(v.assign(2.0))
|
||||
self.assertEqual(2.0, self.evaluate(v.value()))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testLoad(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
variables.global_variables_initializer().run()
|
||||
v.load(2.0)
|
||||
self.assertEqual(2.0, v.value().eval())
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
v.load(2.0)
|
||||
self.assertEqual(2.0, self.evaluate(v.value()))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testSparseRead(self):
|
||||
with self.test_session():
|
||||
init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
constant_op.constant(init_value, dtype=dtypes.int32))
|
||||
variables.global_variables_initializer().run()
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
value = v.sparse_read([0, 3, 1, 2]).eval()
|
||||
value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
|
||||
self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
|
||||
|
||||
def testToFromProto(self):
|
||||
|
|
@ -208,34 +202,33 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto())
|
||||
self.assertEquals(2, math_ops.add(w, 1).eval())
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testAssignAddMethod(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
variables.global_variables_initializer().run()
|
||||
v.assign_add(1.0).eval()
|
||||
self.assertEqual(2.0, v.value().eval())
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(v.assign_add(1.0))
|
||||
self.assertEqual(2.0, self.evaluate(v.value()))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testAssignSubMethod(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(3.0)
|
||||
variables.global_variables_initializer().run()
|
||||
v.assign_sub(1.0).eval()
|
||||
self.assertEqual(2.0, v.value().eval())
|
||||
v = resource_variable_ops.ResourceVariable(3.0)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(v.assign_sub(1.0))
|
||||
self.assertEqual(2.0, self.evaluate(v.value()))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testDestroyResource(self):
|
||||
with self.test_session() as sess:
|
||||
v = resource_variable_ops.ResourceVariable(3.0)
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertEqual(3.0, v.value().eval())
|
||||
sess.run(resource_variable_ops.destroy_resource_op(v.handle))
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
v.value().eval()
|
||||
# Handle to a resource not actually created.
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
# Should raise no exception
|
||||
sess.run(
|
||||
resource_variable_ops.destroy_resource_op(
|
||||
handle, ignore_lookup_error=True))
|
||||
v = resource_variable_ops.ResourceVariable(3.0)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(3.0, self.evaluate(v.value()))
|
||||
self.evaluate(resource_variable_ops.destroy_resource_op(v.handle))
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(v.value())
|
||||
# Handle to a resource not actually created.
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
# Should raise no exception
|
||||
self.evaluate(resource_variable_ops.destroy_resource_op(
|
||||
handle, ignore_lookup_error=True))
|
||||
|
||||
def testAssignDifferentShapes(self):
|
||||
with self.test_session() as sess, variable_scope.variable_scope(
|
||||
|
|
@ -247,12 +240,21 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
[assign],
|
||||
feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)})
|
||||
|
||||
def testAssignDifferentShapesEager(self):
|
||||
with context.eager_mode():
|
||||
with variable_scope.variable_scope("foo"):
|
||||
var = variable_scope.get_variable("x", shape=[1, 1],
|
||||
dtype=dtypes.float32)
|
||||
assign = var.assign(np.zeros(shape=[2, 2]))
|
||||
self.evaluate(assign)
|
||||
|
||||
def testDtypeAfterFromProto(self):
|
||||
v = resource_variable_ops.ResourceVariable(2.0)
|
||||
w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto())
|
||||
self.assertIsInstance(w.dtype, dtypes.DType)
|
||||
self.assertEqual(v.dtype, w.dtype)
|
||||
|
||||
# TODO(alive): get caching to work in eager mode.
|
||||
def testCachingDevice(self):
|
||||
with ops.device("/job:server/task:1"):
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
|
|
@ -268,41 +270,47 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
|
||||
_ = w.value().op.get_attr("_class")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testSharedName(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
|
||||
v.initializer.run()
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1")
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, w_read.eval())
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1")
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, self.evaluate(w_read))
|
||||
|
||||
x = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1/")
|
||||
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
|
||||
with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
|
||||
_ = x_read.eval()
|
||||
x = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var2")
|
||||
if context.in_graph_mode():
|
||||
with self.assertRaisesOpError("Resource .*/var2/.* does not exist"):
|
||||
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
|
||||
self.evaluate(x_read)
|
||||
else:
|
||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
||||
"Attempted to read a nonexistent variable."):
|
||||
_ = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testSharedNameWithNamescope(self):
|
||||
with self.test_session():
|
||||
with ops.name_scope("foo"):
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
|
||||
v.initializer.run()
|
||||
with ops.name_scope("foo"):
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var1")
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, w_read.eval())
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var1")
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, self.evaluate(w_read))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testShape(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
name="var1", initial_value=array_ops.ones(shape=[10, 20, 35]))
|
||||
self.assertEqual("(10, 20, 35)", str(v.shape))
|
||||
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
|
||||
self.assertEqual("(10, 20, 35)", str(v.value().shape))
|
||||
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
name="var1", initial_value=array_ops.ones(shape=[10, 20, 35]))
|
||||
self.assertEqual("(10, 20, 35)", str(v.shape))
|
||||
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
|
||||
self.assertEqual("(10, 20, 35)", str(v.value().shape))
|
||||
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
|
||||
if context.in_graph_mode():
|
||||
self.assertEqual(
|
||||
"<unknown>",
|
||||
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
|
||||
|
|
@ -329,13 +337,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||
with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
|
||||
control_flow_ops.while_loop(cond, body, [0, 0])
|
||||
|
||||
|
||||
# TODO(agarwal,apassos): Add more comprehensive tests and/or translate the above
|
||||
# tests to work in both GRAPH and EAGER modes.
|
||||
# TODO(agarwal): Add tests for sparse_read, scatter_sub
|
||||
class ResourceVariableOpsEagerTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testVariable(self):
|
||||
def testVariableEager(self):
|
||||
with context.eager_mode():
|
||||
init = array_ops.ones(shape=[10, 20, 35], dtype=dtypes.int32)
|
||||
constraint = lambda x: x
|
||||
|
|
|
|||
|
|
@ -251,6 +251,8 @@ class ResourceVariable(variables.Variable):
|
|||
name=name)
|
||||
else:
|
||||
initial_value = initial_value()
|
||||
initial_value = ops.convert_to_tensor(
|
||||
initial_value, name="initial_value", dtype=dtype)
|
||||
self._handle = gen_resource_variable_ops.var_handle_op(
|
||||
shape=initial_value.get_shape(),
|
||||
dtype=initial_value.dtype.base_dtype,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import traceback
|
|||
import six
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.estimator import util as estimator_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
|
|
@ -840,6 +841,7 @@ def no_regularizer(_):
|
|||
return None
|
||||
|
||||
|
||||
# TODO(alive): support caching devices and partitioned variables in Eager mode.
|
||||
class VariableScope(object):
|
||||
"""Variable scope object to carry defaults to provide to `get_variable`.
|
||||
|
||||
|
|
@ -894,6 +896,14 @@ class VariableScope(object):
|
|||
self._dtype = dtype
|
||||
self._use_resource = use_resource
|
||||
self._constraint = constraint
|
||||
if context.in_eager_mode():
|
||||
if self._caching_device is not None:
|
||||
raise NotImplementedError("Caching devices is not yet supported "
|
||||
"in Eager mode.")
|
||||
if self._partitioner is not None:
|
||||
raise NotImplementedError("Partitioned variables are not yet supported "
|
||||
"in Eager mode.")
|
||||
self._use_resource = True
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
|
@ -961,10 +971,16 @@ class VariableScope(object):
|
|||
|
||||
def set_caching_device(self, caching_device):
|
||||
"""Set caching_device for this scope."""
|
||||
if context.in_eager_mode():
|
||||
raise NotImplementedError("Partitioned variables are not yet supported "
|
||||
"in Eager mode.")
|
||||
self._caching_device = caching_device
|
||||
|
||||
def set_partitioner(self, partitioner):
|
||||
"""Set partitioner for this scope."""
|
||||
if context.in_eager_mode():
|
||||
raise NotImplementedError("Partitioned variables are not yet supported "
|
||||
"in Eager mode.")
|
||||
self._partitioner = partitioner
|
||||
|
||||
def set_custom_getter(self, custom_getter):
|
||||
|
|
@ -1034,8 +1050,11 @@ class VariableScope(object):
|
|||
constraint = self._constraint
|
||||
if dtype is None:
|
||||
dtype = self._dtype
|
||||
if use_resource is None:
|
||||
use_resource = self._use_resource
|
||||
if context.in_graph_mode():
|
||||
if use_resource is None:
|
||||
use_resource = self._use_resource
|
||||
else:
|
||||
use_resource = True
|
||||
|
||||
return var_store.get_variable(
|
||||
full_name, shape=shape, dtype=dtype, initializer=initializer,
|
||||
|
|
@ -1060,6 +1079,9 @@ class VariableScope(object):
|
|||
use_resource=None,
|
||||
constraint=None):
|
||||
"""Gets an existing variable with this name or create a new one."""
|
||||
if context.in_eager_mode():
|
||||
raise NotImplementedError("Partitioned variables are not yet supported "
|
||||
"in Eager mode.")
|
||||
if initializer is None:
|
||||
initializer = self._initializer
|
||||
if regularizer is None:
|
||||
|
|
|
|||
|
|
@ -1364,7 +1364,7 @@ def variables_initializer(var_list, name="init"):
|
|||
Returns:
|
||||
An Op that run the initializers of all the specified variables.
|
||||
"""
|
||||
if var_list:
|
||||
if var_list and context.in_graph_mode():
|
||||
return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
|
||||
return control_flow_ops.no_op(name=name)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user