mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: These return views in Python 3, which would not do anything in a lot of the usages currently present in Caffe2. This diff simply removes (almost) all usages of these two in Caffe2 and sub-projects in favor of comprehensions, which are also easier to read/understand. Reviewed By: akyrola. Differential Revision: D5142049. fbshipit-source-id: e800631d2df7d0823fed698cae46c486038007dc
441 lines
15 KiB
Python
441 lines
15 KiB
Python
## @package net_builder
|
|
# Module caffe2.python.net_builder
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core, context
|
|
from caffe2.python.task import Task, TaskGroup
|
|
|
|
|
|
@context.define_context()
class NetBuilder(object):
    """
    Scope-driven mechanism for building nets, loops and conditional blocks.

    A builder keeps an ordered list of children (nets and other child
    objects) together with the net that currently receives new operators.

    Example:
        from caffe2.python.net_builder import NetBuilder, ops
        with NetBuilder() as nb:
            c = ops.Const(5)
            d = ops.Const(0)
            with ops.loop():
                ops.stop_if(ops.LE([c, ops.Const(0)]))
                ops.Add([c, ops.Const(-1)], [c])
                with ops.If(ops.GE([c, ops.Const(3)])):
                    ops.Add([d, ops.Const(10)])
            ops.Print(c, [])
            ops.Print(d, [])
        step = core.to_execution_step(nb)
    """

    def __init__(self, name=None, _stop_blob_required=False,
                 _stop_blob=None, _fullname=None):
        """
        Args:
            name: optional name, appended with '/' to the name of the
                NetBuilder that is current at construction time (if any).
            _stop_blob_required: when True, leaving the context without a
                stop blob having been set fails an assertion.
            _stop_blob: optional pre-existing stop blob.
            _fullname: complete name to use verbatim; mutually exclusive
                with `name`.
        """
        parent = NetBuilder.current(required=False)
        assert not _fullname or not name, 'Cannot set both _fullname and name'
        if _fullname:
            self.name = _fullname
        else:
            # Join the non-empty components of <parent name>/<own name>.
            self.name = '/'.join(
                part
                for part in (parent.name if parent else None, name)
                if part
            )
        self._stop_blob_required = _stop_blob_required
        self._stop_blob = _stop_blob
        self._children = []
        self._current_net = None
        self._frozen = False

    def stop_blob(self):
        """
        Returns the BlobReference to the stop_blob of this NetBuilder,
        creating one if none is available yet.

        This function assumes that the stop_blob() will be used immediately
        in the current net, so no initialization net is added when the
        current net is the first child of the builder.
        """
        if self._stop_blob is not None:
            return self._stop_blob
        net = self.current_net()
        self._stop_blob = core.BlobReference(
            net.NextName('stop_blob'), net=net)
        if self._current_net != self._children[0]:
            # Prepend a net that sets the stop blob to False.
            init_net = core.Net('stop_blob_init')
            self._children.insert(0, init_net)
            init_net.Const(False, blob_out=self._stop_blob)
        return self._stop_blob

    def stop_if(self, blob):
        """
        Copies `blob` into the stop blob and closes the current net, so
        that the next operator is placed in a fresh net.
        """
        ops.Copy(blob, self.stop_blob())
        self._current_net = None

    def _assert_mutable(self):
        """Fails an assertion if this builder has already been frozen."""
        assert not self._frozen, (
            'This NetBuilder (%s) has been built already.' % self.name)

    def add(self, child):
        """
        Appends `child` to this builder and returns it. Only a core.Net
        child becomes the current net; any other child clears it.
        """
        self._assert_mutable()
        self._children.append(child)
        # to-do : check it's not a dag net
        self._current_net = child if isinstance(child, core.Net) else None
        return child

    def current_net(self, name=None):
        """
        Returns the current net. A new net is created and added first when
        there is no current net or when an explicit `name` is given.
        """
        self._assert_mutable()
        needs_new_net = name is not None or self._current_net is None
        if needs_new_net:
            self.add(core.Net(name))
        return self._current_net

    def freeze(self):
        """
        Freezes every child that exposes a `freeze` attribute, then marks
        this builder as frozen.
        """
        for node in self._children:
            if not hasattr(node, 'freeze'):
                continue
            node.freeze()
        self._frozen = True
        self._current_net = None

    def get(self):
        """Freezes the builder and returns its list of children."""
        self.freeze()
        return self._children

    def __exit__(self, etype, *args):
        """
        Freezes the builder. When no exception is in flight, additionally
        checks that a stop blob exists if one was required.
        """
        self.freeze()
        if etype is None:
            has_stop_blob = self._stop_blob is not None
            assert has_stop_blob or not self._stop_blob_required, (
                'This NetBuilder (%s) requires a stop condition ' % self.name +
                'to be set with `stop` or `stop_if`')

    def __str__(self):
        if self.name:
            return self.name
        return 'Un-named NetBuilder'
