from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, context
from caffe2.python.task import Task, TaskGroup


@context.define_context()
class NetBuilder(object):
    """
    Scope-driven mechanism for building nets, loops and conditional blocks.
    Example:
        from caffe2.python.net_builder import NetBuilder, ops
        with NetBuilder() as nb:
            c = ops.Const(5)
            d = ops.Const(0)
            with ops.loop():
                ops.stop_if(ops.LE([c, ops.Const(0)]))
                ops.Add([c, ops.Const(-1)], [c])
                with ops.If(ops.GE([c, ops.Const(3)])):
                    ops.Add([d, ops.Const(10)], [d])
                ops.Print(c, [])
                ops.Print(d, [])
        step = core.to_execution_step(nb)
    """
    def __init__(self, name=None, _stop_blob_required=False):
        self._name = name or ''
        # Prefix applied to the names of the nets this builder creates.
        self._prefix = name + '/' if name else ''
        self._frozen = False
        # Net that new operators are appended to; None means a fresh net is
        # created on the next `current_net()` call.
        self._current_net = None
        # Ordered children added via `add()`: core.Net instances, nested
        # builders (loops, ifs, guards) and TaskGroups.
        self._children = []
        self._stop_blob = None
        # When True, `__exit__` asserts that a stop condition was set
        # (used by `_Loop`).
        self._stop_blob_required = _stop_blob_required

    def stop_blob(self):
        """
        Returns the BlobReference to the stop_blob of this NetBuilder.
        If one is not yet available, creates one.
        This function assumes that the stop_blob() will be used immediately
        in the current net, so it doesn't initialize it if the current net is
        the first of the builder.
        """
        if self._stop_blob is None:
            net = self.current_net()
            self._stop_blob = core.BlobReference(
                net.NextName('stop_blob'), net=net)
            if self._current_net != self._children[0]:
                # The current net is not the first child, so earlier children
                # would run before the blob is written; prepend a net that
                # initializes it to False.
                self._children.insert(0, core.Net(
                    self._prefix + 'stop_blob_init'))
                self._children[0].Const(False, blob_out=self._stop_blob)
        return self._stop_blob

    def stop_if(self, blob):
        """
        Copies `blob` into this builder's stop blob, then closes the current
        net so that operators added afterwards go into a new net.
        """
        ops.Copy(blob, self.stop_blob())
        self._current_net = None

    def _assert_mutable(self):
        """Raises AssertionError if this builder was already frozen."""
        assert not self._frozen, (
            'This NetBuilder (%s) has been built already.' % self._name)

    def add(self, child):
        """
        Appends `child` (a net, nested builder or TaskGroup) as the next step
        and returns it. If `child` is a core.Net it becomes the current net;
        otherwise the next operator will start a new net.
        """
        self._assert_mutable()
        self._current_net = None
        self._children.append(child)
        # to-do : check it's not a dag net
        if isinstance(child, core.Net):
            self._current_net = child
        return child

    def current_net(self):
        """Returns the net receiving new operators, creating one if needed."""
        self._assert_mutable()
        if self._current_net is None:
            self.add(core.Net(self._prefix + 'net'))
        return self._current_net

    def freeze(self):
        """Freezes this builder and every child that supports freezing."""
        for child in self._children:
            if hasattr(child, 'freeze'):
                child.freeze()
        self._current_net = None
        self._frozen = True

    def get(self):
        """Freezes the builder and returns its list of children."""
        self.freeze()
        return self._children

    def __exit__(self, etype, *args):
        self.freeze()
        if etype is not None:
            # Let the in-flight exception propagate without extra checks.
            return
        assert (not self._stop_blob_required) or self._stop_blob is not None, (
            'This NetBuilder (%s) requires a stop condition ' % self._name +
            'to be set with `stop` or `stop_if`')


