Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: At the moment, LocalSession creates a new workspace if none is provided. As a result, anything executed in a local session is not available to the external caller, i.e. everything that uses SingleRunner can only observe side effects and cannot access intermediate blobs. This diff modifies LocalSession to run in the current workspace instead (unless something relies on the privateness of that workspace and produces really weird effects, it should just work).

Differential Revision: D4634743

fbshipit-source-id: 975bed154c7ca215dc3fc0d60f05a7c092711482
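To make the intent concrete, here is a minimal sketch (not part of the file below) of what the new behavior enables; it assumes the standard caffe2.python API, and the net and blob names ('example', 'answer') are made up for illustration:

from caffe2.python import core, workspace
from caffe2.python.session import LocalSession

# A trivial net that writes a single blob.
net = core.Net('example')
net.ConstantFill([], 'answer', shape=[1], value=42.0)

# With no workspace argument, LocalSession now runs in the caller's current
# workspace instead of creating a private one.
session = LocalSession()
session.run(net)

# Because the run happened in the current workspace, the intermediate blob
# is directly observable by the caller.
print(workspace.FetchBlob('answer'))  # expected: [42.]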
178 lines · 5.6 KiB · Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace
from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType


class CompiledRunnable(object):
    """ Wrapper for compiled runnable returned from session.compile() """
    def __init__(self, obj, session_class):
        self.obj = obj
        self.session_class = session_class


class Session(object):
    """
    Allows running Nets, ExecutionSteps, Plans, Tasks and TaskGroups.
    A session can potentially run in multiple nodes concurrently.

    Example:
        from caffe2.python.core import Net
        from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType

        net = Net('test1')
        net.Add([net.Const(1), net.Const(2)])

        net2 = net.Clone()
        step = core.execution_step('step1', [net2])

        with TaskGroup(WorkspaceType.GLOBAL) as init_tg:
            with Node('node1'):
                n1setup = Net('n1setup')
                n1msg = n1setup.Const('Hello from node 1.')
                Task(step=n1setup)

        with TaskGroup() as private_tg:
            with Node('node1'):
                n1 = Net('n1')
                n1.Print(n1msg, 0)
                Task(step=n1)
            with Node('node2'):
                n2 = Net('n2')
                n2.Print(n2.Const('Hello from node 2.'), 0)
                Task(step=n2)

        session = LocalSession()
        session.run(net)
        session.run(step)
        session.run(init_tg)
        session.run(private_tg)

    Global Workspace:
        At the beginning of the session, a global workspace is created and kept
        alive for the duration of the session.

    Private Workspace:
        Tasks can be run either directly on the global workspace, or they can
        instantiate a private child workspace that is released after each run.

    Blob visibility:
        Tasks running in different nodes in parallel will always run under
        different workspaces, so it must be assumed that they won't be able to
        access each other's blobs. On the other hand, tasks running on the same
        node are guaranteed to run in the same workspace within a run.
    """

    _compiled_cache = {}

    def __init__(self):
        self._open = True

    def is_open(self):
        return self._open

    @classmethod
    def compile(cls, runnable):
        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' % (
                    cls.__name__, runnable.session_class.__name__))
            return runnable

        if runnable in cls._compiled_cache:
            return cls._compiled_cache[runnable]

        if isinstance(runnable, TaskGroup):
            tg = runnable
        else:
            # Anything that is not already a TaskGroup (a Task, an
            # ExecutionStep, a Net, ...) gets wrapped into a TaskGroup that
            # runs in the global workspace.
            tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
            if isinstance(runnable, Task):
                tg.add(runnable)
            elif isinstance(runnable, core.ExecutionStep):
                tg.add(Task(step=runnable))
            else:
                step = core.execution_step('runnable', runnable)
                tg.add(Task(step=step))
        compiled = CompiledRunnable(
            cls._compile_task_group(tg), session_class=cls)
        cls._compiled_cache[runnable] = compiled
        return compiled

    def run(self, runnable):
        assert self.is_open(), 'Session is closed.'
        self._run_compiled(self.compile(runnable).obj)

    def close(self):
        if self.is_open():
            self._do_close()
            self._open = False

    def fetch_output(self, output):
        raise NotImplementedError()

    def _run_compiled(self, task_group):
        raise NotImplementedError()

    @classmethod
    def _compile_task_group(cls, task_group):
        return task_group

    def _do_close(self):
        pass

    def __enter__(self):
        assert self._open, 'Session already closed.'
        return self

    def __exit__(self, ex_type, value, traceback):
        if ex_type is None:
            self.close()


class LocalSession(Session):
    """
    Session that runs in a single node.
    Tasks are all remapped to run in parallel in the 'local' node.

    Currently, LocalSession runs all parallel tasks in the same workspace,
    but this behavior may change in the future. Only tasks pointing to the
    same logical node are guaranteed to always run in the same workspace.
    """
    def __init__(self, ws=None):
        Session.__init__(self)
        # Default to the caller's current workspace, so that blobs produced
        # by the session remain visible to the caller (see the diff summary).
        self._ws = ws or workspace.C.Workspace.current

    @classmethod
    def _compile_task_group(cls, task_group):
        with Cluster():
            task = task_group.to_task()
        plan = core.Plan('task_group_plan')
        plan.AddStep(task.get_step())
        return (plan, task.output_list(), task.workspace_type())

    def _run_compiled(self, compiled):
        plan, output_list, workspace_type = compiled

        # make sure the output blobs belong to the parent workspace
        outputs = []
        for name in output_list.names():
            self._ws.create_blob(str(name))
            outputs.append(core.BlobReference(str(name)))
        output_list.set_values(outputs, _fetch_func=self._fetch_output)
        # Private tasks run in a fresh child workspace of the session
        # workspace; global tasks reuse the session workspace directly.
        task_ws = (
            workspace.C.Workspace(self._ws)
            if workspace_type == WorkspaceType.PRIVATE else self._ws)
        with workspace.WorkspaceGuard(task_ws):
            task_ws.run(plan)

    def _fetch_output(self, output):
        return self._ws.blobs[str(output)].fetch()
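As a usage note, here is a hedged sketch of how LocalSession interacts with task outputs and workspace types; it mirrors the docstring example above, the node/net/blob names ('local', 'sum', 'result') are made up, and it assumes the Task/TaskGroup API imported in this file:

from caffe2.python import core
from caffe2.python.session import LocalSession
from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType

# A task group that runs directly in the session's (global) workspace.
with TaskGroup(WorkspaceType.GLOBAL) as tg:
    with Node('local'):
        net = core.Net('sum')
        result = net.Add([net.Const(1), net.Const(2)], 'result')
        task = Task(step=net, outputs=[result])

session = LocalSession()
session.run(tg)

# _run_compiled pre-creates output blobs in the session workspace (see the
# "parent workspace" comment above), so declared task outputs can still be
# fetched after the run, even when a task group is compiled with
# WorkspaceType.PRIVATE and executes in a throwaway child workspace.
print(task.outputs()[0].fetch())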