|
|
|
|
|
|
class Operations(object):
    """
    Operations to be used in the context of a NetBuilder.

    Besides the methods defined here, any other attribute access is
    forwarded to the current net of the active NetBuilder (see
    `__getattr__`), which is how operator calls such as `ops.Const(...)`
    are resolved.
    """

    def net(self, net=None, name=None):
        """
        Retrieves the current net, or add a new net to the builder.
        Args:
            net:   If provided, add the given net to the active builder.
                   Else, returns the current Net or creates a new one as
                   needed.
            name:  if provided, creates a new Net with given name and makes
                   it the new current net of the active builder. Cannot
                   be provided if net is provided.
        Returns:
            The net that was added, or the current net of the active builder.
        """
        assert name is None or net is None, (
            'Cannot provide both `net` and `name`.')
        if net is not None:
            NetBuilder.current().add(net)
            return net
        return NetBuilder.current().current_net(name=name)

    def __getattr__(self, op_type):
        """
        Adds an operator call to the currently active Net.

        Raises:
            AttributeError: for dunder names, or when no NetBuilder is
                active.
        """
        if op_type.startswith('__'):
            # Never forward special-method lookups to a net; name the
            # attribute so the failure is easy to diagnose.
            raise AttributeError(op_type)
        # We want hasattr to work properly even if no context is active.
        if NetBuilder.current(required=False) is None:
            raise AttributeError('No active NetBuilder.')
        return getattr(self.net(), op_type)

    def task_group(self):
        """
        Creates a local task group which will execute as the next step of
        the current NetBuilder.
        """
        from caffe2.python import task
        group = NetBuilder.current()
        with task.Cluster():
            with task.Node('local'):
                tg = task.TaskGroup()
                group.add(tg)
                return tg

    def stop(self):
        """
        Stop execution of the current execution step.
            Example:
                ops.Print(a, 0)
                ops.stop()
                ops.Print(b, 0)
            In the example, 'b' will never be printed.
        """
        return self.stop_if(ops.Const(True))

    def stop_if(self, blob):
        """
        Stop execution of the current execution step if the
        condition `blob` is met.
            Example:
                ops.Print(a, 0)
                ops.stop_if(ops.LE([x, ops.Const(0)]))
                ops.Print(b, 0)
            In the example, 'b' will only be printed if the value of scalar
            tensor 'x' is greater than 0.
        """
        return NetBuilder.current().stop_if(blob)

    def loop(self, iters=None, name=None):
        """
        Creates a NetBuilder that will execute in a loop as the next step of
        the current NetBuilder. If `iters` is provided, the loop will execute
        for `iters` iterations and then stop. `iters` can be a constant or a
        BlobReference. If `iters` is not provided, the loop will execute
        until `ops.stop` or `ops.stop_if` is called.
            Examples:
                a = ops.Const(5)
                with ops.loop():
                    ops.stop_if(ops.LE([a, ops.Const(0)]))
                    ops.Print(a, 0)
                    ops.Add([a, ops.Const(-1)], [a])
            Above, 'a' will be printed 5 times, with values 5 to 1.

                with ops.loop(10) as loop:
                    ops.LogInfo(loop.iter())
            This will print the numbers from 0 to 9.

                x = ops.Add([ops.Const(10), ops.Const(10)])
                with ops.loop(x) as loop:
                    ops.LogInfo(loop.iter())
            This will print the numbers from 0 to 19.
        """
        return NetBuilder.current().add(_Loop(iters, name=name))

    def stop_guard(self, has_stopped_blob=None, name=None):
        """
        Creates a NetBuilder that will execute once as the next step of the
        current NetBuilder. After execution, a bool tensor will indicate
        whether the inner execution was halted with `stop` or `stop_if`.
            Example:
                a = ops.Const(True)
                with ops.stop_guard() as sg1:
                    ops.stop_if(a)
                    ops.Print(ops.Const('did not stop'))
                b = ops.Const(False)
                with ops.stop_guard() as sg2:
                    ops.stop_if(b)
                    ops.Print(ops.Const('did not stop'))
                ops.Print(sg1.has_stopped(), [])
                ops.Print(sg2.has_stopped(), [])
            In the example, 'did not stop' will be printed once,
            followed by True and False.
        """
        return NetBuilder.current().add(
            _StopGuard(has_stopped_blob=has_stopped_blob, name=name))

    def If(self, cond, name=None):
        """
        Creates a NetBuilder that will execute once as the next step of the
        current NetBuilder if the blob `cond` is True.
            Example:
                with ops.If(ops.Const(True)):
                    ops.Print(ops.Const('Will print'))
                with ops.If(ops.Const(False)):
                    ops.Print(ops.Const('Wont print'))
            The example will print 'Will print' once.
        """
        return NetBuilder.current().add(_RunIf(cond, name=name))

    def task_init(self):
        """
        Defines operations that will be executed once at task startup.
        Useful when implementing processors, that don't have access to the Task
        top-level structure.
            Example:
                def my_processor(rec):
                    with ops.task_init():
                        one = ops.Const(1)
                        two = ops.Const(2)
                    return Tuple(
                        ops.Add(rec[0](), one), ops.Add(rec[1](), two))
        """
        setup = _SetupBuilder(_SetupBuilder.INIT)
        self.net().add_attribute(Task.TASK_SETUP, setup)
        return setup

    def task_exit(self):
        """
        Define operations to be executed at task shutdown.
        Useful when implementing processors, that don't have access to the Task
        top-level structure.
            Example:
                def read_queue(queue):
                    with ops.task_exit():
                        queue.close(ops.net())
                    return queue.read(ops.net())
        """
        setup = _SetupBuilder(_SetupBuilder.EXIT)
        self.net().add_attribute(Task.TASK_SETUP, setup)
        return setup

    def local_init(self):
        """
        Similar to `task_init`, but executes at TaskGroup's startup instead,
        before any task of the group starts executing.
        """
        setup = _SetupBuilder(_SetupBuilder.INIT)
        self.net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
        return setup

    def local_exit(self):
        """
        Similar to `task_exit`, but executes at TaskGroup's exit instead,
        after all tasks of the group finished execution.
        """
        setup = _SetupBuilder(_SetupBuilder.EXIT)
        self.net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
        return setup

    def task_reporter(self, interval_ms=1000, name=None):
        """
        Define operations to be executed at every time interval from
        task start-up to finish. These operations are guaranteed to
        execute at least once after all other operations of the task are
        finished.

            Example:
                with ops.task_reporter(interval_ms=10000):
                    ops.LogInfo('10s elapsed')
        """
        return _ReporterBuilder(interval_ms, net=self.net(), name=name)

    def local_reporter(self, interval_ms=1000, name=None):
        """
        Similar to `task_reporter`, but operations defined within this block
        will run repeatedly for as long as any of the tasks in the current
        TaskGroup have not finished.
        """
        return _ReporterBuilder(interval_ms, name=name)
