mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Adds support for weak script modules that get compiled to `ScriptModule`s once added as a submodule of a `ScriptModule`: ```python @weak_module class Test(torch.nn.Module): ... @weak_script_method def forward(self, x): ... ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/12682 Differential Revision: D10458626 Pulled By: driazati fbshipit-source-id: 10ae23cb83cdafc4646cee58f399e14b2e60acd4
102 lines
2.7 KiB
Python
102 lines
2.7 KiB
Python
"""
|
|
The weak_script annotation needs to be here instead of inside torch/jit/ so it
|
|
can be used in other places in torch/ (namely torch.nn) without running into
|
|
circular dependency problems
|
|
"""
|
|
|
|
import weakref
|
|
import inspect
|
|
try:
|
|
import builtins # PY3
|
|
except Exception:
|
|
import __builtin__ as builtins # PY2
|
|
|
|
# Registry: standalone functions passed to weak_script -> compilation record
# (a dict with "status", "compiled_fn" and "rcb" entries). Keys are held
# weakly, so registration alone does not keep a function object alive.
_compiled_weak_fns = weakref.WeakKeyDictionary()

# Registry: methods passed to weak_script_method -> record describing how to
# convert them into strong (script) methods ("rcb", "original_method").
_weak_script_methods = weakref.WeakKeyDictionary()

# Registry: converted modules -> the WeakScriptModuleProxy built for each.
_weak_modules = weakref.WeakKeyDictionary()

# Registry: classes passed to weak_module -> per-type record
# (currently just "method_stubs").
_weak_types = weakref.WeakKeyDictionary()

# Identity-only sentinels for the "status" entry of a _compiled_weak_fns
# record; compare them with ``is``. COMPILATION_PENDING is what weak_script
# stores on registration. COMPILED is presumably written once compilation
# finishes (that code is not in this file section) -- verify against callers.
COMPILATION_PENDING = object()
COMPILED = object()
|
|
|
|
|
|
def createResolutionCallback(frames_up=0):
    """
    Creates a function which, given a string variable name,
    returns the value of the variable in the scope of the caller of
    the function which called createResolutionCallback (by default).

    This is used to enable access to in-scope Python variables inside
    TorchScript fragments.

    frames_up is the number of additional frames to go up on the stack.
    The default value is 0, which corresponds to the frame of the caller
    of createResolutionCallback. For example, if frames_up is set
    to 1, then the frame of the caller's caller of createResolutionCallback
    will be taken.

    The returned callback looks a name up in the chosen frame's locals,
    then its globals, then builtins, and returns None if the name is found
    in none of them. The frame's ``f_locals`` / ``f_globals`` are read once,
    when createResolutionCallback is called.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallback(1)
            print(cb("foo"))

        def baz():
            foo = 2
            bar()

        baz()

    Raises IndexError if the stack is not deep enough for ``frames_up``.
    """
    # Walk the frame chain directly instead of calling inspect.stack().
    # inspect.stack() builds a FrameInfo, including source-context lines,
    # for *every* frame on the stack. That is needlessly expensive for a
    # helper invoked on every weak_script / weak_script_method decoration.
    frame = inspect.currentframe()
    # The first hop skips createResolutionCallback's own frame; the
    # remaining frames_up hops climb further toward the caller's callers.
    for _ in range(1 + frames_up):
        if frame is None:
            break
        frame = frame.f_back
    if frame is None:
        # Keep the exception type the old inspect.stack()[...] indexing raised.
        raise IndexError(
            "createResolutionCallback: frames_up={} goes past the top of "
            "the stack".format(frames_up))

    f_locals = frame.f_locals
    f_globals = frame.f_globals
    # Only the two namespaces are needed; don't keep the frame object alive.
    del frame

    def env(key):
        if key in f_locals:
            return f_locals[key]
        elif key in f_globals:
            return f_globals[key]
        elif hasattr(builtins, key):
            return getattr(builtins, key)
        else:
            return None

    return env
|
|
|
|
|
|
def weak_script(fn, _frames_up=0):
    """
    Marks a function as a weak script function and returns it unchanged.

    When used in a script function or ScriptModule, the weak script function
    will be lazily compiled and inlined in the graph. When not used in a
    script function, the weak script annotation has no effect.

    Registration stores a record in ``_compiled_weak_fns`` keyed by ``fn``,
    with status COMPILATION_PENDING, no compiled function yet, and a
    resolution callback for names visible where ``fn`` was decorated.

    ``_frames_up`` lets wrappers around this decorator push the name-lookup
    scope further up the stack.
    """
    # From createResolutionCallback's point of view, this frame is one level
    # up. The extra +1 therefore lands on whoever applied the decorator.
    resolution_cb = createResolutionCallback(_frames_up + 1)
    _compiled_weak_fns[fn] = dict(
        status=COMPILATION_PENDING,
        compiled_fn=None,
        rcb=resolution_cb,
    )
    return fn
|
|
|
|
|
|
def weak_module(cls):
    """
    Class decorator that declares ``cls`` a weak module type.

    Records ``cls`` in ``_weak_types`` with ``"method_stubs"`` set to None
    and returns the class itself, unmodified.
    """
    registration = dict(method_stubs=None)
    _weak_types[cls] = registration
    return cls
|
|
|
|
|
|
def weak_script_method(fn):
    """
    Method decorator marking ``fn`` for conversion to a strong script method.

    Stores a record in ``_weak_script_methods`` keyed by ``fn``. The record
    holds a name-resolution callback ("rcb") and the undecorated function
    ("original_method"). ``fn`` is returned unchanged.
    """
    # frames_up=2 skips this decorator's frame and the frame that applied it.
    # Assuming the decorator is used directly inside a class body (as in
    # ``@weak_script_method def forward(...)``), names then resolve in the
    # scope enclosing the class definition.
    # NOTE(review): the stored value references ``fn`` itself, so this weakly
    # keyed entry cannot be collected while the record exists -- confirm
    # whether that is intended.
    record = {}
    record["rcb"] = createResolutionCallback(frames_up=2)
    record["original_method"] = fn
    _weak_script_methods[fn] = record
    return fn
|