pytorch/caffe2/python/scope.py
Alexander Sidorov e0e124e617 Fix RNN scoping situation
Summary:
There is a long-standing scoping problem that was introduced in the original Python wrappers early in H1. Basically, each RNNCell implementation has to manually scope the outputs of each of its operators. If somebody forgets to do so, there can be weird bugs with layers, etc.

The approach is the following: the user has to explicitly specify the current scope when using the apply_over_sequence function (and others) if the function is going to be called several times (e.g. for stacking layers). This way we use Caffe2's native scoping approach instead of inventing one more API that people have to use (i.e. passing a scope name as an argument to the RNNCell constructor).
Closes https://github.com/caffe2/caffe2/pull/1681

Differential Revision: D6777536

Pulled By: salexspb

fbshipit-source-id: 73d860b8d4857589e04bdea5a6fcd3080d68427c
2018-02-07 17:35:29 -08:00

115 lines
3.7 KiB
Python

# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
## @package scope
# Module caffe2.python.scope
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import contextlib
import threading
from past.builtins import basestring
from caffe2.proto import caffe2_pb2
# The name scope and device scope when creating a new operator.
# Separator that NameScope() appends to every non-empty prefix it pushes.
_NAMESCOPE_SEPARATOR = '/'
# Thread-local holder for the scopes. Its `namescope` and `devicescope`
# attributes are created lazily by CurrentNameScope() and
# CurrentDeviceScope() the first time they are called in a thread.
_threadlocal_scope = threading.local()
def CurrentNameScope():
    """Return the name scope prefix stored for the calling thread.

    The prefix lives in the `namescope` attribute of the module's
    thread-local object. If the attribute does not exist yet, it is
    initialised to the empty string before being returned.
    """
    try:
        return _threadlocal_scope.namescope
    except AttributeError:
        # First access from this thread: start with an empty scope.
        _threadlocal_scope.namescope = ''
        return _threadlocal_scope.namescope
def CurrentDeviceScope():
    """Return the device scope stored for the calling thread.

    The value lives in the `devicescope` attribute of the module's
    thread-local object. If the attribute does not exist yet, it is
    initialised to None, and None is returned.
    """
    local_state = _threadlocal_scope
    if hasattr(local_state, "devicescope"):
        return local_state.devicescope
    # First access from this thread: no device scope is active.
    local_state.devicescope = None
    return local_state.devicescope
@contextlib.contextmanager
def NameScope(prefix, reset=False):
    """Context manager that pushes `prefix` onto the thread's name scope.

    Args:
        prefix: a string, or None. A non-empty string has the scope
            separator appended; None or an empty string contributes
            nothing to the scope.
        reset: when true, the scope inside the `with` body is just the
            normalised prefix; otherwise the prefix is appended to the
            scope that was current on entry.

    On exit the scope that was current on entry is restored, after
    asserting that the scope still ends with the pushed prefix.
    """
    assert prefix is None or isinstance(prefix, basestring), \
        "NameScope takes in a string as its argument."
    saved_scope = CurrentNameScope()

    # Normalise the prefix: "foo" -> "foo/", None or "" -> "".
    if prefix:
        prefix = prefix + _NAMESCOPE_SEPARATOR
    else:
        prefix = ''

    # With reset=True the enclosing scope is ignored for the duration.
    enclosing = '' if reset else _threadlocal_scope.namescope
    _threadlocal_scope.namescope = enclosing + prefix

    try:
        yield
    finally:
        assert _threadlocal_scope.namescope.endswith(prefix), \
            "The namescope variable is changed from outside NameScope() calls."
        _threadlocal_scope.namescope = saved_scope
@contextlib.contextmanager
def DeviceScope(scope, node_name=None):
    """Context manager that installs a device scope for the calling thread.

    Args:
        scope: a caffe2_pb2.DeviceOption to copy, or a falsy value when
            only `node_name` is supplied.
        node_name: optional node name; when given it overrides the
            node_name of the copied option.

    At least one of the two arguments must be truthy. If the resulting
    option has no node_name but the enclosing device scope has one, the
    enclosing node_name is inherited. On exit the previous device scope
    is restored, after asserting that the installed scope is unchanged.
    """
    new_scope = caffe2_pb2.DeviceOption()
    if not scope:
        assert node_name, "At least one argument should be non-null in DeviceScope"
    else:
        assert isinstance(scope, caffe2_pb2.DeviceOption), \
            "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument."
        new_scope.CopyFrom(scope)

    # An explicitly given node_name always wins over the copied one.
    if node_name:
        new_scope.node_name = node_name

    previous_scope = CurrentDeviceScope()

    # A nested scope inherits the node_name if it is not explicitly set.
    inherits_node_name = (
        previous_scope
        and previous_scope.HasField('node_name')
        and not new_scope.HasField('node_name')
    )
    if inherits_node_name:
        new_scope.node_name = previous_scope.node_name

    _threadlocal_scope.devicescope = new_scope
    try:
        yield
    finally:
        assert _threadlocal_scope.devicescope == new_scope, \
            "The device scope is changed from outside DeviceScope() calls."
        _threadlocal_scope.devicescope = previous_scope
@contextlib.contextmanager
def EmptyDeviceScope():
    """
    Allow users to 'disable' the device scope behaviour (so it can be
    controlled at a NetDef::DeviceOption level, not overridden at
    OperatorDef::DeviceOption level).

    This sets the CurrentDeviceScope() to None, so that the field is
    not set in CreateOperator(...), etc.

    The device scope that was current on entry is restored on exit,
    whether the `with` body finishes normally or raises. An exception
    raised in the body propagates to the caller after the restore.
    """
    old_scope = CurrentDeviceScope()
    try:
        _threadlocal_scope.devicescope = None
        yield
    finally:
        # Deliberately no `return` in this clause: returning from `finally`
        # discards an in-flight exception, which would make this context
        # manager silently swallow every error raised in the `with` body.
        _threadlocal_scope.devicescope = old_scope