from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import logging

from caffe2.python import core, context
from caffe2.python.task import Node, Task, TaskGroup, TaskOutput, WorkspaceType

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@context.define_context()
class Job(object):
    """
    A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the
    `exit_group` which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights, accumulators
    and data file lists.

    The `epoch_group` will be run in a loop after init_group. The loop will
    exit when any of the stop signals added with `add_stop_signal` is True
    at the end of an epoch.

    The `exit_group` will be run only once at the very end of the job, when
    one of the stopping criteria for `epoch_group` has been met. The role of
    this group is to save the results of training at the end of the job.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around.

    Example usage:

        def build_reader(partitions):
            with Job.current().init_group:
                reader = HiveReader(init_reader, ..., partitions)
                Task(step=init_reader)
            with Job.current().epoch_group:
                limited_reader = ReaderWithLimit(reader, num_iter=10000)
                data_queue = pipe(limited_reader, num_threads=8)
                Job.current().add_stop_signal(limited_reader.data_finished())
            return data_queue

        def build_hogwild_trainer(reader, model):
            with Job.current().init_group:
                Task(step=model.param_init_net)
            with Job.current().epoch_group:
                pipe(reader, processor=model, num_threads=8)
            with Job.current().exit_group:
                Task(step=model.save_model_net)

        with Job() as job:
            reader = build_reader(partitions)
            model = build_model(params)
            build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 exit_group=None, stop_signals=None,
                 nodes_to_checkpoint=None):
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_signals = stop_signals or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        else:
            return self.init_group.used_nodes()

    def compile(self, session_class):
        return Job(
            init_group=session_class.compile(self.init_group),
            epoch_group=session_class.compile(self.epoch_group),
            exit_group=session_class.compile(self.exit_group),
            stop_signals=self.stop_signals,
            nodes_to_checkpoint=self.nodes_to_checkpoint())

    def __enter__(self):
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()

    def add_stop_signal(self, output):
        if isinstance(output, core.BlobReference):
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_signals.append(output)
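
# Example (a minimal sketch; `param_init_net`, `reader`, `model` and `save_net`
# are hypothetical placeholders, and `pipe` refers to caffe2.python.pipeline.pipe).
# Because Job.__enter__ enters `epoch_group`, Tasks created directly inside a
# `with Job()` block are added to `epoch_group` by default:
#
#   with Job() as job:
#       with job.init_group:
#           Task(step=param_init_net)      # runs once, before the first epoch
#       pipe(reader, processor=model)      # implicitly part of epoch_group
#       epoch_limiter(num_epochs=5)        # adds a stop signal after 5 epochs
#       with job.exit_group:
#           Task(step=save_net)            # runs once, after the last epoch
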
class CheckpointManager(object):
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `load` and `save` at different moments in between epoch runs.
    """
    def __init__(self, db, db_type):
        self._db = db
        self._db_type = db_type
        # make sure these blobs are the first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput('blob_names')
        self._names_output = None

    def init(self, nodes=None, retrieve_from_epoch=None):
        """
        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')
        net = core.Net('get_blob_list')
        if retrieve_from_epoch is None:
            net.GetAllBlobNames(
                [],
                self._blob_names,
                include_shared=False)
        else:
            net.Load(
                [], self._blob_names,
                db=self._dbname(retrieve_from_epoch),
                db_type=self._db_type,
                absolute_path=True)
        task = Task(step=net, outputs=[self._blob_names])
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        assert self._names_output
        return self._names_output.fetch().tolist()

    def _dbname(self, epoch):
        return '%s.%06d' % (self._db, epoch)

    def load(self, epoch):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from persistent storage.
        """
        net = core.Net('get_blob_list')
        net.Load(
            [],
            self.blob_list(),
            db=self._dbname(epoch),
            db_type=self._db_type,
            absolute_path=True)
        return Task(step=net)

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save op to serialize and persist
        blobs present in the global workspace.
        """
        net = core.Net('checkpoint_save')
        net.Save(
            self.blob_list(), [],
            db=self._dbname(epoch), db_type=self._db_type,
            absolute_path=True)
        return Task(step=net)


class MultiNodeCheckpointManager(object):
    """
    Coordinates checkpointing across multiple nodes. Each of `init`, `load`
    and `save` will build TaskGroups which will trigger checkpointing on each
    of the nodes involved in a distributed job.
    """
    def __init__(
            self, db_prefix, db_type, node_manager_class=CheckpointManager):
        self._node_manager_class = node_manager_class
        self._node_managers = None
        self._db_prefix = db_prefix
        self._db_type = db_type

    def _task_group(self, func, *args, **kw):
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
            return task_group

    def init(self, nodes, retrieve_from_epoch=None):
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
            return
        self._node_managers = []
        for node in nodes:
            with Node(node):
                manager = self._node_manager_class(
                    db=os.path.join(self._db_prefix, node),
                    db_type=self._db_type)
                self._node_managers.append((node, manager))
        return self._task_group(
            self._node_manager_class.init,
            nodes=[node],
            retrieve_from_epoch=retrieve_from_epoch)

    def load(self, epoch):
        return self._task_group(self._node_manager_class.load, epoch)

    def save(self, epoch):
        return self._task_group(self._node_manager_class.save, epoch)
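
# Example (a minimal sketch; `session` stands for any Session-like object with
# a `run(task_group)` method, and the path and db_type below are assumptions):
#
#   checkpoint = MultiNodeCheckpointManager(
#       db_prefix='/tmp/checkpoints', db_type='leveldb')
#   session.run(checkpoint.init(job.nodes_to_checkpoint()))
#   session.run(checkpoint.save(0))    # writes /tmp/checkpoints/<node>.000000
#
#   # Later, to resume from epoch 3 in a fresh workspace, recover the blob
#   # list from the stored checkpoint before loading it:
#   session.run(checkpoint.init(
#       job.nodes_to_checkpoint(), retrieve_from_epoch=3))
#   session.run(checkpoint.load(3))
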
class JobRunner(object):
    """
    Implements the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. Job
    runner is a callable to be called once from the client, passing a Session
    as an argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In this
    case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None):
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint = checkpoint_manager
        self.job = job

    def __call__(self, client):
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            client.run(self.job.init_group)

        if self.checkpoint:
            logger.info('Preparing checkpoint ...')
            client.run(self.checkpoint.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            if from_scratch:
                logger.info('Saving first checkpoint ...')
                client.run(self.checkpoint.save(0))
                logger.info('First checkpoint saved.')
            else:
                logger.info('Loading checkpoint for epoch {} ...'.format(
                    self.resume_from_epoch))
                client.run(self.checkpoint.load(self.resume_from_epoch))
                logger.info('Checkpoint loaded.')

        epoch = 1 if from_scratch else self.resume_from_epoch + 1
        while True:
            logger.info('Starting epoch %d.' % epoch)
            client.run(self.job.epoch_group)
            logger.info('Ran epoch %d.' % epoch)
            stop_signals = [o.fetch() for o in self.job.stop_signals]

            if self.checkpoint:
                logger.info('Saving checkpoint ...')
                client.run(self.checkpoint.save(epoch))
                logger.info('Checkpoint saved.')

            if any(stop_signals):
                logger.info('Stopping.')
                break
            epoch += 1
        client.run(self.job.exit_group)
        return epoch


def epoch_limiter(num_epochs):
    """
    Creates a task that will output True when a given number of epochs has
    finished.
    """
    with Job.current().init_group:
        init_net = core.Net('epoch_counter_init')
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)
    epoch_net = core.Net('epoch_countdown')
    finished = epoch_net.CountDown(counter)
    output = Task(step=epoch_net, outputs=finished).outputs()[0]
    Job.current().add_stop_signal(output)
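
# End-to-end sketch (hedged: LocalSession from caffe2.python.session is used
# here only as an example client; the db path and db_type are assumptions):
#
#   from caffe2.python.session import LocalSession
#
#   with Job() as job:
#       ...                                  # populate init/epoch/exit groups
#       epoch_limiter(num_epochs=10)
#
#   checkpoint = CheckpointManager(db='/tmp/model_ckpt', db_type='leveldb')
#   runner = JobRunner(job, checkpoint_manager=checkpoint)
#   session = LocalSession()
#   last_epoch = runner(session)             # blocks until the job finishes
#
#   # After a restart, resume from the last saved checkpoint instead of
#   # re-running init_group:
#   resumed = JobRunner(job, checkpoint_manager=checkpoint,
#                       resume_from_epoch=last_epoch)
#   resumed(session)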