Fix issues pickling jobs
Summary: We were running into a problem where a Job could not be pickled, yet it must be pickled so that the master flow operator can execute it through the session. This diff introduces the concept of a "compiled" Job, which stores little more than the protobufs of the jobs to be executed, sidestepping the pickling issue entirely.

Reviewed By: dzhulgakov

Differential Revision: D4554799

fbshipit-source-id: 2ee9877ca49a796d51925e5ec917436e3d930984
This commit is contained in:
parent 8fa156d082
commit 6ff05fd49d
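The failure mode the summary describes is easy to reproduce outside caffe2. A minimal pure-Python sketch with hypothetical stand-in classes (not the caffe2 types): an object holding live Python state such as a closure cannot be pickled, while one holding only serialized protobuf strings can.

    import pickle

    class RawJob(object):
        """Stand-in for a Job that holds live Python state."""
        def __init__(self):
            self.step = lambda: None  # closures/lambdas cannot be pickled

    class CompiledJob(object):
        """Stand-in for a compiled Job holding only serialized protobufs."""
        def __init__(self, payload):
            self.payload = payload  # plain bytes pickle cleanly

    try:
        pickle.dumps(RawJob())
    except Exception as e:
        print('raw job failed to pickle:', e)

    print(len(pickle.dumps(CompiledJob(b'serialized-protobufs'))), 'bytes pickled')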
@@ -58,11 +58,30 @@ class Job(object):
         model = build_model(params)
         build_hogwild_trainer(reader, model)
     """
-    def __init__(self):
-        self.init_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
-        self.epoch_group = TaskGroup()
-        self.exit_group = TaskGroup()
-        self.stop_signals = []
+    def __init__(self,
+                 init_group=None, epoch_group=None,
+                 exit_group=None, stop_signals=None,
+                 nodes_to_checkpoint=None):
+        self.init_group = init_group or TaskGroup(
+            workspace_type=WorkspaceType.GLOBAL)
+        self.epoch_group = epoch_group or TaskGroup()
+        self.exit_group = exit_group or TaskGroup()
+        self.stop_signals = stop_signals or []
+        self._nodes_to_checkpoint = nodes_to_checkpoint
+
+    def nodes_to_checkpoint(self):
+        if self._nodes_to_checkpoint:
+            return self._nodes_to_checkpoint
+        else:
+            return self.init_group.used_nodes()
+
+    def compile(self, session_class):
+        return Job(
+            init_group=session_class.compile(self.init_group),
+            epoch_group=session_class.compile(self.epoch_group),
+            exit_group=session_class.compile(self.exit_group),
+            stop_signals=self.stop_signals,
+            nodes_to_checkpoint=self.nodes_to_checkpoint())
 
     def __enter__(self):
         self.epoch_group.__enter__()
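Taken together with the Session changes further down, the intended flow looks roughly like this. The import paths are my assumption from the surrounding codebase, and running it requires a full caffe2 build; treat it as a sketch, not the canonical usage.

    import pickle
    from caffe2.python.session import LocalSession
    from caffe2.python.checkpoint import Job, JobRunner  # assumed location

    job = Job()
    # ... populate job.init_group / job.epoch_group / job.exit_group ...

    compiled_job = job.compile(LocalSession)  # task groups become CompiledRunnables
    payload = pickle.dumps(compiled_job)      # now safe to ship to the master flow operator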
@@ -225,7 +244,7 @@ class JobRunner(object):
         if self.checkpoint:
             logger.info('Preparing checkpoint ...')
             client.run(self.checkpoint.init(
-                self.job.init_group.used_nodes(),
+                self.job.nodes_to_checkpoint(),
                 retrieve_from_epoch=self.resume_from_epoch))
             if from_scratch:
                 logger.info('Saving first checkpoint ...')
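The one-line change above leans on the nodes_to_checkpoint() fallback added to Job: an explicitly supplied node list wins, otherwise the nodes used by the init group are checkpointed. A self-contained illustration with hypothetical stand-ins:

    class FakeGroup(object):
        """Stand-in for TaskGroup; only used_nodes() matters here."""
        def used_nodes(self):
            return ['node0', 'node1']

    class FakeJob(object):
        def __init__(self, nodes_to_checkpoint=None):
            self.init_group = FakeGroup()
            self._nodes_to_checkpoint = nodes_to_checkpoint

        def nodes_to_checkpoint(self):
            # explicit list if given, else fall back to the init group's nodes
            return self._nodes_to_checkpoint or self.init_group.used_nodes()

    print(FakeJob().nodes_to_checkpoint())           # ['node0', 'node1']
    print(FakeJob(['node7']).nodes_to_checkpoint())  # ['node7']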
@@ -55,13 +55,16 @@ class TestCheckpoint(TestCase):
             return output_fetcher.outputs()[0].fetch()
 
         session, checkpoint = builder()
-        num_epochs = JobRunner(job, checkpoint)(session)
+        compiled_job = job.compile(LocalSession)
+        num_epochs = JobRunner(compiled_job, checkpoint)(session)
         self.assertEquals(num_epochs, len(EXPECTED_TOTALS))
         self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])
 
         for initial_epoch in range(1, num_epochs + 1):
             session, checkpoint = builder()
-            JobRunner(job, checkpoint, resume_from_epoch=initial_epoch)(session)
+            JobRunner(
+                compiled_job,
+                checkpoint, resume_from_epoch=initial_epoch)(session)
             self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])
 
         for epoch in range(1, num_epochs + 1):
@@ -5,7 +5,14 @@ from __future__ import unicode_literals
 
 
 from caffe2.python import core, workspace
-from caffe2.python.task import Task, TaskGroup, WorkspaceType
+from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType
+
+
+class CompiledRunnable(object):
+    """ Wrapper for compiled runnable returned from session.compile() """
+    def __init__(self, obj, session_class):
+        self.obj = obj
+        self.session_class = session_class
 
 
 class Session(object):
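The wrapper is deliberately dumb: it records which session class produced the compiled object so a later compile() can reject mismatches, and it stays picklable as long as obj is. A quick check, with a bytes payload standing in for the compiled protobufs:

    import pickle

    class CompiledRunnable(object):
        """ Wrapper for compiled runnable returned from session.compile() """
        def __init__(self, obj, session_class):
            self.obj = obj
            self.session_class = session_class

    blob = pickle.dumps(CompiledRunnable(b'plan-protobuf', session_class=None))
    assert pickle.loads(blob).obj == b'plan-protobuf'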
@@ -62,29 +69,46 @@ class Session(object):
     access each other's blobs. On the other hand, tasks running on the same
     node are guaranteed to run on the same workspace within a run.
     """
+
+    _compiled_cache = {}
+
     def __init__(self):
         self._open = True
-        self._runnable_cache = {}
 
     def is_open(self):
         return self._open
 
+    @classmethod
+    def compile(cls, runnable):
+        if isinstance(runnable, CompiledRunnable):
+            assert cls == runnable.session_class, (
+                'Runnable was compiled for different session type. ' +
+                'Need: %s, got: %s' % (
+                    cls.__name__, runnable.session_class.__name__))
+            return runnable
+
+        if runnable in cls._compiled_cache:
+            return cls._compiled_cache[runnable]
+
+        if isinstance(runnable, TaskGroup):
+            tg = runnable
+        else:
+            tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
+            if isinstance(runnable, Task):
+                tg.add(runnable)
+            elif isinstance(runnable, core.ExecutionStep):
+                tg.add(Task(step=runnable))
+            else:
+                step = core.execution_step('runnable', runnable)
+                tg.add(Task(step=step))
+        compiled = CompiledRunnable(
+            cls._compile_task_group(tg), session_class=cls)
+        cls._compiled_cache[runnable] = compiled
+        return compiled
+
     def run(self, runnable):
         assert self.is_open(), 'Session is closed.'
-        if runnable not in self._runnable_cache:
-            if isinstance(runnable, TaskGroup):
-                tg = runnable
-            else:
-                tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
-                if isinstance(runnable, Task):
-                    tg.add(runnable)
-                elif isinstance(runnable, core.ExecutionStep):
-                    tg.add(Task(step=runnable))
-                else:
-                    step = core.execution_step('runnable', runnable)
-                    tg.add(Task(step=step))
-            self._runnable_cache[runnable] = tg
-        self._run_task_group(self._runnable_cache[runnable])
+        self._run_compiled(self.compile(runnable).obj)
 
     def close(self):
         if self.is_open():
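Stripped of the caffe2 types, compile() is a class-level memo table plus a subclass hook. A pure-Python model of that shape (the names mirror the diff, but this is not the real implementation):

    class Session(object):
        _compiled_cache = {}  # shared memo table, keyed by the runnable

        @classmethod
        def compile(cls, runnable):
            if runnable not in cls._compiled_cache:
                cls._compiled_cache[runnable] = cls._compile_task_group(runnable)
            return cls._compiled_cache[runnable]

        @classmethod
        def _compile_task_group(cls, task_group):
            return task_group  # identity by default; subclasses lower to plans

    class PlanSession(Session):
        @classmethod
        def _compile_task_group(cls, task_group):
            return ('plan-for', task_group)

    assert PlanSession.compile('tg') is PlanSession.compile('tg')  # memoized

Note that in the diff the memo table lives on the base class and is keyed only by the runnable, so it is shared across session types; the session_class assertion guards only runnables that arrive already compiled.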
@@ -94,9 +118,13 @@ class Session(object):
     def fetch_output(self, output):
         raise NotImplementedError()
 
-    def _run_task_group(self, task_group):
+    def _run_compiled(self, task_group):
         raise NotImplementedError()
 
+    @classmethod
+    def _compile_task_group(cls, task_group):
+        return task_group
+
     def _do_close(self):
         pass
 
@@ -121,25 +149,27 @@ class LocalSession(Session):
     def __init__(self, ws=None):
         Session.__init__(self)
         self._ws = ws or workspace.C.Workspace()
-        self._plan_caches = {}
 
-    def _run_task_group(self, task_group):
-        if task_group not in self._plan_caches:
+    @classmethod
+    def _compile_task_group(cls, task_group):
+        with Cluster():
             task = task_group.to_task()
-            plan = core.Plan('task_group_plan')
-            plan.AddStep(task.get_step())
-            self._plan_caches[task_group] = (plan, task)
-        plan, task = self._plan_caches[task_group]
+        plan = core.Plan('task_group_plan')
+        plan.AddStep(task.get_step())
+        return (plan, task.output_list(), task.workspace_type)
+
+    def _run_compiled(self, compiled):
+        plan, output_list, workspace_type = compiled
 
         # make sure the output blobs belong to the parent workspace
         outputs = []
-        for name in task.output_names():
+        for name in output_list.names():
            self._ws.create_blob(str(name))
            outputs.append(core.BlobReference(str(name)))
-        task.set_outputs(outputs, _fetch_func=self._fetch_output)
+        output_list.set_values(outputs, _fetch_func=self._fetch_output)
 
         task_ws = (
             workspace.C.Workspace(self._ws)
-            if task.workspace_type == WorkspaceType.PRIVATE else self._ws)
+            if workspace_type == WorkspaceType.PRIVATE else self._ws)
         with workspace.WorkspaceGuard(task_ws):
             task_ws.run(plan)
@@ -378,6 +378,30 @@ def final_output(blob_or_record):
     return cur_task.add_output(blob_or_record)
 
 
+class TaskOutputList(object):
+    """ Keeps a list of outputs for a task """
+    def __init__(self, outputs=None):
+        self.outputs = outputs or []
+
+    def names(self):
+        """
+        Retrieve the output names.
+        TODO(azzolini): make this schema-based.
+        """
+        names = []
+        for o in self.outputs:
+            names += o.names
+        return names
+
+    def set_values(self, values, _fetch_func=None):
+        offset = 0
+        for o in self.outputs:
+            num = len(o.names)
+            o.set(values[offset:offset + num], _fetch_func)
+            offset += num
+        assert offset == len(values), 'Wrong number of output values.'
+
+
 @context.define_context()
 class Task(object):
     """
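With the TaskOutputList definition from the hunk above in scope, names() concatenates the per-output name lists and set_values() deals values back out by the same offsets. A small exercise with a hypothetical output type (the real outputs live in caffe2.python.task):

    class FakeOutput(object):
        """Stand-in for a task output exposing .names and .set()."""
        def __init__(self, names):
            self.names = names
            self.values = None

        def set(self, values, _fetch_func=None):
            self.values = values

    outs = [FakeOutput(['a', 'b']), FakeOutput(['c'])]
    ol = TaskOutputList(outs)            # class from the hunk above
    assert ol.names() == ['a', 'b', 'c']
    ol.set_values([1, 2, 3])
    assert outs[0].values == [1, 2] and outs[1].values == [3]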
@@ -515,34 +539,12 @@ class Task(object):
         self._step_with_setup = core.execution_step(self.name, [])
         return self._step_with_setup
 
+    def output_list(self):
+        return TaskOutputList(self._outputs)
+
     def outputs(self):
         return self._outputs
 
-    def output_names(self):
-        """
-        Retrive the output names.
-        TODO(azzolini): make this schema-based.
-        """
-        names = []
-        for o in self._outputs:
-            names += o.names
-        return names
-
-    def set_outputs(self, values, _fetch_func):
-        """
-        Set output values.
-        TODO(azzolini): make this schema-based.
-        """
-        offset = 0
-        for o in self._outputs:
-            num = len(o.names)
-            o.set(values[offset:offset + num], _fetch_func)
-            offset += num
-        assert offset == len(values), 'Wrong number of output values.'
-
     def resolved_outputs(self):
         return [output.get() for output in self._outputs]
 
     def _notify_used(self):
         self.get_step()
         self._already_used = True