mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: # Description I'm new to this project just wanted to start with small bug fixes. I found some unused local variables and I've removed them in this pr. Pull Request resolved: https://github.com/pytorch/pytorch/pull/29181 Differential Revision: D18319893 Pulled By: suo fbshipit-source-id: e4f9f13b6db2ca213015569deb12d3fd9beb74a8
618 lines
26 KiB
Python
618 lines
26 KiB
Python
import inspect
|
|
import torch
|
|
import collections
|
|
import types
|
|
import textwrap
|
|
import functools
|
|
import warnings
|
|
|
|
import torch._jit_internal as _jit_internal
|
|
from torch.jit.frontend import get_default_args
|
|
from torch.nn import Module, ModuleList, Sequential, ModuleDict
|
|
from torch._six import get_function_from_type, bind_method
|
|
|
|
|
|
# Everything needed to compile one method: the resolution callback used to
# look up free names, the parsed definition (``def_``), and the Python
# function it came from (``None`` for define()'d methods).
ScriptMethodStub = collections.namedtuple(
    'ScriptMethodStub',
    ['resolution_callback', 'def_', 'original_method'],
)
|
|
|
|
# Attribute names that infer_raw_concrete_type never turns into module
# attributes, even though every nn.Module instance exposes them.
# TODO: there should be a more principled way of doing this.
blacklist = [
    "_version", "_parameters", "_buffers", "_modules", "_initializing",
    "_backward_hooks", "_forward_hooks", "_forward_pre_hooks",
    "_state_dict_hooks", "_load_state_dict_pre_hooks",
    "dump_patches",
]
|
|
|
|
def make_stub(func):
    """
    Package ``func`` as a ScriptMethodStub ready for compilation.

    The stub holds a resolution callback built from ``func``'s closure, the
    JIT definition from ``torch.jit.get_jit_def`` (with
    ``self_name="RecursiveScriptModule"``), and ``func`` itself.
    """
    return ScriptMethodStub(
        resolution_callback=_jit_internal.createResolutionCallbackFromClosure(func),
        def_=torch.jit.get_jit_def(func, self_name="RecursiveScriptModule"),
        original_method=func,
    )
|
|
|
|
def make_stub_from_method(nn_module, method):
    """
    Return a ScriptMethodStub for the attribute named ``method`` on
    ``type(nn_module)``.

    If that class attribute already is a ScriptMethodStub it is returned
    as-is; otherwise the function is wrapped with make_stub.
    """
    candidate = get_function_from_type(type(nn_module), method)
    return candidate if isinstance(candidate, ScriptMethodStub) else make_stub(candidate)
|
|
|
|
# base types that can be constants
|
|
# in addition, tuples and lists of these base types are also considered constants
|
|
# If you edit this list, then you also need to edit the handlers in
|
|
# ConstantValue in jit/script/init.cpp
|
|
_constant_types = (bool, float, int, str, type(None), types.FunctionType, torch.device, torch.layout, torch.dtype)
|
|
|
|
def _get_valid_constant(attr, v):
    """
    Validate ``v`` as the value of the module constant ``attr``.

    Values whose type is in ``_constant_types`` are returned unchanged.
    Lists and tuples are validated element by element (recursively) and
    returned as a tuple. Anything else raises TypeError describing the
    accepted forms.
    """
    if isinstance(v, _constant_types):
        return v
    if isinstance(v, (tuple, list)):
        return tuple(_get_valid_constant(attr, element) for element in v)

    allowed_type_names = ", ".join(t.__name__ for t in _constant_types)
    raise TypeError(textwrap.dedent("""
        '{}' object for attribute '{}' is not a valid constant.
        Valid constants are:
        1. a nn.ModuleList
        2. a value of type {{{}}}
        3. a list or tuple of (2)
        """.format(type(v).__name__, attr, allowed_type_names)))
