Remove weak script (#22212)

Summary:
* Deletes all weak script decorators and their associated data structures / methods
   * To keep supporting the standard library in script, recursive script is now enabled for any function defined in `torch.nn`
   * Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`; only `rnn.py` needed manual editing, using `ignore` and `export` to continue supporting the overloaded `forward` methods
* `Sequential`/`ModuleList` no longer need to be added to `__constants__` since they are compiled on demand
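
For illustration, here is a minimal sketch of what the on-demand compilation means in practice (adapted from the updated `test_script_module_list_sequential` test further down in this diff; it is not part of the commit itself): a `ScriptModule` can now hold and iterate over an `nn.Sequential` without listing it in `__constants__`.

```python
import torch
import torch.nn as nn

class M(torch.jit.ScriptModule):
    # No `__constants__ = ['mods']` entry is needed any more; the Sequential
    # is converted to a constant module list on demand when `forward` is compiled.
    def __init__(self, mod_list):
        super(M, self).__init__()
        self.mods = mod_list

    @torch.jit.script_method
    def forward(self, v):
        for m in self.mods:
            v = m(v)
        return v

m = M(nn.Sequential(nn.ReLU()))
print(m(torch.randn(2, 2)))
```

Before this change the same construction raised "Did you forget to add it to __constants__", as the removed error tests below show.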

This should also fix https://github.com/pytorch/pytorch/issues/22212
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212

Differential Revision: D15988346

Pulled By: driazati

fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
commit 10c4b98ade (parent b93f29ded3)
Authored by David Riazati on 2019-07-03 17:22:22 -07:00; committed by Facebook GitHub Bot
28 changed files with 109 additions and 564 deletions


@ -6979,7 +6979,7 @@ a")
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
M() M()
def test_script_module_list_sequential_error(self): def test_script_module_list_sequential(self):
class M(torch.jit.ScriptModule): class M(torch.jit.ScriptModule):
def __init__(self, mod_list): def __init__(self, mod_list):
super(M, self).__init__(False) super(M, self).__init__(False)
@ -6991,25 +6991,21 @@ a")
v = m(v) v = m(v)
return v return v
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"): m = M(nn.Sequential(nn.ReLU()))
a = M(nn.Sequential(nn.ReLU())) self.assertExportImportModule(m, (torch.randn(2, 2),))
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
a = M(nn.ModuleList([nn.ReLU()]))
def test_attr_module_constants_error(self): def test_attr_module_constants(self):
class M2(torch.jit.ScriptModule): class M2(torch.jit.ScriptModule):
def __init__(self, mod_list): def __init__(self, mod_list):
super(M2, self).__init__(False) super(M2, self).__init__(False)
self.mods = mod_list self.mods = mod_list
@torch.jit.script_method @torch.jit.script_method
def forward(self, v): def forward(self, x):
return self.mods.forward(x) return self.mods.forward(x)
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"): m = M2(nn.Sequential(nn.ReLU()))
M2(nn.Sequential(nn.ReLU())) self.assertExportImportModule(m, (torch.randn(2, 2),))
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
M2(nn.ModuleList([nn.ReLU()]))
def test_script_sequential_for(self): def test_script_sequential_for(self):
class Sub(torch.jit.ScriptModule): class Sub(torch.jit.ScriptModule):
@ -11007,6 +11003,7 @@ a")
with self.assertRaisesRegex(torch.jit.Error, "Exception"): with self.assertRaisesRegex(torch.jit.Error, "Exception"):
foo(torch.tensor(0)) foo(torch.tensor(0))
@unittest.skipIf(True, "Removing weak script")
def test_weak_script_function(self): def test_weak_script_function(self):
outer_var = 10 outer_var = 10
outer_var2 = 11 outer_var2 = 11
@ -11086,6 +11083,7 @@ a")
eg = torch.zeros(3, dtype=torch.uint8) eg = torch.zeros(3, dtype=torch.uint8)
self.assertEqual(foo_traced(eg), foo(eg)) self.assertEqual(foo_traced(eg), foo(eg))
@unittest.skipIf(True, "Removing weak script")
def test_weak_module(self): def test_weak_module(self):
@torch._jit_internal.weak_module @torch._jit_internal.weak_module
@ -11161,6 +11159,7 @@ a")
self.assertEqual(script_result, expected_result) self.assertEqual(script_result, expected_result)
self.assertEqual(script_result, script_result2) self.assertEqual(script_result, script_result2)
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_parameters_and_buffers(self): def test_weak_module_parameters_and_buffers(self):
weights = torch.randn(10, 10) weights = torch.randn(10, 10)
bias = torch.randn(10) bias = torch.randn(10)
@ -11219,6 +11218,7 @@ a")
self.assertEqual(strong_mod(inp), expected_result) self.assertEqual(strong_mod(inp), expected_result)
self.assertExportImportModule(strong_mod, (inp,)) self.assertExportImportModule(strong_mod, (inp,))
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_nested(self): def test_weak_module_nested(self):
@torch._jit_internal.weak_module @torch._jit_internal.weak_module
class OtherWeak(torch.nn.Module): class OtherWeak(torch.nn.Module):
@ -11280,6 +11280,7 @@ a")
+ F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10)) + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_submodule(self): def test_weak_module_submodule(self):
@torch._jit_internal.weak_module @torch._jit_internal.weak_module
class Weak(torch.nn.Module): class Weak(torch.nn.Module):
@ -11319,6 +11320,7 @@ a")
with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
strong_mod = Strong() strong_mod = Strong()
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_copying(self): def test_weak_module_copying(self):
class Submodule(torch.nn.Module): class Submodule(torch.nn.Module):
def __init__(self): def __init__(self):
@ -11385,6 +11387,7 @@ a")
m = M() m = M()
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_attributes(self): def test_weak_module_attributes(self):
tester = self tester = self
@ -11948,6 +11951,7 @@ a")
FileCheck().check_not("prim::PythonOp").run(cu.test.graph) FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
@unittest.skipIf(True, "Removing weak script")
def test_overloading(self): def test_overloading(self):
@torch._jit_internal.weak_module @torch._jit_internal.weak_module
class W(torch.nn.Module): class W(torch.nn.Module):
@ -13623,6 +13627,9 @@ EXCLUDE_SCRIPT_MODULES = {
'test_nn_AdaptiveAvgPool3d_tuple_none', 'test_nn_AdaptiveAvgPool3d_tuple_none',
'test_nn_AdaptiveMaxPool2d_tuple_none', 'test_nn_AdaptiveMaxPool2d_tuple_none',
'test_nn_AdaptiveMaxPool3d_tuple_none', 'test_nn_AdaptiveMaxPool3d_tuple_none',
# Uses Module._backend, so this is not supported
'test_nn_CrossMapLRN2d',
} }
@ -14552,10 +14559,6 @@ def add_nn_module_test(*args, **kwargs):
module_name = name.split("_")[0] module_name = name.split("_")[0]
module = getattr(torch.nn, module_name, None)
if module is None or torch._jit_internal.weak_types.get(module) is None:
return
if 'desc' in kwargs and 'eval' in kwargs['desc']: if 'desc' in kwargs and 'eval' in kwargs['desc']:
# eval() is not supported, so skip these tests # eval() is not supported, so skip these tests
return return


@ -4,29 +4,14 @@ can be used in other places in torch/ (namely torch.nn) without running into
circular dependency problems circular dependency problems
""" """
import weakref
import inspect import inspect
import weakref
from torch._six import builtins from torch._six import builtins
# Tracks standalone weak script functions
compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484
# Tracks which methods should be converted to strong methods
weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484
# Converted modules and their corresponding WeakScriptModuleProxy objects
weak_modules = weakref.WeakKeyDictionary() # noqa: T484
# Types that have been declared as weak modules
weak_types = weakref.WeakKeyDictionary() # noqa: T484
# Wrapper functions that can call either of 2 functions depending on a boolean # Wrapper functions that can call either of 2 functions depending on a boolean
# argument # argument
boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484 boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484
COMPILATION_PENDING = object()
COMPILED = object()
def createResolutionCallback(frames_up=0): def createResolutionCallback(frames_up=0):
""" """
@ -71,51 +56,41 @@ def createResolutionCallback(frames_up=0):
return f_globals[key] return f_globals[key]
elif hasattr(builtins, key): elif hasattr(builtins, key):
return getattr(builtins, key) return getattr(builtins, key)
else:
return None
return env return env
def weak_script(fn, _frames_up=0): def createResolutionCallbackFromClosure(fn):
""" """
Marks a function as a weak script function. When used in a script function Create a resolutionCallback by introspecting the function instead of
or ScriptModule, the weak script function will be lazily compiled and looking up the stack for the enclosing scope
inlined in the graph. When not used in a script function, the weak script
annotation has no effect.
""" """
compiled_weak_fns[fn] = { var_names = fn.__code__.co_freevars
"status": COMPILATION_PENDING,
"compiled_fn": None,
"rcb": createResolutionCallback(_frames_up + 1)
}
return fn
# map of captured name -> value
free_vars = {}
def weak_module(cls): for index, name in enumerate(var_names):
weak_types[cls] = { free_vars[name] = fn.__closure__[index].cell_contents
"method_stubs": None f_globals = fn.__globals__
}
return cls
def env(key):
if key in free_vars:
return free_vars[key]
elif hasattr(builtins, key):
return getattr(builtins, key)
else:
return f_globals.get(key)
def weak_script_method(fn): return env
weak_script_methods[fn] = {
"rcb": createResolutionCallback(frames_up=2),
"original_method": fn
}
return fn
def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name): def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name):
""" """
Dispatches to either of 2 weak script functions based on a boolean argument. Dispatches to either of 2 script functions based on a boolean argument.
In TorchScript, the boolean argument must be constant so that the correct In TorchScript, the boolean argument must be constant so that the correct
function to use can be determined at compile time. function to use can be determined at compile time.
""" """
if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None:
raise RuntimeError("both functions must be weak script")
def fn(*args, **kwargs): def fn(*args, **kwargs):
dispatch_flag = False dispatch_flag = False
if arg_name in kwargs: if arg_name in kwargs:


@ -23,7 +23,8 @@ Decl mergeTypesFromTypeComment(
<< "Number of type annotations (" << "Number of type annotations ("
<< type_annotation_decl.params().size() << type_annotation_decl.params().size()
<< ") did not match the number of " << ") did not match the number of "
<< "function parameters (" << expected_num_annotations << ")"; << (is_method ? "method" : "function")
<< " parameters (" << expected_num_annotations << ")";
} }
auto old = decl.params(); auto old = decl.params();
auto _new = type_annotation_decl.params(); auto _new = type_annotation_decl.params();


@ -244,6 +244,11 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
<< err.str(); << err.str();
} }
bool should_recurse(py::object obj) {
return py::cast<bool>(py::module::import("torch.jit")
.attr("_is_recursive_script_enabled")(obj));
}
std::shared_ptr<SugaredValue> ModuleValue::attr( std::shared_ptr<SugaredValue> ModuleValue::attr(
const SourceRange& loc, const SourceRange& loc,
Function& m, Function& m,
@ -307,7 +312,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
// If recursive script mode is on, create a ScriptModule and register it as // If recursive script mode is on, create a ScriptModule and register it as
// as submodule or register a python method as a script::Method // as submodule or register a python method as a script::Method
if (getRecursiveScriptMode()) { if (should_recurse(attr)) {
if (py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) { if (py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) {
// If the module is a submodule of the py_module, convert it to a // If the module is a submodule of the py_module, convert it to a
// ScriptModule and add it as a submodule to the script::Module. This // ScriptModule and add it as a submodule to the script::Module. This
@ -471,11 +476,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
} }
} }
auto weak_obj =
py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
if (!weak_obj.is_none()) {
obj = weak_obj;
}
if (auto callee = as_function(obj)) { if (auto callee = as_function(obj)) {
return std::make_shared<FunctionValue>(callee); return std::make_shared<FunctionValue>(callee);
} else if (py::isinstance<py::module>(obj)) { } else if (py::isinstance<py::module>(obj)) {
@ -504,12 +504,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
<< "which is currently not supported in Torchscript." << "which is currently not supported in Torchscript."
<< "Please open a feature request to add it."; << "Please open a feature request to add it.";
} }
auto compiled_fn =
py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
if (auto callee = as_function(compiled_fn)) {
return std::make_shared<FunctionValue>(callee);
}
} }
py::object dispatched_fn = py::object dispatched_fn =
@ -528,7 +522,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
} }
} }
if (getRecursiveScriptMode() && py::isinstance<py::function>(obj)) { if (should_recurse(obj) && py::isinstance<py::function>(obj)) {
auto compiled_fn = auto compiled_fn =
py::module::import("torch.jit").attr("_try_compile_fn")(obj); py::module::import("torch.jit").attr("_try_compile_fn")(obj);
if (auto callee = as_function(compiled_fn)) { if (auto callee = as_function(compiled_fn)) {


@ -7,7 +7,7 @@ import torch.backends.cudnn as cudnn
import torch.jit.annotations import torch.jit.annotations
import torch._jit_internal as _jit_internal import torch._jit_internal as _jit_internal
from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \ from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \
string_classes, builtins string_classes
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \ from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
_list_with_default _list_with_default
import torch.testing import torch.testing
@ -930,50 +930,10 @@ def _try_get_overloaded_fn(mod, field):
return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
def _try_compile_weak_script(fn):
entry = _jit_internal.compiled_weak_fns.get(fn)
if entry is None:
return None
if entry["status"] == _jit_internal.COMPILATION_PENDING:
compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
del entry["rcb"]
_jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
entry["status"] = _jit_internal.COMPILED
return compiled_fn
# TODO: use fn.__closure__
raise RuntimeError("Cannot make resolutionCallback in Python 2")
else:
return entry["compiled_fn"]
class ScriptWarning(Warning): class ScriptWarning(Warning):
pass pass
def createResolutionCallbackFromClosure(fn):
"""
Create a resolutionCallback by introspecting the function instead of
looking up the stack for the enclosing scope
"""
var_names = fn.__code__.co_freevars
# map of captured name -> value
free_vars = {}
for index, name in enumerate(var_names):
free_vars[name] = fn.__closure__[index].cell_contents
f_globals = fn.__globals__
def env(key):
if key in free_vars:
return free_vars[key]
elif hasattr(builtins, key):
return getattr(builtins, key)
else:
return f_globals.get(key)
return env
def _create_constant_iterable_module(module): def _create_constant_iterable_module(module):
modules = OrderedDict() modules = OrderedDict()
@ -1012,20 +972,20 @@ def _try_compile_fn(fn):
# Don't do anything for @ignore'd functions # Don't do anything for @ignore'd functions
return None return None
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
"Python functions or methods currently.\n"
"Consider manually annotating `{}` with @torch.jit.script.".format(fn))
if isinstance(fn, torch.nn.Module): if isinstance(fn, torch.nn.Module):
# Since modules are callable pybind recognizes them as functions, but # Since modules are callable pybind recognizes them as functions, but
# don't do anything for them # don't do anything for them
return None return None
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
"Python functions or methods currently.\n"
"Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))
# We don't have the actual scope where the function was defined, but we can # We don't have the actual scope where the function was defined, but we can
# extract the necessary info from the closed over variables on the function # extract the necessary info from the closed over variables on the function
# object # object
rcb = createResolutionCallbackFromClosure(fn) rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
return torch.jit.script(fn, _rcb=rcb) return torch.jit.script(fn, _rcb=rcb)
@ -1040,7 +1000,9 @@ def _disable_emit_hooks():
def _create_method_from_fn(module, fn): def _create_method_from_fn(module, fn):
if _jit_internal.is_ignored_fn(fn): if _jit_internal.is_ignored_fn(fn):
return None return None
stub = script_method(fn, createResolutionCallbackFromClosure(fn)) if not inspect.ismethod(fn):
return None
stub = script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
with _disable_emit_hooks(): with _disable_emit_hooks():
# We don't want to call the hooks here since the graph that is calling # We don't want to call the hooks here since the graph that is calling
# this function is not yet complete # this function is not yet complete
@ -1101,6 +1063,15 @@ def _qualified_name(obj):
return module_name + "." + name return module_name + "." + name
def _is_recursive_script_enabled(value):
# TODO: [enable recursive script]
# when recursive script is made the default, remove this method
enabled = torch._C._jit_recursive_script()
module = inspect.getmodule(value)
if module is not None and 'torch.nn' in module.__name__:
enabled = True
return enabled
@contextlib.contextmanager @contextlib.contextmanager
def _enable_recursive_script(): def _enable_recursive_script():
torch._C._jit_recursive_script(True) torch._C._jit_recursive_script(True)
@ -1114,8 +1085,8 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None):
if _rcb is None: if _rcb is None:
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1) _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
if torch._C._jit_recursive_script():
if isinstance(obj, torch.nn.Module): if isinstance(obj, torch.nn.Module):
if _is_recursive_script_enabled(obj):
return _convert_to_script_module(obj) return _convert_to_script_module(obj)
if inspect.isclass(obj): if inspect.isclass(obj):
@ -1158,21 +1129,6 @@ def script_method(fn, _rcb=None):
return ScriptMethodStub(_rcb, ast, fn) return ScriptMethodStub(_rcb, ast, fn)
def _try_get_weak_module(mod):
"""
Get the WeakScriptModuleProxy corresponding to mod if it exists
"""
if not isinstance(mod, Module):
return None
return _jit_internal.weak_modules.get(mod)
def _is_weak_type(cls):
"""
Check if a type has been annotated with `weak_module`
"""
return cls in _jit_internal.weak_types
# These OrderedDictWrapper classes replace the actual OrderedDicts in # These OrderedDictWrapper classes replace the actual OrderedDicts in
# module with versions that get/set properties inside of script::Module. # module with versions that get/set properties inside of script::Module.
@ -1569,9 +1525,9 @@ if _enabled:
def __setattr__(self, attr, value): def __setattr__(self, attr, value):
if attr not in self._constants_set: if attr not in self._constants_set:
if isinstance(value, Module) and _is_weak_type(type(value)): if isinstance(value, Module) and _is_recursive_script_enabled(value):
# Compile weak script module # Compile weak script module
value = _make_strong(value) value = _convert_to_script_module(value)
if attr == 'training': if attr == 'training':
if self._c._has_attribute('training'): if self._c._has_attribute('training'):
self.__dict__['training'] = value self.__dict__['training'] = value
@ -1684,7 +1640,7 @@ if _enabled:
if isinstance(item, (ModuleList, Sequential)): if isinstance(item, (ModuleList, Sequential)):
# These are in __constants__, so ignore them here # These are in __constants__, so ignore them here
if not torch._C._jit_recursive_script(): if not _is_recursive_script_enabled(item):
# For recursive script, these are constantified after # For recursive script, these are constantified after
# they are used, so they don't need to be in constants. # they are used, so they don't need to be in constants.
# The `continue` here should be deleted along with # The `continue` here should be deleted along with
@ -1774,33 +1730,27 @@ else:
super(ScriptModule, self).__init__() super(ScriptModule, self).__init__()
def _get_weak_stubs(cls): def _convert_to_script_module(mod):
"""
Calls script_method for each method that has been annotated with @weak_script
on the type of the object passed in and returns the generated ScriptMethodStubs.
"""
stubs = []
for name in dir(cls):
func = get_function_from_type(cls, name)
if func in _jit_internal.weak_script_methods:
entry = _jit_internal.weak_script_methods[func]
stub = script_method(entry["original_method"], entry["rcb"])
stubs.append(stub)
return stubs
def _convert_to_script_module(mod, methods=None):
""" """
Makes a ScriptModule from an nn.Module. If `_methods` is provided, Makes a ScriptModule from an nn.Module. If `_methods` is provided,
these methods are treated as @script_methods. If not, it defaults to these methods are treated as @script_methods. If not, it defaults to
`('forward',)`. Methods accessed in forward are scripted on demand if `('forward',)`. Methods accessed in forward are scripted on demand if
`_enable_recursive_script()` is used. `_enable_recursive_script()` is used.
""" """
if isinstance(mod, ScriptModule):
return mod
if isinstance(mod, (ModuleList, Sequential)): if isinstance(mod, (ModuleList, Sequential)):
# Create constant versions for the iterable modules # Create constant versions for the iterable modules
return _create_constant_iterable_module(mod) return _create_constant_iterable_module(mod)
if methods is None: methods = ()
if hasattr(mod, 'forward'):
if mod.forward.__func__ == torch.nn.Module.forward:
# TODO: [enable recursive script]
# forward was not overrided
raise RuntimeError("No forward method was defined on {}".format(mod))
if not _jit_internal.is_ignored_fn(mod.forward):
methods = ('forward',) methods = ('forward',)
exported = [] exported = []
for name in dir(mod): for name in dir(mod):
@ -1812,36 +1762,12 @@ def _convert_to_script_module(mod, methods=None):
def make_stub(method): def make_stub(method):
func = get_function_from_type(type(mod), method) func = get_function_from_type(type(mod), method)
return script_method(func, createResolutionCallbackFromClosure(func)) return script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
stubs = list(map(make_stub, methods)) stubs = list(map(make_stub, methods))
return WeakScriptModuleProxy(mod, stubs) return WeakScriptModuleProxy(mod, stubs)
def _make_strong(mod):
"""
Converts a weak module into a subclass of ScriptModule. If `_methods` is
provided, only these methods are treated as @script_methods.
"""
if mod in _jit_internal.weak_modules:
return _jit_internal.weak_modules[mod]
cls = type(mod)
# Explicitly annotated weak script
stubs = _jit_internal.weak_types.get(cls)["method_stubs"]
if stubs is None:
# Generate stubs and and store on weak_types in case this type is
# used again
stubs = _get_weak_stubs(cls)
_jit_internal.weak_types[cls]["method_stubs"] = stubs
proxy = WeakScriptModuleProxy(mod, stubs)
_jit_internal.weak_modules[mod] = proxy
return proxy
def _get_methods(cls): def _get_methods(cls):
import inspect import inspect
# In Python 3 unbound methods are functions, but in Python 2 they are methods # In Python 3 unbound methods are functions, but in Python 2 they are methods
@ -1937,13 +1863,13 @@ class _ConstModuleList(ScriptModule):
if isinstance(modules, OrderedDict): if isinstance(modules, OrderedDict):
for key, module in modules.items(): for key, module in modules.items():
if _is_weak_type(type(module)): if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module):
module = _make_strong(module) module = _convert_to_script_module(module)
self.add_module(key, module) self.add_module(key, module)
else: else:
for i, module in enumerate(modules): for i, module in enumerate(modules):
if _is_weak_type(type(module)): if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module):
module = _make_strong(module) module = _convert_to_script_module(module)
self.add_module(str(i), module) self.add_module(str(i), module)
def __getitem__(self, idx): def __getitem__(self, idx):


@ -1,9 +1,7 @@
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
from ..._jit_internal import weak_script
@weak_script
def affine_grid_generator(theta, size): def affine_grid_generator(theta, size):
# type: (Tensor, List[int]) -> Tensor # type: (Tensor, List[int]) -> Tensor
if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4 and size[0] < 65536: if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4 and size[0] < 65536:


@ -1,10 +1,8 @@
import warnings import warnings
from .._jit_internal import weak_script
# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h
@weak_script
def get_enum(reduction): def get_enum(reduction):
# type: (str) -> int # type: (str) -> int
if reduction == 'none': if reduction == 'none':
@ -26,7 +24,6 @@ def get_enum(reduction):
# We use these functions in torch/legacy as well, in which case we'll silence the warning # We use these functions in torch/legacy as well, in which case we'll silence the warning
@weak_script
def legacy_get_string(size_average, reduce, emit_warning=True): def legacy_get_string(size_average, reduce, emit_warning=True):
# type: (Optional[bool], Optional[bool], bool) -> str # type: (Optional[bool], Optional[bool], bool) -> str
warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
@ -47,7 +44,6 @@ def legacy_get_string(size_average, reduce, emit_warning=True):
return ret return ret
@weak_script
def legacy_get_enum(size_average, reduce, emit_warning=True): def legacy_get_enum(size_average, reduce, emit_warning=True):
# type: (Optional[bool], Optional[bool], bool) -> int # type: (Optional[bool], Optional[bool], bool) -> int
return get_enum(legacy_get_string(size_average, reduce, emit_warning)) return get_enum(legacy_get_string(size_average, reduce, emit_warning))


@ -12,7 +12,7 @@ from ._functions import vision
from .modules.utils import _single, _pair, _triple, _list_with_default from .modules.utils import _single, _pair, _triple, _list_with_default
from . import grad # noqa: F401 from . import grad # noqa: F401
from . import _VF from . import _VF
from .._jit_internal import weak_script, List from .._jit_internal import boolean_dispatch, List
conv1d = _add_docstr(torch.conv1d, r""" conv1d = _add_docstr(torch.conv1d, r"""
@ -299,7 +299,6 @@ Args:
""") """)
@weak_script
def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False, output_ratio=None, return_indices=False,
_random_samples=None): _random_samples=None):
@ -346,7 +345,6 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)
@weak_script
def _fractional_max_pool2d(input, kernel_size, output_size=None, def _fractional_max_pool2d(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False, output_ratio=None, return_indices=False,
_random_samples=None): _random_samples=None):
@ -355,7 +353,7 @@ def _fractional_max_pool2d(input, kernel_size, output_size=None,
output_ratio, return_indices, output_ratio, return_indices,
_random_samples)[0] _random_samples)[0]
fractional_max_pool2d = torch._jit_internal.boolean_dispatch( fractional_max_pool2d = boolean_dispatch(
arg_name='return_indices', arg_name='return_indices',
arg_index=4, arg_index=4,
default=False, default=False,
@ -365,7 +363,6 @@ fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
func_name='fractional_max_pool2d') func_name='fractional_max_pool2d')
@weak_script
def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False, output_ratio=None, return_indices=False,
_random_samples=None): _random_samples=None):
@ -414,7 +411,6 @@ def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples) return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples)
@weak_script
def _fractional_max_pool3d(input, kernel_size, output_size=None, def _fractional_max_pool3d(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False, output_ratio=None, return_indices=False,
_random_samples=None): _random_samples=None):
@ -423,7 +419,7 @@ def _fractional_max_pool3d(input, kernel_size, output_size=None,
output_ratio, return_indices, output_ratio, return_indices,
_random_samples)[0] _random_samples)[0]
fractional_max_pool3d = torch._jit_internal.boolean_dispatch( fractional_max_pool3d = boolean_dispatch(
arg_name='return_indices', arg_name='return_indices',
arg_index=4, arg_index=4,
default=False, default=False,
@ -433,7 +429,6 @@ fractional_max_pool3d = torch._jit_internal.boolean_dispatch(
func_name='fractional_max_pool3d') func_name='fractional_max_pool3d')
@weak_script
def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
dilation=1, ceil_mode=False, return_indices=False): dilation=1, ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
@ -448,7 +443,6 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
input, kernel_size, stride, padding, dilation, ceil_mode) input, kernel_size, stride, padding, dilation, ceil_mode)
@weak_script
def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False): ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa
@ -457,7 +451,7 @@ def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
return torch.max_pool1d( return torch.max_pool1d(
input, kernel_size, stride, padding, dilation, ceil_mode) input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool1d = torch._jit_internal.boolean_dispatch( max_pool1d = boolean_dispatch(
arg_name='return_indices', arg_name='return_indices',
arg_index=6, arg_index=6,
default=False, default=False,
@ -467,7 +461,6 @@ max_pool1d = torch._jit_internal.boolean_dispatch(
func_name='max_pool1d') func_name='max_pool1d')
@weak_script
def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False): ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
@ -481,7 +474,6 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation
return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
@weak_script
def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False): ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa
@ -490,7 +482,7 @@ def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
return torch.max_pool2d( return torch.max_pool2d(
input, kernel_size, stride, padding, dilation, ceil_mode) input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool2d = torch._jit_internal.boolean_dispatch( max_pool2d = boolean_dispatch(
arg_name='return_indices', arg_name='return_indices',
arg_index=6, arg_index=6,
default=False, default=False,
@ -500,7 +492,6 @@ max_pool2d = torch._jit_internal.boolean_dispatch(
func_name='max_pool2d') func_name='max_pool2d')
@weak_script
def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
dilation=1, ceil_mode=False, return_indices=False): dilation=1, ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
@ -515,7 +506,6 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
input, kernel_size, stride, padding, dilation, ceil_mode) input, kernel_size, stride, padding, dilation, ceil_mode)
@weak_script
def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False): ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa
@ -524,7 +514,7 @@ def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
return torch.max_pool3d( return torch.max_pool3d(
input, kernel_size, stride, padding, dilation, ceil_mode) input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool3d = torch._jit_internal.boolean_dispatch( max_pool3d = boolean_dispatch(
arg_name='return_indices', arg_name='return_indices',
arg_index=6, arg_index=6,
default=False, default=False,
@ -534,7 +524,6 @@ max_pool3d = torch._jit_internal.boolean_dispatch(
func_name='max_pool3d') func_name='max_pool3d')
@weak_script
def _unpool_output_size(input, kernel_size, stride, padding, output_size): def _unpool_output_size(input, kernel_size, stride, padding, output_size):
# type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int] # type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int]
input_size = input.size() input_size = input.size()
@ -564,7 +553,6 @@ def _unpool_output_size(input, kernel_size, stride, padding, output_size):
return ret return ret
@weak_script
def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
output_size=None): output_size=None):
# type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa # type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa
@ -588,7 +576,6 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
output_size).squeeze(3) output_size).squeeze(3)
@weak_script
def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
output_size=None): output_size=None):
# type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa # type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa
@ -607,7 +594,6 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
return torch._C._nn.max_unpool2d(input, indices, output_size) return torch._C._nn.max_unpool2d(input, indices, output_size)
@weak_script
def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
output_size=None): output_size=None):
# type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa # type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa
@ -627,7 +613,6 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
input, indices, output_size, _stride, padding) input, indices, output_size, _stride, padding)
@weak_script
def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
# type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor # type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
r"""Applies a 2D power-average pooling over an input signal composed of r"""Applies a 2D power-average pooling over an input signal composed of
@ -645,7 +630,6 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type) return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type)
@weak_script
def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
# type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor # type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
r"""Applies a 1D power-average pooling over an input signal composed of r"""Applies a 1D power-average pooling over an input signal composed of
@ -662,7 +646,6 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type) return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
@weak_script
def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor] # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 1D adaptive max pooling over an input signal composed of r"""Applies a 1D adaptive max pooling over an input signal composed of
@ -677,12 +660,11 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
return torch.adaptive_max_pool1d(input, output_size) return torch.adaptive_max_pool1d(input, output_size)
@weak_script
def _adaptive_max_pool1d(input, output_size, return_indices=False): def _adaptive_max_pool1d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList1[int], bool) -> Tensor # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
return adaptive_max_pool1d_with_indices(input, output_size)[0] return adaptive_max_pool1d_with_indices(input, output_size)[0]
adaptive_max_pool1d = torch._jit_internal.boolean_dispatch( adaptive_max_pool1d = boolean_dispatch(
arg_name='return_indices', arg_name='return_indices',
arg_index=2, arg_index=2,
default=False, default=False,
@ -692,7 +674,6 @@ adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
func_name='adaptive_max_pool1d') func_name='adaptive_max_pool1d')
@weak_script
def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList2[int], bool) -> Tuple[Tensor, Tensor] # type: (Tensor, BroadcastingList2[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 2D adaptive max pooling over an input signal composed of r"""Applies a 2D adaptive max pooling over an input signal composed of
@ -709,12 +690,11 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
return torch._C._nn.adaptive_max_pool2d(input, output_size) return torch._C._nn.adaptive_max_pool2d(input, output_size)
@weak_script
def _adaptive_max_pool2d(input, output_size, return_indices=False): def _adaptive_max_pool2d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList2[int], bool) -> Tensor # type: (Tensor, BroadcastingList2[int], bool) -> Tensor
return adaptive_max_pool2d_with_indices(input, output_size)[0] return adaptive_max_pool2d_with_indices(input, output_size)[0]
adaptive_max_pool2d = torch._jit_internal.boolean_dispatch( adaptive_max_pool2d = boolean_dispatch(
arg_name='return_indices', arg_name='return_indices',
arg_index=2, arg_index=2,
default=False, default=False,
@ -724,7 +704,6 @@ adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
func_name='adaptive_max_pool2d') func_name='adaptive_max_pool2d')
@weak_script
def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList3[int], bool) -> Tuple[Tensor, Tensor] # type: (Tensor, BroadcastingList3[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 3D adaptive max pooling over an input signal composed of r"""Applies a 3D adaptive max pooling over an input signal composed of
@ -741,12 +720,11 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
return torch._C._nn.adaptive_max_pool3d(input, output_size) return torch._C._nn.adaptive_max_pool3d(input, output_size)
@weak_script
def _adaptive_max_pool3d(input, output_size, return_indices=False): def _adaptive_max_pool3d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList3[int], bool) -> Tensor # type: (Tensor, BroadcastingList3[int], bool) -> Tensor
return adaptive_max_pool3d_with_indices(input, output_size)[0] return adaptive_max_pool3d_with_indices(input, output_size)[0]
adaptive_max_pool3d = torch._jit_internal.boolean_dispatch( adaptive_max_pool3d = boolean_dispatch(
arg_name='return_indices', arg_name='return_indices',
arg_index=2, arg_index=2,
default=False, default=False,
@ -769,7 +747,6 @@ Args:
""") """)
@weak_script
def adaptive_avg_pool2d(input, output_size): def adaptive_avg_pool2d(input, output_size):
# type: (Tensor, BroadcastingList2[int]) -> Tensor # type: (Tensor, BroadcastingList2[int]) -> Tensor
r""" r"""
@ -786,7 +763,6 @@ def adaptive_avg_pool2d(input, output_size):
return torch._C._nn.adaptive_avg_pool2d(input, _output_size) return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
@weak_script
def adaptive_avg_pool3d(input, output_size): def adaptive_avg_pool3d(input, output_size):
# type: (Tensor, BroadcastingList3[int]) -> Tensor # type: (Tensor, BroadcastingList3[int]) -> Tensor
r""" r"""
@ -804,7 +780,6 @@ def adaptive_avg_pool3d(input, output_size):
# Activation functions # Activation functions
@weak_script
def dropout(input, p=0.5, training=True, inplace=False): def dropout(input, p=0.5, training=True, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor # type: (Tensor, float, bool, bool) -> Tensor
r""" r"""
@ -827,7 +802,6 @@ def dropout(input, p=0.5, training=True, inplace=False):
else _VF.dropout(input, p, training)) else _VF.dropout(input, p, training))
@weak_script
def alpha_dropout(input, p=0.5, training=False, inplace=False): def alpha_dropout(input, p=0.5, training=False, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor # type: (Tensor, float, bool, bool) -> Tensor
r"""Applies alpha dropout to the input. r"""Applies alpha dropout to the input.
@ -842,7 +816,6 @@ def alpha_dropout(input, p=0.5, training=False, inplace=False):
else _VF.alpha_dropout(input, p, training)) else _VF.alpha_dropout(input, p, training))
@weak_script
def dropout2d(input, p=0.5, training=True, inplace=False): def dropout2d(input, p=0.5, training=True, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor # type: (Tensor, float, bool, bool) -> Tensor
r""" r"""
@ -867,7 +840,6 @@ def dropout2d(input, p=0.5, training=True, inplace=False):
else _VF.feature_dropout(input, p, training)) else _VF.feature_dropout(input, p, training))
@weak_script
def dropout3d(input, p=0.5, training=True, inplace=False): def dropout3d(input, p=0.5, training=True, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor # type: (Tensor, float, bool, bool) -> Tensor
r""" r"""
@ -894,7 +866,6 @@ def dropout3d(input, p=0.5, training=True, inplace=False):
else _VF.feature_dropout(input, p, training)) else _VF.feature_dropout(input, p, training))
@weak_script
def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor # type: (Tensor, float, bool, bool) -> Tensor
if p < 0. or p > 1.: if p < 0. or p > 1.:
@ -905,7 +876,6 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
else _VF.feature_alpha_dropout(input, p, training)) else _VF.feature_alpha_dropout(input, p, training))
@weak_script
def threshold(input, threshold, value, inplace=False): def threshold(input, threshold, value, inplace=False):
# type: (Tensor, float, float, bool) -> Tensor # type: (Tensor, float, float, bool) -> Tensor
r"""Thresholds each element of the input Tensor. r"""Thresholds each element of the input Tensor.
@ -926,7 +896,6 @@ In-place version of :func:`~threshold`.
""") """)
@weak_script
def relu(input, inplace=False): def relu(input, inplace=False):
# type: (Tensor, bool) -> Tensor # type: (Tensor, bool) -> Tensor
r"""relu(input, inplace=False) -> Tensor r"""relu(input, inplace=False) -> Tensor
@ -948,7 +917,6 @@ In-place version of :func:`~relu`.
""") """)
@weak_script
def glu(input, dim=-1): def glu(input, dim=-1):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
r""" r"""
@ -973,7 +941,6 @@ def glu(input, dim=-1):
return torch._C._nn.glu(input, dim) return torch._C._nn.glu(input, dim)
@weak_script
def hardtanh(input, min_val=-1., max_val=1., inplace=False): def hardtanh(input, min_val=-1., max_val=1., inplace=False):
# type: (Tensor, float, float, bool) -> Tensor # type: (Tensor, float, float, bool) -> Tensor
r""" r"""
@ -996,7 +963,6 @@ In-place version of :func:`~hardtanh`.
""") """)
@weak_script
def relu6(input, inplace=False): def relu6(input, inplace=False):
# type: (Tensor, bool) -> Tensor # type: (Tensor, bool) -> Tensor
r"""relu6(input, inplace=False) -> Tensor r"""relu6(input, inplace=False) -> Tensor
@ -1008,7 +974,6 @@ def relu6(input, inplace=False):
return hardtanh(input, 0., 6., inplace) return hardtanh(input, 0., 6., inplace)
@weak_script
def elu(input, alpha=1., inplace=False): def elu(input, alpha=1., inplace=False):
# type: (Tensor, float, bool) -> Tensor # type: (Tensor, float, bool) -> Tensor
r"""Applies element-wise, r"""Applies element-wise,
@ -1030,7 +995,6 @@ In-place version of :func:`~elu`.
""") """)
@weak_script
def selu(input, inplace=False): def selu(input, inplace=False):
# type: (Tensor, bool) -> Tensor # type: (Tensor, bool) -> Tensor
r"""selu(input, inplace=False) -> Tensor r"""selu(input, inplace=False) -> Tensor
@ -1056,7 +1020,6 @@ In-place version of :func:`~selu`.
""") """)
@weak_script
def celu(input, alpha=1., inplace=False): def celu(input, alpha=1., inplace=False):
# type: (Tensor, float, bool) -> Tensor # type: (Tensor, float, bool) -> Tensor
r"""celu(input, alpha=1., inplace=False) -> Tensor r"""celu(input, alpha=1., inplace=False) -> Tensor
@ -1079,7 +1042,6 @@ In-place version of :func:`~celu`.
""") """)
@weak_script
def leaky_relu(input, negative_slope=0.01, inplace=False): def leaky_relu(input, negative_slope=0.01, inplace=False):
# type: (Tensor, float, bool) -> Tensor # type: (Tensor, float, bool) -> Tensor
r""" r"""
@ -1104,7 +1066,6 @@ In-place version of :func:`~leaky_relu`.
""") """)
@weak_script
def prelu(input, weight): def prelu(input, weight):
# type: (Tensor, Tensor) -> Tensor # type: (Tensor, Tensor) -> Tensor
r"""prelu(input, weight) -> Tensor r"""prelu(input, weight) -> Tensor
@ -1118,7 +1079,6 @@ def prelu(input, weight):
return torch.prelu(input, weight) return torch.prelu(input, weight)
@weak_script
def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
# type: (Tensor, float, float, bool, bool) -> Tensor # type: (Tensor, float, float, bool, bool) -> Tensor
r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
@ -1148,7 +1108,6 @@ Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \ex
See :class:`~torch.nn.LogSigmoid` for more details. See :class:`~torch.nn.LogSigmoid` for more details.
""") """)
@weak_script
def gelu(input): def gelu(input):
r"""gelu(input) -> Tensor r"""gelu(input) -> Tensor
@ -1162,7 +1121,6 @@ def gelu(input):
return torch._C._nn.gelu(input) return torch._C._nn.gelu(input)
@weak_script
def hardshrink(input, lambd=0.5): def hardshrink(input, lambd=0.5):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
r""" r"""
@ -1175,7 +1133,6 @@ def hardshrink(input, lambd=0.5):
return torch.hardshrink(input, lambd) return torch.hardshrink(input, lambd)
@weak_script
def tanhshrink(input): def tanhshrink(input):
r"""tanhshrink(input) -> Tensor r"""tanhshrink(input) -> Tensor
@ -1186,7 +1143,6 @@ def tanhshrink(input):
return input - input.tanh() return input - input.tanh()
@weak_script
def softsign(input): def softsign(input):
r"""softsign(input) -> Tensor r"""softsign(input) -> Tensor
@ -1202,7 +1158,6 @@ softplus(input, beta=1, threshold=20) -> Tensor
""") """)
@weak_script
def _get_softmax_dim(name, ndim, stacklevel): def _get_softmax_dim(name, ndim, stacklevel):
# type: (str, int, int) -> int # type: (str, int, int) -> int
warnings.warn("Implicit dimension choice for {} has been deprecated. " warnings.warn("Implicit dimension choice for {} has been deprecated. "
@ -1214,7 +1169,6 @@ def _get_softmax_dim(name, ndim, stacklevel):
return ret return ret
@weak_script
def softmin(input, dim=None, _stacklevel=3, dtype=None): def softmin(input, dim=None, _stacklevel=3, dtype=None):
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
r"""Applies a softmin function. r"""Applies a softmin function.
@ -1240,7 +1194,6 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None):
return ret return ret
@weak_script
def softmax(input, dim=None, _stacklevel=3, dtype=None): def softmax(input, dim=None, _stacklevel=3, dtype=None):
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
r"""Applies a softmax function. r"""Applies a softmax function.
@ -1276,7 +1229,6 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None):
return ret return ret
@weak_script
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
# type: (Tensor, float, bool, float, int) -> Tensor # type: (Tensor, float, bool, float, int) -> Tensor
r""" r"""
@ -1337,7 +1289,6 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
return ret return ret
@weak_script
def log_softmax(input, dim=None, _stacklevel=3, dtype=None): def log_softmax(input, dim=None, _stacklevel=3, dtype=None):
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
r"""Applies a softmax followed by a logarithm. r"""Applies a softmax followed by a logarithm.
@ -1373,7 +1324,6 @@ See :class:`~torch.nn.Softshrink` for more details.
""") """)
@weak_script
def tanh(input): def tanh(input):
r"""tanh(input) -> Tensor r"""tanh(input) -> Tensor
@ -1386,7 +1336,6 @@ def tanh(input):
return input.tanh() return input.tanh()
@weak_script
def sigmoid(input): def sigmoid(input):
r"""sigmoid(input) -> Tensor r"""sigmoid(input) -> Tensor
@ -1398,7 +1347,6 @@ def sigmoid(input):
return input.sigmoid() return input.sigmoid()
@weak_script
def linear(input, weight, bias=None): def linear(input, weight, bias=None):
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
r""" r"""
@ -1423,7 +1371,6 @@ def linear(input, weight, bias=None):
return ret return ret
@weak_script
def bilinear(input1, input2, weight, bias=None): def bilinear(input1, input2, weight, bias=None):
# type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor
return torch.bilinear(input1, input2, weight, bias) return torch.bilinear(input1, input2, weight, bias)
@ -1435,7 +1382,6 @@ def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
torch.embedding_renorm_(weight, input, max_norm, norm_type) torch.embedding_renorm_(weight, input, max_norm, norm_type)
@weak_script
def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
scale_grad_by_freq=False, sparse=False): scale_grad_by_freq=False, sparse=False):
# type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
@ -1517,7 +1463,6 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
@weak_script
def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
scale_grad_by_freq=False, mode='mean', sparse=False, scale_grad_by_freq=False, mode='mean', sparse=False,
per_sample_weights=None): per_sample_weights=None):
@ -1677,7 +1622,6 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
return ret return ret
@weak_script
def batch_norm(input, running_mean, running_var, weight=None, bias=None, def batch_norm(input, running_mean, running_var, weight=None, bias=None,
training=False, momentum=0.1, eps=1e-5): training=False, momentum=0.1, eps=1e-5):
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
@ -1709,7 +1653,6 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None,
) )
@weak_script
def instance_norm(input, running_mean=None, running_var=None, weight=None, def instance_norm(input, running_mean=None, running_var=None, weight=None,
bias=None, use_input_stats=True, momentum=0.1, eps=1e-5): bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
@ -1725,7 +1668,6 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None,
) )
@weak_script
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
# type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor # type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor
r"""Applies Layer Normalization for last certain number of dimensions. r"""Applies Layer Normalization for last certain number of dimensions.
@ -1736,7 +1678,6 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
torch.backends.cudnn.enabled) torch.backends.cudnn.enabled)
@weak_script
def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
# type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
r"""Applies Group Normalization for last certain number of dimensions. r"""Applies Group Normalization for last certain number of dimensions.
@ -1747,7 +1688,6 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
torch.backends.cudnn.enabled) torch.backends.cudnn.enabled)
@weak_script
def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
# type: (Tensor, int, float, float, float) -> Tensor # type: (Tensor, int, float, float, float) -> Tensor
r"""Applies local response normalization over an input signal composed of r"""Applies local response normalization over an input signal composed of
@ -1776,7 +1716,6 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
# loss # loss
@weak_script
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
reduction='mean', zero_infinity=False): reduction='mean', zero_infinity=False):
# type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor # type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor
@ -1824,7 +1763,6 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
zero_infinity) zero_infinity)
@weak_script
def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
@ -1903,7 +1841,6 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
return ret return ret
@weak_script
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8, def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor # type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor
@ -1949,7 +1886,6 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
return ret return ret
@weak_script
def kl_div(input, target, size_average=None, reduce=None, reduction='mean'): def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""The `Kullback-Leibler divergence`_ Loss. r"""The `Kullback-Leibler divergence`_ Loss.
@ -2007,7 +1943,6 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
return reduced return reduced
@weak_script
def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
@ -2056,7 +1991,6 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
@weak_script
def binary_cross_entropy(input, target, weight=None, size_average=None, def binary_cross_entropy(input, target, weight=None, size_average=None,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@ -2113,7 +2047,6 @@ def binary_cross_entropy(input, target, weight=None, size_average=None,
input, target, weight, reduction_enum) input, target, weight, reduction_enum)
@weak_script
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
reduce=None, reduction='mean', pos_weight=None): reduce=None, reduction='mean', pos_weight=None):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
@ -2174,14 +2107,12 @@ def _pointwise_loss(lambd, lambd_optimized, input, target, reduction='mean'):
return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction)) return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction))
@weak_script
def _smooth_l1_loss(input, target): def _smooth_l1_loss(input, target):
# type: (Tensor, Tensor) -> Tensor # type: (Tensor, Tensor) -> Tensor
t = torch.abs(input - target) t = torch.abs(input - target)
return torch.where(t < 1, 0.5 * t ** 2, t - 0.5) return torch.where(t < 1, 0.5 * t ** 2, t - 0.5)
@weak_script
def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""Function that uses a squared term if the absolute r"""Function that uses a squared term if the absolute
@ -2206,7 +2137,6 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea
return ret return ret
@weak_script
def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@ -2232,7 +2162,6 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
return ret return ret
@weak_script
def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@ -2258,7 +2187,6 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
return ret return ret
@weak_script
def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@ -2276,7 +2204,6 @@ def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
@weak_script
def hinge_embedding_loss(input, target, margin=1.0, size_average=None, def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@ -2291,7 +2218,6 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
return torch.hinge_embedding_loss(input, target, margin, reduction_enum) return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
@weak_script
def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@ -2305,7 +2231,6 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
@weak_script
def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@ -2319,7 +2244,6 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m
return torch._C._nn.soft_margin_loss(input, target, reduction_enum) return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
@weak_script
def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@ -2349,7 +2273,6 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
return ret return ret
@weak_script
def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@ -2364,7 +2287,6 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
@weak_script
def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None, def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
# type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@ -2635,7 +2557,6 @@ GRID_SAMPLE_PADDING_MODES = {
} }
@weak_script
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'): def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
# type: (Tensor, Tensor, str, str) -> Tensor # type: (Tensor, Tensor, str, str) -> Tensor
r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the
@ -2717,7 +2638,6 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum) return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum)
@weak_script
def affine_grid(theta, size): def affine_grid(theta, size):
# type: (Tensor, List[int]) -> Tensor # type: (Tensor, List[int]) -> Tensor
r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta`. r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta`.
@ -2735,7 +2655,6 @@ def affine_grid(theta, size):
return vision.affine_grid_generator(theta, size) return vision.affine_grid_generator(theta, size)
@weak_script
def pad(input, pad, mode='constant', value=0): def pad(input, pad, mode='constant', value=0):
# type: (Tensor, List[int], str, float) -> Tensor # type: (Tensor, List[int], str, float) -> Tensor
r"""Pads tensor. r"""Pads tensor.
@ -2844,7 +2763,6 @@ def pad(input, pad, mode='constant', value=0):
# distance # distance
@weak_script
def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False):
# type: (Tensor, Tensor, float, float, bool) -> Tensor # type: (Tensor, Tensor, float, float, bool) -> Tensor
r""" r"""
@ -2952,7 +2870,6 @@ Examples:
""") """)
@weak_script
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None, def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
reduce=None, reduction="mean"): reduce=None, reduction="mean"):
# type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor # type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor
@ -2967,7 +2884,6 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
swap, reduction_enum) swap, reduction_enum)
@weak_script
def normalize(input, p=2, dim=1, eps=1e-12, out=None): def normalize(input, p=2, dim=1, eps=1e-12, out=None):
# type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
r"""Performs :math:`L_p` normalization of inputs over specified dimension. r"""Performs :math:`L_p` normalization of inputs over specified dimension.
@ -3001,7 +2917,6 @@ def assert_int_or_pair(arg, arg_name, message):
assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
@weak_script
def unfold(input, kernel_size, dilation=1, padding=0, stride=1): def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
# type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
r"""Extracts sliding local blocks from an batched input tensor. r"""Extracts sliding local blocks from an batched input tensor.
@ -3036,7 +2951,6 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
return ret return ret
@weak_script
def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
# type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
r"""Combines an array of sliding local blocks into a large containing r"""Combines an array of sliding local blocks into a large containing
@ -3064,7 +2978,6 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
return ret return ret
@weak_script
def _pad_circular(input, padding): def _pad_circular(input, padding):
# type: (Tensor, List[int]) -> Tensor # type: (Tensor, List[int]) -> Tensor
""" """
@ -3090,7 +3003,6 @@ def _pad_circular(input, padding):
return input return input
@weak_script
def multi_head_attention_forward(query, # type: Tensor def multi_head_attention_forward(query, # type: Tensor
key, # type: Tensor key, # type: Tensor
value, # type: Tensor value, # type: Tensor
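
The hunks above strip the `@weak_script` decorators from the ops in `torch.nn.functional`. As a rough, illustrative sketch only (not part of this patch; the function name and tensor shapes are invented, and it assumes the JIT compiles these library functions on first use when they are reached from scripted code):

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def linear_relu(x, w, b):
    # With the decorators removed, F.linear and F.relu are ordinary Python
    # functions; the compiler is expected to script them recursively when
    # this function is compiled.
    return F.relu(F.linear(x, w, b))

out = linear_relu(torch.randn(2, 3), torch.randn(4, 3), torch.randn(4))
```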


@ -4,7 +4,6 @@ import math
import warnings import warnings
import torch import torch
from .._jit_internal import weak_script
# These no_grad_* functions are necessary as wrappers around the parts of these # These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context # functions that use `with torch.no_grad()`. The JIT doesn't support context
@ -72,7 +71,6 @@ def calculate_gain(nonlinearity, param=None):
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
@weak_script
def uniform_(tensor, a=0., b=1.): def uniform_(tensor, a=0., b=1.):
# type: (Tensor, float, float) -> Tensor # type: (Tensor, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from the uniform r"""Fills the input Tensor with values drawn from the uniform
@ -90,7 +88,6 @@ def uniform_(tensor, a=0., b=1.):
return _no_grad_uniform_(tensor, a, b) return _no_grad_uniform_(tensor, a, b)
@weak_script
def normal_(tensor, mean=0., std=1.): def normal_(tensor, mean=0., std=1.):
# type: (Tensor, float, float) -> Tensor # type: (Tensor, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from the normal r"""Fills the input Tensor with values drawn from the normal
@ -108,7 +105,6 @@ def normal_(tensor, mean=0., std=1.):
return _no_grad_normal_(tensor, mean, std) return _no_grad_normal_(tensor, mean, std)
@weak_script
def constant_(tensor, val): def constant_(tensor, val):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
r"""Fills the input Tensor with the value :math:`\text{val}`. r"""Fills the input Tensor with the value :math:`\text{val}`.
@ -124,7 +120,6 @@ def constant_(tensor, val):
return _no_grad_fill_(tensor, val) return _no_grad_fill_(tensor, val)
@weak_script
def ones_(tensor): def ones_(tensor):
# type: (Tensor) -> Tensor # type: (Tensor) -> Tensor
r"""Fills the input Tensor with ones`. r"""Fills the input Tensor with ones`.
@ -139,7 +134,6 @@ def ones_(tensor):
return _no_grad_fill_(tensor, 1.) return _no_grad_fill_(tensor, 1.)
@weak_script
def zeros_(tensor): def zeros_(tensor):
# type: (Tensor) -> Tensor # type: (Tensor) -> Tensor
r"""Fills the input Tensor with zeros`. r"""Fills the input Tensor with zeros`.
@ -205,7 +199,6 @@ def dirac_(tensor):
return tensor return tensor
@weak_script
def _calculate_fan_in_and_fan_out(tensor): def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim() dimensions = tensor.dim()
if dimensions < 2: if dimensions < 2:
@ -226,7 +219,6 @@ def _calculate_fan_in_and_fan_out(tensor):
return fan_in, fan_out return fan_in, fan_out
@weak_script
def xavier_uniform_(tensor, gain=1.): def xavier_uniform_(tensor, gain=1.):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
r"""Fills the input `Tensor` with values according to the method r"""Fills the input `Tensor` with values according to the method
@ -255,7 +247,6 @@ def xavier_uniform_(tensor, gain=1.):
return _no_grad_uniform_(tensor, -a, a) return _no_grad_uniform_(tensor, -a, a)
@weak_script
def xavier_normal_(tensor, gain=1.): def xavier_normal_(tensor, gain=1.):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
r"""Fills the input `Tensor` with values according to the method r"""Fills the input `Tensor` with values according to the method


@ -7,10 +7,8 @@ from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class Threshold(Module): class Threshold(Module):
r"""Thresholds each element of the input Tensor. r"""Thresholds each element of the input Tensor.
@ -48,7 +46,6 @@ class Threshold(Module):
self.inplace = inplace self.inplace = inplace
# TODO: check in THNN (if inplace == True, then assert value <= threshold) # TODO: check in THNN (if inplace == True, then assert value <= threshold)
@weak_script_method
def forward(self, input): def forward(self, input):
return F.threshold(input, self.threshold, self.value, self.inplace) return F.threshold(input, self.threshold, self.value, self.inplace)
@ -59,7 +56,6 @@ class Threshold(Module):
) )
@weak_module
class ReLU(Module): class ReLU(Module):
r"""Applies the rectified linear unit function element-wise: r"""Applies the rectified linear unit function element-wise:
@ -94,7 +90,6 @@ class ReLU(Module):
super(ReLU, self).__init__() super(ReLU, self).__init__()
self.inplace = inplace self.inplace = inplace
@weak_script_method
def forward(self, input): def forward(self, input):
return F.relu(input, inplace=self.inplace) return F.relu(input, inplace=self.inplace)
@ -103,7 +98,6 @@ class ReLU(Module):
return inplace_str return inplace_str
@weak_module
class RReLU(Module): class RReLU(Module):
r"""Applies the randomized leaky rectified liner unit function, element-wise, r"""Applies the randomized leaky rectified liner unit function, element-wise,
as described in the paper: as described in the paper:
@ -151,7 +145,6 @@ class RReLU(Module):
self.upper = upper self.upper = upper
self.inplace = inplace self.inplace = inplace
@weak_script_method
def forward(self, input): def forward(self, input):
return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
@ -160,7 +153,6 @@ class RReLU(Module):
return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str) return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
@weak_module
class Hardtanh(Module): class Hardtanh(Module):
r"""Applies the HardTanh function element-wise r"""Applies the HardTanh function element-wise
@ -213,7 +205,6 @@ class Hardtanh(Module):
self.inplace = inplace self.inplace = inplace
assert self.max_val > self.min_val assert self.max_val > self.min_val
@weak_script_method
def forward(self, input): def forward(self, input):
return F.hardtanh(input, self.min_val, self.max_val, self.inplace) return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
@ -224,7 +215,6 @@ class Hardtanh(Module):
) )
@weak_module
class ReLU6(Hardtanh): class ReLU6(Hardtanh):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -256,7 +246,6 @@ class ReLU6(Hardtanh):
return inplace_str return inplace_str
@weak_module
class Sigmoid(Module): class Sigmoid(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -278,12 +267,10 @@ class Sigmoid(Module):
>>> output = m(input) >>> output = m(input)
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return torch.sigmoid(input) return torch.sigmoid(input)
@weak_module
class Tanh(Module): class Tanh(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -304,12 +291,10 @@ class Tanh(Module):
>>> output = m(input) >>> output = m(input)
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return torch.tanh(input) return torch.tanh(input)
@weak_module
class ELU(Module): class ELU(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -340,7 +325,6 @@ class ELU(Module):
self.alpha = alpha self.alpha = alpha
self.inplace = inplace self.inplace = inplace
@weak_script_method
def forward(self, input): def forward(self, input):
return F.elu(input, self.alpha, self.inplace) return F.elu(input, self.alpha, self.inplace)
@ -349,7 +333,6 @@ class ELU(Module):
return 'alpha={}{}'.format(self.alpha, inplace_str) return 'alpha={}{}'.format(self.alpha, inplace_str)
@weak_module
class CELU(Module): class CELU(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -385,7 +368,6 @@ class CELU(Module):
self.alpha = alpha self.alpha = alpha
self.inplace = inplace self.inplace = inplace
@weak_script_method
def forward(self, input): def forward(self, input):
return F.celu(input, self.alpha, self.inplace) return F.celu(input, self.alpha, self.inplace)
@ -394,7 +376,6 @@ class CELU(Module):
return 'alpha={}{}'.format(self.alpha, inplace_str) return 'alpha={}{}'.format(self.alpha, inplace_str)
@weak_module
class SELU(Module): class SELU(Module):
r"""Applied element-wise, as: r"""Applied element-wise, as:
@ -430,7 +411,6 @@ class SELU(Module):
super(SELU, self).__init__() super(SELU, self).__init__()
self.inplace = inplace self.inplace = inplace
@weak_script_method
def forward(self, input): def forward(self, input):
return F.selu(input, self.inplace) return F.selu(input, self.inplace)
@ -439,7 +419,6 @@ class SELU(Module):
return inplace_str return inplace_str
@weak_module
class GLU(Module): class GLU(Module):
r"""Applies the gated linear unit function r"""Applies the gated linear unit function
:math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
@ -465,7 +444,6 @@ class GLU(Module):
super(GLU, self).__init__() super(GLU, self).__init__()
self.dim = dim self.dim = dim
@weak_script_method
def forward(self, input): def forward(self, input):
return F.glu(input, self.dim) return F.glu(input, self.dim)
@ -473,7 +451,6 @@ class GLU(Module):
return 'dim={}'.format(self.dim) return 'dim={}'.format(self.dim)
@weak_module
class Hardshrink(Module): class Hardshrink(Module):
r"""Applies the hard shrinkage function element-wise: r"""Applies the hard shrinkage function element-wise:
@ -507,7 +484,6 @@ class Hardshrink(Module):
super(Hardshrink, self).__init__() super(Hardshrink, self).__init__()
self.lambd = lambd self.lambd = lambd
@weak_script_method
def forward(self, input): def forward(self, input):
return F.hardshrink(input, self.lambd) return F.hardshrink(input, self.lambd)
@ -515,7 +491,6 @@ class Hardshrink(Module):
return '{}'.format(self.lambd) return '{}'.format(self.lambd)
@weak_module
class LeakyReLU(Module): class LeakyReLU(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -556,7 +531,6 @@ class LeakyReLU(Module):
self.negative_slope = negative_slope self.negative_slope = negative_slope
self.inplace = inplace self.inplace = inplace
@weak_script_method
def forward(self, input): def forward(self, input):
return F.leaky_relu(input, self.negative_slope, self.inplace) return F.leaky_relu(input, self.negative_slope, self.inplace)
@ -565,7 +539,6 @@ class LeakyReLU(Module):
return 'negative_slope={}{}'.format(self.negative_slope, inplace_str) return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
@weak_module
class LogSigmoid(Module): class LogSigmoid(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -586,12 +559,10 @@ class LogSigmoid(Module):
>>> output = m(input) >>> output = m(input)
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.logsigmoid(input) return F.logsigmoid(input)
@weak_module
class Softplus(Module): class Softplus(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -628,7 +599,6 @@ class Softplus(Module):
self.beta = beta self.beta = beta
self.threshold = threshold self.threshold = threshold
@weak_script_method
def forward(self, input): def forward(self, input):
return F.softplus(input, self.beta, self.threshold) return F.softplus(input, self.beta, self.threshold)
@ -636,7 +606,6 @@ class Softplus(Module):
return 'beta={}, threshold={}'.format(self.beta, self.threshold) return 'beta={}, threshold={}'.format(self.beta, self.threshold)
@weak_module
class Softshrink(Module): class Softshrink(Module):
r"""Applies the soft shrinkage function elementwise: r"""Applies the soft shrinkage function elementwise:
@ -670,7 +639,6 @@ class Softshrink(Module):
super(Softshrink, self).__init__() super(Softshrink, self).__init__()
self.lambd = lambd self.lambd = lambd
@weak_script_method
def forward(self, input): def forward(self, input):
return F.softshrink(input, self.lambd) return F.softshrink(input, self.lambd)
@ -678,7 +646,6 @@ class Softshrink(Module):
return str(self.lambd) return str(self.lambd)
@weak_module
class MultiheadAttention(Module): class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information r"""Allows the model to jointly attend to information
from different representation subspaces. from different representation subspaces.
@ -759,7 +726,6 @@ class MultiheadAttention(Module):
if self.bias_v is not None: if self.bias_v is not None:
xavier_normal_(self.bias_v) xavier_normal_(self.bias_v)
@weak_script_method
def forward(self, query, key, value, key_padding_mask=None, def forward(self, query, key, value, key_padding_mask=None,
need_weights=True, attn_mask=None): need_weights=True, attn_mask=None):
r""" r"""
@ -817,7 +783,6 @@ class MultiheadAttention(Module):
attn_mask=attn_mask) attn_mask=attn_mask)
@weak_module
class PReLU(Module): class PReLU(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -874,7 +839,6 @@ class PReLU(Module):
super(PReLU, self).__init__() super(PReLU, self).__init__()
self.weight = Parameter(torch.Tensor(num_parameters).fill_(init)) self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
@weak_script_method
def forward(self, input): def forward(self, input):
return F.prelu(input, self.weight) return F.prelu(input, self.weight)
@ -882,7 +846,6 @@ class PReLU(Module):
return 'num_parameters={}'.format(self.num_parameters) return 'num_parameters={}'.format(self.num_parameters)
@weak_module
class Softsign(Module): class Softsign(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -903,12 +866,10 @@ class Softsign(Module):
>>> output = m(input) >>> output = m(input)
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.softsign(input) return F.softsign(input)
@weak_module
class Tanhshrink(Module): class Tanhshrink(Module):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -929,12 +890,10 @@ class Tanhshrink(Module):
>>> output = m(input) >>> output = m(input)
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.tanhshrink(input) return F.tanhshrink(input)
@weak_module
class Softmin(Module): class Softmin(Module):
r"""Applies the Softmin function to an n-dimensional input Tensor r"""Applies the Softmin function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor rescaling them so that the elements of the n-dimensional output Tensor
@ -970,12 +929,10 @@ class Softmin(Module):
super(Softmin, self).__init__() super(Softmin, self).__init__()
self.dim = dim self.dim = dim
@weak_script_method
def forward(self, input): def forward(self, input):
return F.softmin(input, self.dim, _stacklevel=5) return F.softmin(input, self.dim, _stacklevel=5)
@weak_module
class Softmax(Module): class Softmax(Module):
r"""Applies the Softmax function to an n-dimensional input Tensor r"""Applies the Softmax function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor rescaling them so that the elements of the n-dimensional output Tensor
@ -1021,7 +978,6 @@ class Softmax(Module):
if not hasattr(self, 'dim'): if not hasattr(self, 'dim'):
self.dim = None self.dim = None
@weak_script_method
def forward(self, input): def forward(self, input):
return F.softmax(input, self.dim, _stacklevel=5) return F.softmax(input, self.dim, _stacklevel=5)
@ -1029,7 +985,6 @@ class Softmax(Module):
return 'dim={dim}'.format(dim=self.dim) return 'dim={dim}'.format(dim=self.dim)
@weak_module
class Softmax2d(Module): class Softmax2d(Module):
r"""Applies SoftMax over features to each spatial location. r"""Applies SoftMax over features to each spatial location.
@ -1052,13 +1007,11 @@ class Softmax2d(Module):
>>> output = m(input) >>> output = m(input)
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input' assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
return F.softmax(input, 1, _stacklevel=5) return F.softmax(input, 1, _stacklevel=5)
@weak_module
class LogSoftmax(Module): class LogSoftmax(Module):
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
input Tensor. The LogSoftmax formulation can be simplified as: input Tensor. The LogSoftmax formulation can be simplified as:
@ -1095,6 +1048,5 @@ class LogSoftmax(Module):
if not hasattr(self, 'dim'): if not hasattr(self, 'dim'):
self.dim = None self.dim = None
@weak_script_method
def forward(self, input): def forward(self, input):
return F.log_softmax(input, self.dim, _stacklevel=5) return F.log_softmax(input, self.dim, _stacklevel=5)
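
The activation hunks above delete `@weak_module` / `@weak_script_method` from `Threshold` through `LogSoftmax`, leaving them as plain `nn.Module` subclasses. A minimal sketch of how such modules can then be used under scripting (illustration only, not from the patch; the class name and sizes are invented, and it assumes `torch.jit.script` accepts an `nn.Module` instance and compiles its submodules on demand):

```python
import torch
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        # Plain library modules; no weak-module bookkeeping is involved.
        self.body = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
        self.head = nn.LogSoftmax(dim=1)

    def forward(self, x):
        return self.head(self.body(x))

scripted = torch.jit.script(Classifier())
out = scripted(torch.randn(2, 8))
```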


@ -6,12 +6,10 @@ from .module import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .. import functional as F from .. import functional as F
from .. import init from .. import init
from ..._jit_internal import weak_module, weak_script_method
# TODO: check contiguous in THNN # TODO: check contiguous in THNN
# TODO: use separate backend functions? # TODO: use separate backend functions?
@weak_module
class _BatchNorm(Module): class _BatchNorm(Module):
_version = 2 _version = 2
__constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias', __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
@ -57,7 +55,6 @@ class _BatchNorm(Module):
def _check_input_dim(self, input): def _check_input_dim(self, input):
raise NotImplementedError raise NotImplementedError
@weak_script_method
def forward(self, input): def forward(self, input):
self._check_input_dim(input) self._check_input_dim(input)
@ -103,7 +100,6 @@ class _BatchNorm(Module):
missing_keys, unexpected_keys, error_msgs) missing_keys, unexpected_keys, error_msgs)
@weak_module
class BatchNorm1d(_BatchNorm): class BatchNorm1d(_BatchNorm):
r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
inputs with optional additional channel dimension) as described in the paper inputs with optional additional channel dimension) as described in the paper
@ -170,14 +166,12 @@ class BatchNorm1d(_BatchNorm):
https://arxiv.org/abs/1502.03167 https://arxiv.org/abs/1502.03167
""" """
@weak_script_method
def _check_input_dim(self, input): def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3: if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)' raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim())) .format(input.dim()))
@weak_module
class BatchNorm2d(_BatchNorm): class BatchNorm2d(_BatchNorm):
r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
with additional channel dimension) as described in the paper with additional channel dimension) as described in the paper
@ -244,14 +238,12 @@ class BatchNorm2d(_BatchNorm):
https://arxiv.org/abs/1502.03167 https://arxiv.org/abs/1502.03167
""" """
@weak_script_method
def _check_input_dim(self, input): def _check_input_dim(self, input):
if input.dim() != 4: if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)' raise ValueError('expected 4D input (got {}D input)'
.format(input.dim())) .format(input.dim()))
@weak_module
class BatchNorm3d(_BatchNorm): class BatchNorm3d(_BatchNorm):
r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
with additional channel dimension) as described in the paper with additional channel dimension) as described in the paper
@ -319,7 +311,6 @@ class BatchNorm3d(_BatchNorm):
https://arxiv.org/abs/1502.03167 https://arxiv.org/abs/1502.03167
""" """
@weak_script_method
def _check_input_dim(self, input): def _check_input_dim(self, input):
if input.dim() != 5: if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)' raise ValueError('expected 5D input (got {}D input)'


@ -6,10 +6,9 @@ from .. import functional as F
from .. import init from .. import init
from .module import Module from .module import Module
from .utils import _single, _pair, _triple from .utils import _single, _pair, _triple
from ..._jit_internal import weak_module, weak_script_method, List from ..._jit_internal import List
@weak_module
class _ConvNd(Module): class _ConvNd(Module):
__constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
@ -74,7 +73,6 @@ class _ConvNd(Module):
self.padding_mode = 'zeros' self.padding_mode = 'zeros'
@weak_module
class Conv1d(_ConvNd): class Conv1d(_ConvNd):
r"""Applies a 1D convolution over an input signal composed of several input r"""Applies a 1D convolution over an input signal composed of several input
planes. planes.
@ -192,7 +190,6 @@ class Conv1d(_ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation, in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias, padding_mode) False, _single(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input): def forward(self, input):
if self.padding_mode == 'circular': if self.padding_mode == 'circular':
expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2) expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
@ -203,7 +200,6 @@ class Conv1d(_ConvNd):
self.padding, self.dilation, self.groups) self.padding, self.dilation, self.groups)
@weak_module
class Conv2d(_ConvNd): class Conv2d(_ConvNd):
r"""Applies a 2D convolution over an input signal composed of several input r"""Applies a 2D convolution over an input signal composed of several input
planes. planes.
@ -333,7 +329,6 @@ class Conv2d(_ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation, in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode) False, _pair(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input): def forward(self, input):
if self.padding_mode == 'circular': if self.padding_mode == 'circular':
expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
@ -345,7 +340,6 @@ class Conv2d(_ConvNd):
self.padding, self.dilation, self.groups) self.padding, self.dilation, self.groups)
@weak_module
class Conv3d(_ConvNd): class Conv3d(_ConvNd):
r"""Applies a 3D convolution over an input signal composed of several input r"""Applies a 3D convolution over an input signal composed of several input
planes. planes.
@ -470,7 +464,6 @@ class Conv3d(_ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation, in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _triple(0), groups, bias, padding_mode) False, _triple(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input): def forward(self, input):
if self.padding_mode == 'circular': if self.padding_mode == 'circular':
expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2, expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
@ -483,9 +476,7 @@ class Conv3d(_ConvNd):
self.padding, self.dilation, self.groups) self.padding, self.dilation, self.groups)
@weak_module
class _ConvTransposeMixin(object): class _ConvTransposeMixin(object):
@weak_script_method
def forward(self, input, output_size=None): def forward(self, input, output_size=None):
# type(Tensor, Optional[List[int]]) -> Tensor # type(Tensor, Optional[List[int]]) -> Tensor
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size) output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
@ -497,7 +488,6 @@ class _ConvTransposeMixin(object):
else: else:
return func(input, self.weight, self.bias) return func(input, self.weight, self.bias)
@weak_script_method
def _output_padding(self, input, output_size, stride, padding, kernel_size): def _output_padding(self, input, output_size, stride, padding, kernel_size):
# type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int] # type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int]
if output_size is None: if output_size is None:
@ -537,7 +527,6 @@ class _ConvTransposeMixin(object):
return ret return ret
@weak_module
class ConvTranspose1d(_ConvTransposeMixin, _ConvNd): class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
r"""Applies a 1D transposed convolution operator over an input image r"""Applies a 1D transposed convolution operator over an input image
composed of several input planes. composed of several input planes.
@ -638,7 +627,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation, in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode) True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None): def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor # type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros': if self.padding_mode != 'zeros':
@ -650,7 +638,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
output_padding, self.groups, self.dilation) output_padding, self.groups, self.dilation)
@weak_module
class ConvTranspose2d(_ConvTransposeMixin, _ConvNd): class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
r"""Applies a 2D transposed convolution operator over an input image r"""Applies a 2D transposed convolution operator over an input image
composed of several input planes. composed of several input planes.
@ -786,7 +773,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation, in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode) True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None): def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor # type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros': if self.padding_mode != 'zeros':
@ -799,7 +785,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
output_padding, self.groups, self.dilation) output_padding, self.groups, self.dilation)
@weak_module
class ConvTranspose3d(_ConvTransposeMixin, _ConvNd): class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
r"""Applies a 3D transposed convolution operator over an input image composed of several input r"""Applies a 3D transposed convolution operator over an input image composed of several input
planes. planes.
@ -931,7 +916,6 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation, in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode) True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None): def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor # type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros': if self.padding_mode != 'zeros':


@ -1,9 +1,7 @@
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class PairwiseDistance(Module): class PairwiseDistance(Module):
r""" r"""
Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm: Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
@ -35,12 +33,10 @@ class PairwiseDistance(Module):
self.eps = eps self.eps = eps
self.keepdim = keepdim self.keepdim = keepdim
@weak_script_method
def forward(self, x1, x2): def forward(self, x1, x2):
return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim) return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)
@weak_module
class CosineSimilarity(Module): class CosineSimilarity(Module):
r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim. r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.
@ -68,6 +64,5 @@ class CosineSimilarity(Module):
self.dim = dim self.dim = dim
self.eps = eps self.eps = eps
@weak_script_method
def forward(self, x1, x2): def forward(self, x1, x2):
return F.cosine_similarity(x1, x2, self.dim, self.eps) return F.cosine_similarity(x1, x2, self.dim, self.eps)


@ -1,6 +1,5 @@
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
class _DropoutNd(Module): class _DropoutNd(Module):
@ -18,7 +17,6 @@ class _DropoutNd(Module):
return 'p={}, inplace={}'.format(self.p, self.inplace) return 'p={}, inplace={}'.format(self.p, self.inplace)
@weak_module
class Dropout(_DropoutNd): class Dropout(_DropoutNd):
r"""During training, randomly zeroes some of the elements of the input r"""During training, randomly zeroes some of the elements of the input
tensor with probability :attr:`p` using samples from a Bernoulli tensor with probability :attr:`p` using samples from a Bernoulli
@ -52,12 +50,10 @@ class Dropout(_DropoutNd):
detectors: https://arxiv.org/abs/1207.0580 detectors: https://arxiv.org/abs/1207.0580
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.dropout(input, self.p, self.training, self.inplace) return F.dropout(input, self.p, self.training, self.inplace)
@weak_module
class Dropout2d(_DropoutNd): class Dropout2d(_DropoutNd):
r"""Randomly zero out entire channels (a channel is a 2D feature map, r"""Randomly zero out entire channels (a channel is a 2D feature map,
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
@ -96,12 +92,10 @@ class Dropout2d(_DropoutNd):
http://arxiv.org/abs/1411.4280 http://arxiv.org/abs/1411.4280
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.dropout2d(input, self.p, self.training, self.inplace) return F.dropout2d(input, self.p, self.training, self.inplace)
@weak_module
class Dropout3d(_DropoutNd): class Dropout3d(_DropoutNd):
r"""Randomly zero out entire channels (a channel is a 3D feature map, r"""Randomly zero out entire channels (a channel is a 3D feature map,
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
@ -140,12 +134,10 @@ class Dropout3d(_DropoutNd):
http://arxiv.org/abs/1411.4280 http://arxiv.org/abs/1411.4280
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.dropout3d(input, self.p, self.training, self.inplace) return F.dropout3d(input, self.p, self.training, self.inplace)
@weak_module
class AlphaDropout(_DropoutNd): class AlphaDropout(_DropoutNd):
r"""Applies Alpha Dropout over the input. r"""Applies Alpha Dropout over the input.
@ -184,14 +176,11 @@ class AlphaDropout(_DropoutNd):
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.alpha_dropout(input, self.p, self.training) return F.alpha_dropout(input, self.p, self.training)
@weak_module
class FeatureAlphaDropout(_DropoutNd): class FeatureAlphaDropout(_DropoutNd):
@weak_script_method
def forward(self, input): def forward(self, input):
return F.feature_alpha_dropout(input, self.p, self.training) return F.feature_alpha_dropout(input, self.p, self.training)
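
The dropout hunks above follow the same pattern for `Dropout`, `Dropout2d`, `Dropout3d`, `AlphaDropout`, and `FeatureAlphaDropout`. As an illustrative sketch (not part of the diff; names and sizes are invented, and it assumes the `torch.jit.script(module)` entry point), the `training` flag keeps working as usual on the scripted module:

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc = nn.Linear(16, 16)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        # Dropout.forward calls F.dropout(input, self.p, self.training, ...),
        # so train()/eval() on the scripted module changes its behavior.
        return self.drop(torch.relu(self.fc(x)))

scripted = torch.jit.script(MLP())
scripted.eval()  # dropout becomes a no-op in eval mode
out = scripted(torch.randn(2, 16))
```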


@ -1,10 +1,8 @@
# coding=utf-8 # coding=utf-8
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class Fold(Module): class Fold(Module):
r"""Combines an array of sliding local blocks into a large containing r"""Combines an array of sliding local blocks into a large containing
tensor. tensor.
@ -101,7 +99,6 @@ class Fold(Module):
self.padding = padding self.padding = padding
self.stride = stride self.stride = stride
@weak_script_method
def forward(self, input): def forward(self, input):
return F.fold(input, self.output_size, self.kernel_size, self.dilation, return F.fold(input, self.output_size, self.kernel_size, self.dilation,
self.padding, self.stride) self.padding, self.stride)
@ -113,7 +110,6 @@ class Fold(Module):
) )
@weak_module
class Unfold(Module): class Unfold(Module):
r"""Extracts sliding local blocks from a batched input tensor. r"""Extracts sliding local blocks from a batched input tensor.
@ -217,7 +213,6 @@ class Unfold(Module):
self.padding = padding self.padding = padding
self.stride = stride self.stride = stride
@weak_script_method
def forward(self, input): def forward(self, input):
return F.unfold(input, self.kernel_size, self.dilation, return F.unfold(input, self.kernel_size, self.dilation,
self.padding, self.stride) self.padding, self.stride)


@ -1,6 +1,5 @@
from .batchnorm import _BatchNorm from .batchnorm import _BatchNorm
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
class _InstanceNorm(_BatchNorm): class _InstanceNorm(_BatchNorm):
@ -9,7 +8,6 @@ class _InstanceNorm(_BatchNorm):
super(_InstanceNorm, self).__init__( super(_InstanceNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats) num_features, eps, momentum, affine, track_running_stats)
@weak_script_method
def _check_input_dim(self, input): def _check_input_dim(self, input):
raise NotImplementedError raise NotImplementedError
@ -43,7 +41,6 @@ class _InstanceNorm(_BatchNorm):
state_dict, prefix, local_metadata, strict, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs) missing_keys, unexpected_keys, error_msgs)
@weak_script_method
def forward(self, input): def forward(self, input):
self._check_input_dim(input) self._check_input_dim(input)
@ -52,7 +49,6 @@ class _InstanceNorm(_BatchNorm):
self.training or not self.track_running_stats, self.momentum, self.eps) self.training or not self.track_running_stats, self.momentum, self.eps)
@weak_module
class InstanceNorm1d(_InstanceNorm): class InstanceNorm1d(_InstanceNorm):
r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
inputs with optional additional channel dimension) as described in the paper inputs with optional additional channel dimension) as described in the paper
@ -121,7 +117,6 @@ class InstanceNorm1d(_InstanceNorm):
https://arxiv.org/abs/1607.08022 https://arxiv.org/abs/1607.08022
""" """
@weak_script_method
def _check_input_dim(self, input): def _check_input_dim(self, input):
if input.dim() == 2: if input.dim() == 2:
raise ValueError( raise ValueError(
@ -135,7 +130,6 @@ class InstanceNorm1d(_InstanceNorm):
.format(input.dim())) .format(input.dim()))
@weak_module
class InstanceNorm2d(_InstanceNorm): class InstanceNorm2d(_InstanceNorm):
r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
with additional channel dimension) as described in the paper with additional channel dimension) as described in the paper
@ -204,14 +198,12 @@ class InstanceNorm2d(_InstanceNorm):
https://arxiv.org/abs/1607.08022 https://arxiv.org/abs/1607.08022
""" """
@weak_script_method
def _check_input_dim(self, input): def _check_input_dim(self, input):
if input.dim() != 4: if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)' raise ValueError('expected 4D input (got {}D input)'
.format(input.dim())) .format(input.dim()))
@weak_module
class InstanceNorm3d(_InstanceNorm): class InstanceNorm3d(_InstanceNorm):
r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
with additional channel dimension) as described in the paper with additional channel dimension) as described in the paper
@ -280,7 +272,6 @@ class InstanceNorm3d(_InstanceNorm):
https://arxiv.org/abs/1607.08022 https://arxiv.org/abs/1607.08022
""" """
@weak_script_method
def _check_input_dim(self, input): def _check_input_dim(self, input):
if input.dim() != 5: if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)' raise ValueError('expected 5D input (got {}D input)'


@ -5,10 +5,8 @@ from torch.nn.parameter import Parameter
from .. import functional as F from .. import functional as F
from .. import init from .. import init
from .module import Module from .module import Module
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class Identity(Module): class Identity(Module):
r"""A placeholder identity operator that is argument-insensitive. r"""A placeholder identity operator that is argument-insensitive.
@ -28,12 +26,10 @@ class Identity(Module):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Identity, self).__init__() super(Identity, self).__init__()
@weak_script_method
def forward(self, input): def forward(self, input):
return input return input
@weak_module
class Linear(Module): class Linear(Module):
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
@ -87,7 +83,6 @@ class Linear(Module):
bound = 1 / math.sqrt(fan_in) bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound) init.uniform_(self.bias, -bound, bound)
@weak_script_method
def forward(self, input): def forward(self, input):
return F.linear(input, self.weight, self.bias) return F.linear(input, self.weight, self.bias)
@ -97,7 +92,6 @@ class Linear(Module):
) )
@weak_module
class Bilinear(Module): class Bilinear(Module):
r"""Applies a bilinear transformation to the incoming data: r"""Applies a bilinear transformation to the incoming data:
:math:`y = x_1 A x_2 + b` :math:`y = x_1 A x_2 + b`
@ -157,7 +151,6 @@ class Bilinear(Module):
if self.bias is not None: if self.bias is not None:
init.uniform_(self.bias, -bound, bound) init.uniform_(self.bias, -bound, bound)
@weak_script_method
def forward(self, input1, input2): def forward(self, input1, input2):
return F.bilinear(input1, input2, self.weight, self.bias) return F.bilinear(input1, input2, self.weight, self.bias)
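A minimal usage sketch, not part of this diff (the Model class below is made up for illustration): with the decorators stripped above, nn.Linear is an ordinary Python module and can still be called from a ScriptModule's compiled forward without any weak-script bookkeeping.

import torch
import torch.nn as nn

class Model(torch.jit.ScriptModule):
    def __init__(self):
        super(Model, self).__init__()
        # plain nn.Linear submodule, no @weak_module involved
        self.fc = nn.Linear(4, 2)

    @torch.jit.script_method
    def forward(self, x):
        return self.fc(x)

m = Model()
print(m(torch.randn(3, 4)))  # tensor of shape (3, 2)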
View File
@ -3,7 +3,6 @@ import warnings
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from .. import _reduction as _Reduction from .. import _reduction as _Reduction
from ..._jit_internal import weak_module, weak_script_method
class _Loss(Module): class _Loss(Module):
@ -21,7 +20,6 @@ class _WeightedLoss(_Loss):
self.register_buffer('weight', weight) self.register_buffer('weight', weight)
@weak_module
class L1Loss(_Loss): class L1Loss(_Loss):
r"""Creates a criterion that measures the mean absolute error (MAE) between each element in r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
the input :math:`x` and target :math:`y`. the input :math:`x` and target :math:`y`.
@ -86,12 +84,10 @@ class L1Loss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'): def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(L1Loss, self).__init__(size_average, reduce, reduction) super(L1Loss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.l1_loss(input, target, reduction=self.reduction) return F.l1_loss(input, target, reduction=self.reduction)
@weak_module
class NLLLoss(_WeightedLoss): class NLLLoss(_WeightedLoss):
r"""The negative log likelihood loss. It is useful to train a classification r"""The negative log likelihood loss. It is useful to train a classification
problem with `C` classes. problem with `C` classes.
@ -204,12 +200,10 @@ class NLLLoss(_WeightedLoss):
super(NLLLoss, self).__init__(weight, size_average, reduce, reduction) super(NLLLoss, self).__init__(weight, size_average, reduce, reduction)
self.ignore_index = ignore_index self.ignore_index = ignore_index
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
@weak_module
class NLLLoss2d(NLLLoss): class NLLLoss2d(NLLLoss):
def __init__(self, weight=None, size_average=None, ignore_index=-100, def __init__(self, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'): reduce=None, reduction='mean'):
@ -219,7 +213,6 @@ class NLLLoss2d(NLLLoss):
super(NLLLoss2d, self).__init__(weight, size_average, ignore_index, reduce, reduction) super(NLLLoss2d, self).__init__(weight, size_average, ignore_index, reduce, reduction)
@weak_module
class PoissonNLLLoss(_Loss): class PoissonNLLLoss(_Loss):
r"""Negative log likelihood loss with Poisson distribution of target. r"""Negative log likelihood loss with Poisson distribution of target.
@ -286,13 +279,11 @@ class PoissonNLLLoss(_Loss):
self.full = full self.full = full
self.eps = eps self.eps = eps
@weak_script_method
def forward(self, log_input, target): def forward(self, log_input, target):
return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full, return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full,
eps=self.eps, reduction=self.reduction) eps=self.eps, reduction=self.reduction)
@weak_module
class KLDivLoss(_Loss): class KLDivLoss(_Loss):
r"""The `Kullback-Leibler divergence`_ Loss r"""The `Kullback-Leibler divergence`_ Loss
@ -370,12 +361,10 @@ class KLDivLoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'): def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(KLDivLoss, self).__init__(size_average, reduce, reduction) super(KLDivLoss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.kl_div(input, target, reduction=self.reduction) return F.kl_div(input, target, reduction=self.reduction)
@weak_module
class MSELoss(_Loss): class MSELoss(_Loss):
r"""Creates a criterion that measures the mean squared error (squared L2 norm) between r"""Creates a criterion that measures the mean squared error (squared L2 norm) between
each element in the input :math:`x` and target :math:`y`. each element in the input :math:`x` and target :math:`y`.
@ -438,12 +427,10 @@ class MSELoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'): def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(MSELoss, self).__init__(size_average, reduce, reduction) super(MSELoss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.mse_loss(input, target, reduction=self.reduction) return F.mse_loss(input, target, reduction=self.reduction)
@weak_module
class BCELoss(_WeightedLoss): class BCELoss(_WeightedLoss):
r"""Creates a criterion that measures the Binary Cross Entropy r"""Creates a criterion that measures the Binary Cross Entropy
between the target and the output: between the target and the output:
@ -507,12 +494,10 @@ class BCELoss(_WeightedLoss):
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
super(BCELoss, self).__init__(weight, size_average, reduce, reduction) super(BCELoss, self).__init__(weight, size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
@weak_module
class BCEWithLogitsLoss(_Loss): class BCEWithLogitsLoss(_Loss):
r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single
class. This version is more numerically stable than using a plain `Sigmoid` class. This version is more numerically stable than using a plain `Sigmoid`
@ -609,7 +594,6 @@ class BCEWithLogitsLoss(_Loss):
self.register_buffer('weight', weight) self.register_buffer('weight', weight)
self.register_buffer('pos_weight', pos_weight) self.register_buffer('pos_weight', pos_weight)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.binary_cross_entropy_with_logits(input, target, return F.binary_cross_entropy_with_logits(input, target,
self.weight, self.weight,
@ -617,7 +601,6 @@ class BCEWithLogitsLoss(_Loss):
reduction=self.reduction) reduction=self.reduction)
@weak_module
class HingeEmbeddingLoss(_Loss): class HingeEmbeddingLoss(_Loss):
r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`
(containing 1 or -1). (containing 1 or -1).
@ -673,12 +656,10 @@ class HingeEmbeddingLoss(_Loss):
super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction) super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin self.margin = margin
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction) return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction)
@weak_module
class MultiLabelMarginLoss(_Loss): class MultiLabelMarginLoss(_Loss):
r"""Creates a criterion that optimizes a multi-class multi-classification r"""Creates a criterion that optimizes a multi-class multi-classification
hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
@ -739,12 +720,10 @@ class MultiLabelMarginLoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'): def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction) super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.multilabel_margin_loss(input, target, reduction=self.reduction) return F.multilabel_margin_loss(input, target, reduction=self.reduction)
@weak_module
class SmoothL1Loss(_Loss): class SmoothL1Loss(_Loss):
r"""Creates a criterion that uses a squared term if the absolute r"""Creates a criterion that uses a squared term if the absolute
element-wise error falls below 1 and an L1 term otherwise. element-wise error falls below 1 and an L1 term otherwise.
@ -799,12 +778,10 @@ class SmoothL1Loss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'): def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(SmoothL1Loss, self).__init__(size_average, reduce, reduction) super(SmoothL1Loss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.smooth_l1_loss(input, target, reduction=self.reduction) return F.smooth_l1_loss(input, target, reduction=self.reduction)
@weak_module
class SoftMarginLoss(_Loss): class SoftMarginLoss(_Loss):
r"""Creates a criterion that optimizes a two-class classification r"""Creates a criterion that optimizes a two-class classification
logistic loss between input tensor :math:`x` and target tensor :math:`y` logistic loss between input tensor :math:`x` and target tensor :math:`y`
@ -842,12 +819,10 @@ class SoftMarginLoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'): def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(SoftMarginLoss, self).__init__(size_average, reduce, reduction) super(SoftMarginLoss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.soft_margin_loss(input, target, reduction=self.reduction) return F.soft_margin_loss(input, target, reduction=self.reduction)
@weak_module
class CrossEntropyLoss(_WeightedLoss): class CrossEntropyLoss(_WeightedLoss):
r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class. r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.
@ -936,13 +911,11 @@ class CrossEntropyLoss(_WeightedLoss):
super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction) super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
self.ignore_index = ignore_index self.ignore_index = ignore_index
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.cross_entropy(input, target, weight=self.weight, return F.cross_entropy(input, target, weight=self.weight,
ignore_index=self.ignore_index, reduction=self.reduction) ignore_index=self.ignore_index, reduction=self.reduction)
@weak_module
class MultiLabelSoftMarginLoss(_WeightedLoss): class MultiLabelSoftMarginLoss(_WeightedLoss):
r"""Creates a criterion that optimizes a multi-label one-versus-all r"""Creates a criterion that optimizes a multi-label one-versus-all
loss based on max-entropy, between input :math:`x` and target :math:`y` of size loss based on max-entropy, between input :math:`x` and target :math:`y` of size
@ -986,12 +959,10 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction) super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction) return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction)
@weak_module
class CosineEmbeddingLoss(_Loss): class CosineEmbeddingLoss(_Loss):
r"""Creates a criterion that measures the loss given input tensors r"""Creates a criterion that measures the loss given input tensors
:math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1. :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1.
@ -1034,12 +1005,10 @@ class CosineEmbeddingLoss(_Loss):
super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction) super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin self.margin = margin
@weak_script_method
def forward(self, input1, input2, target): def forward(self, input1, input2, target):
return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
@weak_module
class MarginRankingLoss(_Loss): class MarginRankingLoss(_Loss):
r"""Creates a criterion that measures the loss given r"""Creates a criterion that measures the loss given
inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`, inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`,
@ -1082,12 +1051,10 @@ class MarginRankingLoss(_Loss):
super(MarginRankingLoss, self).__init__(size_average, reduce, reduction) super(MarginRankingLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin self.margin = margin
@weak_script_method
def forward(self, input1, input2, target): def forward(self, input1, input2, target):
return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
@weak_module
class MultiMarginLoss(_WeightedLoss): class MultiMarginLoss(_WeightedLoss):
r"""Creates a criterion that optimizes a multi-class classification hinge r"""Creates a criterion that optimizes a multi-class classification hinge
loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and
@ -1145,13 +1112,11 @@ class MultiMarginLoss(_WeightedLoss):
self.p = p self.p = p
self.margin = margin self.margin = margin
@weak_script_method
def forward(self, input, target): def forward(self, input, target):
return F.multi_margin_loss(input, target, p=self.p, margin=self.margin, return F.multi_margin_loss(input, target, p=self.p, margin=self.margin,
weight=self.weight, reduction=self.reduction) weight=self.weight, reduction=self.reduction)
@weak_module
class TripletMarginLoss(_Loss): class TripletMarginLoss(_Loss):
r"""Creates a criterion that measures the triplet loss given an input r"""Creates a criterion that measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
@ -1221,13 +1186,11 @@ class TripletMarginLoss(_Loss):
self.eps = eps self.eps = eps
self.swap = swap self.swap = swap
@weak_script_method
def forward(self, anchor, positive, negative): def forward(self, anchor, positive, negative):
return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p, return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p,
eps=self.eps, swap=self.swap, reduction=self.reduction) eps=self.eps, swap=self.swap, reduction=self.reduction)
@weak_module
class CTCLoss(_Loss): class CTCLoss(_Loss):
r"""The Connectionist Temporal Classification loss. r"""The Connectionist Temporal Classification loss.
@ -1327,7 +1290,6 @@ class CTCLoss(_Loss):
self.blank = blank self.blank = blank
self.zero_infinity = zero_infinity self.zero_infinity = zero_infinity
@weak_script_method
def forward(self, log_probs, targets, input_lengths, target_lengths): def forward(self, log_probs, targets, input_lengths, target_lengths):
return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
self.zero_infinity) self.zero_infinity)
View File
@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from .. import init from .. import init
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class LocalResponseNorm(Module): class LocalResponseNorm(Module):
r"""Applies local response normalization over an input signal composed r"""Applies local response normalization over an input signal composed
of several input planes, where channels occupy the second dimension. of several input planes, where channels occupy the second dimension.
@ -45,7 +43,6 @@ class LocalResponseNorm(Module):
self.beta = beta self.beta = beta
self.k = k self.k = k
@weak_script_method
def forward(self, input): def forward(self, input):
return F.local_response_norm(input, self.size, self.alpha, self.beta, return F.local_response_norm(input, self.size, self.alpha, self.beta,
self.k) self.k)
@ -71,7 +68,6 @@ class CrossMapLRN2d(Module):
return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__) return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
@weak_module
class LayerNorm(Module): class LayerNorm(Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ . the paper `Layer Normalization`_ .
@ -151,7 +147,6 @@ class LayerNorm(Module):
init.ones_(self.weight) init.ones_(self.weight)
init.zeros_(self.bias) init.zeros_(self.bias)
@weak_script_method
def forward(self, input): def forward(self, input):
return F.layer_norm( return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps) input, self.normalized_shape, self.weight, self.bias, self.eps)
@ -161,7 +156,6 @@ class LayerNorm(Module):
'elementwise_affine={elementwise_affine}'.format(**self.__dict__) 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
@weak_module
class GroupNorm(Module): class GroupNorm(Module):
r"""Applies Group Normalization over a mini-batch of inputs as described in r"""Applies Group Normalization over a mini-batch of inputs as described in
the paper `Group Normalization`_ . the paper `Group Normalization`_ .
@ -226,7 +220,6 @@ class GroupNorm(Module):
init.ones_(self.weight) init.ones_(self.weight)
init.zeros_(self.bias) init.zeros_(self.bias)
@weak_script_method
def forward(self, input): def forward(self, input):
return F.group_norm( return F.group_norm(
input, self.num_groups, self.weight, self.bias, self.eps) input, self.num_groups, self.weight, self.bias, self.eps)
View File
@ -1,13 +1,11 @@
from .module import Module from .module import Module
from .utils import _pair, _quadruple, _ntuple from .utils import _pair, _quadruple, _ntuple
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
# TODO: grad_output size asserts in THNN # TODO: grad_output size asserts in THNN
@weak_module
class _ConstantPadNd(Module): class _ConstantPadNd(Module):
__constants__ = ['padding', 'value'] __constants__ = ['padding', 'value']
@ -15,7 +13,6 @@ class _ConstantPadNd(Module):
super(_ConstantPadNd, self).__init__() super(_ConstantPadNd, self).__init__()
self.value = value self.value = value
@weak_script_method
def forward(self, input): def forward(self, input):
return F.pad(input, self.padding, 'constant', self.value) return F.pad(input, self.padding, 'constant', self.value)
@ -23,7 +20,6 @@ class _ConstantPadNd(Module):
return 'padding={}, value={}'.format(self.padding, self.value) return 'padding={}, value={}'.format(self.padding, self.value)
@weak_module
class ConstantPad1d(_ConstantPadNd): class ConstantPad1d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value. r"""Pads the input tensor boundaries with a constant value.
@ -73,7 +69,6 @@ class ConstantPad1d(_ConstantPadNd):
self.padding = _pair(padding) self.padding = _pair(padding)
@weak_module
class ConstantPad2d(_ConstantPadNd): class ConstantPad2d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value. r"""Pads the input tensor boundaries with a constant value.
@ -123,7 +118,6 @@ class ConstantPad2d(_ConstantPadNd):
self.padding = _quadruple(padding) self.padding = _quadruple(padding)
@weak_module
class ConstantPad3d(_ConstantPadNd): class ConstantPad3d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value. r"""Pads the input tensor boundaries with a constant value.
@ -162,11 +156,9 @@ class ConstantPad3d(_ConstantPadNd):
self.padding = _ntuple(6)(padding) self.padding = _ntuple(6)(padding)
@weak_module
class _ReflectionPadNd(Module): class _ReflectionPadNd(Module):
__constants__ = ['padding'] __constants__ = ['padding']
@weak_script_method
def forward(self, input): def forward(self, input):
return F.pad(input, self.padding, 'reflect') return F.pad(input, self.padding, 'reflect')
@ -174,7 +166,6 @@ class _ReflectionPadNd(Module):
return '{}'.format(self.padding) return '{}'.format(self.padding)
@weak_module
class ReflectionPad1d(_ReflectionPadNd): class ReflectionPad1d(_ReflectionPadNd):
r"""Pads the input tensor using the reflection of the input boundary. r"""Pads the input tensor using the reflection of the input boundary.
@ -214,7 +205,6 @@ class ReflectionPad1d(_ReflectionPadNd):
self.padding = _pair(padding) self.padding = _pair(padding)
@weak_module
class ReflectionPad2d(_ReflectionPadNd): class ReflectionPad2d(_ReflectionPadNd):
r"""Pads the input tensor using the reflection of the input boundary. r"""Pads the input tensor using the reflection of the input boundary.
@ -265,11 +255,9 @@ class ReflectionPad2d(_ReflectionPadNd):
self.padding = _quadruple(padding) self.padding = _quadruple(padding)
@weak_module
class _ReplicationPadNd(Module): class _ReplicationPadNd(Module):
__constants__ = ['padding'] __constants__ = ['padding']
@weak_script_method
def forward(self, input): def forward(self, input):
return F.pad(input, self.padding, 'replicate') return F.pad(input, self.padding, 'replicate')
@ -277,7 +265,6 @@ class _ReplicationPadNd(Module):
return '{}'.format(self.padding) return '{}'.format(self.padding)
@weak_module
class ReplicationPad1d(_ReplicationPadNd): class ReplicationPad1d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary. r"""Pads the input tensor using replication of the input boundary.
@ -317,7 +304,6 @@ class ReplicationPad1d(_ReplicationPadNd):
self.padding = _pair(padding) self.padding = _pair(padding)
@weak_module
class ReplicationPad2d(_ReplicationPadNd): class ReplicationPad2d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary. r"""Pads the input tensor using replication of the input boundary.
@ -368,7 +354,6 @@ class ReplicationPad2d(_ReplicationPadNd):
self.padding = _quadruple(padding) self.padding = _quadruple(padding)
@weak_module
class ReplicationPad3d(_ReplicationPadNd): class ReplicationPad3d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary. r"""Pads the input tensor using replication of the input boundary.
@ -407,7 +392,6 @@ class ReplicationPad3d(_ReplicationPadNd):
self.padding = _ntuple(6)(padding) self.padding = _ntuple(6)(padding)
@weak_module
class ZeroPad2d(ConstantPad2d): class ZeroPad2d(ConstantPad2d):
r"""Pads the input tensor boundaries with zero. r"""Pads the input tensor boundaries with zero.
View File
@ -1,9 +1,7 @@
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class PixelShuffle(Module): class PixelShuffle(Module):
r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(*, C, H \times r, W \times r)`. to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
@ -41,7 +39,6 @@ class PixelShuffle(Module):
super(PixelShuffle, self).__init__() super(PixelShuffle, self).__init__()
self.upscale_factor = upscale_factor self.upscale_factor = upscale_factor
@weak_script_method
def forward(self, input): def forward(self, input):
return F.pixel_shuffle(input, self.upscale_factor) return F.pixel_shuffle(input, self.upscale_factor)
View File
@ -1,10 +1,8 @@
from .module import Module from .module import Module
from .utils import _single, _pair, _triple from .utils import _single, _pair, _triple
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class _MaxPoolNd(Module): class _MaxPoolNd(Module):
__constants__ = ['kernel_size', 'stride', 'padding', 'dilation', __constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
'return_indices', 'ceil_mode'] 'return_indices', 'ceil_mode']
@ -24,7 +22,6 @@ class _MaxPoolNd(Module):
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__) ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
@weak_module
class MaxPool1d(_MaxPoolNd): class MaxPool1d(_MaxPoolNd):
r"""Applies a 1D max pooling over an input signal composed of several input r"""Applies a 1D max pooling over an input signal composed of several input
planes. planes.
@ -68,7 +65,6 @@ class MaxPool1d(_MaxPoolNd):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.max_pool1d(input, self.kernel_size, self.stride, return F.max_pool1d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode, self.padding, self.dilation, self.ceil_mode,
@ -79,7 +75,6 @@ class MaxPool1d(_MaxPoolNd):
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__) ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
@weak_module
class MaxPool2d(_MaxPoolNd): class MaxPool2d(_MaxPoolNd):
r"""Applies a 2D max pooling over an input signal composed of several input r"""Applies a 2D max pooling over an input signal composed of several input
planes. planes.
@ -139,14 +134,12 @@ class MaxPool2d(_MaxPoolNd):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.max_pool2d(input, self.kernel_size, self.stride, return F.max_pool2d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode, self.padding, self.dilation, self.ceil_mode,
self.return_indices) self.return_indices)
@weak_module
class MaxPool3d(_MaxPoolNd): class MaxPool3d(_MaxPoolNd):
r"""Applies a 3D max pooling over an input signal composed of several input r"""Applies a 3D max pooling over an input signal composed of several input
planes. planes.
@ -210,14 +203,12 @@ class MaxPool3d(_MaxPoolNd):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
""" # noqa: E501 """ # noqa: E501
@weak_script_method
def forward(self, input): def forward(self, input):
return F.max_pool3d(input, self.kernel_size, self.stride, return F.max_pool3d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode, self.padding, self.dilation, self.ceil_mode,
self.return_indices) self.return_indices)
@weak_module
class _MaxUnpoolNd(Module): class _MaxUnpoolNd(Module):
def extra_repr(self): def extra_repr(self):
@ -226,7 +217,6 @@ class _MaxUnpoolNd(Module):
) )
@weak_module
class MaxUnpool1d(_MaxUnpoolNd): class MaxUnpool1d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool1d`. r"""Computes a partial inverse of :class:`MaxPool1d`.
@ -292,7 +282,6 @@ class MaxUnpool1d(_MaxUnpoolNd):
self.padding, output_size) self.padding, output_size)
@weak_module
class MaxUnpool2d(_MaxUnpoolNd): class MaxUnpool2d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool2d`. r"""Computes a partial inverse of :class:`MaxPool2d`.
@ -366,7 +355,6 @@ class MaxUnpool2d(_MaxUnpoolNd):
self.padding, output_size) self.padding, output_size)
@weak_module
class MaxUnpool3d(_MaxUnpoolNd): class MaxUnpool3d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool3d`. r"""Computes a partial inverse of :class:`MaxPool3d`.
@ -429,7 +417,6 @@ class MaxUnpool3d(_MaxUnpoolNd):
self.padding, output_size) self.padding, output_size)
@weak_module
class _AvgPoolNd(Module): class _AvgPoolNd(Module):
__constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad'] __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']
@ -439,7 +426,6 @@ class _AvgPoolNd(Module):
) )
@weak_module
class AvgPool1d(_AvgPoolNd): class AvgPool1d(_AvgPoolNd):
r"""Applies a 1D average pooling over an input signal composed of several r"""Applies a 1D average pooling over an input signal composed of several
input planes. input planes.
@ -490,14 +476,12 @@ class AvgPool1d(_AvgPoolNd):
self.ceil_mode = ceil_mode self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad self.count_include_pad = count_include_pad
@weak_script_method
def forward(self, input): def forward(self, input):
return F.avg_pool1d( return F.avg_pool1d(
input, self.kernel_size, self.stride, self.padding, self.ceil_mode, input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
self.count_include_pad) self.count_include_pad)
@weak_module
class AvgPool2d(_AvgPoolNd): class AvgPool2d(_AvgPoolNd):
r"""Applies a 2D average pooling over an input signal composed of several input r"""Applies a 2D average pooling over an input signal composed of several input
planes. planes.
@ -557,13 +541,11 @@ class AvgPool2d(_AvgPoolNd):
self.ceil_mode = ceil_mode self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad self.count_include_pad = count_include_pad
@weak_script_method
def forward(self, input): def forward(self, input):
return F.avg_pool2d(input, self.kernel_size, self.stride, return F.avg_pool2d(input, self.kernel_size, self.stride,
self.padding, self.ceil_mode, self.count_include_pad) self.padding, self.ceil_mode, self.count_include_pad)
@weak_module
class AvgPool3d(_AvgPoolNd): class AvgPool3d(_AvgPoolNd):
r"""Applies a 3D average pooling over an input signal composed of several input r"""Applies a 3D average pooling over an input signal composed of several input
planes. planes.
@ -630,7 +612,6 @@ class AvgPool3d(_AvgPoolNd):
self.ceil_mode = ceil_mode self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad self.count_include_pad = count_include_pad
@weak_script_method
def forward(self, input): def forward(self, input):
return F.avg_pool3d(input, self.kernel_size, self.stride, return F.avg_pool3d(input, self.kernel_size, self.stride,
self.padding, self.ceil_mode, self.count_include_pad) self.padding, self.ceil_mode, self.count_include_pad)
@ -642,7 +623,6 @@ class AvgPool3d(_AvgPoolNd):
self.__dict__.setdefault('count_include_pad', True) self.__dict__.setdefault('count_include_pad', True)
@weak_module
class FractionalMaxPool2d(Module): class FractionalMaxPool2d(Module):
r"""Applies a 2D fractional max pooling over an input signal composed of several input planes. r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.
@ -694,7 +674,6 @@ class FractionalMaxPool2d(Module):
raise ValueError("output_ratio must be between 0 and 1 (got {})" raise ValueError("output_ratio must be between 0 and 1 (got {})"
.format(output_ratio)) .format(output_ratio))
@weak_script_method
def forward(self, input): def forward(self, input):
return F.fractional_max_pool2d( return F.fractional_max_pool2d(
input, self.kernel_size, self.output_size, self.output_ratio, input, self.kernel_size, self.output_size, self.output_ratio,
@ -702,7 +681,6 @@ class FractionalMaxPool2d(Module):
_random_samples=self._random_samples) _random_samples=self._random_samples)
@weak_module
class FractionalMaxPool3d(Module): class FractionalMaxPool3d(Module):
r"""Applies a 3D fractional max pooling over an input signal composed of several input planes. r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.
@ -754,7 +732,6 @@ class FractionalMaxPool3d(Module):
raise ValueError("output_ratio must be between 0 and 1 (got {})" raise ValueError("output_ratio must be between 0 and 1 (got {})"
.format(output_ratio)) .format(output_ratio))
@weak_script_method
def forward(self, input): def forward(self, input):
return F.fractional_max_pool3d( return F.fractional_max_pool3d(
input, self.kernel_size, self.output_size, self.output_ratio, input, self.kernel_size, self.output_size, self.output_ratio,
@ -762,7 +739,6 @@ class FractionalMaxPool3d(Module):
_random_samples=self._random_samples) _random_samples=self._random_samples)
@weak_module
class _LPPoolNd(Module): class _LPPoolNd(Module):
__constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode'] __constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode']
@ -778,7 +754,6 @@ class _LPPoolNd(Module):
'ceil_mode={ceil_mode}'.format(**self.__dict__) 'ceil_mode={ceil_mode}'.format(**self.__dict__)
@weak_module
class LPPool1d(_LPPoolNd): class LPPool1d(_LPPoolNd):
r"""Applies a 1D power-average pooling over an input signal composed of several input r"""Applies a 1D power-average pooling over an input signal composed of several input
planes. planes.
@ -814,14 +789,11 @@ class LPPool1d(_LPPoolNd):
>>> output = m(input) >>> output = m(input)
""" """
@weak_script_method
@weak_script_method
def forward(self, input): def forward(self, input):
return F.lp_pool1d(input, float(self.norm_type), self.kernel_size, return F.lp_pool1d(input, float(self.norm_type), self.kernel_size,
self.stride, self.ceil_mode) self.stride, self.ceil_mode)
@weak_module
class LPPool2d(_LPPoolNd): class LPPool2d(_LPPoolNd):
r"""Applies a 2D power-average pooling over an input signal composed of several input r"""Applies a 2D power-average pooling over an input signal composed of several input
planes. planes.
@ -871,13 +843,11 @@ class LPPool2d(_LPPoolNd):
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.lp_pool2d(input, float(self.norm_type), self.kernel_size, return F.lp_pool2d(input, float(self.norm_type), self.kernel_size,
self.stride, self.ceil_mode) self.stride, self.ceil_mode)
@weak_module
class _AdaptiveMaxPoolNd(Module): class _AdaptiveMaxPoolNd(Module):
__constants__ = ['output_size', 'return_indices'] __constants__ = ['output_size', 'return_indices']
@ -893,7 +863,6 @@ class _AdaptiveMaxPoolNd(Module):
# output shapes are, and how the operation computes output. # output shapes are, and how the operation computes output.
@weak_module
class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
@ -913,12 +882,10 @@ class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.adaptive_max_pool1d(input, self.output_size, self.return_indices) return F.adaptive_max_pool1d(input, self.output_size, self.return_indices)
@weak_module
class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
@ -949,12 +916,10 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.adaptive_max_pool2d(input, self.output_size, self.return_indices) return F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
@weak_module
class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
@ -986,12 +951,10 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.adaptive_max_pool3d(input, self.output_size, self.return_indices) return F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
@weak_module
class _AdaptiveAvgPoolNd(Module): class _AdaptiveAvgPoolNd(Module):
__constants__ = ['output_size'] __constants__ = ['output_size']
@ -1003,7 +966,6 @@ class _AdaptiveAvgPoolNd(Module):
return 'output_size={}'.format(self.output_size) return 'output_size={}'.format(self.output_size)
@weak_module
class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes. r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
@ -1021,12 +983,10 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.adaptive_avg_pool1d(input, self.output_size) return F.adaptive_avg_pool1d(input, self.output_size)
@weak_module
class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes. r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
@ -1055,12 +1015,10 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.adaptive_avg_pool2d(input, self.output_size) return F.adaptive_avg_pool2d(input, self.output_size)
@weak_module
class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes. r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
@ -1089,6 +1047,5 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
""" """
@weak_script_method
def forward(self, input): def forward(self, input):
return F.adaptive_avg_pool3d(input, self.output_size) return F.adaptive_avg_pool3d(input, self.output_size)
View File
@ -8,8 +8,7 @@ from ..parameter import Parameter
from ..utils.rnn import PackedSequence, get_packed_sequence from ..utils.rnn import PackedSequence, get_packed_sequence
from .. import init from .. import init
from .. import _VF from .. import _VF
from ..._jit_internal import weak_module, weak_script_method, weak_script, \ from ..._jit_internal import _parameter_list
_parameter_list
_rnn_impls = { _rnn_impls = {
'GRU': _VF.gru, 'GRU': _VF.gru,
@ -18,7 +17,6 @@ _rnn_impls = {
} }
@weak_script
def apply_permutation(tensor, permutation, dim=1): def apply_permutation(tensor, permutation, dim=1):
# type: (Tensor, Tensor, int) -> Tensor # type: (Tensor, Tensor, int) -> Tensor
return tensor.index_select(dim, permutation) return tensor.index_select(dim, permutation)
@ -139,7 +137,6 @@ class RNNBase(Module):
def _get_flat_weights(self): def _get_flat_weights(self):
return self._flat_weights return self._flat_weights
@weak_script_method
def check_input(self, input, batch_sizes): def check_input(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> None # type: (Tensor, Optional[Tensor]) -> None
expected_input_dim = 2 if batch_sizes is not None else 3 expected_input_dim = 2 if batch_sizes is not None else 3
@ -152,7 +149,6 @@ class RNNBase(Module):
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format( 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
self.input_size, input.size(-1))) self.input_size, input.size(-1)))
@weak_script_method
def get_expected_hidden_size(self, input, batch_sizes): def get_expected_hidden_size(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int] # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
if batch_sizes is not None: if batch_sizes is not None:
@ -165,7 +161,6 @@ class RNNBase(Module):
mini_batch, self.hidden_size) mini_batch, self.hidden_size)
return expected_hidden_size return expected_hidden_size
@weak_script_method
def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'): def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
# type: (Tensor, Tuple[int, int, int], str) -> None # type: (Tensor, Tuple[int, int, int], str) -> None
if hx.size() != expected_hidden_size: if hx.size() != expected_hidden_size:
@ -374,7 +369,6 @@ class RNN(RNNBase):
super(RNN, self).__init__(mode, *args, **kwargs) super(RNN, self).__init__(mode, *args, **kwargs)
@weak_module
class LSTM(RNNBase): class LSTM(RNNBase):
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
sequence. sequence.
@ -484,7 +478,6 @@ class LSTM(RNNBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(LSTM, self).__init__('LSTM', *args, **kwargs) super(LSTM, self).__init__('LSTM', *args, **kwargs)
@weak_script_method
def check_forward_args(self, input, hidden, batch_sizes): def check_forward_args(self, input, hidden, batch_sizes):
# type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
self.check_input(input, batch_sizes) self.check_input(input, batch_sizes)
@ -495,14 +488,12 @@ class LSTM(RNNBase):
self.check_hidden_size(hidden[1], expected_hidden_size, self.check_hidden_size(hidden[1], expected_hidden_size,
'Expected hidden[1] size {}, got {}') 'Expected hidden[1] size {}, got {}')
@weak_script_method
def permute_hidden(self, hx, permutation): def permute_hidden(self, hx, permutation):
# type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor] # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
if permutation is None: if permutation is None:
return hx return hx
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
@weak_script_method
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
if hx is None: if hx is None:
@ -528,7 +519,7 @@ class LSTM(RNNBase):
return output, hidden return output, hidden
@weak_script_method @torch._jit_internal.export
def forward_tensor(self, input, hx=None): def forward_tensor(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
batch_sizes = None batch_sizes = None
@ -540,7 +531,7 @@ class LSTM(RNNBase):
return output, self.permute_hidden(hidden, unsorted_indices) return output, self.permute_hidden(hidden, unsorted_indices)
@weak_script_method @torch._jit_internal.export
def forward_packed(self, input, hx=None): def forward_packed(self, input, hx=None):
# type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa # type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
input, batch_sizes, sorted_indices, unsorted_indices = input input, batch_sizes, sorted_indices, unsorted_indices = input
@ -552,6 +543,7 @@ class LSTM(RNNBase):
output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices) output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices) return output, self.permute_hidden(hidden, unsorted_indices)
@torch._jit_internal.ignore
def forward(self, input, hx=None): def forward(self, input, hx=None):
if isinstance(input, PackedSequence): if isinstance(input, PackedSequence):
return self.forward_packed(input, hx) return self.forward_packed(input, hx)
@ -694,14 +686,12 @@ class RNNCellBase(Module):
s += ', nonlinearity={nonlinearity}' s += ', nonlinearity={nonlinearity}'
return s.format(**self.__dict__) return s.format(**self.__dict__)
@weak_script_method
def check_forward_input(self, input): def check_forward_input(self, input):
if input.size(1) != self.input_size: if input.size(1) != self.input_size:
raise RuntimeError( raise RuntimeError(
"input has inconsistent input_size: got {}, expected {}".format( "input has inconsistent input_size: got {}, expected {}".format(
input.size(1), self.input_size)) input.size(1), self.input_size))
@weak_script_method
def check_forward_hidden(self, input, hx, hidden_label=''): def check_forward_hidden(self, input, hx, hidden_label=''):
# type: (Tensor, Tensor, str) -> None # type: (Tensor, Tensor, str) -> None
if input.size(0) != hx.size(0): if input.size(0) != hx.size(0):
@ -720,7 +710,6 @@ class RNNCellBase(Module):
init.uniform_(weight, -stdv, stdv) init.uniform_(weight, -stdv, stdv)
@weak_module
class RNNCell(RNNCellBase): class RNNCell(RNNCellBase):
r"""An Elman RNN cell with tanh or ReLU non-linearity. r"""An Elman RNN cell with tanh or ReLU non-linearity.
@ -784,7 +773,6 @@ class RNNCell(RNNCellBase):
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1) super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
self.nonlinearity = nonlinearity self.nonlinearity = nonlinearity
@weak_script_method
def forward(self, input, hx=None): def forward(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor # type: (Tensor, Optional[Tensor]) -> Tensor
self.check_forward_input(input) self.check_forward_input(input)
@ -810,7 +798,6 @@ class RNNCell(RNNCellBase):
return ret return ret
@weak_module
class LSTMCell(RNNCellBase): class LSTMCell(RNNCellBase):
r"""A long short-term memory (LSTM) cell. r"""A long short-term memory (LSTM) cell.
@ -875,7 +862,6 @@ class LSTMCell(RNNCellBase):
def __init__(self, input_size, hidden_size, bias=True): def __init__(self, input_size, hidden_size, bias=True):
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4) super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
@weak_script_method
def forward(self, input, hx=None): def forward(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
self.check_forward_input(input) self.check_forward_input(input)
@ -891,7 +877,6 @@ class LSTMCell(RNNCellBase):
) )
@weak_module
class GRUCell(RNNCellBase): class GRUCell(RNNCellBase):
r"""A gated recurrent unit (GRU) cell r"""A gated recurrent unit (GRU) cell
@ -957,7 +942,6 @@ class GRUCell(RNNCellBase):
def __init__(self, input_size, hidden_size, bias=True): def __init__(self, input_size, hidden_size, bias=True):
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3) super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
@weak_script_method
def forward(self, input, hx=None): def forward(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor # type: (Tensor, Optional[Tensor]) -> Tensor
self.check_forward_input(input) self.check_forward_input(input)
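The hunks above swap @weak_script_method for @torch._jit_internal.export on LSTM's forward_tensor and forward_packed, and mark the polymorphic Python-level forward (which accepts either a Tensor or a PackedSequence) with @torch._jit_internal.ignore so it stays out of the compiler. A minimal sketch of that export/ignore pattern, with a made-up TwoWay module standing in for LSTM:

import torch
from torch import Tensor
from typing import List

class TwoWay(torch.nn.Module):
    # statically typed entry points, one per accepted input type
    @torch._jit_internal.export
    def forward_tensor(self, x):
        # type: (Tensor) -> Tensor
        return x + 1

    @torch._jit_internal.export
    def forward_list(self, xs):
        # type: (List[Tensor]) -> Tensor
        return torch.stack(xs).sum(dim=0) + 1

    # left as plain Python, so isinstance-based dispatch on a non-Tensor type is fine
    @torch._jit_internal.ignore
    def forward(self, x):
        if isinstance(x, list):
            return self.forward_list(x)
        return self.forward_tensor(x)

m = TwoWay()
print(m(torch.randn(3)))
print(m([torch.randn(3), torch.randn(3)]))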
View File
@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from .. import init from .. import init
from torch._jit_internal import weak_module, weak_script_method
@weak_module
class Embedding(Module): class Embedding(Module):
r"""A simple lookup table that stores embeddings of a fixed dictionary and size. r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
@ -110,7 +108,6 @@ class Embedding(Module):
with torch.no_grad(): with torch.no_grad():
self.weight[self.padding_idx].fill_(0) self.weight[self.padding_idx].fill_(0)
@weak_script_method
def forward(self, input): def forward(self, input):
return F.embedding( return F.embedding(
input, self.weight, self.padding_idx, self.max_norm, input, self.weight, self.padding_idx, self.max_norm,
@ -173,7 +170,6 @@ class Embedding(Module):
return embedding return embedding
@weak_module
class EmbeddingBag(Module): class EmbeddingBag(Module):
r"""Computes sums or means of 'bags' of embeddings, without instantiating the r"""Computes sums or means of 'bags' of embeddings, without instantiating the
intermediate embeddings. intermediate embeddings.
@ -277,7 +273,6 @@ class EmbeddingBag(Module):
def reset_parameters(self): def reset_parameters(self):
init.normal_(self.weight) init.normal_(self.weight)
@weak_script_method
def forward(self, input, offsets=None, per_sample_weights=None): def forward(self, input, offsets=None, per_sample_weights=None):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
return F.embedding_bag(input, self.weight, offsets, return F.embedding_bag(input, self.weight, offsets,
View File
@ -1,9 +1,7 @@
from .module import Module from .module import Module
from .. import functional as F from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class Upsample(Module): class Upsample(Module):
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
@ -129,7 +127,6 @@ class Upsample(Module):
self.mode = mode self.mode = mode
self.align_corners = align_corners self.align_corners = align_corners
@weak_script_method
def forward(self, input): def forward(self, input):
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners) return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
@ -142,7 +139,6 @@ class Upsample(Module):
return info return info
@weak_module
class UpsamplingNearest2d(Upsample): class UpsamplingNearest2d(Upsample):
r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input
channels. channels.
@ -188,7 +184,6 @@ class UpsamplingNearest2d(Upsample):
super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest') super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest')
@weak_module
class UpsamplingBilinear2d(Upsample): class UpsamplingBilinear2d(Upsample):
r"""Applies a 2D bilinear upsampling to an input signal composed of several input r"""Applies a 2D bilinear upsampling to an input signal composed of several input
channels. channels.
View File
@ -23,10 +23,9 @@ def _is_jit_enabled():
# Check if we can safely replicate the module. # Check if we can safely replicate the module.
# there are three types of module: # there are two types of module:
# 1. python modules # 1. python modules
# 2. weak python modules (nn.Module annotated by @weak_module) # 2. ScriptModule
# 3. ScriptModule
# #
# currently a module cannot be replicated properly if the descendants of # currently a module cannot be replicated properly if the descendants of
# any ScriptModule contains python module (type 1 above) # any ScriptModule contains python module (type 1 above)
View File
@ -5,9 +5,7 @@ from __future__ import unicode_literals
from .. import functional as F from .. import functional as F
from ...modules.module import Module from ...modules.module import Module
from ...._jit_internal import weak_module, weak_script_method
@weak_module
class ReLU(Module): class ReLU(Module):
r"""Applies quantized rectified linear unit function element-wise: r"""Applies quantized rectified linear unit function element-wise:
@ -37,7 +35,6 @@ class ReLU(Module):
assert not inplace, 'torch.nn.quantized.ReLU does not support inplace' assert not inplace, 'torch.nn.quantized.ReLU does not support inplace'
@weak_script_method
def forward(self, input): def forward(self, input):
return F.relu(input) return F.relu(input)
View File
@ -2,9 +2,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import torch import torch
from ...modules.module import Module from ...modules.module import Module
from ...modules.linear import Linear as NNLinear from ...modules.linear import Linear as NNLinear
from ...._jit_internal import weak_module
@weak_module
class Quantize(Module): class Quantize(Module):
r"""Quantizes an incoming tensor r"""Quantizes an incoming tensor
Args: Args:
@ -39,7 +37,6 @@ class Quantize(Module):
def from_float(mod): def from_float(mod):
return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8) return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8)
@weak_module
class DeQuantize(Module): class DeQuantize(Module):
r"""Dequantizes an incoming tensor r"""Dequantizes an incoming tensor
@ -65,7 +62,6 @@ class DeQuantize(Module):
def from_float(mod): def from_float(mod):
return DeQuantize() return DeQuantize()
@weak_module
class Linear(NNLinear): class Linear(NNLinear):
r""" r"""
A quantized linear module with quantized tensor as inputs A quantized linear module with quantized tensor as inputs