pytorch/caffe2/python/checkpoint_test.py
Bor-Yiing Su 7fa4acab9b Loads only the model blobs from the checkpoints.
Summary:
To evaluate a model from checkpoints, we first need to load the model from
them. However, checkpoints store many more blobs than the model itself
needs. This function lets the model builder load only the blobs associated
with the model into the workspace, and then evaluate the model from the
populated workspace.
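A minimal sketch of the intended usage, assuming a JobRunner that has
already written checkpoints (the blob names and epoch below are
illustrative):

    # Hypothetical blob names; any blobs present in the checkpoint work.
    model_blob_names = ['init_net0/GivenTensorInt64Fill:0']
    job_runner.load_blobs_from_checkpoints(
        blob_names=model_blob_names, epoch=1, session=session)
    # The fresh workspace now holds only the requested model blobs.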

Reviewed By: azzolini

Differential Revision: D4751414

fbshipit-source-id: a7a420228d681fc2dcfd8573cf69a97b1abc2ef3
2017-03-27 10:02:11 -07:00


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.schema import Struct, ConstRecord
from caffe2.python import core, workspace
from caffe2.python.session import LocalSession
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.checkpoint import (
    CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner)
from caffe2.python.task import Task, Node
from caffe2.python.test_util import TestCase
from caffe2.python.dataio import ReaderWithLimit

import tempfile
import numpy as np
import shutil


def build_job(num_nodes):
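    """Builds a Job with `num_nodes` reader nodes. Each node owns a small
    Dataset holding the values 0..9 and a running `total` blob initialized
    to 100; every epoch pipes up to 3 dataset entries through `inc_total`,
    which adds them to the total. Returns the job together with a Task
    that fetches every node's total."""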
    all_outputs = []
    with Job() as job:
        for node_id in range(num_nodes):
            with Node('reader' + str(node_id)):
                with job.init_group:
                    init_net = core.Net('init_net' + str(node_id))
                    data_arr = Struct(('val', np.array(range(10))))
                    data = ConstRecord(init_net, data_arr)
                    ds = Dataset(data, name='dataset' + str(node_id))
                    full_reader = ds.reader(init_net)
                    total = init_net.Const([100])
                    Task(step=init_net)

                def inc_total(rec):
                    net = core.Net('inc_total' + str(node_id))
                    net.Add([total, rec.val()], [total])
                    return [net]

                epoch_reader = ReaderWithLimit(full_reader, num_iter=3)
                pipe(epoch_reader, processor=inc_total)
                job.add_stop_signal(epoch_reader.data_finished())
                all_outputs.append(total)

    total_fetcher = Task(step=core.Net('empty'), outputs=all_outputs)
    return job, total_fetcher


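# Per-epoch totals for a single node: the total starts at 100 and the
# reader yields 0..9, three entries per epoch, so epoch 1 adds 0+1+2
# (=103), epoch 2 adds 3+4+5 (=115), epoch 3 adds 6+7+8 (=136), and
# epoch 4 adds the final 9 (=145).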
EXPECTED_TOTALS = [103, 115, 136, 145]


class TestCheckpoint(TestCase):
    def run_with(self, builder):
        job, output_fetcher = build_job(num_nodes=1)

        def fetch_total(session):
            session.run(output_fetcher)
            return output_fetcher.outputs()[0].fetch()
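
        # First run: execute the job to completion and count the epochs.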
        session, checkpoint = builder()
        compiled_job = job.compile(LocalSession)
        num_epochs = JobRunner(compiled_job, checkpoint)(session)
        self.assertEquals(num_epochs, len(EXPECTED_TOTALS))
        self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])
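
        # Resuming from any checkpointed epoch should reach the same
        # final total.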
        for initial_epoch in range(1, num_epochs + 1):
            session, checkpoint = builder()
            JobRunner(
                compiled_job,
                checkpoint, resume_from_epoch=initial_epoch)(session)
            self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])
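
        # Loading an epoch's checkpoint directly should restore the
        # total recorded at that epoch.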
        for epoch in range(1, num_epochs + 1):
            session.run(checkpoint.load(epoch))
            self.assertEquals(fetch_total(session), EXPECTED_TOTALS[epoch - 1])

    def test_single_checkpoint(self):
        # test single node
        with tempfile.NamedTemporaryFile() as tmp:

            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = CheckpointManager(tmp.name, 'minidb')
                return session, checkpoint

            self.run_with(builder)

        # test multi-node
        try:
            tmpdir = tempfile.mkdtemp()

            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)

    def test_load_model_from_checkpoints(self):
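        # Run a 3-node job to completion so that every epoch gets
        # checkpointed, then load only selected model blobs back.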
        try:
            tmpdir = tempfile.mkdtemp()
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            job, output_fetcher = build_job(num_nodes=3)
            compiled_job = job.compile(LocalSession)
            job_runner = JobRunner(compiled_job, checkpoint)
            num_epochs = job_runner(session)
            self.assertEquals(num_epochs, len(EXPECTED_TOTALS))
            # There are 44 blobs after finishing up the job runner.
            self.assertEquals(len(ws.blobs), 44)
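
            # Load only the model blobs from the epoch-1 checkpoint into
            # a fresh workspace. After epoch 1 each node's total blob
            # holds 103 (see EXPECTED_TOTALS), so that is the value the
            # loaded blobs should contain.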
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            self.assertEquals(len(ws.blobs), 0)
            model_blob_names = ['init_net0/GivenTensorInt64Fill:0',
                                'init_net1/GivenTensorInt64Fill:0']
            job_runner.load_blobs_from_checkpoints(blob_names=model_blob_names,
                                                   epoch=1, session=session)
            # In addition to the two model blobs, we also have 3 output blobs
            # and one runnable blob. So there are 6 blobs in total.
            self.assertEquals(len(ws.blobs), 6)
            # Check that all the model blobs are loaded.
            for blob_name in model_blob_names:
                self.assertTrue(ws.has_blob(blob_name))
                self.assertEquals(ws.fetch_blob(blob_name), np.array([103]))
        finally:
            shutil.rmtree(tmpdir)