|
|
|
|
|
|
# Module-level Operations instance, used throughout this file as `ops`.
ops = Operations()
|
|
|
|
|
|
class _ReporterBuilder(NetBuilder):
    """
    NetBuilder whose content is turned into a report step on clean exit.

    The step is scheduled with `RunEveryMillis(interval_ms)` and attached
    either to the net given at construction (as a `Task.REPORT_STEP`
    attribute) or, when no net was given, to the current TaskGroup.
    """

    def __init__(self, interval_ms, net=None, name=None):
        NetBuilder.__init__(self, name)
        self.interval_ms = interval_ms
        self._net = net

    def __exit__(self, etype, *args):
        clean_exit = etype is None
        if clean_exit:
            report_step = core.to_execution_step(self)
            report_step.RunEveryMillis(self.interval_ms)
            if not self._net:
                # No net to attach to: register with the current TaskGroup.
                TaskGroup.current().report_step(
                    report_step, interval_ms=self.interval_ms)
            else:
                self._net.add_attribute(Task.REPORT_STEP, report_step)
        NetBuilder.__exit__(self, etype, *args)
|
|
|
|
|
|
class _SetupBuilder(NetBuilder):
    """
    NetBuilder tagged as either an INIT or an EXIT block.

    `setup` yields an execution step only for INIT builders and `exit`
    only for EXIT builders; in the other case each returns None.
    """
    INIT = 'init'
    EXIT = 'exit'

    def __init__(self, type, name=None):
        NetBuilder.__init__(self, name)
        self.type = type

    def setup(self, net):
        """Returns this builder's execution step if it is an INIT block."""
        is_init = self.type == _SetupBuilder.INIT
        return core.to_execution_step(self) if is_init else None

    def exit(self, net):
        """Returns this builder's execution step if it is an EXIT block."""
        is_exit = self.type == _SetupBuilder.EXIT
        return core.to_execution_step(self) if is_exit else None
