mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Test save/restore variable from graph_callable.
PiperOrigin-RevId: 171051237
This commit is contained in:
parent
cf17ec96ed
commit
3cf41b2edd
|
|
@ -81,6 +81,7 @@ cuda_py_test(
|
|||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/eager:graph_callable",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -21,10 +21,14 @@ import os
|
|||
|
||||
from tensorflow.contrib.eager.python import saver as _saver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import graph_callable
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
|
|
@ -87,6 +91,53 @@ class SaverTest(test.TestCase):
|
|||
with _saver.restore_variables_on_create(ckpt_prefix):
|
||||
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
|
||||
|
||||
def testSaveRestoreGraphCallable(self):
|
||||
with context.eager_mode(), ops.device(self._dev()):
|
||||
@graph_callable.graph_callable(
|
||||
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
|
||||
def model(x):
|
||||
v = variable_scope.get_variable(
|
||||
'v', initializer=init_ops.zeros_initializer(), shape=())
|
||||
return v + x
|
||||
|
||||
# Default 2 + 0 = 2
|
||||
self.assertEqual(
|
||||
2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||
|
||||
# Save the variable value 0.
|
||||
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
|
||||
_saver.Saver(model.variables).save(ckpt_prefix)
|
||||
|
||||
# update variable to 1, so that 2 + 1 = 3
|
||||
model.variables[0].assign(1.)
|
||||
self.assertEqual(
|
||||
3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||
|
||||
# load the variable value 0, so that 2 + 0 = 2
|
||||
_saver.Saver(model.variables).restore(ckpt_prefix)
|
||||
self.assertEqual(
|
||||
2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||
|
||||
# update checkpoint variable to 1 and memory value to 2.
|
||||
model.variables[0].assign(1.)
|
||||
_saver.Saver(model.variables).save(ckpt_prefix)
|
||||
model.variables[0].assign(2.)
|
||||
self.assertEqual(
|
||||
4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||
|
||||
# reset the graph and reload on create, so that 1 + 2 = 3
|
||||
with ops.Graph().as_default():
|
||||
with _saver.restore_variables_on_create(ckpt_prefix):
|
||||
@graph_callable.graph_callable(
|
||||
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
|
||||
def model2(x):
|
||||
v = variable_scope.get_variable(
|
||||
'v', initializer=init_ops.zeros_initializer(), shape=())
|
||||
return v + x
|
||||
|
||||
self.assertEqual(
|
||||
3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user