mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Noticed while trying to script one of the models which happened to have numpy values as constants. Lacking the numpy prefix in the error message was quite confusing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/41024 Differential Revision: D22426399 Pulled By: dzhulgakov fbshipit-source-id: 06158b75355fac6871e4861f82fc637c2420e370
650 lines
28 KiB
Python
650 lines
28 KiB
Python
import inspect
|
|
import torch
|
|
import collections
|
|
import textwrap
|
|
import functools
|
|
import warnings
|
|
|
|
import torch._jit_internal as _jit_internal
|
|
from torch.jit.frontend import get_default_args, get_jit_def
|
|
from torch.jit._builtins import _find_builtin
|
|
from torch.nn import Module
|
|
from torch._six import get_function_from_type, bind_method
|
|
|
|
|
|
# Record describing one method to compile: the callback used to resolve names
# referenced by the method, its parsed definition, and the Python callable it
# was created from.
ScriptMethodStub = collections.namedtuple(
    'ScriptMethodStub',
    ['resolution_callback', 'def_', 'original_method'],
)

# TODO: there should be a more principled way of doing this.
# Instance-`__dict__` entries that are skipped when module attributes are
# collected for compilation (see the `name in blacklist` check further down).
blacklist = [
    "_version",
    "_parameters",
    "_buffers",
    "_modules",
    "_initializing",
    "_backward_hooks",
    "_forward_hooks",
    "_forward_pre_hooks",
    "_state_dict_hooks",
    "_load_state_dict_pre_hooks",
    "dump_patches",
]
|
|
|
|
def make_stub(func, name):
    """
    Create a ScriptMethodStub for `func`.

    Arguments:
        func: The Python callable to build the stub from.
        name: The name passed to `get_jit_def` for the resulting definition.

    Returns:
        A ScriptMethodStub holding a resolution callback created from the
        closure of `func`, the definition returned by `get_jit_def`, and
        `func` itself.
    """
    resolution_callback = _jit_internal.createResolutionCallbackFromClosure(func)
    jit_def = get_jit_def(func, name, self_name="RecursiveScriptModule")
    return ScriptMethodStub(
        resolution_callback=resolution_callback,
        def_=jit_def,
        original_method=func,
    )
|
|
|
|
def make_stub_from_method(nn_module, method_name):
    """
    Return a ScriptMethodStub for the attribute `method_name` of `nn_module`.

    If the attribute already is a ScriptMethodStub it is returned unchanged;
    otherwise a new stub is built from it with `make_stub`.
    """
    candidate = getattr(nn_module, method_name)
    if not isinstance(candidate, ScriptMethodStub):
        # Pass `method_name` explicitly so the name in the resulting AST
        # matches the name requested here. The two can differ when a method is
        # aliased, e.g.:
        #     def _forward(self):
        #         pass
        #     forward = _forward
        # Here the function object is named `_forward` even though a stub for
        # `forward` was requested.
        candidate = make_stub(candidate, method_name)
    return candidate
|
|
|
|
# base types that can be constants
|
|
# in addition, tuples and lists of these base types are also considered constants
|
|
# If you edit this list, then you also need to edit the handlers in
|
|
# ConstantValue in jit/script/init.cpp
|
|
_constant_types = (bool, float, int, str, type(None), torch.device, torch.layout, torch.dtype)
|
|
|
|
def _get_valid_constant(attr, v):
|
|
if isinstance(v, _constant_types):
|
|
return v
|
|
elif isinstance(v, tuple) or isinstance(v, list):
|
|
return tuple(_get_valid_constant(attr, x) for x in v)
|
|
constants = ", ".join(torch.typename(typ) for typ in _constant_types)
|
|
raise TypeError(textwrap.dedent("""
|
|
'{}' object for attribute '{}' is not a valid constant.
|
|
Valid constants are:
|
|
1. a nn.ModuleList
|
|
2. a value of type {{{}}}
|
|
3. a list or tuple of (2)
|
|
""".format(torch.typename(type(v)), attr, constants)))
|
|
|
|
|
|
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
    """
    Python subclass of the `SourceRangeFactory` binding.

    It adds no behavior of its own: the constructor forwards all of its
    arguments unchanged to the base class.
    """

    def __init__(self, source, filename, file_lineno, leading_whitespace_len):
        # Pure pass-through to SourceRangeFactory.__init__.
        super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
