pytorch/caffe2/python/checkpoint_test.py
Bor-Yiing Su 8f9cd757db Skips the initialization phase of the individual checkpoint objects.
Summary:
The initialization phase of each checkpoint object simply loads the names of
the blobs in the checkpoints. When we load from the checkpoints, the names of
the blobs are already given, so we can skip this init step.

Reviewed By: azzolini

Differential Revision: D4808114

fbshipit-source-id: 4c740049c1014f3e93b4b87f43e3937afdefa25a
2017-03-31 10:10:56 -07:00
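
When the caller already knows which blobs to restore, they can be loaded straight
from an existing checkpoint without running the checkpoint managers' init step,
whose only job is to read those names out of the db. The test below exercises this
through JobRunner.load_blobs_from_checkpoints; a minimal sketch of the call pattern
(assuming a compiled job, a LocalSession, a checkpoint directory tmpdir that has
already been written to, and an example blob name taken from the test):

    checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
    job_runner = JobRunner(compiled_job, checkpoint)
    # Blob names are passed in explicitly, so no init net has to run to
    # discover them from the checkpoint db.
    job_runner.load_blobs_from_checkpoints(
        blob_names=['reader:1/task/GivenTensorInt64Fill:0'],
        epoch=1, session=session)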


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python.schema import Struct, ConstRecord
from caffe2.python import core, workspace
from caffe2.python.session import LocalSession
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.checkpoint import (
CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner)
from caffe2.python.net_builder import ops
from caffe2.python.task import Task, Node
from caffe2.python.test_util import TestCase
from caffe2.python.dataio import ReaderWithLimit
import tempfile
import numpy as np
import shutil
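

# build_pipeline constructs a toy per-node job: a constant dataset holding the
# values 0..9 is read a handful of records per epoch (ReaderWithLimit with
# num_iter=3), and every value read is added to a running 'total' blob that
# starts at 100. The job stops once the reader has been exhausted.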
def build_pipeline(node_id):
    with Node('reader:%d' % node_id):
        with Job.current().init_group, Task():
            data_arr = Struct(('val', np.array(range(10))))
            data = ConstRecord(ops, data_arr)
            ds = Dataset(data, name='dataset:%d' % node_id)
            full_reader = ds.reader(ops)
            total = ops.Const([100])

        def inc_total(rec):
            ops.Add([total, rec.val()], [total])

        epoch_reader = ReaderWithLimit(full_reader, num_iter=3)
        pipe(epoch_reader, processor=inc_total)
        Job.current().add_stop_signal(epoch_reader.data_finished())
    return [total]
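
# Running total after each epoch: the reader yields three values per epoch from
# range(10), so 100 -> +0+1+2 = 103 -> +3+4+5 = 115 -> +6+7+8 = 136 -> +9 = 145.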
EXPECTED_TOTALS = [103, 115, 136, 145]


class TestCheckpoint(TestCase):
    def run_with(self, builder):
        with Job() as job:
            outputs = build_pipeline(node_id=0)
        output_fetcher = Task(step=core.Net('empty'), outputs=outputs)

        def fetch_total(session):
            session.run(output_fetcher)
            return output_fetcher.outputs()[0].fetch()

        session, checkpoint = builder()
        compiled_job = job.compile(LocalSession)
        num_epochs = JobRunner(compiled_job, checkpoint)(session)
        self.assertEquals(num_epochs, len(EXPECTED_TOTALS))
        self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])
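
        # Resuming from any epoch must reach the same final total, since the
        # runner replays the remaining epochs from that epoch's checkpoint.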
        for initial_epoch in range(1, num_epochs + 1):
            session, checkpoint = builder()
            JobRunner(
                compiled_job,
                checkpoint, resume_from_epoch=initial_epoch)(session)
            self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])
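
        # Loading a single epoch's checkpoint restores the total exactly as it
        # stood at the end of that epoch.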
        for epoch in range(1, num_epochs + 1):
            session.run(checkpoint.load(epoch))
            self.assertEquals(fetch_total(session), EXPECTED_TOTALS[epoch - 1])

    def test_single_checkpoint(self):
        # test single node
        with tempfile.NamedTemporaryFile() as tmp:

            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = CheckpointManager(tmp.name, 'minidb')
                return session, checkpoint

            self.run_with(builder)

        # test multi-node
        try:
            tmpdir = tempfile.mkdtemp()

            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)
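
    # Each node's pipeline is run and checkpointed separately; afterwards a
    # fresh workspace restores selected blobs by name from the epoch-1
    # checkpoints.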
    def test_load_model_from_checkpoints(self):
        try:
            tmpdir = tempfile.mkdtemp()
            for node_id in range(3):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                with Job() as job:
                    build_pipeline(node_id)
                compiled_job = job.compile(LocalSession)
                job_runner = JobRunner(compiled_job, checkpoint)
                num_epochs = job_runner(session)
                self.assertEquals(num_epochs, len(EXPECTED_TOTALS))
                # There are 16 blobs after finishing up the job runner.
                self.assertEquals(len(ws.blobs), 16)
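
            # Restore the model blobs by name from the epoch-1 checkpoints into
            # a brand new workspace; because the names are given explicitly,
            # the checkpoint objects' init step does not have to run.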
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            self.assertEquals(len(ws.blobs), 0)
            model_blob_names = ['reader:1/task/GivenTensorInt64Fill:0',
                                'reader:2/task/GivenTensorInt64Fill:0']
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Job() as job:
                for node_id in range(3):
                    build_pipeline(node_id)
            compiled_job = job.compile(LocalSession)
            job_runner = JobRunner(compiled_job, checkpoint)
            job_runner.load_blobs_from_checkpoints(
                blob_names=model_blob_names, epoch=1, session=session)
            # Check that all the model blobs are loaded.
            for blob_name in model_blob_names:
                self.assertTrue(ws.has_blob(blob_name))
                self.assertEquals(ws.fetch_blob(blob_name), np.array([103]))
        finally:
            shutil.rmtree(tmpdir)