Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40807

We pack a lot of logic into `jit/__init__.py`, making it unclear to developers and users which parts of our API are public vs. internal. This is one in a series of PRs intended to pull implementation out into separate files and leave `__init__.py` as a place to register the public API.

This PR moves all the tracing-related stuff out and fixes other spots up as necessary. Followups will move other core APIs out. The desired end state is that we conform to the relevant rules in [PEP 8](https://www.python.org/dev/peps/pep-0008/#public-and-internal-interfaces). In particular:
- Internal implementation goes in modules prefixed by `_`.
- `__init__.py` exposes a public API from these private modules, and nothing more.
- We set `__all__` appropriately to declare our public API.
- All uses of JIT-internal functionality outside the JIT are removed (in particular, ONNX is relying on a number of internal APIs). Since they will need to be imported explicitly, it will be easier to catch new uses of internal APIs in review.

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D22320645

Pulled By: suo

fbshipit-source-id: 0720ea9976240e09837d76695207e89afcc58270

"""TorchScript
|
|
|
|
This module contains functionality to support the JIT's scripting frontend, notably:
|
|
- torch.jit.script
|
|
|
|
This is not intended to be imported directly; please use the exposed
|
|
functionalities in `torch.jit`.
|
|
"""

import collections
import functools

import torch
from torch.jit._recursive import ScriptMethodStub
from torch.nn import Module
from torch.jit._state import _enabled
from torch._six import with_metaclass

if _enabled:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:

    def Attribute(value, type):
        return value
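
# Illustrative sketch of how `Attribute` is used (for documentation only;
# `MyModule` is a hypothetical user class): it declares a typed attribute on a
# ScriptModule when the type cannot be inferred from the initial value alone,
# e.g. an empty list:
#
#     import torch
#     from typing import List
#
#     class MyModule(torch.jit.ScriptModule):
#         def __init__(self):
#             super(MyModule, self).__init__()
#             # Stored as a List[int] attribute; unwrapped by
#             # ScriptModule.__setattr__ below.
#             self.values = torch.jit.Attribute([], List[int])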


# For each user-defined class that subclasses ScriptModule, this metaclass:
# (1) finds all the methods annotated with @script_method in a ScriptModule and
#     removes them from the class attributes
# (2) puts a wrapper around the class's __init__ method to recursively compile
#     all of the script_methods with the module after the original __init__ has
#     run. This has to occur after the user-defined __init__ so that submodules
#     and parameters are initialized _before_ the script compiler resolves
#     references to `self.param` or `self.module`.
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs):  # noqa: B902
        # Aggregate all the ScriptMethods and constants from superclasses
        cls._methods = {}
        cls._constants_set = set(getattr(cls, "__constants__", ()))
        for base in reversed(bases):
            for k, v in getattr(base, "_methods", {}).items():
                cls._methods[k] = v
            base_constants = getattr(base, "_constants_set", set())
            cls._constants_set = cls._constants_set.union(base_constants)

        # find all the script methods of the current class
        for k, v in sorted(attrs.items()):
            if isinstance(v, ScriptMethodStub):
                delattr(cls, k)
                cls._methods[v.original_method.__name__] = v

        if getattr(cls, "_disable_script_meta", False):
            # We leave built-in ScriptModule types alone, since this metaclass
            # is only for compiling user classes that inherit from
            # ScriptModule.
            return super(ScriptMeta, cls).__init__(name, bases, attrs)

        original_init = getattr(cls, "__init__", lambda self: None)

        @functools.wraps(original_init)
        def init_then_script(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            if type(self) == cls:

                def make_stubs(module):
                    cls = type(module)
                    return [v for k, v in sorted(cls._methods.items())]

                self.__dict__[
                    "_actual_script_module"
                ] = torch.jit._recursive.create_script_module(self, make_stubs)

                # Delete the Python attributes that now shadow the ScriptModule
                # ones, so that __getattr__ and __setattr__ will properly find
                # the scripted versions.
                concrete_type = self._actual_script_module._concrete_type
                for name in concrete_type.get_attributes():
                    delattr(self, name)
                for name, _ in concrete_type.get_modules():
                    delattr(self, name)
                for name in ("_parameters", "_buffers", "_modules"):
                    delattr(self, name)

        cls.__init__ = init_then_script
        return super(ScriptMeta, cls).__init__(name, bases, attrs)
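

# Illustrative sketch of the legacy pattern ScriptMeta enables (for
# documentation only; `MyScriptModule` is a hypothetical user class):
#
#     import torch
#
#     class MyScriptModule(torch.jit.ScriptModule):
#         def __init__(self):
#             super(MyScriptModule, self).__init__()
#             self.weight = torch.nn.Parameter(torch.rand(10))
#
#         @torch.jit.script_method
#         def forward(self, x):
#             # Compiled by init_then_script after __init__ finishes, so
#             # self.weight is already defined when references are resolved.
#             return x + self.weight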


class _CachedForward(object):
    def __get__(self, obj, cls):
        # `object` has no __getattr__, so this raises AttributeError; attribute
        # lookup then falls back to the owning class's __getattr__ (see
        # ScriptModule.__getattr__ below).
        return self.__getattr__("forward")


if _enabled:
    # This is a Python 'non-data descriptor' that causes the first access
    # to ScriptModule's forward to look up the forward method and stash
    # it in the object's dict. Due to the standard rules for attribute lookup,
    # subsequent lookups will just directly return the previously looked-up
    # method.
    # This is necessary because nn.Module defines forward as a method. If we
    # did nothing, __getattr__ would not be called. Instead we'd get
    # nn.Module.forward, which always throws an exception.

    class ScriptModule(with_metaclass(ScriptMeta, Module)):
        """
        A ``ScriptModule`` wraps a C++ ``torch::jit::Module``. ``ScriptModule``s
        contain methods, attributes, parameters, and constants, which can be
        accessed the same way as on a normal ``nn.Module``.
        """

        def __init__(self):
            super(ScriptModule, self).__init__()

        forward = _CachedForward()

        def __getattr__(self, attr):
            if "_actual_script_module" not in self.__dict__:
                return super(ScriptModule, self).__getattr__(attr)
            return getattr(self._actual_script_module, attr)

        def __setattr__(self, attr, value):
            if "_actual_script_module" not in self.__dict__:
                # Unwrap torch.jit.Attribute into a regular setattr + record
                # the provided type in __annotations__.
                #
                # This ensures that if we use the attr again in `__init__`, it
                # will look like the actual value, not an instance of Attribute.
                if isinstance(value, Attribute):
                    # NB: Ensure that we set __annotations__ on the specific
                    # class in question, and not on a superclass (which would
                    # be wrong wrong wrong!).
                    # See also https://github.com/pytorch/pytorch/issues/39463
                    if "__annotations__" not in self.__class__.__dict__:
                        self.__class__.__annotations__ = {}
                    self.__annotations__[attr] = value.type
                    value = value.value
                return super(ScriptModule, self).__setattr__(attr, value)

            setattr(self._actual_script_module, attr, value)

        def define(self, src):
            if "_actual_script_module" in self.__dict__:
                # If we have completed initialization, just defer to the
                # backing RecursiveScriptModule to eagerly compile the provided
                # source.
                return self._actual_script_module.define(src)

            # Otherwise, we are still in the object's __init__.
            # In that case, add `src` as a stub to be compiled.
            #
            # We use frames_up=1 to get to the proper surrounding scope. The
            # stack will look like:
            # 0. createResolutionCallbackFromFrame
            # 1. define()
            # 2. surrounding scope
            #
            # createResolutionCallbackFromFrame internally adds 1 to get us to
            # our frame, then we add 1 to get to the proper surrounding scope.
            rcb = torch._jit_internal.createResolutionCallbackFromFrame(frames_up=1)
            ast = torch._C._parse_source_def(src)
            self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None)
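
        # Illustrative sketch of define() (for documentation only;
        # `MyScriptModule` is a hypothetical user class):
        #
        #     class MyScriptModule(torch.jit.ScriptModule):
        #         def __init__(self):
        #             super(MyScriptModule, self).__init__()
        #             # Compiled from TorchScript source; names in the
        #             # surrounding scope are resolvable via the rcb above.
        #             self.define('''
        #                 def forward(self, x):
        #                     return x + 1
        #             ''')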

        def _replicate_for_data_parallel(self):
            return self._actual_script_module._replicate_for_data_parallel()


else:
    # TODO: make sure that disabling works
    class ScriptModule(torch.nn.Module):
        def __init__(self):
            super(ScriptModule, self).__init__()