## @package context
# Module caffe2.python.context
"""Thread-local stacks of "active" context objects.

A class decorated with ``define_context`` becomes a context manager whose
instances are pushed on a per-thread stack while their ``with`` block runs.
The innermost active instance is available through ``cls.current()``.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import threading

import six


class _ContextInfo(object):
    """Book-keeping for one context class: its options and per-thread stack."""

    def __init__(self, cls, allow_default, arg_name):
        """Store the context class and the options given to define_context.

        Args:
            cls: the context class this record describes.
            allow_default: if true, ``get_active`` may build a default
                instance with ``cls()`` when no context is active.
            arg_name: stored as-is; not used anywhere in this module.
        """
        self.cls = cls
        self.allow_default = allow_default
        self.arg_name = arg_name
        # threading.local gives every thread its own independent stack.
        self._local_stack = threading.local()

    @property
    def _stack(self):
        """Return the calling thread's stack, creating it on first access."""
        try:
            return self._local_stack.obj
        except AttributeError:
            self._local_stack.obj = []
            return self._local_stack.obj

    def enter(self, value):
        """Push ``value`` as the innermost active context of this thread."""
        self._stack.append(value)

    def exit(self, value):
        """Pop the innermost active context, which must equal ``value``.

        Raises:
            AssertionError: if the stack is empty or its top is not ``value``.
        """
        assert self._stack, 'Context %s is empty.' % self.cls
        # The pop must not live inside the assert expression: asserts are
        # removed under ``python -O``, which would leave the context on the
        # stack forever.
        popped = self._stack.pop()
        assert popped == value

    def get_active(self, required=True):
        """Return the innermost active context of the calling thread.

        Args:
            required: when false and nothing is active, return None.

        Returns:
            The top of the stack. If the stack is empty and ``required`` is
            true, a default ``cls()`` instance is created, pushed (it stays
            on the stack) and returned.

        Raises:
            AssertionError: if nothing is active, ``required`` is true and
                the context class does not allow a default.
        """
        if not self._stack:
            if not required:
                return None
            assert self.allow_default, (
                'Context %s is required but none is active.' % self.cls)
            self.enter(self.cls())
        return self._stack[-1]


class _ContextRegistry(object):
    """Mapping from context class to its ``_ContextInfo`` record."""

    def __init__(self):
        self._ctxs = {}

    def register(self, ctx_info):
        """Add ``ctx_info``; its class must not already be registered."""
        assert isinstance(ctx_info, _ContextInfo)
        assert (ctx_info.cls not in self._ctxs), (
            'Context %s already registered' % ctx_info.cls)
        self._ctxs[ctx_info.cls] = ctx_info

    def get(self, cls):
        """Return the record registered for ``cls``.

        Raises:
            AssertionError: if ``cls`` was never registered.
        """
        assert cls in self._ctxs, 'Context %s not registered.' % cls
        return self._ctxs[cls]


_CONTEXT_REGISTRY = _ContextRegistry()


def _context_registry():
    """Return the module-wide context registry."""
    return _CONTEXT_REGISTRY


# The four functions below are not used directly: ``define_context`` installs
# them on the decorated class as ``__enter__``, ``__exit__``, ``__call__`` and
# ``current`` respectively.

def __enter__(self):
    """Run the class's original ``__enter__`` (if any), then activate self."""
    if self._prev_enter is not None:
        self._prev_enter()
    _context_registry().get(self._ctx_class).enter(self)
    return self


def __exit__(self, *args):
    """Deactivate self, then run the class's original ``__exit__`` (if any)."""
    _context_registry().get(self._ctx_class).exit(self)
    if self._prev_exit is not None:
        self._prev_exit(*args)


def __call__(self, func):
    """Decorate ``func`` so that every call runs with self as active context."""
    @six.wraps(func)
    def wrapper(*args, **kwargs):
        with self:
            return func(*args, **kwargs)
    return wrapper


@classmethod
def _current(cls, value=None, required=True):
    """Return ``value`` if given, else the active context of type ``cls``."""
    return _get_active_context(cls, value, required)


class define_context(object):
    """Class decorator that turns a class into a thread-local context.

    The decorated class is registered and gains ``__enter__``/``__exit__``
    (wrapping any it already had), ``__call__`` (use an instance as a
    function decorator) and the classmethod ``current``.
    """

    def __init__(self, arg_name=None, allow_default=False):
        """Record the options forwarded to the class's ``_ContextInfo``."""
        self.arg_name = arg_name
        self.allow_default = allow_default

    def __call__(self, cls):
        """Register ``cls`` as a context class, patch it and return it.

        Raises:
            AssertionError: if ``cls`` (or a parent) already defines a
                context, or ``cls`` is already registered.
        """
        assert not hasattr(cls, '_ctx_class'), (
            '%s parent class (%s) already defines context.' % (
                cls, cls._ctx_class))
        cls._ctx_class = cls

        _context_registry().register(
            _ContextInfo(cls, self.allow_default, self.arg_name)
        )

        # Keep the original enter/exit hooks so the injected ones can chain
        # to them; None means the class had none.
        cls._prev_enter = getattr(cls, '__enter__', None)
        cls._prev_exit = getattr(cls, '__exit__', None)

        cls.__enter__ = __enter__
        cls.__exit__ = __exit__
        cls.__call__ = __call__
        cls.current = _current
        return cls


def _get_active_context(cls, val=None, required=True):
    """Resolve the context of type ``cls`` to use.

    Args:
        cls: a registered context class.
        val: an explicit context; when not None it is type-checked and
            returned instead of the active one.
        required: forwarded to ``_ContextInfo.get_active``.

    Raises:
        AssertionError: if ``cls`` is not registered, ``val`` is not an
            instance of ``cls``, or a required context is missing.
    """
    # Looked up first so an unregistered class fails even when val is given.
    ctx_info = _context_registry().get(cls)
    if val is None:
        return ctx_info.get_active(required=required)
    assert isinstance(val, cls), (
        'Wrong context type. Expected: %s, got %s.' % (cls, type(val)))
    return val