pytorch/caffe2/python/operator_test/checkpoint_test.py
Yangqing Jia 4858a6bc6f snapshot -> checkpoint
Summary:
This renames the "Snapshot" op name to "Checkpoint" as we discussed earlier.

The earlier Snapshot name is still available, but we should move to the new name and
eventually deprecate the old one.

The Python SnapshotManager should also be changed, cc azzolini

Reviewed By: dzhulgakov

Differential Revision: D4272021

fbshipit-source-id: 4b8e029354416530dfbf0d538bfc91a0f61e0296
2016-12-15 12:01:30 -08:00

45 lines
1.5 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, workspace
import os
import shutil
import tempfile
import unittest
class CheckpointTest(unittest.TestCase):
    """A simple test case to make sure that the checkpoint behavior is correct.

    Builds a net that bumps an iteration counter, fills a constant blob, and
    runs a Checkpoint op with ``every=10`` writing leveldb DBs named
    ``test_checkpoint_at_%05d`` under a temporary directory. After running the
    net 100 times, the test expects a DB to exist for iterations 10, 20, ..., 90.
    """

    def testCheckpoint(self):
        temp_root = tempfile.mkdtemp()
        # Register cleanup immediately so the temporary directory is removed
        # even when one of the assertions below fails.
        self.addCleanup(shutil.rmtree, temp_root)
        db_name_format = "test_checkpoint_at_%05d"

        net = core.Net("test_checkpoint")
        # Note(jiayq): I am being a bit lazy here and am using the old iter
        # convention that does not have an input. Optionally change it to the
        # new style if needed.
        net.Iter([], "iter")
        net.ConstantFill([], "value", shape=[1, 2, 3])
        net.Checkpoint(["iter", "value"], [],
                       db=os.path.join(temp_root, db_name_format),
                       db_type="leveldb", every=10, absolute_path=True)
        self.assertTrue(workspace.CreateNet(net))
        for _ in range(100):
            self.assertTrue(workspace.RunNet("test_checkpoint"))
        for i in range(1, 10):
            checkpoint_path = os.path.join(temp_root, db_name_format % (i * 10))
            # The message names the missing DB, replacing the old debug prints.
            self.assertTrue(os.path.exists(checkpoint_path),
                            "missing checkpoint: %s" % checkpoint_path)
if __name__ == "__main__":
    # `unittest` is already imported at module level; no re-import needed.
    unittest.main()