Summary: So far we format the epoch name with 6 digits, but this is constraining. In order to have consistent naming, we can simply append the epoch to the suffix. Then we have consistent naming rules for both small and large epoch numbers.

Reviewed By: azzolini
Differential Revision: D5653871
fbshipit-source-id: acdf26a14b731347bb85fe2f33c1b89e2ba83bdd
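For illustration, a minimal sketch of the naming rule described above (and asserted by test_get_ckpt_db_name in the file below), assuming a db_prefix directory and the 'trainer:N' node names this test uses; this is not the actual checkpoint.py implementation:

def ckpt_db_name(db_prefix, node_name, epoch):
    # Append the epoch as a plain suffix rather than zero-padding it to
    # six digits, so small and large epoch numbers follow one rule.
    return '%s/%s.%d' % (db_prefix, node_name, epoch)

# ckpt_db_name('/tmp/ckpt', 'trainer:0', 5) == '/tmp/ckpt/trainer:0.5'
# The old 6-digit scheme would presumably have produced
# '/tmp/ckpt/trainer:0.000005'.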
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.schema import Struct, ConstRecord
from caffe2.python import core, workspace
from caffe2.python.session import LocalSession
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.checkpoint import (
    CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner,
    UploadTaskGroupBuilder)
from caffe2.python.net_builder import ops
from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType
from caffe2.python.test_util import TestCase
from caffe2.python.dataio import ReaderWithLimit

import numpy as np
import os
import shutil
import tempfile
import unittest

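# Builds a one-node pipeline: a constant dataset holding the values 0..9 is
# read 3 rows per epoch, each row being added into `total` (initialized to
# 100). The job stops once the reader has exhausted the dataset.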
def build_pipeline(node_id):
    with Node('trainer:%d' % node_id):
        with Job.current().init_group, Task():
            data_arr = Struct(('val', np.array(list(range(10)))))
            data = ConstRecord(ops, data_arr)
            ds = Dataset(data, name='dataset:%d' % node_id)
            full_reader = ds.reader(ops)
            total = ops.Const([100])

        def inc_total(rec):
            ops.Add([total, rec.val()], [total])

        epoch_reader = ReaderWithLimit(full_reader, num_iter=3)
        pipe(epoch_reader, processor=inc_total)
        Job.current().add_stop_signal(epoch_reader.data_finished())
    return [total]


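# Running totals after each of the 4 epochs: 100+0+1+2 = 103, then
# +3+4+5 = 115, +6+7+8 = 136, and finally +9 = 145 once the 10-element
# dataset is exhausted.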
EXPECTED_TOTALS = [103, 115, 136, 145]


def local_copy_op(src, dest):
    def copy_op(inputs, outputs):
        shutil.copyfile(src, dest)
    return copy_op


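# Test stand-in for a real uploader: "upload" each node's checkpoint db by
# copying it into a local destination directory.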
class UploadToLocalFile(UploadTaskGroupBuilder):
    def __init__(self, dest_dir):
        self.dest_dir = dest_dir

    def build(self, epoch, checkpoint_manager):
        with TaskGroup(WorkspaceType.GLOBAL) as upload_task_group:
            for node, manager in checkpoint_manager._node_managers:
                with Node(str(node)), Task():
                    src_path = manager._db_name(epoch)
                    dest_path = os.path.join(self.dest_dir, str(node))
                    ops.Python((local_copy_op,
                                [src_path, dest_path], {}))([], [])
        return upload_task_group


class TestCheckpoint(TestCase):
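    # Shared driver: run the job to completion, then verify both resuming
    # from each saved epoch and loading each epoch's checkpoint directly.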
    def run_with(self, builder):
        with Job() as job:
            outputs = build_pipeline(node_id=0)
            output_fetcher = Task(step=core.Net('empty'), outputs=outputs)

        def fetch_total(session):
            session.run(output_fetcher)
            return output_fetcher.outputs()[0].fetch()

        session, checkpoint = builder()
        compiled_job = job.compile(LocalSession)
        num_epochs = JobRunner(compiled_job, checkpoint)(session)
        self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
        self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

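        # Resuming from any saved epoch must reproduce the same final total.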
        for initial_epoch in range(1, num_epochs + 1):
            session, checkpoint = builder()
            JobRunner(
                compiled_job,
                checkpoint, resume_from_epoch=initial_epoch)(session)
            self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

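        # Loading an epoch's checkpoint directly restores that epoch's total.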
        for epoch in range(1, num_epochs + 1):
            session.run(checkpoint.load(epoch))
            self.assertEqual(fetch_total(session), EXPECTED_TOTALS[epoch - 1])

    def test_single_checkpoint(self):
        # test single node
        tmpdir = tempfile.mkdtemp()
        try:
            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = CheckpointManager(tmpdir, 'temp_node', 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)

        # test multi-node
        tmpdir = tempfile.mkdtemp()
        try:
            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)

    # Note(wyiming): we have yet to find out why Travis fails with errors like:
    # E: AssertionError: 'trainer:1/task/GivenTensorInt64Fill:0, a C++ native class of type nullptr (uninitialized).' != array([103])
    # See for example https://travis-ci.org/caffe2/caffe2/jobs/265665119
    # As a result, we check whether we are running on Travis and, if so,
    # skip this test.
    @unittest.skipIf(os.environ.get("TRAVIS"),
                     "Test has a known issue with Travis.")
    def test_load_model_from_checkpoints(self):
        tmpdir = tempfile.mkdtemp()
        try:
            for node_id in range(3):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                with Job() as job:
                    build_pipeline(node_id)
                compiled_job = job.compile(LocalSession)
                job_runner = JobRunner(compiled_job, checkpoint)
                num_epochs = job_runner(session)
                self.assertEqual(num_epochs, len(EXPECTED_TOTALS))

                # There are 12 global blobs after finishing up the job runner.
                # (only blobs on init_group are checkpointed)
                self.assertEqual(len(ws.blobs), 12)

            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            self.assertEqual(len(ws.blobs), 0)
            model_blob_names = ['trainer:1/task/GivenTensorInt64Fill:0',
                                'trainer:2/task/GivenTensorInt64Fill:0']
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Job() as job:
                for node_id in range(3):
                    build_pipeline(node_id)
            compiled_job = job.compile(LocalSession)
            job_runner = JobRunner(compiled_job, checkpoint)
            job_runner.load_blobs_from_checkpoints(
                blob_names=model_blob_names, epoch=1, session=session)

            # Check that we can successfully load from checkpoints of epochs
            # 1 to 4, but not epoch 5.
            for epoch in range(1, 5):
                self.assertTrue(
                    job_runner.load_blobs_from_checkpoints(
                        blob_names=model_blob_names, epoch=epoch,
                        session=session))
                # Check that all the model blobs are loaded.
                for blob_name in model_blob_names:
                    self.assertTrue(ws.has_blob(blob_name))
                    self.assertEqual(ws.fetch_blob(blob_name),
                                     np.array([EXPECTED_TOTALS[epoch - 1]]))
            self.assertFalse(
                job_runner.load_blobs_from_checkpoints(
                    blob_names=model_blob_names, epoch=5, session=session))
        finally:
            shutil.rmtree(tmpdir)

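    # Checkpoint db names append the epoch as a plain suffix (see the
    # Summary above), e.g. '<tmpdir>/trainer:0.5' for epoch 5.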
    def test_get_ckpt_db_name(self):
        tmpdir = tempfile.mkdtemp()
        try:
            num_nodes = 3
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Job() as job:
                for node_id in range(num_nodes):
                    build_pipeline(node_id)
            compiled_job = job.compile(LocalSession)
            checkpoint.init(compiled_job.nodes_to_checkpoint())

            for node_id in range(num_nodes):
                epoch = 5
                node_name = 'trainer:%d' % node_id
                expected_db_name = tmpdir + '/' + node_name + '.5'
                self.assertEqual(
                    checkpoint.get_ckpt_db_name(node_name, epoch),
                    expected_db_name)
        finally:
            shutil.rmtree(tmpdir)

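    # End-to-end check that JobRunner drives the upload builder: after the
    # job finishes, each node's checkpoint copy must exist under upload_dir.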
    def test_upload_checkpoint(self):
        tmpdir = tempfile.mkdtemp()
        try:
            upload_dir = os.path.join(tmpdir, "upload")
            os.mkdir(upload_dir)
            num_nodes = 3

            # The uploaded files do not exist yet.
            for node_id in range(num_nodes):
                node_name = 'trainer:%d' % node_id
                upload_path = os.path.join(upload_dir, node_name)
                self.assertFalse(os.path.exists(upload_path))

            # Create and run the job runner.
            for node_id in range(num_nodes):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                with Job() as job:
                    build_pipeline(node_id)
                compiled_job = job.compile(LocalSession)
                local_upload_builder = UploadToLocalFile(upload_dir)
                job_runner = JobRunner(
                    compiled_job, checkpoint,
                    upload_task_group_builder=local_upload_builder)
                num_epochs = job_runner(session)
                self.assertEqual(num_epochs, len(EXPECTED_TOTALS))

            # The uploaded files should exist now.
            for node_id in range(num_nodes):
                node_name = 'trainer:%d' % node_id
                upload_path = os.path.join(upload_dir, node_name)
                self.assertTrue(os.path.exists(upload_path))
        finally:
            shutil.rmtree(tmpdir)