|
|
|
|
def infer_raw_concrete_type(nn_module):
    """
    Build a ConcreteModuleType from an nn.Module. This ConcreteModuleType
    doesn't have a JIT type associated with it yet, it must be filled in
    by the caller.

    Entries are added in this order:
    1. parameters
    2. submodules (resolved through ``concrete_type_store``)
    3. buffers
    4. constants (from ``__constants__`` and ``Final`` annotations)
    5. overloads
    6. the remaining instance attributes / properties

    Function or data attributes that cannot be converted are recorded with
    ``add_failed_attribute`` and a hint instead of raising.

    Arguments:
        nn_module: The nn.Module instance to analyze. Its ``__constants__``
            and ``__overloads__`` are read but not modified.

    Returns:
        The newly built ``torch._C.ConcreteModuleType``.
    """
    concrete_type = torch._C.ConcreteModuleType()
    concrete_type.add_pyclass(type(nn_module))
    if isinstance(nn_module, (torch.nn.ModuleDict, torch.jit._ConstModuleDict)):
        concrete_type.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.jit._ConstModuleList)):
        concrete_type.set_module_list()

    # Names already registered on `concrete_type`; later passes skip them.
    added_names = set()

    for name, item in nn_module._parameters.items():
        if item is None:
            # TODO special case: parameters can be None. The JIT assumes
            # parameters are Tensor types, so None entries are skipped here
            # (they can still be picked up below, e.g. when listed in
            # `__constants__`).
            # The "correct" fix here is to add the parameter as a NoneType
            # attribute, but NoneType refinement is currently wonky
            continue
        assert isinstance(item, torch.Tensor)
        attr_type = torch._C._jit_try_infer_type(item)
        concrete_type.add_attribute(name, attr_type, True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
        concrete_type.add_module(name, sub_concrete_type)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        if item is None:
            # TODO special case: buffers can be None. They are skipped for
            # the same reason as None parameters above.
            continue
        assert isinstance(item, torch.Tensor)
        attr_type = torch._C._jit_try_infer_type(item)
        concrete_type.add_attribute(name, attr_type, False)
        added_names.add(name)

    # Read once; used both for `Final` constants and for attribute types.
    class_annotations = getattr(nn_module, '__annotations__', {})

    # populate constants_set. Build a fresh set instead of using
    # `__constants__` directly. `__constants__` is user-defined, may be a
    # list or tuple (which has no `.add`), and is typically a class
    # attribute, so adding `Final` names to it in place would leak into
    # every instance of the class.
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # XXX: It is possible for something to be in the constants set but
            # also in the parameters/buffers. This happens in BatchNorm as a
            # hack to support optional parameters.
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but there are a couple
            # extant examples of this so leave it for a future PR.
            warnings.warn("'{}' was found in ScriptModule constants, "
                          "but was not actually set in __init__. "
                          "Consider removing it.".format(name))
            continue
        value = getattr(nn_module, name)
        concrete_type.add_constant(name, _get_valid_constant(name, value))
        added_names.add(name)

    # populate overloads. Copy first so the module's own `__overloads__`
    # dict is not mutated by the update below (infer_methods_to_compile
    # copies it the same way).
    overloads = dict(getattr(nn_module, "__overloads__", {}))
    # update with any annotated overloads
    overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module)))
    for name, overloaded_names in overloads.items():
        concrete_type.add_overload(name, overloaded_names)

    # TODO: [switch to __dict__]
    # we should use __dict__ here because we only want to pick up attributes on
    # this module instance, not the class itself. We can't do it right now
    # because there is code that relies on properties being turned into attributes.
    # This is wrong (the property function is only evaluated once then "saved"
    # as an attribute), so we should fix that and then switch this to using __dict__
    for name in dir(nn_module):
        if name in blacklist or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        if not hasattr(nn_module, name):
            # TODO: delete this when [switch to __dict__]
            continue

        item = getattr(nn_module, name)
        if name not in nn_module.__dict__ and not isinstance(getattr(type(nn_module), name, None), property):
            # Skip class attributes that aren't properties
            # TODO: delete this when [switch to __dict__]
            continue

        # Handle Python function attributes
        if inspect.isfunction(item) and not inspect.ismethod(item):
            cls_attr = getattr(type(nn_module), name, None)
            if inspect.isfunction(cls_attr):
                # Skip function attributes that exist on the nn_module class.
                # TODO: delete this when [switch to __dict__]
                continue

            try:
                scripted_fn = torch.jit.script(item)
                concrete_type.add_function_attribute(
                    name,
                    torch._C._jit_try_infer_type(scripted_fn),
                    item)
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = ("(This function exists as an attribute on the Python module, "
                        "but we failed to compile it to a TorchScript function. "
                        "\nThe error stack is reproduced here:\n{}").format(e)
                concrete_type.add_failed_attribute(name, hint)

            continue

        # Handle Script function attributes
        if isinstance(item, torch.jit.ScriptFunction):
            concrete_type.add_function_attribute(
                name,
                torch._C._jit_try_infer_type(item),
                item)
            continue

        # If we got here, this is a regular "data" attribute. Try to infer
        # the type and add it to the concrete type. An explicit class
        # annotation wins over a torch.jit.Attribute's declared type, which
        # wins over inference from the value.
        if name in class_annotations:
            attr_type = torch.jit.annotations.ann_to_type(class_annotations[name])
        elif isinstance(item, torch.jit.Attribute):
            attr_type = torch.jit.annotations.ann_to_type(item.type)
        else:
            attr_type = torch._C._jit_try_infer_type(item)

        if attr_type is not None:
            concrete_type.add_attribute(name, attr_type, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            hint = ("(This attribute exists on the Python module, "
                    "but we failed to convert Python type: '{}' "
                    "to a TorchScript type.)").format(type(item).__name__)
            concrete_type.add_failed_attribute(name, hint)

    return concrete_type
|
|
|
|
class ConcreteTypeStore(object):
    """
    Cache of ConcreteModuleTypes grouped by Python module class.

    Modules whose inferred concrete types compare equal share one
    underlying JIT type.
    """

    def __init__(self):
        # Python module type => List[ConcreteModuleType]
        self.type_store = {}
        # ConcreteTypes that have had their methods already compiled
        self.methods_compiled = set()

    def get_or_create_concrete_type(self, nn_module):
        """
        Infer a ConcreteType from this `nn.Module` instance. Underlying JIT
        types are re-used if possible.

        Lookup proceeds in three stages:
        1. A ScriptModule that already carries ``_concrete_type`` returns it.
        2. ModuleList/Sequential/ModuleDict go through the constant-iterable
           compilation path.
        3. Otherwise the freshly inferred type is compared against the types
           already stored for the same class. A new JIT type is created only
           when none of them is equal.
        """
        assert isinstance(nn_module, Module)
        already_scripted = isinstance(nn_module, torch.jit.ScriptModule)
        if already_scripted and hasattr(nn_module, "_concrete_type"):
            return nn_module._concrete_type

        if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
            # TODO: This is here because the compilation path for constant iterable
            # modules is different from everything else. Instead of calling
            # create_script_module, we directly create a
            # _ConstSequential/ModuleList/ModuleDict instance.
            #
            # The path used to create ConcreteTypes involves going in and analyzing
            # all the nn.Modules ahead of time.
            #
            # That leads to skew where the result of generating a ConcreteType
            # (which involves looking at torch.nn.Sequential) is different from the
            # actual compilation path (which directly builds _ConstSequential).
            #
            # The right solution is to make these modules not special in the
            # compilation path. But for now, just mimic what compilation does when
            # generating a ConcreteType
            return create_constant_iterable_module(nn_module)._concrete_type

        candidate = infer_raw_concrete_type(nn_module)
        known_types = self.type_store.setdefault(type(nn_module), [])

        # Search the type store for an already-available JIT type
        for known_type in known_types:
            if candidate.equals(known_type):
                return known_type

        # We didn't find anything; generate a new JIT type from this concrete type
        candidate.create_new_type_from_this()
        known_types.append(candidate)
        return candidate
|
|
|
|
# Module-level store shared by create_script_module, create_script_module_impl
# and infer_raw_concrete_type, so JIT types and compiled methods are reused
# across scripting calls.
concrete_type_store = ConcreteTypeStore()
|
|
|
|
def create_methods_from_stubs(concrete_type, stubs):
    """
    Compile every stub in ``stubs`` onto ``concrete_type`` in one batch.

    The definitions, resolution callbacks, and default arguments (taken from
    each stub's ``original_method``) are handed to
    ``concrete_type._create_methods`` as three parallel lists.
    """
    definitions, callbacks, default_args = [], [], []
    for stub in stubs:
        definitions.append(stub.def_)
        callbacks.append(stub.resolution_callback)
        default_args.append(get_default_args(stub.original_method))
    concrete_type._create_methods(definitions, callbacks, default_args)
|
|
|
|
def create_script_module_for_tracing(nn_module, stubs):
    """
    Creates a new ScriptModule from an nn.Module, but always uses a fresh type.

    NOTE: Only use this when we cannot guarantee type sharing will work
    correctly. This only happens today for traced modules, where the same
    module can produce different traced methods depending on the inputs.

    Arguments:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        stubs:  ScriptMethodStubs to compile as part of the conversion process.
    """
    check_module_initialized(nn_module)
    # Start from a ConcreteType with no JIT type; a fresh one is built from
    # a brand-new C++ module and attached below.
    concrete_type = infer_raw_concrete_type(nn_module)
    qualified_name = torch._jit_internal._qualified_name(type(nn_module))
    cpp_module = torch._C.ScriptModule(qualified_name, torch.jit._python_cu, True)
    # Poison this concrete type to ensure that it never gets re-used
    concrete_type.set_poisoned()
    concrete_type.add_jit_type(cpp_module._type())
    return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
|
|
|
|
|
|
def create_script_module(nn_module, stubs):
    """
    Creates a new ScriptModule from an nn.Module, sharing underlying JIT types if possible

    Arguments:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        stubs:  ScriptMethodStubs to compile as part of the conversion process.
    """
    check_module_initialized(nn_module)
    shared_type = concrete_type_store.get_or_create_concrete_type(nn_module)
    return create_script_module_impl(
        nn_module,
        shared_type,
        torch._C._create_module_with_type(shared_type.jit_type),
        stubs,
    )
|
|
|
|
def create_script_module_impl(nn_module, concrete_type, cpp_module, stubs):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Arguments:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        cpp_module:  A newly-constructed C++ script::Module to copy stuff into.
        stubs:  ScriptMethodStubs to compile as part of the conversion process.

    Returns:
        The new RecursiveScriptModule wrapping ``cpp_module``. Methods are
        compiled only the first time a given ``concrete_type`` is seen
        (tracked in ``concrete_type_store.methods_compiled``). Later modules
        sharing that type reuse the compiled methods.
    """
    # The JIT type must already be attached and must be the type cpp_module was built with.
    assert concrete_type.jit_type and concrete_type.jit_type == cpp_module._type()

    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name, (attr_type, is_param) in concrete_type.get_attributes().items():
            orig_value = getattr(nn_module, name)

            if is_param:
                cpp_module._register_parameter(name, orig_value, False)
            elif isinstance(orig_value, torch.jit.Attribute):
                # torch.jit.Attribute carries a declared `.type` alongside its
                # `.value`; only the wrapped value is stored on the module.
                cpp_module._register_attribute(name, attr_type, orig_value.value)
            else:
                cpp_module._register_attribute(name, attr_type, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name in concrete_type.get_module_names():
            orig_value = getattr(nn_module, name)
            assert isinstance(orig_value, Module)
            scripted = recursive_script(orig_value)
            cpp_module._register_module(name, scripted._c)

            # Keep the Python wrapper too, so attribute access on the
            # ScriptModule returns the scripted submodule object.
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            item = getattr(nn_module, name, None)
            if not inspect.ismethod(item):
                continue
            if _jit_internal.is_ignored_fn(item):
                setattr(script_module, name, item)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Compile methods if necessary
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_from_stubs(concrete_type, stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Make the compiled methods available to the Python ScriptModule class.
    for stub in stubs:
        if stub.original_method is None:
            # define()'d methods don't have an Python original_method, so we
            # don't need to do any Python re-wrapping stuff
            continue

        name = stub.original_method.__name__
        if name != stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this for functions that are recursively
        # compiled, we should.
        script_method = functools.wraps(stub.original_method)(script_method)

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = script_method

    return script_module
|
|
|
|
def get_overload_annotations(mod):
    """
    Collect the registered method overloads for callables reachable on ``mod``.

    Returns:
        A dict mapping each original callable to a list of
        ``(mangled_name, overload_fn)`` pairs. The mangled names are
        ``"<name>__0"``, ``"<name>__1"``, ... in overload order. Callables
        without overloads are omitted.
    """
    # original function => [(mangled overload name, overload function)]
    result = {}
    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name, None)
        if not callable(candidate):
            continue

        # builtin functions like repr() in python 2 do not have __module__ defined
        if getattr(candidate, "__module__", None) is None:
            continue

        found = _jit_internal._get_overloaded_methods(candidate, mod.__class__)
        if found is None:
            continue

        mangled = ["{}__{}".format(attr_name, idx) for idx in range(len(found))]
        result[candidate] = list(zip(mangled, found))

    return result
|
|
|
|
def get_overload_name_mapping(overload_info):
    """
    Convert get_overload_annotations output into the ``__overloads__`` format.

    Arguments:
        overload_info: original function => [(overload name, overload fn)].

    Returns:
        original function *name* => [overload names]. Functions sharing a
        ``__name__`` have their overload names merged in iteration order.
    """
    mapping = {}
    for orig_fn, overloads in overload_info.items():
        names = mapping.setdefault(orig_fn.__name__, [])
        names.extend(overload_name for overload_name, _ in overloads)
    return mapping
|
|
|
|
def make_stubs_for_overloads(overload_info):
    """
    Create one ScriptMethodStub per overload listed in ``overload_info``.

    For each overload, a new method AST is produced by
    ``torch._C._replace_overloaded_method_decl`` from the overload's
    declaration and the original function's definition, under the mangled
    overload name. Each stub resolves names through the original function's
    closure.
    """
    stubs = []
    for original_fn, overload_pairs in overload_info.items():
        original_ast = torch.jit.get_jit_def(original_fn, self_name="RecursiveScriptModule")
        for mangled_name, overload_fn in overload_pairs:
            torch.jit._check_no_signature(overload_fn)
            overload_ast = torch.jit.get_jit_def(overload_fn, self_name="RecursiveScriptModule")
            combined_ast = torch._C._replace_overloaded_method_decl(
                overload_ast.decl(), original_ast, mangled_name)
            callback = _jit_internal.createResolutionCallbackFromClosure(original_fn)
            stubs.append(ScriptMethodStub(callback, combined_ast, overload_fn))
    return stubs
|
|
|
|
def check_module_initialized(mod):
    """
    Ensure ``mod`` went through ``nn.Module.__init__``.

    Raises:
        RuntimeError: if ``mod`` has no ``_parameters`` attribute, which is
            how a missing ``super().__init__()`` call is detected.
    """
    assert isinstance(mod, torch.nn.Module)
    if hasattr(mod, '_parameters'):
        return
    raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                       .format(type(mod).__name__))
|
|
|
|
def infer_methods_to_compile(nn_module):
    """
    Implements the default rules for which methods should act as starting
    points for compilation (TODO add a link when the rules are published).

    Rules applied:
    - ``forward`` is included unless it is ``nn.Module.forward`` itself or
      is @ignore'd.
    - Every attribute marked with the EXPORT modifier is included.
    - Names that have overloads are excluded; their overloads contribute
      their own stubs instead, placed before the regular stubs.

    Side effect: ``nn_module.__overloads__`` is replaced by the merged
    overload-name mapping.
    """
    check_module_initialized(nn_module)

    candidate_names = []
    if hasattr(nn_module, 'forward'):
        forward = nn_module.forward
        is_default_forward = getattr(forward, "__func__", None) == torch.nn.Module.forward
        # TODO, we deleted a check that forward is actually defined, instead skipping it
        if not is_default_forward and not _jit_internal.is_ignored_fn(forward):
            candidate_names.append('forward')

    export_modifier = _jit_internal.FunctionModifiers.EXPORT
    for name in dir(nn_module):
        if _jit_internal.get_torchscript_modifier(getattr(nn_module, name, None)) is export_modifier:
            candidate_names.append(name)

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)

    nn_module.__overloads__ = overload_name_mappings

    # We shouldn't directly compile overloaded methods, just their overloads.
    # Deduplicate with an OrderedDict rather than a set so that compile
    # order stays deterministic (first occurrence wins).
    ordered_names = collections.OrderedDict.fromkeys(
        name for name in candidate_names if name not in overload_name_mappings)

    stubs = [make_stub_from_method(nn_module, name) for name in ordered_names]
    return overload_stubs + stubs
|
|
|
|
def recursive_script(nn_module):
    """
    Makes a ScriptModule from an nn.Module, using the default rules for
    determining which methods to compile.

    Arguments:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.

    Returns:
        One of:
        - ``nn_module`` itself if it already is a ScriptModule;
        - a constant container module for ModuleList/Sequential/ModuleDict;
        - otherwise a new ScriptModule built by create_script_module.
    """
    if isinstance(nn_module, torch.jit.ScriptModule):
        return nn_module

    check_module_initialized(nn_module)

    is_iterable_container = isinstance(
        nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict))
    if is_iterable_container:
        # Create constant versions for the iterable modules
        return create_constant_iterable_module(nn_module)

    methods = infer_methods_to_compile(nn_module)
    return create_script_module(nn_module, methods)
|
|
|
|
def try_compile_fn(fn, loc):
    """
    Script a free function encountered during recursive compilation.

    Returns None for @ignore'd functions and for nn.Module instances. Raises
    RuntimeError for anything else that is not a Python function or method.
    ``loc`` is not used by this function.
    """
    # Don't do anything for @ignore'd functions. Modules are skipped too:
    # since modules are callable, pybind recognizes them as functions.
    if _jit_internal.is_ignored_fn(fn) or isinstance(fn, torch.nn.Module):
        return None

    if not (inspect.isfunction(fn) or inspect.ismethod(fn)):
        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
                           "Python functions or methods currently.\n"
                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))

    # We don't have the actual scope where the function was defined, but we can
    # extract the necessary info from the closed over variables on the function
    # object
    return torch.jit.script(fn, _rcb=_jit_internal.createResolutionCallbackFromClosure(fn))
|
|
|
|
def create_constant_iterable_module(module):
    """
    Build the constant (_Const*) counterpart of a ModuleList, Sequential or
    ModuleDict.

    Every child is scripted first. Nested containers are converted
    recursively. Any other ``module`` type raises RuntimeError.
    """
    scripted_children = collections.OrderedDict()
    for key, child in module._modules.items():
        # Make each item in the module a constant
        if isinstance(child, (ModuleList, Sequential, ModuleDict)):
            scripted_children[key] = create_constant_iterable_module(child)
        else:
            scripted_children[key] = recursive_script(child)

    if isinstance(module, Sequential):
        return torch.jit._ConstSequential(Sequential(scripted_children))
    if isinstance(module, ModuleList):
        return torch.jit._ConstModuleList(scripted_children)
    if isinstance(module, ModuleDict):
        return torch.jit._ConstModuleDict(scripted_children)

    raise RuntimeError("Only nn.ModuleList, nn.Sequential, and nn.ModuleDict can be made "
                       "into constant modules, found {}".format(module))
|
|
|
|
def wrap_cpp_module(cpp_module):
    """
    Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules
    """
    def init_fn(script_module):
        # Wrap every C++ submodule and attach it under the same name.
        for child_name, child_cpp in script_module._c._get_modules():
            setattr(script_module, child_name, wrap_cpp_module(child_cpp))

    return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
|
|
|
|
def compile_unbound_method(concrete_type, fn):
    """
    Compile ``fn`` as a method of ``concrete_type``.

    Returns the ScriptMethodStub that was compiled, or None if ``fn`` is
    @ignore'd (in which case nothing is compiled). Emit hooks are disabled
    while compiling.
    """
    if not _jit_internal.is_ignored_fn(fn):
        new_stub = make_stub(fn)
        # We don't want to call the hooks here since the graph that is calling
        # this function is not yet complete
        with torch.jit._disable_emit_hooks():
            create_methods_from_stubs(concrete_type, [new_stub])
        return new_stub
    return None
|
|
|
|
def lazy_bind(concrete_type, unbound_method):
|
|
"""
|
|
Returns a function that lazily binds `unbound_method` to a provided
|
|
Module IValue, then invokes the method. We do this so that any Python
|
|
shenanigans that will poison type sharing are impossible at compile
|
|
time.
|
|
"""
|
|
def lazy_binding_method(cpp_module, *args):
|
|
def init_fn(script_module):
|
|
orig_class = concrete_type.py_class
|
|
|
|
# Copy @ignored/@unused methods from the original module to the new one.
|
|
# This ensures they are available during execution.
|
|
for name in dir(orig_class):
|
|
item = getattr(orig_class, name, None)
|
|
if _jit_internal.is_ignored_fn(item):
|
|
setattr(script_module, name, item)
|
|
|
|
# Copy constants over so they are available during execution.
|
|
for name, value in concrete_type.get_constants().items():
|
|
setattr(script_module, name, value)
|
|
|
|
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
|
|
method = bind_method(unbound_method, script_module, torch.jit.RecursiveScriptModule)
|
|
return method(*args)
|
|
|
|
# make the lazy binding method "look like" the original method
|
|
lazy_binding_method.original_fn = unbound_method
|
|
lazy_binding_method.__name__ = unbound_method.__name__
|
|
torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method)
|
|
|
|
return lazy_binding_method
|