import contextlib
import logging

from caffe2.python import scope

logger = logging.getLogger(__name__)


class ParameterSharingContext(object):
    """
    This class manages scope driven way of parameter sharing across different
    NameScopes.
    """

    def __init__(self):
        # Merged view of every active override: {source_scope: target_scope}.
        self._scope_overrides = {}
        # Stack of override dicts, one per active ParameterSharing context;
        # used to rebuild self._scope_overrides on pop().
        self._contexts = []

    def _resolve_scope_overrides(self, candidate_scope, _visited=None):
        """
        Recursively resolves all scope overrides, i.e. multiple steps of
        override can be chained.

        For example, if one provides the following scope overrides:
        {'scope_b': 'scope_a'} and within 'scope_b' - {'shared_child': ''},
        then name 'w' will get resolved to the following blobs depending on the
        namescope:
          a. 'scope_a' -> 'scope_a/w'
          b. 'scope_b' -> 'scope_a/w'
          c. 'scope_c' -> 'scope_c/w'
          d. 'scope_b/shared_child' -> 'scope_a/w'
          e. 'scope_b/unshared_child' -> 'scope_a/unshared_child/w'

        Args:
            candidate_scope: namescope string to resolve (as returned by
                scope.CurrentNameScope()).
            _visited: internal; the set of scopes already resolved in the
                current chain, used to detect cyclic overrides.

        Returns:
            The resolved namescope string.

        Raises:
            ValueError: if the active overrides form a cycle, which would
                otherwise recurse forever.
        """
        separator = scope._NAMESCOPE_SEPARATOR
        visited = set() if _visited is None else _visited
        # Resolution is a pure function of (scope, overrides), so seeing the
        # same scope twice in one chain means the recursion never terminates.
        if candidate_scope in visited:
            raise ValueError(
                "Cyclic parameter sharing overrides detected while resolving "
                "scope {0}".format(candidate_scope))
        visited.add(candidate_scope)

        best_scope = candidate_scope
        best_scope_idx = 0
        sub_scopes = candidate_scope.split(separator)
        cur_scope = ''
        for idx, sub_scope in enumerate(sub_scopes):
            cur_scope = cur_scope + sub_scope + separator
            # Later (deeper) matching prefixes overwrite earlier ones, so the
            # longest overridden prefix wins.
            if cur_scope in self._scope_overrides:
                best_scope = self._scope_overrides[cur_scope]
                best_scope_idx = idx
        if best_scope == candidate_scope:
            return candidate_scope
        # The override target may itself be overridden; resolve it, then
        # re-append the part of the candidate below the matched prefix.
        return (self._resolve_scope_overrides(best_scope, visited) +
                separator.join(sub_scopes[best_scope_idx + 1:]))

    def get_parameter_name(self, name):
        """
        Return `name` prefixed with the current namescope after applying all
        active scope overrides. Logs at INFO level when an override applies.
        """
        candidate_scope = scope.CurrentNameScope()
        best_scope = self._resolve_scope_overrides(candidate_scope)
        if best_scope != candidate_scope:
            logger.info("Overwriting scope %s with scope %s",
                        candidate_scope, best_scope)
        return best_scope + name

    def add_scope_overrides(self, shared_scopes):
        """Push a dict of {source_scope: target_scope} overrides."""
        self._contexts.append(shared_scopes)
        self._scope_overrides.update(shared_scopes)

    def pop(self):
        """
        Drop the most recently pushed overrides and rebuild the merged view
        from the remaining contexts, in push order, so later contexts win.
        """
        assert len(self._contexts) > 0
        self._contexts.pop()
        self._scope_overrides = {}
        for x in self._contexts:
            self._scope_overrides.update(x)


parameter_sharing_context = ParameterSharingContext()


def _normalize_namescope(namescope):
    """
    Return `namescope` terminated by the namescope separator, so that
    'scope_a' and 'scope_a/' are equivalent. The empty (root) scope is
    returned unchanged.
    """
    if namescope and not namescope.endswith(scope._NAMESCOPE_SEPARATOR):
        return namescope + scope._NAMESCOPE_SEPARATOR
    return namescope


@contextlib.contextmanager
def ParameterSharing(shared_scopes):
    """
    Helper function for sharing scopes.
    All the parameters within the shared_scopes will be remapped relative to
    CurrentNameScope().

    I.e. if one calls ParameterSharing with {'scope_b': 'scope_a'} from the
    scope 'some_global_scope', it'll effectively mean that all parameters from
    'some_global_scope/scope_b' will be shared with the parameters from
    'some_global_scope/scope_a'.

    Args:
        shared_scopes: dict mapping a scope (relative to the current
            namescope) to the scope whose parameters it should reuse.

    Raises:
        AssertionError: if `shared_scopes` is not a dict, or if a target
            scope lies inside (or equals) its source scope.
    """
    assert isinstance(shared_scopes, dict)

    shared_scope_overrides = {}
    current_scope = scope.CurrentNameScope()
    for k, v in shared_scopes.items():
        # Normalize all the scopes, so scope_a and scope_a/ are equivalent
        full_k = _normalize_namescope(current_scope + k)
        full_v = _normalize_namescope(current_scope + v)
        # Compare separator-terminated scopes so the check is on whole path
        # components: 'scope' -> 'scope_a' is legal, 'a' -> 'a/b' is not.
        assert not full_v.startswith(full_k), (
            "Illegal override for parameter sharing. {} is prefix of {}".
            format(k, v))
        shared_scope_overrides[full_k] = full_v

    # Push outside the try so pop() only runs for a context that was pushed.
    parameter_sharing_context.add_scope_overrides(shared_scope_overrides)
    try:
        yield
    finally:
        parameter_sharing_context.pop()