mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
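Add an arg_scope module for model_helpers: a context manager that records keyword arguments for one or more helper functions for the duration of a `with` block.
Example usage: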
    with model_helpers.arg_scope([model_helpers.FC], **kwargs):
        model_helpers.FC(model, "x", "out_1", n, n)
    with model_helpers.arg_scope([myhelper], n=-3):
        with model_helpers.arg_scope([myhelper], n=-2):
            with model_helpers.arg_scope([myhelper], n=n):
                res = model_helpers.myhelper(None)

    with model_helpers.arg_scope([myhelper], n=-3), \
            model_helpers.arg_scope([myhelper], n=-2), \
            model_helpers.arg_scope([myhelper], n=n):
        res = model_helpers.myhelper(None)
Reviewed By: salexspb
Differential Revision: D4837180
fbshipit-source-id: 2cbd81681779d6cd1e61ee189edcc1cf3bb07d15
36 lines
1.2 KiB
Python
36 lines
1.2 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
import contextlib
|
|
import copy
|
|
import threading
|
|
|
|
# Per-thread holder for the active argument scope. The ``current_scope``
# attribute is created lazily by get_current_scope(), so each thread starts
# with its own empty scope and never sees scopes opened by other threads.
_threadlocal_scope = threading.local()


@contextlib.contextmanager
def arg_scope(single_helper_or_list, **kwargs):
    """Record ``kwargs`` for one or more helper functions inside a ``with`` block.

    While the block is active, ``get_current_scope()[helper.__name__]`` holds
    the keyword arguments given here, merged on top of those registered by any
    enclosing ``arg_scope`` for the same helper (inner values win on key
    collisions). On exit, including exit via an exception, the scope is
    restored to exactly what it was before this ``arg_scope`` was entered.

    Args:
        single_helper_or_list: a callable, or a list of callables. Each one
            is keyed in the scope by its ``__name__``.
        **kwargs: keyword arguments to associate with every listed helper.

    Raises:
        AssertionError: if the argument, or any item in the list, is not
            callable. The scope is left unchanged in that case.
    """
    if not isinstance(single_helper_or_list, list):
        assert callable(single_helper_or_list), \
            "arg_scope is only supporting single or a list of helper functions."
        single_helper_or_list = [single_helper_or_list]

    current_scope = get_current_scope()
    # Snapshot two levels deep. The per-helper dicts are mutated in place
    # below, so they must be copied. The argument values are kept by
    # reference: a deepcopy would hand back copies of user objects after an
    # inner scope exits, and it fails for objects that cannot be deep-copied.
    old_scope = {key: dict(value) for key, value in current_scope.items()}
    try:
        for helper in single_helper_or_list:
            assert callable(helper), \
                "arg_scope is only supporting a list of callable helper functions."
            current_scope.setdefault(helper.__name__, {}).update(kwargs)
        yield
    finally:
        # Runs on normal exit, when the with-body raises, and when a
        # non-callable helper fails the assert above. Without it, partially
        # applied or abandoned arguments would leak into later code.
        _threadlocal_scope.current_scope = old_scope


def get_current_scope():
    """Return this thread's scope dict, creating an empty one on first use.

    The dict maps a helper's ``__name__`` to the dict of keyword arguments
    registered for it via ``arg_scope``. The live object is returned, not a
    copy. Because ``arg_scope`` rebinds the scope when it exits, callers
    should re-fetch it rather than hold on to an old reference.
    """
    try:
        return _threadlocal_scope.current_scope
    except AttributeError:
        # First access on this thread.
        _threadlocal_scope.current_scope = {}
        return _threadlocal_scope.current_scope
|