pytorch/torch/jit/_recursive.py
davidriazati cb4a6327a3 Delete WeakScriptModuleProxy (#23398)
Summary:
This PR deletes `WeakScriptModuleProxy`, uses `ScriptModule` directly, and moves the recursive script logic into `torch/jit/_recursive.py`. The first commit just moves code; the latter two contain the actual changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23398

Pulled By: driazati

Reviewed By: eellison

Differential Revision: D16712340

fbshipit-source-id: f907efcec59bb2694c079ab655304324c125e9bb
2019-08-12 13:36:47 -07:00


import inspect
import torch
import collections
import torch._jit_internal as _jit_internal
from torch.nn import Module, ModuleList, Parameter, Sequential
from torch._six import get_function_from_type


def copy_to_script_module(original, stubs):
    """
    Copies the parameters, buffers, constants, attributes, and submodules
    of an nn.Module into a new ScriptModule and returns it.
    """
    if not hasattr(original, '_parameters'):
        raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                           .format(type(original).__name__))

    qualified_name = torch.jit._qualified_name(type(original))
    script_module = torch.jit.ScriptModule(_qualified_name=qualified_name)
    constants_set = set(getattr(original, "__constants__", []))

    script_module.__dict__["_constants_set"] = {}

    # Copy Parameters and Modules
    for name in dir(original):
        item = getattr(original, name)
        if item is None and name in original._parameters:
            # XXX: treat None values simply as module attributes instead of adding them to the parameter list
            # TODO: handle this more generally once non-tensor attributes can be added to modules
            object.__setattr__(script_module, name, item)
        elif item is script_module:
            continue
        elif isinstance(item, (Parameter, Module, torch.jit.Attribute)):
            setattr(script_module, name, item)

    # Copy buffers
    for name in original._buffers:
        if original._buffers[name] is None:
            object.__setattr__(script_module, name, None)
        else:
            script_module.register_buffer(name, original._buffers[name])

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in getattr(original, '__annotations__', {}).items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    # Copy constants
    script_module.__dict__["_constants_set"] = constants_set
    for name in script_module.__dict__["_constants_set"]:
        if hasattr(original, name):
            if (name in original._parameters or name in original._buffers) and getattr(original, name) is not None:
                # Parameters/buffers with real (non-None) values were already
                # copied above, so only `None` placeholders fall through here
                continue
            setattr(script_module, name, getattr(original, name))

    # Copy annotations, pull types from `__annotations__` or try to infer
    # the type if possible
    class_annotations = getattr(original, '__annotations__', {})
    for name in dir(original):
        if name in ("training", "__dict__"):
            # TODO: removing this skip should let us remove the code to add training as an
            # attribute in python_sugared_value.cpp
            continue
        if hasattr(script_module, name):
            # Don't re-copy properties
            continue
        item = getattr(original, name)
        if name in class_annotations:
            the_type = torch.jit.annotations.ann_to_type(class_annotations[name])
        else:
            the_type = torch._C._jit_try_infer_type(item)
        if the_type is not None:
            script_module._c._register_attribute(name, the_type, item)

    # Copy overloads
    script_module.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))

    # Copy links to Python methods so they can be resolved when compiling
    for name in dir(original):
        item = getattr(original, name)
        if hasattr(script_module, name):
            # Skip Python builtins and all the module methods that are already
            # attached to this since it inherits from nn.Module
            continue
        if inspect.ismethod(item):
            setattr(script_module, name, item)

    torch.jit._create_methods_from_stubs(script_module, stubs)

    # Now that the methods have been compiled, have them shadow their
    # corresponding Python functions
    for method_name in script_module._c._method_names():
        setattr(script_module, method_name, script_module._c._get_method(method_name))

    return script_module
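

def _example_constants():
    # A minimal, hypothetical sketch of how a name ends up in `_constants_set`:
    # either list it in `__constants__` (as here), or annotate it with `Final[T]`
    # so the `__annotations__` scan above picks it up. `_ExampleConstNet` is made
    # up for illustration; `recursive_script` (defined just below) builds the
    # stubs and calls copy_to_script_module.
    class _ExampleConstNet(torch.nn.Module):
        __constants__ = ['num_layers']

        def __init__(self):
            super(_ExampleConstNet, self).__init__()
            self.num_layers = 2

        def forward(self, x):
            # `num_layers` is treated as a constant when forward is compiled,
            # rather than being registered as a mutable attribute
            return x * self.num_layers

    return recursive_script(_ExampleConstNet())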


def recursive_script(mod):
    """
    Makes a ScriptModule from an nn.Module. Compiles `forward` (when it is
    defined and not `@ignore`d) along with any methods marked `@torch.jit.export`;
    other methods accessed in forward are scripted on demand.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential)):
        # Create constant versions for the iterable modules
        return create_constant_iterable_module(mod)

    methods = ()
    if hasattr(mod, 'forward'):
        if mod.forward.__func__ == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            methods = ('forward',)

    exported = []
    for name in dir(mod):
        item = getattr(mod, name)
        if callable(item):
            if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
                exported.append(name)
    methods = methods + tuple(exported)

    def make_stub(method):
        func = get_function_from_type(type(mod), method)
        return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))

    stubs = list(map(make_stub, methods))
    return copy_to_script_module(mod, stubs)
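

def _example_recursive_script():
    # A minimal usage sketch; `_ExampleNet` and this helper are hypothetical and
    # only illustrate which methods become stubs: `forward` (defined and not
    # @ignore'd) plus any method carrying the EXPORT modifier. Methods that
    # `forward` calls are compiled on demand.
    class _ExampleNet(torch.nn.Module):
        def __init__(self):
            super(_ExampleNet, self).__init__()
            self.fc = torch.nn.Linear(4, 4)

        @torch.jit.export
        def scale(self, x):
            # Stubbed because @torch.jit.export sets the EXPORT modifier
            return x * 2.0

        def forward(self, x):
            return self.scale(self.fc(x))

    scripted = recursive_script(_ExampleNet())
    assert isinstance(scripted, torch.jit.ScriptModule)
    return scripted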


def create_method_from_fn(module, fn):
    if _jit_internal.is_ignored_fn(fn):
        return None
    if not inspect.ismethod(fn):
        return None
    stub = torch.jit.script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
    with torch.jit._disable_emit_hooks():
        # We don't want to call the hooks here since the graph that is calling
        # this function is not yet complete
        torch.jit._create_methods_from_stubs(module, (stub,))
    return stub


def make_strong_submodule(field, module, parent):
    if field not in parent._modules:
        # It's not a submodule, don't do anything
        return None

    # Convert the module to a ScriptModule
    new_strong_submodule = recursive_script(module)

    # Install the ScriptModule on the python side
    parent._modules._python_modules[field] = new_strong_submodule

    return new_strong_submodule


def try_compile_fn(fn, loc):
    if _jit_internal.is_ignored_fn(fn):
        # Don't do anything for @ignore'd functions
        return None

    if isinstance(fn, torch.nn.Module):
        # Since modules are callable, pybind recognizes them as functions, but
        # don't do anything for them here
        return None

    if not inspect.isfunction(fn) and not inspect.ismethod(fn):
        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
                           "Python functions or methods currently.\n"
                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))

    # We don't have the actual scope where the function was defined, but we can
    # extract the necessary info from the closed-over variables on the function
    # object
    rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=rcb)
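

def _example_try_compile_fn():
    # A minimal sketch; `_double` is a hypothetical free function. try_compile_fn
    # derives a resolution callback from the function's closure and hands it to
    # torch.jit.script, which is how plain functions referenced from a compiled
    # forward get compiled. `loc` is not used by the body above, so None is
    # passed here.
    def _double(x):
        return x + x

    return try_compile_fn(_double, None)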


def create_constant_iterable_module(module):
    modules = collections.OrderedDict()

    for key, submodule in module._modules.items():
        if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            # Make each item in the module a constant
            modules[key] = create_constant_iterable_module(submodule)
        else:
            modules[key] = recursive_script(submodule)

    if isinstance(module, Sequential):
        return torch.jit._ConstSequential(Sequential(modules))
    elif isinstance(module, ModuleList):
        return torch.jit._ConstModuleList(modules)
    else:
        raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made "
                           "into constant modules, found {}".format(module))