|
|
|
|
|
|
def infer_concrete_type_builder(nn_module):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module. This
    ConcreteModuleType doesn't have a JIT type associated with it yet, it
    must be filled in by the caller.

    The builder is populated, in this order, from the module's parameters,
    buffers, submodules, constants (`__constants__` plus names annotated as
    `Final`), overloads, the remaining entries of the instance `__dict__`,
    and finally `@property` members of the class (recorded as failed
    attributes so that a helpful error can be reported).

    Arguments:
        nn_module: The `nn.Module` instance to inspect. Its `__constants__`
            and `__overloads__` are read but never mutated.

    Returns:
        The populated `torch._C.ConcreteModuleTypeBuilder`.

    Raises:
        TypeError: if a declared constant holds a value rejected by
            `_get_valid_constant`.
        AssertionError: if a parameter/buffer is neither `None` nor a Tensor,
            or if a submodule has an inferred type that is not an interface
            type.
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, (torch.nn.ModuleDict)):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()

    class_annotations = getattr(nn_module, '__annotations__', {})

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        # The forward function from Module is special; never use this annotations; we
        # need to infer type directly using JIT.  I originally wanted to write
        # this test as isinstance(class_annotations[name], Callable) but
        # isinstance on typing things doesn't seem to work: isinstance(list, Callable)
        # is also true!
        if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
            attr_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
        elif isinstance(item, torch.jit.Attribute):
            attr_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
        else:
            attr_type = torch._C._jit_try_infer_type(item)
        return attr_type

    # Names already registered on the builder; later passes skip them.
    added_names = set()

    for name, item in nn_module._parameters.items():
        assert item is None or isinstance(item, torch.Tensor)
        attr_type = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type, True, False)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        assert item is None or isinstance(item, torch.Tensor)
        attr_type = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type, False, True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        attr_type = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type, False, False)
            continue
        if attr_type is not None:
            assert attr_type.is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type)
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # Collect constant names into a private, de-duplicated, order-preserving
    # container. `__constants__` is frequently a list (or a class-level
    # attribute shared by every instance), so it must neither be assumed to
    # have `.add` nor be mutated in place.
    constant_names = collections.OrderedDict.fromkeys(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constant_names[name] = None

    for name in constant_names:
        if name in added_names:
            # TODO: We should really error in this case, but its bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError("added_names must be submodule, parameter, or buffer")

            warnings.warn("'{}' was found in ScriptModule constants, "
                          " but it is a non-constant {}. Consider removing it.".format(name, hint))
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but its bc-breaking so
            # we need to warn for at least one release
            warnings.warn("'{}' was found in ScriptModule constants, "
                          "but was not actually set in __init__. "
                          "Consider removing it.".format(name))
            continue
        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(name, _get_valid_constant(name, value))
        added_names.add(name)

    # Populate overloads. Work on a copy so the module's own `__overloads__`
    # mapping is left untouched (matching `infer_methods_to_compile`).
    overloads = dict(getattr(nn_module, "__overloads__", {}))
    # update with any annotated overloads
    overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module)))
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in blacklist or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name,
                    torch._C._jit_try_infer_type(scripted_fn),
                    value)
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = ("(This function exists as an attribute on the Python module, "
                        "but we failed to compile it to a TorchScript function. "
                        "\nThe error stack is reproduced here:\n{}").format(e)
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name, builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name,
                torch._C._jit_try_infer_type(value),
                value)
            continue

        # If we got here, this is a regular "data" attribute, Add it to the concrete type
        attr_type = infer_type(name, value)
        if attr_type is not None:
            concrete_type_builder.add_attribute(name, attr_type, False, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            hint = ("(This attribute exists on the Python module, "
                    "but we failed to convert Python type: '{}' "
                    "to a TorchScript type.)").format(torch.typename(type(value)))
            concrete_type_builder.add_failed_attribute(name, hint)

    # Add @property methods as failed attributes, to give a better error message.
    for name, value in type(nn_module).__dict__.items():
        if isinstance(value, property):
            hint = ("\n(This attribute exists on the Python module, but it's an @property "
                    "method. @property methods are not yet supported in TorchScript. "
                    "Please file a feature request on Github)")
            concrete_type_builder.add_failed_attribute(name, hint)

    return concrete_type_builder
|
|
|
|
class ConcreteTypeStore(object):
    """
    Cache of concrete module types, keyed by Python module class, plus the set
    of concrete types whose methods have already been compiled.
    """

    def __init__(self):
        # Python module type => List[ConcreteModuleType]
        self.type_store = {}
        # ConcreteTypes that have had their methods already compiled
        self.methods_compiled = set()

    def get_or_create_concrete_type(self, nn_module):
        """
        Infer a ConcreteType from this `nn.Module` instance. Underlying JIT
        types are re-used if possible.
        """
        assert isinstance(nn_module, Module)

        # A ScriptModule that already carries its concrete type needs no inference.
        is_script_module = isinstance(nn_module, torch.jit.ScriptModule)
        if is_script_module and hasattr(nn_module, "_concrete_type"):
            return nn_module._concrete_type

        builder = infer_concrete_type_builder(nn_module)

        # Look for an already-available JIT type among those recorded for this class.
        candidates = self.type_store.setdefault(type(nn_module), [])
        for existing in candidates:
            if existing.equals(builder):
                return existing

        # Nothing matched; generate a new JIT type from this builder and record it.
        fresh_type = builder.build()
        candidates.append(fresh_type)
        return fresh_type


# Module-level store shared by the functions in this file.
concrete_type_store = ConcreteTypeStore()
|
|
|
|
def create_methods_from_stubs(concrete_type, stubs):
    """
    Hand the contents of `stubs` to `concrete_type._create_methods`.

    Arguments:
        concrete_type: Object exposing `_create_methods(defs, rcbs, defaults)`.
        stubs: Sequence of ScriptMethodStub-like objects.

    For every stub, its definition, its resolution callback and the default
    arguments of its original method are collected into three parallel lists.
    """
    defs, rcbs, defaults = [], [], []
    for stub in stubs:
        defs.append(stub.def_)
        rcbs.append(stub.resolution_callback)
        defaults.append(get_default_args(stub.original_method))
    concrete_type._create_methods(defs, rcbs, defaults)
|
|
|
|
def create_script_module(nn_module, stubs_fn, share_types=True):
    """
    Creates a new ScriptModule from an nn.Module

    Arguments:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
        share_types:  Whether to share underlying JIT types between modules (if possible).
            NOTE: Only set this to False when we cannot guarantee type sharing will work
                correctly. This only happens today for traced modules, where the same
                module can produce different traced methods depending on the inputs.
    """
    assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
    check_module_initialized(nn_module)

    if not share_types:
        # Build a concrete type directly, without trying to re-use an existing
        # JIT type from the type store.
        builder = infer_concrete_type_builder(nn_module)
        builder.set_poisoned()
        concrete_type = builder.build()
    else:
        # Consult the store of cached JIT types.
        concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)

    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
