mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: This renames the "Snapshot" op to "Checkpoint", as we discussed earlier. The earlier Snapshot name is still available, but we should move to the new name and eventually deprecate the old one. The Python SnapshotManager should also be changed; cc azzolini Reviewed By: dzhulgakov Differential Revision: D4272021 fbshipit-source-id: 4b8e029354416530dfbf0d538bfc91a0f61e0296
45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core, workspace
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
import unittest
|
|
|
|
|
|
class CheckpointTest(unittest.TestCase):
    """A simple test case to make sure that the checkpoint behavior is correct.

    Builds a net with an ``Iter`` counter, a constant blob and a
    ``Checkpoint`` op configured with ``every=10``, runs it 100 times, and
    then checks that a checkpoint db path exists for iterations
    10, 20, ..., 90.
    """

    def testCheckpoint(self):
        temp_root = tempfile.mkdtemp()
        # Register cleanup immediately so the scratch directory is removed
        # even when an assertion below fails (a trailing rmtree would be
        # skipped on failure and leak the directory).
        self.addCleanup(shutil.rmtree, temp_root)

        net = core.Net("test_checkpoint")
        # Note(jiayq): I am being a bit lazy here and am using the old iter
        # convention that does not have an input. Optionally change it to the
        # new style if needed.
        net.Iter([], "iter")
        net.ConstantFill([], "value", shape=[1, 2, 3])
        # The "%05d" in the db path is expected (see the checks below) to be
        # filled in by the op with the iteration count.
        net.Checkpoint(["iter", "value"], [],
                       db=os.path.join(temp_root, "test_checkpoint_at_%05d"),
                       db_type="leveldb", every=10, absolute_path=True)
        self.assertTrue(workspace.CreateNet(net))

        for _ in range(100):
            self.assertTrue(workspace.RunNet("test_checkpoint"))

        for i in range(1, 10):
            checkpoint_path = os.path.join(
                temp_root, "test_checkpoint_at_%05d" % (i * 10))
            # The message replaces the old commented-out debug prints.
            self.assertTrue(os.path.exists(checkpoint_path),
                            "Missing checkpoint: %s" % checkpoint_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # unittest is already imported at module level; no need to re-import.
    unittest.main()
|