mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Test save/restore variable from graph_callable.
PiperOrigin-RevId: 171051237
This commit is contained in:
parent
cf17ec96ed
commit
3cf41b2edd
|
|
@ -81,6 +81,7 @@ cuda_py_test(
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client",
|
"//tensorflow/python:client",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/eager:graph_callable",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,14 @@ import os
|
||||||
|
|
||||||
from tensorflow.contrib.eager.python import saver as _saver
|
from tensorflow.contrib.eager.python import saver as _saver
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import graph_callable
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -87,6 +91,53 @@ class SaverTest(test.TestCase):
|
||||||
with _saver.restore_variables_on_create(ckpt_prefix):
|
with _saver.restore_variables_on_create(ckpt_prefix):
|
||||||
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
|
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
|
||||||
|
|
||||||
|
def testSaveRestoreGraphCallable(self):
|
||||||
|
with context.eager_mode(), ops.device(self._dev()):
|
||||||
|
@graph_callable.graph_callable(
|
||||||
|
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
|
||||||
|
def model(x):
|
||||||
|
v = variable_scope.get_variable(
|
||||||
|
'v', initializer=init_ops.zeros_initializer(), shape=())
|
||||||
|
return v + x
|
||||||
|
|
||||||
|
# Default 2 + 0 = 2
|
||||||
|
self.assertEqual(
|
||||||
|
2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||||
|
|
||||||
|
# Save the variable value 0.
|
||||||
|
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
|
||||||
|
_saver.Saver(model.variables).save(ckpt_prefix)
|
||||||
|
|
||||||
|
# update variable to 1, so that 2 + 1 = 3
|
||||||
|
model.variables[0].assign(1.)
|
||||||
|
self.assertEqual(
|
||||||
|
3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||||
|
|
||||||
|
# load the variable value 0, so that 2 + 0 = 2
|
||||||
|
_saver.Saver(model.variables).restore(ckpt_prefix)
|
||||||
|
self.assertEqual(
|
||||||
|
2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||||
|
|
||||||
|
# update checkpoint variable to 1 and memory value to 2.
|
||||||
|
model.variables[0].assign(1.)
|
||||||
|
_saver.Saver(model.variables).save(ckpt_prefix)
|
||||||
|
model.variables[0].assign(2.)
|
||||||
|
self.assertEqual(
|
||||||
|
4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||||
|
|
||||||
|
# reset the graph and reload on create, so that 1 + 2 = 3
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with _saver.restore_variables_on_create(ckpt_prefix):
|
||||||
|
@graph_callable.graph_callable(
|
||||||
|
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
|
||||||
|
def model2(x):
|
||||||
|
v = variable_scope.get_variable(
|
||||||
|
'v', initializer=init_ops.zeros_initializer(), shape=())
|
||||||
|
return v + x
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user