class Operations(object):
    """
    Operations to be used in the context of a NetBuilder.
    """
    def net(self, net=None):
        """
        Retrieves the current net, or adds a new net to the builder.
        """
        if net is not None:
            NetBuilder.current().add(net)
            return net
        return NetBuilder.current().current_net()

    def __getattr__(self, op_type):
        """
        Adds an operator call to the currently active Net.
        """
        if op_type.startswith('__'):
            # Don't masquerade as special/dunder attributes (e.g. during
            # pickling or introspection).
            raise AttributeError(op_type)
        return getattr(self.net(), op_type)

    def task_group(self):
        """
        Creates a local task group which will execute as the next step of
        the current NetBuilder.
        """
        from caffe2.python import task
        group = NetBuilder.current()
        with task.Cluster():
            with task.Node('local'):
                tg = task.TaskGroup()
                group.add(tg)
                return tg

    def stop(self):
        """
        Stop execution of the current execution step.
            Example:
                ops.Print(a, 0)
                ops.stop()
                ops.Print(b, 0)
            In the example, 'b' will never be printed.
        """
        return self.stop_if(ops.Const(True))

    def stop_if(self, blob):
        """
        Stop execution of the current execution step if the
        condition `blob` is met.
            Example:
                ops.Print(a, 0)
                ops.stop_if(ops.LE([x, ops.Const(0)]))
                ops.Print(b, 0)
            In the example, 'b' will only be printed if the value of scalar
            tensor 'x' is greater than 0.
        """
        return NetBuilder.current().stop_if(blob)

    def loop(self, iters=None):
        """
        Creates a NetBuilder that will execute in a loop as the next step of
        the current NetBuilder. If `iters` is provided, the loop will execute
        for `iters` iterations and then stop. `iters` can be a constant or a
        BlobReference. If `iters` is not provided, the loop will execute
        until `ops.stop` or `ops.stop_if` is called.
            Examples:
                a = ops.Const(5)
                with ops.loop():
                    ops.stop_if(ops.LE([a, ops.Const(0)]))
                    ops.Print(a, 0)
                    ops.Add([a, ops.Const(-1)], [a])
            Above, 'a' will be printed 5 times, with values 5 to 1.

                with ops.loop(10) as loop:
                    ops.LogInfo(loop.iter())
            This will print the numbers from 0 to 9.

                x = ops.Add([ops.Const(10), ops.Const(10)])
                with ops.loop(x) as loop:
                    ops.LogInfo(loop.iter())
            This will print the numbers from 0 to 19.
        """
        return NetBuilder.current().add(_Loop(iters))

    def stop_guard(self, has_stopped_blob=None):
        """
        Creates a NetBuilder that will execute once as the next step of the
        current NetBuilder. After execution, a bool tensor will indicate
        whether the inner execution was halted with `stop` or `stop_if`.
            Example:
                a = ops.Const(True)
                with ops.stop_guard() as sg1:
                    ops.stop_if(a)
                    ops.Print(ops.Const('did not stop'))
                b = ops.Const(False)
                with ops.stop_guard() as sg2:
                    ops.stop_if(b)
                    ops.Print(ops.Const('did not stop'))
                ops.Print(sg1.has_stopped(), [])
                ops.Print(sg2.has_stopped(), [])
            In the example, 'did not stop' will be printed once,
            followed by True and False.
        """
        return NetBuilder.current().add(
            _StopGuard(has_stopped_blob=has_stopped_blob))

    def If(self, cond):
        """
        Creates a NetBuilder that will execute once as the next step of the
        current NetBuilder if the blob `cond` is True.
            Example:
                with ops.If(ops.Const(True)):
                    ops.Print(ops.Const('Will print'))
                with ops.If(ops.Const(False)):
                    ops.Print(ops.Const('Wont print'))
            The example will print 'Will print' once.
        """
        return NetBuilder.current().add(_RunIf(cond))

    def task_init(self):
        """
        Defines operations that will be executed once at task startup.
        Useful when implementing processors, that don't have access to the
        Task top-level structure.
            Example:
                def my_processor(rec):
                    with ops.task_init():
                        one = ops.Const(1)
                        two = ops.Const(2)
                    return Tuple(
                        ops.Add(rec[0](), one), ops.Add(rec[1](), two))
        """
        setup = _SetupBuilder(_SetupBuilder.INIT)
        self.net().add_attribute(Task.TASK_SETUP, setup)
        return setup

    def task_exit(self):
        """
        Define operations to be executed at task shutdown.
        Useful when implementing processors, that don't have access to the
        Task top-level structure.
            Example:
                def read_queue(queue):
                    with ops.task_exit():
                        queue.close(ops.net())
                    return queue.read(ops.net())
        """
        setup = _SetupBuilder(_SetupBuilder.EXIT)
        self.net().add_attribute(Task.TASK_SETUP, setup)
        return setup

    def local_init(self):
        """
        Similar to `task_init`, but executes at TaskGroup's startup instead,
        before any task of the group starts executing.
        """
        setup = _SetupBuilder(_SetupBuilder.INIT)
        self.net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
        return setup

    def local_exit(self):
        """
        Similar to `task_exit`, but executes at TaskGroup's exit instead,
        after all tasks of the group finished execution.
        """
        setup = _SetupBuilder(_SetupBuilder.EXIT)
        self.net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
        return setup


ops = Operations()