|
|
|
|
|
|
class _RunOnce(NetBuilder):
    """
    NetBuilder that issues a final `ops.stop()` on clean exit whenever a
    stop blob exists for it.
    """

    def __init__(self, name=None):
        NetBuilder.__init__(self, name)

    def __exit__(self, etype, *args):
        clean_exit = etype is None
        uses_stop_blob = self._stop_blob is not None
        if clean_exit and uses_stop_blob:
            # NOTE(review): presumably this keeps the step from running
            # again once a stop blob is in play -- confirm against
            # core.to_execution_step.
            ops.stop()
        NetBuilder.__exit__(self, etype, *args)
|
|
|
|
|
|
class _StopGuard(_RunOnce):
    """
    Run-once builder that records, in a bool blob, whether its body was
    halted early: the blob is set to True on entry and to False by the
    last operators added on a clean exit.
    """

    def __init__(self, has_stopped_blob=None, name=None):
        _RunOnce.__init__(self, name)
        self._ran = False
        self._stopped = has_stopped_blob

    def __enter__(self):
        entered = _RunOnce.__enter__(self)
        # Start out as "stopped"; only reaching the end of the body
        # flips the blob back to False.
        self._stopped = ops.Const(True, blob_out=self._stopped)
        return entered

    def __exit__(self, etype, *args):
        clean_exit = etype is None
        if clean_exit:
            self._ran = True
            ops.Const(False, blob_out=self._stopped)
        _RunOnce.__exit__(self, etype, *args)

    def has_stopped(self):
        """
        Return a blob that will be set to scalar bool `True` after
        this net builder ran, iff it was halted early.
        """
        assert self._ran, 'Context not used yet.'
        return self._stopped