|
|
|
|
def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Arguments:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    stubs = stubs_fn(nn_module)

    def init_fn(script_module):
        # Initialization of the new ScriptModule happens in three steps.

        # Step 1: copy attributes/parameters/buffers from `nn_module`,
        # unwrapping `torch.jit.Attribute` wrappers to their raw value.
        for attr_name, (_attr_type, _is_param) in concrete_type.get_attributes().items():
            attr_value = getattr(nn_module, attr_name)
            if isinstance(attr_value, torch.jit.Attribute):
                attr_value = attr_value.value
            cpp_module.setattr(attr_name, attr_value)

        # Step 2: copy submodules from `nn_module`, scripting them recursively.
        for sub_name, sub_concrete_type in concrete_type.get_modules():
            submodule = getattr(nn_module, sub_name)
            assert isinstance(submodule, Module), "Expected Module but got {}".format(type(submodule))
            sub_jit_type = sub_concrete_type.jit_type
            if isinstance(sub_jit_type, torch._C.InterfaceType):
                # interface-typed submodule: compile using the interface inference rule
                scripted = interface_script(sub_jit_type, submodule)
            elif isinstance(submodule, torch.jit.ScriptModule):
                # already scripted: reuse as-is
                scripted = submodule
            else:
                # default recursive rule
                scripted = create_script_module_impl(submodule, sub_concrete_type, infer_methods_to_compile)

            cpp_module.setattr(sub_name, scripted)
            script_module._modules[sub_name] = scripted

        # Step 3: re-bind @ignored/@unused methods of `nn_module` onto the new
        # ScriptModule so these Python methods stay accessible on it.
        for method_name in dir(nn_module):
            candidate = getattr(nn_module, method_name, None)
            if inspect.ismethod(candidate) and _jit_internal.is_ignored_fn(candidate):
                unbound_function = getattr(type(nn_module), method_name)
                setattr(script_module, method_name, unbound_function.__get__(script_module))

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Create the ScriptModule, initializing it with the function defined above.
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Methods are compiled at most once per concrete type.
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_from_stubs(concrete_type, stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Special handling so methods like __len__ work in script methods on classes derived from containers
    container_types = (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
    if isinstance(nn_module, container_types) and '__len__' not in cpp_module._method_names():
        script_module.define("def __len__(self):\n return {}\n".format(len(nn_module)))
    if isinstance(nn_module, torch.nn.ModuleDict) and '__contains__' not in cpp_module._method_names():
        if len(nn_module.keys()) > 0:
            key_repr = repr(list(nn_module.keys()))
            contains_src = "def __contains__(self, key: str):\n return key in {}\n".format(key_repr)
        else:
            contains_src = "def __contains__(self, key: str):\n return False\n"
        script_module.define(contains_src)

    # Make the compiled methods available to the Python ScriptModule class.
    for stub in stubs:
        original = stub.original_method
        if original is None:
            # define()'d methods don't have a Python original_method, so no
            # Python re-wrapping is needed for them.
            continue

        method_name = original.__name__
        if method_name != stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this for functions that are recursively
        # compiled, we should.
        compiled_method = functools.wraps(original)(cpp_module._get_method(method_name))

        # Store the method in the instance dict directly, so it is found first
        # when the name is looked up (as opposed to the stubs or
        # nn.Module.forward).
        script_module.__dict__[method_name] = compiled_method

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    for attr_name in dir(nn_module):
        attr = getattr(nn_module, attr_name, None)
        if _jit_internal.get_torchscript_modifier(attr) is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER:
            add_python_attr_to_scripted_model(script_module, nn_module, attr_name)

    return script_module
|
|
|
|
|
|
# RecursiveScriptModule carries shims for certain attributes in order to
# support magic methods. Deciding whether a script model really defines an
# attribute therefore also requires checking that the attribute is not the shim.
def script_model_defines_attr(script_model, attr):
    """
    Return True when `script_model` has a non-None attribute `attr` that
    differs from the function `get_function_from_type` finds for `attr` on
    `torch.jit.RecursiveScriptModule`. Return False if either lookup is None.
    """
    candidate = getattr(script_model, attr, None)
    if candidate is not None:
        shim = get_function_from_type(torch.jit.RecursiveScriptModule, attr)
        if shim is not None:
            return candidate != shim
    return False
|
|
|
|
def add_python_attr_to_scripted_model(script_model, orig, attr):
    """
    Copy attribute `attr` from `orig` onto `script_model`, provided `orig` has
    it and `script_model_defines_attr(script_model, attr)` is true.
    """
    if not hasattr(orig, attr):
        return
    if not script_model_defines_attr(script_model, attr):
        return
    setattr(script_model, attr, getattr(orig, attr))
|
|
|
|
def get_overload_annotations(mod):
    """
    Collect the overloads registered for callables reachable on `mod`.

    Returns:
        A dict mapping each original callable to a list of
        `(mangled overload name, overload function)` pairs, where the mangled
        name is `<attribute name>__<index>`.
    """
    # original function => [(mangled overload name, overload function)]
    collected = {}

    for attr_name in dir(type(mod)):
        member = getattr(mod, attr_name, None)
        if not callable(member):
            continue

        # builtin functions like repr() in python 2 do not have __module__ defined
        if not hasattr(member, "__module__") or member.__module__ is None:
            continue

        method_overloads = _jit_internal._get_overloaded_methods(member, mod.__class__)
        if method_overloads is None:
            continue

        collected[member] = [
            (attr_name + "__" + str(index), overload_fn)
            for index, overload_fn in enumerate(method_overloads)
        ]

    return collected
|
|
|
|
def get_overload_name_mapping(overload_info):
    """
    Reduce overload info to a name-based mapping.

    Arguments:
        overload_info: dict of original function =>
            [(overload name, overload function)].

    Returns:
        A dict in the same format as `__overloads__`:
        original function name => [overload names]. Functions sharing a
        `__name__` have their overload names accumulated in one list.
    """
    name_mapping = {}
    for orig_fn, overloads in overload_info.items():
        bucket = name_mapping.setdefault(orig_fn.__name__, [])
        bucket.extend(overload_name for overload_name, _ in overloads)
    return name_mapping
|
|
|
|
def _check_no_signature(func):
    """
    Raise a RuntimeError if `torch.jit.annotations.get_signature` yields no
    signature for `func`; overloaded functions must be explicitly annotated.
    """
    fake_range = _jit_internal.fake_range()
    is_method = inspect.ismethod(func)
    signature = torch.jit.annotations.get_signature(func, None, fake_range, is_method)
    if signature is not None:
        return
    qual_name = _jit_internal._qualified_name(func)
    raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
|
|
|
|
def make_stubs_for_overloads(overload_info):
    """
    Build one ScriptMethodStub per overload described in `overload_info`.

    Arguments:
        overload_info: dict of original function =>
            [(overload name, overload function)].

    For each overload, the declaration of the overload function is combined
    with the definition of the original function under the overload name via
    `torch._C._replace_overloaded_method_decl`.
    """
    stubs = []
    for original_fn, overload_pairs in overload_info.items():
        original_def = get_jit_def(original_fn, original_fn.__name__, self_name="RecursiveScriptModule")
        for mangled_name, overload_fn in overload_pairs:
            # Overloads must carry explicit type annotations.
            _check_no_signature(overload_fn)
            overload_def = get_jit_def(overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule")
            merged_def = torch._C._replace_overloaded_method_decl(overload_def.decl(), original_def, mangled_name)
            # Names are resolved through the closure of the original function.
            resolution_callback = _jit_internal.createResolutionCallbackFromClosure(original_fn)
            stubs.append(ScriptMethodStub(resolution_callback, merged_def, overload_fn))
    return stubs
|
|
|
|
def check_module_initialized(mod):
    """
    Verify that `mod` is an `nn.Module` with a `_parameters` attribute.

    Raises:
        AssertionError: if `mod` is not a `torch.nn.Module`.
        RuntimeError: if `mod` has no `_parameters` attribute.
    """
    assert isinstance(mod, torch.nn.Module)
    if hasattr(mod, '_parameters'):
        return
    type_name = torch.typename(type(mod))
    raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                       .format(type_name))
|
|
|
|
def infer_methods_to_compile(nn_module):
    """
    Implements the default rules for which methods should act as starting
    points for compilation (TODO add a link when the rules are published).

    Returns the overload stubs followed by stubs for `forward` (when
    applicable) and for every attribute whose TorchScript modifier is EXPORT.
    As a side effect, `nn_module.__overloads__` is set to the merged overload
    name mapping.
    """
    check_module_initialized(nn_module)

    # `forward` is an entry point unless it is ignored or is the very function
    # found for "forward" on torch.nn.Module.
    entry_points = []
    if hasattr(nn_module, 'forward') and not _jit_internal.is_ignored_fn(nn_module.forward):
        own_forward = getattr(nn_module.forward, "__func__", None)
        base_forward = get_function_from_type(torch.nn.Module, "forward")
        if own_forward != base_forward:
            entry_points.append('forward')

    # Everything marked for export is an entry point too.
    export_marker = _jit_internal.FunctionModifiers.EXPORT
    entry_points.extend(
        attr_name for attr_name in dir(nn_module)
        if _jit_internal.get_torchscript_modifier(getattr(nn_module, attr_name, None)) is export_marker
    )

    # Merge declared `__overloads__` (copied) with the annotated overloads.
    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)

    nn_module.__overloads__ = overload_name_mappings

    # Overloaded methods are never compiled directly, only their overloads.
    # De-duplicate while preserving order: a set alone would introduce
    # non-determinism to compile order.
    seen = set()
    ordered_names = []
    for method_name in entry_points:
        if method_name in overload_name_mappings or method_name in seen:
            continue
        seen.add(method_name)
        ordered_names.append(method_name)

    method_stubs = [make_stub_from_method(nn_module, method_name) for method_name in ordered_names]
    return overload_stubs + method_stubs
|
|
|
|
def interface_script(mod_interface, nn_module):
    """
    Makes a ScriptModule from an nn.Module, using the interface methods rule for
    determining which methods to compile.

    Arguments:
        mod_interface: the interface type that the module has
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
    """
    # Already a ScriptModule: nothing to compile.
    if isinstance(nn_module, torch.jit.ScriptModule):
        return nn_module

    check_module_initialized(nn_module)

    def infer_interface_methods_to_compile(module):
        """
        Rule to infer the methods from the interface type to know which
        methods need to act as starting points for compilation.
        """
        return [
            make_stub_from_method(module, method_name)
            for method_name in mod_interface.getMethodNames()
        ]

    return create_script_module(nn_module, infer_interface_methods_to_compile)
|
|
|
|
def try_compile_fn(fn, loc):
    """
    Script `fn` if it is a plain Python function or method.

    Arguments:
        fn: The object to compile.
        loc: Not used by this function.

    Returns:
        None for @ignore'd functions and for `nn.Module` instances, otherwise
        the result of `torch.jit.script`.

    Raises:
        RuntimeError: if `fn` is neither a function nor a method.
    """
    # @ignore'd functions are left alone. Modules are callable, so pybind
    # recognizes them as functions, but nothing is done for them either.
    if _jit_internal.is_ignored_fn(fn) or isinstance(fn, torch.nn.Module):
        return None

    if not (inspect.isfunction(fn) or inspect.ismethod(fn)):
        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
                           "Python functions or methods currently.\n"
                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))

    # The scope in which the function was defined is not available here, but
    # the necessary info can be extracted from the variables the function
    # object closes over.
    resolution_callback = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=resolution_callback)
