mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: There is a long-standing scoping problem that was introduced in the original Python wrappers early in H1. Basically, each RNNCell implementation has to manually scope the outputs of each of its operators. If somebody forgets, there can be weird bugs with layers, etc. The approach is the following: the user has to explicitly specify the current scope when using the apply_over_sequence function (and others) if the function is going to be called several times (e.g. for stacking layers). This way we use Caffe2's native scoping approach instead of inventing one extra API people have to use (i.e. passing the scope name as an argument to the RNNCell constructor). Closes https://github.com/caffe2/caffe2/pull/1681 Differential Revision: D6777536 Pulled By: salexspb fbshipit-source-id: 73d860b8d4857589e04bdea5a6fcd3080d68427c
115 lines
3.7 KiB
Python
115 lines
3.7 KiB
Python
# Copyright (c) 2016-present, Facebook, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
##############################################################################
|
|
|
|
## @package scope
|
|
# Module caffe2.python.scope
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import contextlib
|
|
import threading
|
|
from past.builtins import basestring
|
|
|
|
from caffe2.proto import caffe2_pb2
|
|
|
|
|
|
# The name scope and device scope when creating a new operator.
|
|
# Separator that NameScope() appends after every non-empty prefix it adds to
# the current name scope.
_NAMESCOPE_SEPARATOR = '/'
|
|
|
|
# Per-thread storage for the current scopes. Its `namescope` and `devicescope`
# attributes are created lazily by CurrentNameScope() and CurrentDeviceScope().
_threadlocal_scope = threading.local()
|
|
|
|
|
|
def CurrentNameScope():
    """Return the name scope string of the calling thread.

    The value lives on the module's thread-local object. The first time a
    thread asks for it, the scope is initialised to the empty string.
    """
    try:
        return _threadlocal_scope.namescope
    except AttributeError:
        # First access from this thread: start with an empty scope.
        _threadlocal_scope.namescope = ''
        return _threadlocal_scope.namescope
|
|
|
|
|
|
def CurrentDeviceScope():
    """Return the device scope of the calling thread.

    The value lives on the module's thread-local object. The first time a
    thread asks for it, the scope is initialised to None.
    """
    try:
        return _threadlocal_scope.devicescope
    except AttributeError:
        # First access from this thread: no device scope is active yet.
        _threadlocal_scope.devicescope = None
        return _threadlocal_scope.devicescope
|
|
|
|
|
|
@contextlib.contextmanager
def NameScope(prefix, reset=False):
    """Context manager that extends the current thread's name scope.

    While the with-block runs, the thread-local name scope is the previous
    scope followed by ``prefix`` and ``_NAMESCOPE_SEPARATOR``. On exit the
    previous scope is put back.

    Args:
        prefix: a string or None. A falsy value (None or '') contributes
            nothing to the scope.
        reset: when true, the scope is replaced by the new prefix instead of
            being appended to the enclosing scope.

    Raises:
        AssertionError: if ``prefix`` is neither a string nor None, or if on
            exit the scope no longer ends with the prefix added here (i.e. it
            was modified from outside NameScope()).
    """
    assert isinstance(prefix, basestring) or prefix is None, \
        "NameScope takes in a string as its argument."
    old_scope = CurrentNameScope()
    prefix = prefix + _NAMESCOPE_SEPARATOR if prefix else ''
    if reset:
        _threadlocal_scope.namescope = prefix
    else:
        # CurrentNameScope() just returned the stored value, so old_scope is
        # exactly what the thread-local attribute holds right now.
        _threadlocal_scope.namescope = old_scope + prefix

    try:
        yield
    finally:
        try:
            assert _threadlocal_scope.namescope.endswith(prefix), \
                "The namescope variable is changed from outside NameScope() calls."
        finally:
            # Restore unconditionally: a failed consistency check must not
            # leave this thread stuck with a corrupted scope.
            _threadlocal_scope.namescope = old_scope
|
|
|
|
|
|
@contextlib.contextmanager
def DeviceScope(scope, node_name=None):
    """Context manager that installs a device scope for the current thread.

    A fresh ``caffe2_pb2.DeviceOption`` is built (copied from ``scope`` when
    one is given) and stored as the thread-local device scope while the
    with-block runs. On exit the previous device scope is put back.

    Args:
        scope: a ``caffe2_pb2.DeviceOption`` to copy, or a falsy value when
            only ``node_name`` is being set.
        node_name: optional node name. When given it overrides the
            ``node_name`` of the new scope. When it is not given and the
            enclosing scope has a ``node_name``, the new scope inherits it.

    Raises:
        AssertionError: if ``scope`` is truthy but not a DeviceOption, if both
            ``scope`` and ``node_name`` are falsy, or if on exit the
            thread-local device scope is no longer equal to the one installed
            here (i.e. it was modified from outside DeviceScope()).
    """
    new_scope = caffe2_pb2.DeviceOption()
    if scope:
        assert isinstance(scope, caffe2_pb2.DeviceOption), \
            "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument."
        new_scope.CopyFrom(scope)
    else:
        assert node_name, "At least one argument should be non-null in DeviceScope"

    # rewrite node_name if it is explicitly given
    if node_name:
        new_scope.node_name = node_name
    old_scope = CurrentDeviceScope()
    # nested scope should inherit the node_name if it is not explicitly set
    if old_scope and old_scope.HasField('node_name') and \
            not new_scope.HasField('node_name'):
        new_scope.node_name = old_scope.node_name
    _threadlocal_scope.devicescope = new_scope
    try:
        yield
    finally:
        try:
            assert _threadlocal_scope.devicescope == new_scope, \
                "The device scope is changed from outside DeviceScope() calls."
        finally:
            # Restore unconditionally: a failed consistency check must not
            # leave this thread stuck with a stale device scope.
            _threadlocal_scope.devicescope = old_scope
|
|
|
|
|
|
@contextlib.contextmanager
def EmptyDeviceScope():
    """
    Allow users to 'disable' the device scope behaviour (so it can be
    controlled at a NetDef::DeviceOption level, not overridden at
    OperatorDef::DeviceOption level).

    This sets the CurrentDeviceScope() to None, so that the field is
    not set in CreateOperator(...), etc.

    The previous device scope is restored when the with-block exits, and any
    exception raised inside the block propagates to the caller.
    """
    old_scope = CurrentDeviceScope()
    try:
        _threadlocal_scope.devicescope = None
        yield
    finally:
        # No `return` here: returning from a `finally` clause in a
        # @contextmanager generator would discard the in-flight exception and
        # make the with-statement silently suppress errors.
        _threadlocal_scope.devicescope = old_scope
|