|
|
|
|
|
|
class _Loop(NetBuilder):
    """
    NetBuilder that requires a stop condition, optionally bounded by an
    iteration count.

    When `iters` is given, a counter blob is created, a stop condition
    `counter >= iters` is added on entry, and the counter is incremented
    on clean exit. Without `iters`, the body must set a stop condition
    itself through `ops.stop` / `ops.stop_if`.
    """

    def __init__(self, iters=None, name=None):
        """
        Args:
            iters: number of iterations, either a core.BlobReference or a
                value wrapped with `ops.Const`; None for an unbounded loop.
            name: optional builder name.
        """
        NetBuilder.__init__(self, name, _stop_blob_required=True)
        # Always define the counter attributes so an unbounded loop has a
        # consistent state rather than missing attributes.
        self._inc = None
        self._iter = None
        self._num_iters = None
        if iters is not None:
            self._inc = ops.Const(1)
            self._iter = ops.Const(0)
            self._num_iters = (
                iters if isinstance(iters, core.BlobReference)
                else ops.Const(iters))

    def iter(self):
        """
        Returns the blob holding the iteration counter. Only available
        for loops created with a number of iterations.
        """
        assert self._num_iters is not None, (
            'This loop does not have a number of iterations.')
        assert self._iter is not None, (
            'iter() must be called from inside the loop context')
        return self._iter

    def __enter__(self):
        builder = NetBuilder.__enter__(self)
        if self._num_iters is not None:
            # First thing in the body: stop once the counter reaches the
            # requested number of iterations.
            ops.stop_if(ops.GE([self._iter, self._num_iters]))
        return builder

    def __exit__(self, type, *args):
        if type is None and self._num_iters is not None:
            # Last thing in the body: increment the counter in place.
            self.current_net().Add([self._iter, self._inc], [self._iter])
        NetBuilder.__exit__(self, type, *args)
|
|
|
|
|
|
class _RunIf(_RunOnce):
    """
    Run-once builder implementing If / Elif / Else blocks.

    On entry the body is stopped right away if `_else_blob` is set:
    - for an initial If, `_else_blob` is `Not(cond_blob)`;
    - for an Elif, it is `Or(_already_ran, Not(cond_blob))`;
    - for an Else (no `cond_blob`), it is `_already_ran` itself.
    `_already_ran` is shared along the chain and set to True by whichever
    branch gets past the stop check.
    """

    def __init__(self, cond_blob=None, name=None, _already_ran=None):
        """
        Args:
            cond_blob: condition blob; None creates an Else block.
            name: optional builder name.
            _already_ran: blob shared with the preceding branches of the
                chain; None for the initial If.
        """
        _RunOnce.__init__(self, name)
        assert cond_blob or _already_ran, (
            'Either cond_blob or _already_ran must be provided.')
        self._is_else = cond_blob is None
        if _already_ran is None:
            self._else_blob = ops.Not(cond_blob)
            self._already_ran = ops.Const(False)
        else:
            self._already_ran = _already_ran
            self._else_blob = _already_ran if cond_blob is None else (
                ops.Or([_already_ran, ops.Not(cond_blob)]))

    def __enter__(self):
        r = _RunOnce.__enter__(self)
        ops.stop_if(self._else_blob)
        # Reached only when this branch executes: mark the chain as taken.
        ops.Const(True, blob_out=self._already_ran)
        return r

    def Elif(self, cond, name=None):
        """
        Adds an Elif branch to the current NetBuilder, sharing this
        branch's `_already_ran` blob. Not allowed after an Else.
        """
        assert not self._is_else, 'Elif not allowed for an Else.'
        return NetBuilder.current().add(_RunIf(
            cond, name=name or self.name, _already_ran=self._already_ran))

    def Else(self, name=None):
        """
        Adds an Else branch to the current NetBuilder, sharing this
        branch's `_already_ran` blob. Not allowed after an Else.
        """
        assert not self._is_else, 'Else not allowed for an Else.'
        return NetBuilder.current().add(
            _RunIf(name=name or self.name, _already_ran=self._already_ran))
|