mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18622 — Set up correct inheritance structure for remaining operator test classes. Reviewed By: ezyang. Differential Revision: D14685941. fbshipit-source-id: a6b1b3be325935b7fec7515be13a4994b3016bf0
44 lines
1.5 KiB
Python
44 lines
1.5 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core, workspace, test_util
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
import unittest
|
|
|
|
|
|
class CheckpointTest(test_util.TestCase):
    """A simple test case to make sure that the checkpoint behavior is correct.

    Builds a net with an ``Iter`` op producing the "iter" blob, a
    ``ConstantFill`` producing the "value" blob, and a ``Checkpoint`` op
    (``every=10``) that saves both blobs to leveldb databases named
    ``test_checkpoint_at_%05d`` under a temporary directory. After running
    the net 100 times, the test asserts that the databases for iterations
    10, 20, ..., 90 exist on disk.
    """

    def testCheckpoint(self):
        """Run the checkpointing net 100 times and verify the saved databases.

        The temporary directory is removed through ``addCleanup``, so it is
        deleted even when an assertion fails partway through the test.
        """
        temp_root = tempfile.mkdtemp()
        # Register cleanup right away: removing the directory only as the
        # last statement leaked it whenever an earlier assertion failed.
        self.addCleanup(shutil.rmtree, temp_root)

        net = core.Net("test_checkpoint")
        # Note(jiayq): I am being a bit lazy here and am using the old iter
        # convention that does not have an input. Optionally change it to the
        # new style if needed.
        net.Iter([], "iter")
        net.ConstantFill([], "value", shape=[1, 2, 3])
        # The "%05d" placeholder in the db name is filled with the iteration
        # number; the existence checks below rely on that naming.
        net.Checkpoint(["iter", "value"], [],
                       db=os.path.join(temp_root, "test_checkpoint_at_%05d"),
                       db_type="leveldb", every=10, absolute_path=True)
        self.assertTrue(workspace.CreateNet(net))

        for _ in range(100):
            self.assertTrue(workspace.RunNet("test_checkpoint"))

        # Expect one checkpoint database per multiple of 10 (10..90).
        for i in range(1, 10):
            checkpoint_path = os.path.join(
                temp_root, "test_checkpoint_at_%05d" % (i * 10))
            self.assertTrue(os.path.exists(checkpoint_path), checkpoint_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only when this file is executed directly (not imported): hand control
    # to unittest's command-line runner, which collects and runs the
    # TestCase classes defined in this module.
    unittest.main()
|