|
|
|
|
def wrap_cpp_module(cpp_module):
    """
    Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules
    """
    def init_fn(script_module):
        # Wrap every child module and attach it under the same name.
        children = torch._C.ModuleDict(script_module._c)
        for child_name, child_cpp_module in children.items():
            setattr(script_module, child_name, wrap_cpp_module(child_cpp_module))
        # Derive the concrete type from the JIT type of the wrapped module.
        jit_type = script_module._c._type()
        script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(jit_type)

    return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
|
|
|
|
def compile_unbound_method(concrete_type, fn):
    """
    Compile `fn` as a method of `concrete_type` and return its stub.

    Returns None, without compiling anything, when `fn` is an ignored function.
    """
    if _jit_internal.is_ignored_fn(fn):
        return None

    method_stub = make_stub(fn, fn.__name__)
    # Emit hooks are disabled here because the graph that is calling this
    # function is not yet complete.
    with torch.jit._disable_emit_hooks():
        create_methods_from_stubs(concrete_type, (method_stub,))
    return method_stub
|
|
|
|
def lazy_bind(concrete_type, unbound_method):
    """
    Returns a function that lazily binds `unbound_method` to a provided
    Module IValue, then invokes the method. We do this so that any Python
    shenanigans that will poison type sharing are impossible at compile
    time.
    """
    def lazy_binding_method(cpp_module, *args):
        def init_fn(script_module):
            py_class = concrete_type.py_class

            # Copy @ignored/@unused methods from the original class to the new
            # module so they are available during execution.
            for member_name in dir(py_class):
                member = getattr(py_class, member_name, None)
                if _jit_internal.is_ignored_fn(member):
                    setattr(script_module, member_name, member)

            # Copy constants over so they are available during execution.
            for const_name, const_value in concrete_type.get_constants().items():
                setattr(script_module, const_name, const_value)

        wrapper_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
        bound = bind_method(unbound_method, wrapper_module, torch.jit.RecursiveScriptModule)
        return bound(*args)

    # make the lazy binding method "look like" the original method
    lazy_binding_method.original_fn = unbound_method
    lazy_binding_method.__name__ = unbound_method.__name__
    torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method)

    return lazy_binding_method
|