Making load() work for resource variables.

PiperOrigin-RevId: 158205361
This commit is contained in:
Alexandre Passos 2017-06-06 16:13:00 -07:00 committed by TensorFlower Gardener
parent 05412bd367
commit 7f5384dccf
2 changed files with 11 additions and 4 deletions

View File

@ -153,6 +153,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v.assign(2.0).eval()
self.assertEqual(2.0, v.value().eval())
def testLoad(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
v.load(2.0)
self.assertEqual(2.0, v.value().eval())
def testToFromProto(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(1.0)

View File

@ -203,7 +203,7 @@ class ResourceVariable(variables.Variable):
gen_resource_variable_ops.var_is_initialized_op(self._handle))
if initial_value is not None:
with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
self._initializer_op = gen_resource_variable_ops.assign_variable_op(
self._handle, self._initial_value, name=n)
with ops.name_scope("Read"), ops.colocate_with(self._handle):
# Manually assign reads to the handle's device to avoid log messages.
@ -237,7 +237,7 @@ class ResourceVariable(variables.Variable):
self._handle = g.as_graph_element(
ops.prepend_name_scope(variable_def.variable_name,
import_scope=import_scope))
self._initialize_op = g.as_graph_element(
self._initializer_op = g.as_graph_element(
ops.prepend_name_scope(variable_def.initializer_name,
import_scope=import_scope))
if variable_def.snapshot_name:
@ -282,7 +282,7 @@ class ResourceVariable(variables.Variable):
@property
def create(self):
"""The op responsible for initializing this variable."""
return self._initialize_op
return self._initializer_op
@property
def handle(self):
@ -305,7 +305,7 @@ class ResourceVariable(variables.Variable):
@property
def initializer(self):
"""The op responsible for initializing this variable."""
return self._initialize_op
return self._initializer_op
@property
def initial_value(self):