class _SetupBuilder(NetBuilder):
    """
    NetBuilder registered as a Task/TaskGroup setup attribute by
    `Operations.task_init` / `task_exit` / `local_init` / `local_exit`.
    `setup()` yields an execution step only for INIT builders and `exit()`
    only for EXIT builders; otherwise they return None.
    """
    INIT = 'init'
    EXIT = 'exit'

    def __init__(self, type, name=None):
        # NOTE: `type` shadows the builtin; kept for interface compatibility.
        NetBuilder.__init__(self, name)
        self.type = type

    def setup(self, net):
        if self.type == _SetupBuilder.INIT:
            return core.to_execution_step(self)
        return None

    def exit(self, net):
        if self.type == _SetupBuilder.EXIT:
            return core.to_execution_step(self)
        return None


class _RunOnce(NetBuilder):
    """
    NetBuilder whose body is meant to run a single time. If a stop blob was
    created while building the body, an unconditional `ops.stop()` is
    appended on normal exit (presumably because a builder with a stop blob is
    executed until that blob is set — confirm against core.to_execution_step).
    """
    def __init__(self, name=None):
        NetBuilder.__init__(self, name)

    def __exit__(self, etype, *args):
        if etype is None and self._stop_blob is not None:
            ops.stop()
        NetBuilder.__exit__(self, etype, *args)


class _StopGuard(_RunOnce):
    """
    Run-once builder that records whether its body was halted early.
    The "stopped" blob is set to True on entry and reset to False by the last
    operator of the body, so it stays True only if execution stopped first.
    """
    def __init__(self, name=None, has_stopped_blob=None):
        _RunOnce.__init__(self, name)
        self._stopped = has_stopped_blob
        self._ran = False

    def __enter__(self):
        r = _RunOnce.__enter__(self)
        self._stopped = ops.Const(True, blob_out=self._stopped)
        return r

    def __exit__(self, etype, *args):
        if etype is None:
            self._ran = True
            ops.Const(False, blob_out=self._stopped)
        _RunOnce.__exit__(self, etype, *args)

    def has_stopped(self):
        """
        Return a blob that will be set to scalar bool `True` after
        this net builder ran, iff it was halted early.
        """
        assert self._ran, 'Context not used yet.'
        return self._stopped


class _Loop(NetBuilder):
    """
    NetBuilder executed repeatedly until its stop condition is met. When
    `iters` is given, a counter is checked at the top of every iteration and
    incremented at the bottom.
    """
    def __init__(self, iters=None, name=None):
        NetBuilder.__init__(self, name, _stop_blob_required=True)
        if iters is not None:
            # Created via `ops`, i.e. in the builder that is current at
            # construction time (this loop has not been entered yet).
            self._inc = ops.Const(1)
            self._iter = ops.Const(0)
            self._num_iters = (
                iters if isinstance(iters, core.BlobReference)
                else ops.Const(iters))
        else:
            self._inc = None
            self._iter = None
            self._num_iters = None

    def iter(self):
        """Returns the blob holding the current (0-based) iteration index."""
        assert self._num_iters is not None, (
            'This loop does not have a number of iterations.')
        assert self._iter is not None, (
            'iter() must be called from inside the loop context')
        return self._iter

    def __enter__(self):
        builder = NetBuilder.__enter__(self)
        if self._num_iters is not None:
            ops.stop_if(ops.GE([self._iter, self._num_iters]))
        return builder

    def __exit__(self, etype, *args):
        if etype is None and self._num_iters is not None:
            self.current_net().Add([self._iter, self._inc], [self._iter])
        NetBuilder.__exit__(self, etype, *args)


class _RunIf(_RunOnce):
    """
    Run-once builder implementing If / Elif / Else. All branches of one chain
    share the `_already_ran` blob: a branch stops early when its condition is
    false or when an earlier branch already ran, and marks the chain as run
    otherwise.
    """
    def __init__(self, cond_blob=None, name=None, _already_ran=None):
        _RunOnce.__init__(self, name)
        assert cond_blob is not None or _already_ran is not None
        # A branch without a condition is an Else.
        self._is_else = cond_blob is None
        if _already_ran is None:
            self._else_blob = ops.Not(cond_blob)
            self._already_ran = ops.Const(False)
        else:
            self._already_ran = _already_ran
            self._else_blob = _already_ran if cond_blob is None else (
                ops.Or([_already_ran, ops.Not(cond_blob)]))

    def __enter__(self):
        r = _RunOnce.__enter__(self)
        ops.stop_if(self._else_blob)
        ops.Const(True, blob_out=self._already_ran)
        return r

    def Elif(self, cond):
        assert not self._is_else, 'Elif not allowed for an Else.'
        return NetBuilder.current().add(
            _RunIf(cond, _already_ran=self._already_ran))

    def Else(self):
        assert not self._is_else, 'Else not allowed for an Else.'
        return NetBuilder.current().add(
            _RunIf(_already_ran=self._already_ran))