Fix issues pickling jobs

Summary:
We were running into a problem where a Job could not be pickled. It needs to be pickled so that the master flow operator can execute it using the session.
This introduces the concept of a "compiled" Job, which stores little more than the protobufs of the task groups to be executed, avoiding any pickling issues.
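
For context, a minimal sketch of the intended usage, mirroring the updated checkpoint test below (build_job and checkpoint are hypothetical stand-ins for the user code that builds the Job and its checkpoint manager):

    from caffe2.python.checkpoint import JobRunner
    from caffe2.python.session import LocalSession

    job = build_job()  # hypothetical helper that fills the Job's task groups
    # compile() swaps each task group for its session-specific protobufs,
    # so the resulting Job can be pickled and shipped for execution.
    compiled_job = job.compile(LocalSession)
    session = LocalSession()
    num_epochs = JobRunner(compiled_job, checkpoint)(session)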

Reviewed By: dzhulgakov

Differential Revision: D4554799

fbshipit-source-id: 2ee9877ca49a796d51925e5ec917436e3d930984
Alisson Gusatti Azzolini 2017-02-21 20:42:35 -08:00 committed by Facebook Github Bot
parent 8fa156d082
commit 6ff05fd49d
4 changed files with 114 additions and 60 deletions

View File

@@ -58,11 +58,30 @@ class Job(object):
         model = build_model(params)
         build_hogwild_trainer(reader, model)
     """
-    def __init__(self):
-        self.init_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
-        self.epoch_group = TaskGroup()
-        self.exit_group = TaskGroup()
-        self.stop_signals = []
+    def __init__(self,
+                 init_group=None, epoch_group=None,
+                 exit_group=None, stop_signals=None,
+                 nodes_to_checkpoint=None):
+        self.init_group = init_group or TaskGroup(
+            workspace_type=WorkspaceType.GLOBAL)
+        self.epoch_group = epoch_group or TaskGroup()
+        self.exit_group = exit_group or TaskGroup()
+        self.stop_signals = stop_signals or []
+        self._nodes_to_checkpoint = nodes_to_checkpoint
+
+    def nodes_to_checkpoint(self):
+        if self._nodes_to_checkpoint:
+            return self._nodes_to_checkpoint
+        else:
+            return self.init_group.used_nodes()
+
+    def compile(self, session_class):
+        return Job(
+            init_group=session_class.compile(self.init_group),
+            epoch_group=session_class.compile(self.epoch_group),
+            exit_group=session_class.compile(self.exit_group),
+            stop_signals=self.stop_signals,
+            nodes_to_checkpoint=self.nodes_to_checkpoint())
 
     def __enter__(self):
         self.epoch_group.__enter__()
@@ -225,7 +244,7 @@ class JobRunner(object):
         if self.checkpoint:
             logger.info('Preparing checkpoint ...')
             client.run(self.checkpoint.init(
-                self.job.init_group.used_nodes(),
+                self.job.nodes_to_checkpoint(),
                 retrieve_from_epoch=self.resume_from_epoch))
             if from_scratch:
                 logger.info('Saving first checkpoint ...')

View File

@@ -55,13 +55,16 @@ class TestCheckpoint(TestCase):
             return output_fetcher.outputs()[0].fetch()
 
         session, checkpoint = builder()
-        num_epochs = JobRunner(job, checkpoint)(session)
+        compiled_job = job.compile(LocalSession)
+        num_epochs = JobRunner(compiled_job, checkpoint)(session)
         self.assertEquals(num_epochs, len(EXPECTED_TOTALS))
         self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])
 
         for initial_epoch in range(1, num_epochs + 1):
             session, checkpoint = builder()
-            JobRunner(job, checkpoint, resume_from_epoch=initial_epoch)(session)
+            JobRunner(
+                compiled_job,
+                checkpoint, resume_from_epoch=initial_epoch)(session)
             self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])
 
         for epoch in range(1, num_epochs + 1):

View File

@@ -5,7 +5,14 @@ from __future__ import unicode_literals
 
 from caffe2.python import core, workspace
-from caffe2.python.task import Task, TaskGroup, WorkspaceType
+from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType
+
+
+class CompiledRunnable(object):
+    """ Wrapper for compiled runnable returned from session.compile() """
+    def __init__(self, obj, session_class):
+        self.obj = obj
+        self.session_class = session_class
 
 
 class Session(object):
@@ -62,29 +69,46 @@ class Session(object):
     access each other's blobs. On the other hand, tasks running on the same
     node are guaranteed to run on the same workspace within a run.
     """
+    _compiled_cache = {}
+
     def __init__(self):
         self._open = True
-        self._runnable_cache = {}
 
     def is_open(self):
         return self._open
 
+    @classmethod
+    def compile(cls, runnable):
+        if isinstance(runnable, CompiledRunnable):
+            assert cls == runnable.session_class, (
+                'Runnable was compiled for different session type. ' +
+                'Need: %s, got: %s' % (
+                    cls.__name__, runnable.session_class.__name__))
+            return runnable
+
+        if runnable in cls._compiled_cache:
+            return cls._compiled_cache[runnable]
+
+        if isinstance(runnable, TaskGroup):
+            tg = runnable
+        else:
+            tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
+            if isinstance(runnable, Task):
+                tg.add(runnable)
+            elif isinstance(runnable, core.ExecutionStep):
+                tg.add(Task(step=runnable))
+            else:
+                step = core.execution_step('runnable', runnable)
+                tg.add(Task(step=step))
+        compiled = CompiledRunnable(
+            cls._compile_task_group(tg), session_class=cls)
+        cls._compiled_cache[runnable] = compiled
+        return compiled
+
     def run(self, runnable):
         assert self.is_open(), 'Session is closed.'
-        if runnable not in self._runnable_cache:
-            if isinstance(runnable, TaskGroup):
-                tg = runnable
-            else:
-                tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
-                if isinstance(runnable, Task):
-                    tg.add(runnable)
-                elif isinstance(runnable, core.ExecutionStep):
-                    tg.add(Task(step=runnable))
-                else:
-                    step = core.execution_step('runnable', runnable)
-                    tg.add(Task(step=step))
-            self._runnable_cache[runnable] = tg
-        self._run_task_group(self._runnable_cache[runnable])
+        self._run_compiled(self.compile(runnable).obj)
 
     def close(self):
         if self.is_open():
@@ -94,9 +118,13 @@ class Session(object):
 
     def fetch_output(self, output):
         raise NotImplementedError()
 
-    def _run_task_group(self, task_group):
+    def _run_compiled(self, task_group):
         raise NotImplementedError()
 
+    @classmethod
+    def _compile_task_group(cls, task_group):
+        return task_group
+
     def _do_close(self):
         pass
@@ -121,25 +149,27 @@ class LocalSession(Session):
     def __init__(self, ws=None):
         Session.__init__(self)
         self._ws = ws or workspace.C.Workspace()
-        self._plan_caches = {}
 
-    def _run_task_group(self, task_group):
-        if task_group not in self._plan_caches:
+    @classmethod
+    def _compile_task_group(cls, task_group):
+        with Cluster():
             task = task_group.to_task()
-            plan = core.Plan('task_group_plan')
-            plan.AddStep(task.get_step())
-            self._plan_caches[task_group] = (plan, task)
-        plan, task = self._plan_caches[task_group]
+        plan = core.Plan('task_group_plan')
+        plan.AddStep(task.get_step())
+        return (plan, task.output_list(), task.workspace_type)
+
+    def _run_compiled(self, compiled):
+        plan, output_list, workspace_type = compiled
 
         # make sure the output blobs belong to the parent workspace
         outputs = []
-        for name in task.output_names():
+        for name in output_list.names():
             self._ws.create_blob(str(name))
             outputs.append(core.BlobReference(str(name)))
-        task.set_outputs(outputs, _fetch_func=self._fetch_output)
+        output_list.set_values(outputs, _fetch_func=self._fetch_output)
 
         task_ws = (
             workspace.C.Workspace(self._ws)
-            if task.workspace_type == WorkspaceType.PRIVATE else self._ws)
+            if workspace_type == WorkspaceType.PRIVATE else self._ws)
         with workspace.WorkspaceGuard(task_ws):
             task_ws.run(plan)

View File

@@ -378,6 +378,30 @@ def final_output(blob_or_record):
     return cur_task.add_output(blob_or_record)
 
 
+class TaskOutputList(object):
+    """ Keeps a list of outputs for a task """
+    def __init__(self, outputs=None):
+        self.outputs = outputs or []
+
+    def names(self):
+        """
+        Retrieve the output names.
+        TODO(azzolini): make this schema-based.
+        """
+        names = []
+        for o in self.outputs:
+            names += o.names
+        return names
+
+    def set_values(self, values, _fetch_func=None):
+        offset = 0
+        for o in self.outputs:
+            num = len(o.names)
+            o.set(values[offset:offset + num], _fetch_func)
+            offset += num
+        assert offset == len(values), 'Wrong number of output values.'
+
+
 @context.define_context()
 class Task(object):
     """
@@ -515,34 +539,12 @@ class Task(object):
         self._step_with_setup = core.execution_step(self.name, [])
         return self._step_with_setup
 
+    def output_list(self):
+        return TaskOutputList(self._outputs)
+
     def outputs(self):
         return self._outputs
 
-    def output_names(self):
-        """
-        Retrive the output names.
-        TODO(azzolini): make this schema-based.
-        """
-        names = []
-        for o in self._outputs:
-            names += o.names
-        return names
-
-    def set_outputs(self, values, _fetch_func):
-        """
-        Set output values.
-        TODO(azzolini): make this schema-based.
-        """
-        offset = 0
-        for o in self._outputs:
-            num = len(o.names)
-            o.set(values[offset:offset + num], _fetch_func)
-            offset += num
-        assert offset == len(values), 'Wrong number of output values.'
-
     def resolved_outputs(self):
         return [output.get() for output in self._outputs]
 
     def _notify_used(self):
         self.get_step()
         self._already_used = True
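
Taken together, the session changes give Session.compile a small, checkable contract: compiling an already-compiled runnable is a pass-through, and run() compiles on demand. A minimal sketch (task_group is a placeholder for any TaskGroup built by the caller):

    from caffe2.python.session import LocalSession

    compiled = LocalSession.compile(task_group)  # returns a CompiledRunnable
    # Re-compiling is a no-op as long as the session class matches:
    assert LocalSession.compile(compiled) is compiled
    # run() accepts both raw runnables and pre-compiled ones:
    LocalSession().run(compiled)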