Remove weak script (#22212)

Summary:
* Deletes all weak script decorators / associated data structures / methods
   * In order to keep supporting the standard library in script, this enables recursive script on any function defined in `torch.nn`
   * Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`; only `rnn.py` needed manual editing, using `ignore` and `export` to keep supporting the overloaded `forward` methods
* `Sequential`/`ModuleList` no longer need to be added to `__constants__` since they are compiled on demand (see the short sketch below)
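
A minimal sketch of the new behavior, mirroring the updated `test_script_module_list_sequential` further down (this assumes the legacy `ScriptModule`/`script_method` API of this era and is illustrative rather than the exact test code):

```python
import torch
import torch.nn as nn

class M(torch.jit.ScriptModule):
    def __init__(self, mods):
        super(M, self).__init__()
        # Before this change, holding a Sequential/ModuleList here required
        # __constants__ = ['mods']; now it is compiled on demand.
        self.mods = mods

    @torch.jit.script_method
    def forward(self, v):
        for m in self.mods:
            v = m(v)
        return v

m = M(nn.Sequential(nn.ReLU()))
print(m(torch.randn(2, 2)))
```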

This should also fix https://github.com/pytorch/pytorch/issues/22212
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212

Differential Revision: D15988346

Pulled By: driazati

fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
David Riazati 2019-07-03 17:22:22 -07:00 committed by Facebook Github Bot
parent b93f29ded3
commit 10c4b98ade
28 changed files with 109 additions and 564 deletions

View File

@ -6979,7 +6979,7 @@ a")
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
M()
def test_script_module_list_sequential_error(self):
def test_script_module_list_sequential(self):
class M(torch.jit.ScriptModule):
def __init__(self, mod_list):
super(M, self).__init__(False)
@ -6991,25 +6991,21 @@ a")
v = m(v)
return v
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
a = M(nn.Sequential(nn.ReLU()))
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
a = M(nn.ModuleList([nn.ReLU()]))
m = M(nn.Sequential(nn.ReLU()))
self.assertExportImportModule(m, (torch.randn(2, 2),))
def test_attr_module_constants_error(self):
def test_attr_module_constants(self):
class M2(torch.jit.ScriptModule):
def __init__(self, mod_list):
super(M2, self).__init__(False)
self.mods = mod_list
@torch.jit.script_method
def forward(self, v):
def forward(self, x):
return self.mods.forward(x)
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
M2(nn.Sequential(nn.ReLU()))
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
M2(nn.ModuleList([nn.ReLU()]))
m = M2(nn.Sequential(nn.ReLU()))
self.assertExportImportModule(m, (torch.randn(2, 2),))
def test_script_sequential_for(self):
class Sub(torch.jit.ScriptModule):
@ -11007,6 +11003,7 @@ a")
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
foo(torch.tensor(0))
@unittest.skipIf(True, "Removing weak script")
def test_weak_script_function(self):
outer_var = 10
outer_var2 = 11
@ -11086,6 +11083,7 @@ a")
eg = torch.zeros(3, dtype=torch.uint8)
self.assertEqual(foo_traced(eg), foo(eg))
@unittest.skipIf(True, "Removing weak script")
def test_weak_module(self):
@torch._jit_internal.weak_module
@ -11161,6 +11159,7 @@ a")
self.assertEqual(script_result, expected_result)
self.assertEqual(script_result, script_result2)
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_parameters_and_buffers(self):
weights = torch.randn(10, 10)
bias = torch.randn(10)
@ -11219,6 +11218,7 @@ a")
self.assertEqual(strong_mod(inp), expected_result)
self.assertExportImportModule(strong_mod, (inp,))
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_nested(self):
@torch._jit_internal.weak_module
class OtherWeak(torch.nn.Module):
@ -11280,6 +11280,7 @@ a")
+ F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
self.assertEqual(result, expected_result)
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_submodule(self):
@torch._jit_internal.weak_module
class Weak(torch.nn.Module):
@ -11319,6 +11320,7 @@ a")
with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
strong_mod = Strong()
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_copying(self):
class Submodule(torch.nn.Module):
def __init__(self):
@ -11385,6 +11387,7 @@ a")
m = M()
@unittest.skipIf(True, "Removing weak script")
def test_weak_module_attributes(self):
tester = self
@ -11948,6 +11951,7 @@ a")
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
@unittest.skipIf(True, "Removing weak script")
def test_overloading(self):
@torch._jit_internal.weak_module
class W(torch.nn.Module):
@ -13623,6 +13627,9 @@ EXCLUDE_SCRIPT_MODULES = {
'test_nn_AdaptiveAvgPool3d_tuple_none',
'test_nn_AdaptiveMaxPool2d_tuple_none',
'test_nn_AdaptiveMaxPool3d_tuple_none',
# Uses Module._backend, so this is not supported
'test_nn_CrossMapLRN2d',
}
@ -14552,10 +14559,6 @@ def add_nn_module_test(*args, **kwargs):
module_name = name.split("_")[0]
module = getattr(torch.nn, module_name, None)
if module is None or torch._jit_internal.weak_types.get(module) is None:
return
if 'desc' in kwargs and 'eval' in kwargs['desc']:
# eval() is not supported, so skip these tests
return

View File

@ -4,29 +4,14 @@ can be used in other places in torch/ (namely torch.nn) without running into
circular dependency problems
"""
import weakref
import inspect
import weakref
from torch._six import builtins
# Tracks standalone weak script functions
compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484
# Tracks which methods should be converted to strong methods
weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484
# Converted modules and their corresponding WeakScriptModuleProxy objects
weak_modules = weakref.WeakKeyDictionary() # noqa: T484
# Types that have been declared as weak modules
weak_types = weakref.WeakKeyDictionary() # noqa: T484
# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484
COMPILATION_PENDING = object()
COMPILED = object()
def createResolutionCallback(frames_up=0):
"""
@ -71,51 +56,41 @@ def createResolutionCallback(frames_up=0):
return f_globals[key]
elif hasattr(builtins, key):
return getattr(builtins, key)
else:
return None
return env
def weak_script(fn, _frames_up=0):
def createResolutionCallbackFromClosure(fn):
"""
Marks a function as a weak script function. When used in a script function
or ScriptModule, the weak script function will be lazily compiled and
inlined in the graph. When not used in a script function, the weak script
annotation has no effect.
Create a resolutionCallback by introspecting the function instead of
looking up the stack for the enclosing scope
"""
compiled_weak_fns[fn] = {
"status": COMPILATION_PENDING,
"compiled_fn": None,
"rcb": createResolutionCallback(_frames_up + 1)
}
return fn
var_names = fn.__code__.co_freevars
# map of captured name -> value
free_vars = {}
def weak_module(cls):
weak_types[cls] = {
"method_stubs": None
}
return cls
for index, name in enumerate(var_names):
free_vars[name] = fn.__closure__[index].cell_contents
f_globals = fn.__globals__
def env(key):
if key in free_vars:
return free_vars[key]
elif hasattr(builtins, key):
return getattr(builtins, key)
else:
return f_globals.get(key)
def weak_script_method(fn):
weak_script_methods[fn] = {
"rcb": createResolutionCallback(frames_up=2),
"original_method": fn
}
return fn
return env
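
As a point of reference, a small standalone sketch (plain Python, no torch required) of what `createResolutionCallbackFromClosure` introspects: a function's free variables via `__code__.co_freevars`/`__closure__`, with the defining module's globals as the fallback:

```python
import math

def make_fn():
    scale = 2.0  # captured in fn's closure

    def fn(x):
        return math.sqrt(x) * scale
    return fn

fn = make_fn()
# Captured names and values live on __code__.co_freevars / __closure__ ...
free_vars = {name: cell.cell_contents
             for name, cell in zip(fn.__code__.co_freevars, fn.__closure__)}
assert free_vars == {'scale': 2.0}
# ... and everything else resolves through the defining module's globals.
assert fn.__globals__['math'] is math
```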
def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name):
"""
Dispatches to either of 2 weak script functions based on a boolean argument.
Dispatches to either of 2 script functions based on a boolean argument.
In TorchScript, the boolean argument must be constant so that the correct
function to use can be determined at compile time.
"""
if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None:
raise RuntimeError("both functions must be weak script")
def fn(*args, **kwargs):
dispatch_flag = False
if arg_name in kwargs:
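
For context, `boolean_dispatch` builds one callable that forwards to either of two functions based on a boolean argument. A minimal usage sketch following the pattern used throughout `torch/nn/functional.py` (note that `boolean_dispatch` lives in the internal `torch._jit_internal` module, so treat this as illustrative rather than public API; the function names are hypothetical):

```python
import torch
from torch._jit_internal import boolean_dispatch

def pool_with_indices(x, return_indices=True):
    return x, x.argmax()

def pool_without_indices(x, return_indices=False):
    return x

# One name, dispatching on the boolean `return_indices` argument, which must
# be a compile-time constant when used from TorchScript.
pool = boolean_dispatch(
    arg_name='return_indices',
    arg_index=1,
    default=False,
    if_true=pool_with_indices,
    if_false=pool_without_indices,
    module_name=__name__,
    func_name='pool')

t = torch.tensor([1., 3., 2.])
print(pool(t))                       # -> tensor, via pool_without_indices
print(pool(t, return_indices=True))  # -> (tensor, index), via pool_with_indices
```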

View File

@ -23,7 +23,8 @@ Decl mergeTypesFromTypeComment(
<< "Number of type annotations ("
<< type_annotation_decl.params().size()
<< ") did not match the number of "
<< "function parameters (" << expected_num_annotations << ")";
<< (is_method ? "method" : "function")
<< " parameters (" << expected_num_annotations << ")";
}
auto old = decl.params();
auto _new = type_annotation_decl.params();

View File

@ -244,6 +244,11 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
<< err.str();
}
bool should_recurse(py::object obj) {
return py::cast<bool>(py::module::import("torch.jit")
.attr("_is_recursive_script_enabled")(obj));
}
std::shared_ptr<SugaredValue> ModuleValue::attr(
const SourceRange& loc,
Function& m,
@ -307,7 +312,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
// If recursive script mode is on, create a ScriptModule and register it as
// as submodule or register a python method as a script::Method
if (getRecursiveScriptMode()) {
if (should_recurse(attr)) {
if (py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) {
// If the module is a submodule of the py_module, convert it to a
// ScriptModule and add it as a submodule to the script::Module. This
@ -471,11 +476,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
}
}
auto weak_obj =
py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
if (!weak_obj.is_none()) {
obj = weak_obj;
}
if (auto callee = as_function(obj)) {
return std::make_shared<FunctionValue>(callee);
} else if (py::isinstance<py::module>(obj)) {
@ -504,12 +504,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
<< "which is currently not supported in Torchscript."
<< "Please open a feature request to add it.";
}
auto compiled_fn =
py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
if (auto callee = as_function(compiled_fn)) {
return std::make_shared<FunctionValue>(callee);
}
}
py::object dispatched_fn =
@ -528,7 +522,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
}
}
if (getRecursiveScriptMode() && py::isinstance<py::function>(obj)) {
if (should_recurse(obj) && py::isinstance<py::function>(obj)) {
auto compiled_fn =
py::module::import("torch.jit").attr("_try_compile_fn")(obj);
if (auto callee = as_function(compiled_fn)) {
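
The Python-level effect of this recursion path, sketched with a hypothetical helper (at the time of this PR it required `_enable_recursive_script()` or a function defined under `torch.nn`; in later releases recursive scripting became the default):

```python
import torch

def helper(x):
    # A plain Python function: when reached from a scripted function it is
    # compiled on demand via _try_compile_fn instead of raising an error.
    return x + 1

@torch.jit.script
def caller(x):
    return helper(x)

print(caller(torch.tensor(1)))  # tensor(2)
```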

View File

@ -7,7 +7,7 @@ import torch.backends.cudnn as cudnn
import torch.jit.annotations
import torch._jit_internal as _jit_internal
from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \
string_classes, builtins
string_classes
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
_list_with_default
import torch.testing
@ -930,50 +930,10 @@ def _try_get_overloaded_fn(mod, field):
return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
def _try_compile_weak_script(fn):
entry = _jit_internal.compiled_weak_fns.get(fn)
if entry is None:
return None
if entry["status"] == _jit_internal.COMPILATION_PENDING:
compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
del entry["rcb"]
_jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
entry["status"] = _jit_internal.COMPILED
return compiled_fn
# TODO: use fn.__closure__
raise RuntimeError("Cannot make resolutionCallback in Python 2")
else:
return entry["compiled_fn"]
class ScriptWarning(Warning):
pass
def createResolutionCallbackFromClosure(fn):
"""
Create a resolutionCallback by introspecting the function instead of
looking up the stack for the enclosing scope
"""
var_names = fn.__code__.co_freevars
# map of captured name -> value
free_vars = {}
for index, name in enumerate(var_names):
free_vars[name] = fn.__closure__[index].cell_contents
f_globals = fn.__globals__
def env(key):
if key in free_vars:
return free_vars[key]
elif hasattr(builtins, key):
return getattr(builtins, key)
else:
return f_globals.get(key)
return env
def _create_constant_iterable_module(module):
modules = OrderedDict()
@ -1012,20 +972,20 @@ def _try_compile_fn(fn):
# Don't do anything for @ignore'd functions
return None
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
"Python functions or methods currently.\n"
"Consider manually annotating `{}` with @torch.jit.script.".format(fn))
if isinstance(fn, torch.nn.Module):
# Since modules are callable pybind recognizes them as functions, but
# don't do anything for them
return None
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
"Python functions or methods currently.\n"
"Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))
# We don't have the actual scope where the function was defined, but we can
# extract the necessary info from the closed over variables on the function
# object
rcb = createResolutionCallbackFromClosure(fn)
rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
return torch.jit.script(fn, _rcb=rcb)
@ -1040,7 +1000,9 @@ def _disable_emit_hooks():
def _create_method_from_fn(module, fn):
if _jit_internal.is_ignored_fn(fn):
return None
stub = script_method(fn, createResolutionCallbackFromClosure(fn))
if not inspect.ismethod(fn):
return None
stub = script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
with _disable_emit_hooks():
# We don't want to call the hooks here since the graph that is calling
# this function is not yet complete
@ -1101,6 +1063,15 @@ def _qualified_name(obj):
return module_name + "." + name
def _is_recursive_script_enabled(value):
# TODO: [enable recursive script]
# when recursive script is made the default, remove this method
enabled = torch._C._jit_recursive_script()
module = inspect.getmodule(value)
if module is not None and 'torch.nn' in module.__name__:
enabled = True
return enabled
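
A standalone sketch of the gating rule above (rewritten here for illustration; the real check reads the global flag from `torch._C._jit_recursive_script()`):

```python
import inspect
import torch

def is_recursive_script_enabled(value, global_flag):
    # Anything defined under a torch.nn.* module is always eligible for
    # recursive scripting; everything else follows the global flag.
    module = inspect.getmodule(value)
    if module is not None and 'torch.nn' in module.__name__:
        return True
    return global_flag

print(is_recursive_script_enabled(torch.nn.functional.relu, False))  # True
print(is_recursive_script_enabled(len, False))                       # False
```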
@contextlib.contextmanager
def _enable_recursive_script():
torch._C._jit_recursive_script(True)
@ -1114,8 +1085,8 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None):
if _rcb is None:
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
if torch._C._jit_recursive_script():
if isinstance(obj, torch.nn.Module):
if _is_recursive_script_enabled(obj):
return _convert_to_script_module(obj)
if inspect.isclass(obj):
@ -1158,21 +1129,6 @@ def script_method(fn, _rcb=None):
return ScriptMethodStub(_rcb, ast, fn)
def _try_get_weak_module(mod):
"""
Get the WeakScriptModuleProxy corresponding to mod if it exists
"""
if not isinstance(mod, Module):
return None
return _jit_internal.weak_modules.get(mod)
def _is_weak_type(cls):
"""
Check if a type has been annotated with `weak_module`
"""
return cls in _jit_internal.weak_types
# These OrderedDictWrapper classes replace the actual OrderedDicts in
# module with versions that get/set properties inside of script::Module.
@ -1569,9 +1525,9 @@ if _enabled:
def __setattr__(self, attr, value):
if attr not in self._constants_set:
if isinstance(value, Module) and _is_weak_type(type(value)):
if isinstance(value, Module) and _is_recursive_script_enabled(value):
# Compile weak script module
value = _make_strong(value)
value = _convert_to_script_module(value)
if attr == 'training':
if self._c._has_attribute('training'):
self.__dict__['training'] = value
@ -1684,7 +1640,7 @@ if _enabled:
if isinstance(item, (ModuleList, Sequential)):
# These are in __constants__, so ignore them here
if not torch._C._jit_recursive_script():
if not _is_recursive_script_enabled(item):
# For recursive script, these are constantified after
# they are used, so they don't need to be in constants.
# The `continue` here should be deleted along with
@ -1774,33 +1730,27 @@ else:
super(ScriptModule, self).__init__()
def _get_weak_stubs(cls):
"""
Calls script_method for each method that has been annotated with @weak_script
on the type of the object passed in and returns the generated ScriptMethodStubs.
"""
stubs = []
for name in dir(cls):
func = get_function_from_type(cls, name)
if func in _jit_internal.weak_script_methods:
entry = _jit_internal.weak_script_methods[func]
stub = script_method(entry["original_method"], entry["rcb"])
stubs.append(stub)
return stubs
def _convert_to_script_module(mod, methods=None):
def _convert_to_script_module(mod):
"""
Makes a ScriptModule from an nn.Module. If `_methods` is provided,
these methods are treated as @script_methods. If not, it defaults to
`('forward',)`. Methods accessed in forward are scripted on demand if
`_enable_recursive_script()` is used.
"""
if isinstance(mod, ScriptModule):
return mod
if isinstance(mod, (ModuleList, Sequential)):
# Create constant versions for the iterable modules
return _create_constant_iterable_module(mod)
if methods is None:
methods = ()
if hasattr(mod, 'forward'):
if mod.forward.__func__ == torch.nn.Module.forward:
# TODO: [enable recursive script]
# forward was not overrided
raise RuntimeError("No forward method was defined on {}".format(mod))
if not _jit_internal.is_ignored_fn(mod.forward):
methods = ('forward',)
exported = []
for name in dir(mod):
@ -1812,36 +1762,12 @@ def _convert_to_script_module(mod, methods=None):
def make_stub(method):
func = get_function_from_type(type(mod), method)
return script_method(func, createResolutionCallbackFromClosure(func))
return script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
stubs = list(map(make_stub, methods))
return WeakScriptModuleProxy(mod, stubs)
def _make_strong(mod):
"""
Converts a weak module into a subclass of ScriptModule. If `_methods` is
provided, only these methods are treated as @script_methods.
"""
if mod in _jit_internal.weak_modules:
return _jit_internal.weak_modules[mod]
cls = type(mod)
# Explicitly annotated weak script
stubs = _jit_internal.weak_types.get(cls)["method_stubs"]
if stubs is None:
# Generate stubs and and store on weak_types in case this type is
# used again
stubs = _get_weak_stubs(cls)
_jit_internal.weak_types[cls]["method_stubs"] = stubs
proxy = WeakScriptModuleProxy(mod, stubs)
_jit_internal.weak_modules[mod] = proxy
return proxy
def _get_methods(cls):
import inspect
# In Python 3 unbound methods are functions, but in Python 2 they are methods
@ -1937,13 +1863,13 @@ class _ConstModuleList(ScriptModule):
if isinstance(modules, OrderedDict):
for key, module in modules.items():
if _is_weak_type(type(module)):
module = _make_strong(module)
if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module):
module = _convert_to_script_module(module)
self.add_module(key, module)
else:
for i, module in enumerate(modules):
if _is_weak_type(type(module)):
module = _make_strong(module)
if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module):
module = _convert_to_script_module(module)
self.add_module(str(i), module)
def __getitem__(self, idx):

View File

@ -1,9 +1,7 @@
import torch
import torch.backends.cudnn as cudnn
from ..._jit_internal import weak_script
@weak_script
def affine_grid_generator(theta, size):
# type: (Tensor, List[int]) -> Tensor
if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4 and size[0] < 65536:

View File

@ -1,10 +1,8 @@
import warnings
from .._jit_internal import weak_script
# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h
@weak_script
def get_enum(reduction):
# type: (str) -> int
if reduction == 'none':
@ -26,7 +24,6 @@ def get_enum(reduction):
# We use these functions in torch/legacy as well, in which case we'll silence the warning
@weak_script
def legacy_get_string(size_average, reduce, emit_warning=True):
# type: (Optional[bool], Optional[bool], bool) -> str
warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
@ -47,7 +44,6 @@ def legacy_get_string(size_average, reduce, emit_warning=True):
return ret
@weak_script
def legacy_get_enum(size_average, reduce, emit_warning=True):
# type: (Optional[bool], Optional[bool], bool) -> int
return get_enum(legacy_get_string(size_average, reduce, emit_warning))

View File

@ -12,7 +12,7 @@ from ._functions import vision
from .modules.utils import _single, _pair, _triple, _list_with_default
from . import grad # noqa: F401
from . import _VF
from .._jit_internal import weak_script, List
from .._jit_internal import boolean_dispatch, List
conv1d = _add_docstr(torch.conv1d, r"""
@ -299,7 +299,6 @@ Args:
""")
@weak_script
def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False,
_random_samples=None):
@ -346,7 +345,6 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)
@weak_script
def _fractional_max_pool2d(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False,
_random_samples=None):
@ -355,7 +353,7 @@ def _fractional_max_pool2d(input, kernel_size, output_size=None,
output_ratio, return_indices,
_random_samples)[0]
fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
fractional_max_pool2d = boolean_dispatch(
arg_name='return_indices',
arg_index=4,
default=False,
@ -365,7 +363,6 @@ fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
func_name='fractional_max_pool2d')
@weak_script
def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False,
_random_samples=None):
@ -414,7 +411,6 @@ def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples)
@weak_script
def _fractional_max_pool3d(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False,
_random_samples=None):
@ -423,7 +419,7 @@ def _fractional_max_pool3d(input, kernel_size, output_size=None,
output_ratio, return_indices,
_random_samples)[0]
fractional_max_pool3d = torch._jit_internal.boolean_dispatch(
fractional_max_pool3d = boolean_dispatch(
arg_name='return_indices',
arg_index=4,
default=False,
@ -433,7 +429,6 @@ fractional_max_pool3d = torch._jit_internal.boolean_dispatch(
func_name='fractional_max_pool3d')
@weak_script
def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
dilation=1, ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
@ -448,7 +443,6 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
input, kernel_size, stride, padding, dilation, ceil_mode)
@weak_script
def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa
@ -457,7 +451,7 @@ def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
return torch.max_pool1d(
input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool1d = torch._jit_internal.boolean_dispatch(
max_pool1d = boolean_dispatch(
arg_name='return_indices',
arg_index=6,
default=False,
@ -467,7 +461,6 @@ max_pool1d = torch._jit_internal.boolean_dispatch(
func_name='max_pool1d')
@weak_script
def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
@ -481,7 +474,6 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation
return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
@weak_script
def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa
@ -490,7 +482,7 @@ def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
return torch.max_pool2d(
input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool2d = torch._jit_internal.boolean_dispatch(
max_pool2d = boolean_dispatch(
arg_name='return_indices',
arg_index=6,
default=False,
@ -500,7 +492,6 @@ max_pool2d = torch._jit_internal.boolean_dispatch(
func_name='max_pool2d')
@weak_script
def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
dilation=1, ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
@ -515,7 +506,6 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
input, kernel_size, stride, padding, dilation, ceil_mode)
@weak_script
def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa
@ -524,7 +514,7 @@ def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
return torch.max_pool3d(
input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool3d = torch._jit_internal.boolean_dispatch(
max_pool3d = boolean_dispatch(
arg_name='return_indices',
arg_index=6,
default=False,
@ -534,7 +524,6 @@ max_pool3d = torch._jit_internal.boolean_dispatch(
func_name='max_pool3d')
@weak_script
def _unpool_output_size(input, kernel_size, stride, padding, output_size):
# type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int]
input_size = input.size()
@ -564,7 +553,6 @@ def _unpool_output_size(input, kernel_size, stride, padding, output_size):
return ret
@weak_script
def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
output_size=None):
# type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa
@ -588,7 +576,6 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
output_size).squeeze(3)
@weak_script
def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
output_size=None):
# type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa
@ -607,7 +594,6 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
return torch._C._nn.max_unpool2d(input, indices, output_size)
@weak_script
def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
output_size=None):
# type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa
@ -627,7 +613,6 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
input, indices, output_size, _stride, padding)
@weak_script
def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
# type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
r"""Applies a 2D power-average pooling over an input signal composed of
@ -645,7 +630,6 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type)
@weak_script
def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
# type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
r"""Applies a 1D power-average pooling over an input signal composed of
@ -662,7 +646,6 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
@weak_script
def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 1D adaptive max pooling over an input signal composed of
@ -677,12 +660,11 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
return torch.adaptive_max_pool1d(input, output_size)
@weak_script
def _adaptive_max_pool1d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList1[int], bool) -> Tensor
return adaptive_max_pool1d_with_indices(input, output_size)[0]
adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
adaptive_max_pool1d = boolean_dispatch(
arg_name='return_indices',
arg_index=2,
default=False,
@ -692,7 +674,6 @@ adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
func_name='adaptive_max_pool1d')
@weak_script
def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList2[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 2D adaptive max pooling over an input signal composed of
@ -709,12 +690,11 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
return torch._C._nn.adaptive_max_pool2d(input, output_size)
@weak_script
def _adaptive_max_pool2d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList2[int], bool) -> Tensor
return adaptive_max_pool2d_with_indices(input, output_size)[0]
adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
adaptive_max_pool2d = boolean_dispatch(
arg_name='return_indices',
arg_index=2,
default=False,
@ -724,7 +704,6 @@ adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
func_name='adaptive_max_pool2d')
@weak_script
def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList3[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 3D adaptive max pooling over an input signal composed of
@ -741,12 +720,11 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
return torch._C._nn.adaptive_max_pool3d(input, output_size)
@weak_script
def _adaptive_max_pool3d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList3[int], bool) -> Tensor
return adaptive_max_pool3d_with_indices(input, output_size)[0]
adaptive_max_pool3d = torch._jit_internal.boolean_dispatch(
adaptive_max_pool3d = boolean_dispatch(
arg_name='return_indices',
arg_index=2,
default=False,
@ -769,7 +747,6 @@ Args:
""")
@weak_script
def adaptive_avg_pool2d(input, output_size):
# type: (Tensor, BroadcastingList2[int]) -> Tensor
r"""
@ -786,7 +763,6 @@ def adaptive_avg_pool2d(input, output_size):
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
@weak_script
def adaptive_avg_pool3d(input, output_size):
# type: (Tensor, BroadcastingList3[int]) -> Tensor
r"""
@ -804,7 +780,6 @@ def adaptive_avg_pool3d(input, output_size):
# Activation functions
@weak_script
def dropout(input, p=0.5, training=True, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor
r"""
@ -827,7 +802,6 @@ def dropout(input, p=0.5, training=True, inplace=False):
else _VF.dropout(input, p, training))
@weak_script
def alpha_dropout(input, p=0.5, training=False, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor
r"""Applies alpha dropout to the input.
@ -842,7 +816,6 @@ def alpha_dropout(input, p=0.5, training=False, inplace=False):
else _VF.alpha_dropout(input, p, training))
@weak_script
def dropout2d(input, p=0.5, training=True, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor
r"""
@ -867,7 +840,6 @@ def dropout2d(input, p=0.5, training=True, inplace=False):
else _VF.feature_dropout(input, p, training))
@weak_script
def dropout3d(input, p=0.5, training=True, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor
r"""
@ -894,7 +866,6 @@ def dropout3d(input, p=0.5, training=True, inplace=False):
else _VF.feature_dropout(input, p, training))
@weak_script
def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor
if p < 0. or p > 1.:
@ -905,7 +876,6 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
else _VF.feature_alpha_dropout(input, p, training))
@weak_script
def threshold(input, threshold, value, inplace=False):
# type: (Tensor, float, float, bool) -> Tensor
r"""Thresholds each element of the input Tensor.
@ -926,7 +896,6 @@ In-place version of :func:`~threshold`.
""")
@weak_script
def relu(input, inplace=False):
# type: (Tensor, bool) -> Tensor
r"""relu(input, inplace=False) -> Tensor
@ -948,7 +917,6 @@ In-place version of :func:`~relu`.
""")
@weak_script
def glu(input, dim=-1):
# type: (Tensor, int) -> Tensor
r"""
@ -973,7 +941,6 @@ def glu(input, dim=-1):
return torch._C._nn.glu(input, dim)
@weak_script
def hardtanh(input, min_val=-1., max_val=1., inplace=False):
# type: (Tensor, float, float, bool) -> Tensor
r"""
@ -996,7 +963,6 @@ In-place version of :func:`~hardtanh`.
""")
@weak_script
def relu6(input, inplace=False):
# type: (Tensor, bool) -> Tensor
r"""relu6(input, inplace=False) -> Tensor
@ -1008,7 +974,6 @@ def relu6(input, inplace=False):
return hardtanh(input, 0., 6., inplace)
@weak_script
def elu(input, alpha=1., inplace=False):
# type: (Tensor, float, bool) -> Tensor
r"""Applies element-wise,
@ -1030,7 +995,6 @@ In-place version of :func:`~elu`.
""")
@weak_script
def selu(input, inplace=False):
# type: (Tensor, bool) -> Tensor
r"""selu(input, inplace=False) -> Tensor
@ -1056,7 +1020,6 @@ In-place version of :func:`~selu`.
""")
@weak_script
def celu(input, alpha=1., inplace=False):
# type: (Tensor, float, bool) -> Tensor
r"""celu(input, alpha=1., inplace=False) -> Tensor
@ -1079,7 +1042,6 @@ In-place version of :func:`~celu`.
""")
@weak_script
def leaky_relu(input, negative_slope=0.01, inplace=False):
# type: (Tensor, float, bool) -> Tensor
r"""
@ -1104,7 +1066,6 @@ In-place version of :func:`~leaky_relu`.
""")
@weak_script
def prelu(input, weight):
# type: (Tensor, Tensor) -> Tensor
r"""prelu(input, weight) -> Tensor
@ -1118,7 +1079,6 @@ def prelu(input, weight):
return torch.prelu(input, weight)
@weak_script
def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
# type: (Tensor, float, float, bool, bool) -> Tensor
r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
@ -1148,7 +1108,6 @@ Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \ex
See :class:`~torch.nn.LogSigmoid` for more details.
""")
@weak_script
def gelu(input):
r"""gelu(input) -> Tensor
@ -1162,7 +1121,6 @@ def gelu(input):
return torch._C._nn.gelu(input)
@weak_script
def hardshrink(input, lambd=0.5):
# type: (Tensor, float) -> Tensor
r"""
@ -1175,7 +1133,6 @@ def hardshrink(input, lambd=0.5):
return torch.hardshrink(input, lambd)
@weak_script
def tanhshrink(input):
r"""tanhshrink(input) -> Tensor
@ -1186,7 +1143,6 @@ def tanhshrink(input):
return input - input.tanh()
@weak_script
def softsign(input):
r"""softsign(input) -> Tensor
@ -1202,7 +1158,6 @@ softplus(input, beta=1, threshold=20) -> Tensor
""")
@weak_script
def _get_softmax_dim(name, ndim, stacklevel):
# type: (str, int, int) -> int
warnings.warn("Implicit dimension choice for {} has been deprecated. "
@ -1214,7 +1169,6 @@ def _get_softmax_dim(name, ndim, stacklevel):
return ret
@weak_script
def softmin(input, dim=None, _stacklevel=3, dtype=None):
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
r"""Applies a softmin function.
@ -1240,7 +1194,6 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None):
return ret
@weak_script
def softmax(input, dim=None, _stacklevel=3, dtype=None):
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
r"""Applies a softmax function.
@ -1276,7 +1229,6 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None):
return ret
@weak_script
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
# type: (Tensor, float, bool, float, int) -> Tensor
r"""
@ -1337,7 +1289,6 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
return ret
@weak_script
def log_softmax(input, dim=None, _stacklevel=3, dtype=None):
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
r"""Applies a softmax followed by a logarithm.
@ -1373,7 +1324,6 @@ See :class:`~torch.nn.Softshrink` for more details.
""")
@weak_script
def tanh(input):
r"""tanh(input) -> Tensor
@ -1386,7 +1336,6 @@ def tanh(input):
return input.tanh()
@weak_script
def sigmoid(input):
r"""sigmoid(input) -> Tensor
@ -1398,7 +1347,6 @@ def sigmoid(input):
return input.sigmoid()
@weak_script
def linear(input, weight, bias=None):
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
r"""
@ -1423,7 +1371,6 @@ def linear(input, weight, bias=None):
return ret
@weak_script
def bilinear(input1, input2, weight, bias=None):
# type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor
return torch.bilinear(input1, input2, weight, bias)
@ -1435,7 +1382,6 @@ def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
torch.embedding_renorm_(weight, input, max_norm, norm_type)
@weak_script
def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
scale_grad_by_freq=False, sparse=False):
# type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
@ -1517,7 +1463,6 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
@weak_script
def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
scale_grad_by_freq=False, mode='mean', sparse=False,
per_sample_weights=None):
@ -1677,7 +1622,6 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
return ret
@weak_script
def batch_norm(input, running_mean, running_var, weight=None, bias=None,
training=False, momentum=0.1, eps=1e-5):
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
@ -1709,7 +1653,6 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None,
)
@weak_script
def instance_norm(input, running_mean=None, running_var=None, weight=None,
bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
@ -1725,7 +1668,6 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None,
)
@weak_script
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
# type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor
r"""Applies Layer Normalization for last certain number of dimensions.
@ -1736,7 +1678,6 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
torch.backends.cudnn.enabled)
@weak_script
def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
# type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
r"""Applies Group Normalization for last certain number of dimensions.
@ -1747,7 +1688,6 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
torch.backends.cudnn.enabled)
@weak_script
def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
# type: (Tensor, int, float, float, float) -> Tensor
r"""Applies local response normalization over an input signal composed of
@ -1776,7 +1716,6 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
# loss
@weak_script
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
reduction='mean', zero_infinity=False):
# type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor
@ -1824,7 +1763,6 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
zero_infinity)
@weak_script
def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
@ -1903,7 +1841,6 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
return ret
@weak_script
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor
@ -1949,7 +1886,6 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
return ret
@weak_script
def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""The `Kullback-Leibler divergence`_ Loss.
@ -2007,7 +1943,6 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
return reduced
@weak_script
def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
@ -2056,7 +1991,6 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
@weak_script
def binary_cross_entropy(input, target, weight=None, size_average=None,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@ -2113,7 +2047,6 @@ def binary_cross_entropy(input, target, weight=None, size_average=None,
input, target, weight, reduction_enum)
@weak_script
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
reduce=None, reduction='mean', pos_weight=None):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
@ -2174,14 +2107,12 @@ def _pointwise_loss(lambd, lambd_optimized, input, target, reduction='mean'):
return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction))
@weak_script
def _smooth_l1_loss(input, target):
# type: (Tensor, Tensor) -> Tensor
t = torch.abs(input - target)
return torch.where(t < 1, 0.5 * t ** 2, t - 0.5)
@weak_script
def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""Function that uses a squared term if the absolute
@ -2206,7 +2137,6 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea
return ret
@weak_script
def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@ -2232,7 +2162,6 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
return ret
@weak_script
def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@ -2258,7 +2187,6 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
return ret
@weak_script
def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@ -2276,7 +2204,6 @@ def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
@weak_script
def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@ -2291,7 +2218,6 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
@weak_script
def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@ -2305,7 +2231,6 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
@weak_script
def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@ -2319,7 +2244,6 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m
return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
@weak_script
def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@ -2349,7 +2273,6 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
return ret
@weak_script
def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@ -2364,7 +2287,6 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
@weak_script
def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None,
reduce=None, reduction='mean'):
# type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@ -2635,7 +2557,6 @@ GRID_SAMPLE_PADDING_MODES = {
}
@weak_script
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
# type: (Tensor, Tensor, str, str) -> Tensor
r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the
@ -2717,7 +2638,6 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum)
@weak_script
def affine_grid(theta, size):
# type: (Tensor, List[int]) -> Tensor
r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta`.
@ -2735,7 +2655,6 @@ def affine_grid(theta, size):
return vision.affine_grid_generator(theta, size)
@weak_script
def pad(input, pad, mode='constant', value=0):
# type: (Tensor, List[int], str, float) -> Tensor
r"""Pads tensor.
@ -2844,7 +2763,6 @@ def pad(input, pad, mode='constant', value=0):
# distance
@weak_script
def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False):
# type: (Tensor, Tensor, float, float, bool) -> Tensor
r"""
@ -2952,7 +2870,6 @@ Examples:
""")
@weak_script
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
reduce=None, reduction="mean"):
# type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor
@ -2967,7 +2884,6 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
swap, reduction_enum)
@weak_script
def normalize(input, p=2, dim=1, eps=1e-12, out=None):
# type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
@ -3001,7 +2917,6 @@ def assert_int_or_pair(arg, arg_name, message):
assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
@weak_script
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
# type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
r"""Extracts sliding local blocks from an batched input tensor.
@ -3036,7 +2951,6 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
return ret
@weak_script
def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
# type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
r"""Combines an array of sliding local blocks into a large containing
@ -3064,7 +2978,6 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
return ret
@weak_script
def _pad_circular(input, padding):
# type: (Tensor, List[int]) -> Tensor
"""
@ -3090,7 +3003,6 @@ def _pad_circular(input, padding):
return input
@weak_script
def multi_head_attention_forward(query, # type: Tensor
key, # type: Tensor
value, # type: Tensor

View File

@ -4,7 +4,6 @@ import math
import warnings
import torch
from .._jit_internal import weak_script
# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
@ -72,7 +71,6 @@ def calculate_gain(nonlinearity, param=None):
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
@weak_script
def uniform_(tensor, a=0., b=1.):
# type: (Tensor, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from the uniform
@ -90,7 +88,6 @@ def uniform_(tensor, a=0., b=1.):
return _no_grad_uniform_(tensor, a, b)
@weak_script
def normal_(tensor, mean=0., std=1.):
# type: (Tensor, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from the normal
@ -108,7 +105,6 @@ def normal_(tensor, mean=0., std=1.):
return _no_grad_normal_(tensor, mean, std)
@weak_script
def constant_(tensor, val):
# type: (Tensor, float) -> Tensor
r"""Fills the input Tensor with the value :math:`\text{val}`.
@ -124,7 +120,6 @@ def constant_(tensor, val):
return _no_grad_fill_(tensor, val)
@weak_script
def ones_(tensor):
# type: (Tensor) -> Tensor
r"""Fills the input Tensor with ones`.
@ -139,7 +134,6 @@ def ones_(tensor):
return _no_grad_fill_(tensor, 1.)
@weak_script
def zeros_(tensor):
# type: (Tensor) -> Tensor
r"""Fills the input Tensor with zeros`.
@ -205,7 +199,6 @@ def dirac_(tensor):
return tensor
@weak_script
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim()
if dimensions < 2:
@ -226,7 +219,6 @@ def _calculate_fan_in_and_fan_out(tensor):
return fan_in, fan_out
@weak_script
def xavier_uniform_(tensor, gain=1.):
# type: (Tensor, float) -> Tensor
r"""Fills the input `Tensor` with values according to the method
@ -255,7 +247,6 @@ def xavier_uniform_(tensor, gain=1.):
return _no_grad_uniform_(tensor, -a, a)
@weak_script
def xavier_normal_(tensor, gain=1.):
# type: (Tensor, float) -> Tensor
r"""Fills the input `Tensor` with values according to the method

View File

@ -7,10 +7,8 @@ from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class Threshold(Module):
r"""Thresholds each element of the input Tensor.
@ -48,7 +46,6 @@ class Threshold(Module):
self.inplace = inplace
# TODO: check in THNN (if inplace == True, then assert value <= threshold)
@weak_script_method
def forward(self, input):
return F.threshold(input, self.threshold, self.value, self.inplace)
@ -59,7 +56,6 @@ class Threshold(Module):
)
@weak_module
class ReLU(Module):
r"""Applies the rectified linear unit function element-wise:
@ -94,7 +90,6 @@ class ReLU(Module):
super(ReLU, self).__init__()
self.inplace = inplace
@weak_script_method
def forward(self, input):
return F.relu(input, inplace=self.inplace)
@ -103,7 +98,6 @@ class ReLU(Module):
return inplace_str
@weak_module
class RReLU(Module):
r"""Applies the randomized leaky rectified liner unit function, element-wise,
as described in the paper:
@ -151,7 +145,6 @@ class RReLU(Module):
self.upper = upper
self.inplace = inplace
@weak_script_method
def forward(self, input):
return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
@ -160,7 +153,6 @@ class RReLU(Module):
return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
@weak_module
class Hardtanh(Module):
r"""Applies the HardTanh function element-wise
@ -213,7 +205,6 @@ class Hardtanh(Module):
self.inplace = inplace
assert self.max_val > self.min_val
@weak_script_method
def forward(self, input):
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
@ -224,7 +215,6 @@ class Hardtanh(Module):
)
@weak_module
class ReLU6(Hardtanh):
r"""Applies the element-wise function:
@ -256,7 +246,6 @@ class ReLU6(Hardtanh):
return inplace_str
@weak_module
class Sigmoid(Module):
r"""Applies the element-wise function:
@ -278,12 +267,10 @@ class Sigmoid(Module):
>>> output = m(input)
"""
@weak_script_method
def forward(self, input):
return torch.sigmoid(input)
@weak_module
class Tanh(Module):
r"""Applies the element-wise function:
@ -304,12 +291,10 @@ class Tanh(Module):
>>> output = m(input)
"""
@weak_script_method
def forward(self, input):
return torch.tanh(input)
@weak_module
class ELU(Module):
r"""Applies the element-wise function:
@ -340,7 +325,6 @@ class ELU(Module):
self.alpha = alpha
self.inplace = inplace
@weak_script_method
def forward(self, input):
return F.elu(input, self.alpha, self.inplace)
@ -349,7 +333,6 @@ class ELU(Module):
return 'alpha={}{}'.format(self.alpha, inplace_str)
@weak_module
class CELU(Module):
r"""Applies the element-wise function:
@ -385,7 +368,6 @@ class CELU(Module):
self.alpha = alpha
self.inplace = inplace
@weak_script_method
def forward(self, input):
return F.celu(input, self.alpha, self.inplace)
@ -394,7 +376,6 @@ class CELU(Module):
return 'alpha={}{}'.format(self.alpha, inplace_str)
@weak_module
class SELU(Module):
r"""Applied element-wise, as:
@ -430,7 +411,6 @@ class SELU(Module):
super(SELU, self).__init__()
self.inplace = inplace
@weak_script_method
def forward(self, input):
return F.selu(input, self.inplace)
@ -439,7 +419,6 @@ class SELU(Module):
return inplace_str
@weak_module
class GLU(Module):
r"""Applies the gated linear unit function
:math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
@ -465,7 +444,6 @@ class GLU(Module):
super(GLU, self).__init__()
self.dim = dim
@weak_script_method
def forward(self, input):
return F.glu(input, self.dim)
@ -473,7 +451,6 @@ class GLU(Module):
return 'dim={}'.format(self.dim)
@weak_module
class Hardshrink(Module):
r"""Applies the hard shrinkage function element-wise:
@ -507,7 +484,6 @@ class Hardshrink(Module):
super(Hardshrink, self).__init__()
self.lambd = lambd
@weak_script_method
def forward(self, input):
return F.hardshrink(input, self.lambd)
@ -515,7 +491,6 @@ class Hardshrink(Module):
return '{}'.format(self.lambd)
@weak_module
class LeakyReLU(Module):
r"""Applies the element-wise function:
@ -556,7 +531,6 @@ class LeakyReLU(Module):
self.negative_slope = negative_slope
self.inplace = inplace
@weak_script_method
def forward(self, input):
return F.leaky_relu(input, self.negative_slope, self.inplace)
@ -565,7 +539,6 @@ class LeakyReLU(Module):
return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
@weak_module
class LogSigmoid(Module):
r"""Applies the element-wise function:
@ -586,12 +559,10 @@ class LogSigmoid(Module):
>>> output = m(input)
"""
@weak_script_method
def forward(self, input):
return F.logsigmoid(input)
@weak_module
class Softplus(Module):
r"""Applies the element-wise function:
@ -628,7 +599,6 @@ class Softplus(Module):
self.beta = beta
self.threshold = threshold
@weak_script_method
def forward(self, input):
return F.softplus(input, self.beta, self.threshold)
@ -636,7 +606,6 @@ class Softplus(Module):
return 'beta={}, threshold={}'.format(self.beta, self.threshold)
@weak_module
class Softshrink(Module):
r"""Applies the soft shrinkage function elementwise:
@ -670,7 +639,6 @@ class Softshrink(Module):
super(Softshrink, self).__init__()
self.lambd = lambd
@weak_script_method
def forward(self, input):
return F.softshrink(input, self.lambd)
@ -678,7 +646,6 @@ class Softshrink(Module):
return str(self.lambd)
@weak_module
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces.
@ -759,7 +726,6 @@ class MultiheadAttention(Module):
if self.bias_v is not None:
xavier_normal_(self.bias_v)
@weak_script_method
def forward(self, query, key, value, key_padding_mask=None,
need_weights=True, attn_mask=None):
r"""
@ -817,7 +783,6 @@ class MultiheadAttention(Module):
attn_mask=attn_mask)
@weak_module
class PReLU(Module):
r"""Applies the element-wise function:
@ -874,7 +839,6 @@ class PReLU(Module):
super(PReLU, self).__init__()
self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
@weak_script_method
def forward(self, input):
return F.prelu(input, self.weight)
@ -882,7 +846,6 @@ class PReLU(Module):
return 'num_parameters={}'.format(self.num_parameters)
@weak_module
class Softsign(Module):
r"""Applies the element-wise function:
@ -903,12 +866,10 @@ class Softsign(Module):
>>> output = m(input)
"""
@weak_script_method
def forward(self, input):
return F.softsign(input)
@weak_module
class Tanhshrink(Module):
r"""Applies the element-wise function:
@ -929,12 +890,10 @@ class Tanhshrink(Module):
>>> output = m(input)
"""
@weak_script_method
def forward(self, input):
return F.tanhshrink(input)
@weak_module
class Softmin(Module):
r"""Applies the Softmin function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
@ -970,12 +929,10 @@ class Softmin(Module):
super(Softmin, self).__init__()
self.dim = dim
@weak_script_method
def forward(self, input):
return F.softmin(input, self.dim, _stacklevel=5)
@weak_module
class Softmax(Module):
r"""Applies the Softmax function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
@ -1021,7 +978,6 @@ class Softmax(Module):
if not hasattr(self, 'dim'):
self.dim = None
@weak_script_method
def forward(self, input):
return F.softmax(input, self.dim, _stacklevel=5)
@ -1029,7 +985,6 @@ class Softmax(Module):
return 'dim={dim}'.format(dim=self.dim)
@weak_module
class Softmax2d(Module):
r"""Applies SoftMax over features to each spatial location.
@ -1052,13 +1007,11 @@ class Softmax2d(Module):
>>> output = m(input)
"""
@weak_script_method
def forward(self, input):
assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
return F.softmax(input, 1, _stacklevel=5)
@weak_module
class LogSoftmax(Module):
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
input Tensor. The LogSoftmax formulation can be simplified as:
@ -1095,6 +1048,5 @@ class LogSoftmax(Module):
if not hasattr(self, 'dim'):
self.dim = None
@weak_script_method
def forward(self, input):
return F.log_softmax(input, self.dim, _stacklevel=5)
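Note on the activation hunks above: with the weak-script decorators deleted, these torch.nn activation modules are simply compiled on demand the first time a ScriptModule reaches them. A rough, untested sketch of that usage, assuming the recursive-compilation behavior described in the summary (the Shrinker class is invented for illustration):

import torch
import torch.nn as nn

class Shrinker(torch.jit.ScriptModule):
    def __init__(self):
        super(Shrinker, self).__init__()
        # plain nn.Module submodule; no @weak_module registration needed anymore
        self.act = nn.Hardshrink(lambd=0.5)

    @torch.jit.script_method
    def forward(self, x):
        # nn.Hardshrink.forward is compiled recursively when this method is compiled
        return self.act(x)

m = Shrinker()
print(m(torch.randn(4)))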

View File

@ -6,12 +6,10 @@ from .module import Module
from torch.nn.parameter import Parameter
from .. import functional as F
from .. import init
from ..._jit_internal import weak_module, weak_script_method
# TODO: check contiguous in THNN
# TODO: use separate backend functions?
@weak_module
class _BatchNorm(Module):
_version = 2
__constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
@ -57,7 +55,6 @@ class _BatchNorm(Module):
def _check_input_dim(self, input):
raise NotImplementedError
@weak_script_method
def forward(self, input):
self._check_input_dim(input)
@ -103,7 +100,6 @@ class _BatchNorm(Module):
missing_keys, unexpected_keys, error_msgs)
@weak_module
class BatchNorm1d(_BatchNorm):
r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
inputs with optional additional channel dimension) as described in the paper
@ -170,14 +166,12 @@ class BatchNorm1d(_BatchNorm):
https://arxiv.org/abs/1502.03167
"""
@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
@weak_module
class BatchNorm2d(_BatchNorm):
r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
with additional channel dimension) as described in the paper
@ -244,14 +238,12 @@ class BatchNorm2d(_BatchNorm):
https://arxiv.org/abs/1502.03167
"""
@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
@weak_module
class BatchNorm3d(_BatchNorm):
r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
with additional channel dimension) as described in the paper
@ -319,7 +311,6 @@ class BatchNorm3d(_BatchNorm):
https://arxiv.org/abs/1502.03167
"""
@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'

View File

@ -6,10 +6,9 @@ from .. import functional as F
from .. import init
from .module import Module
from .utils import _single, _pair, _triple
from ..._jit_internal import weak_module, weak_script_method, List
from ..._jit_internal import List
@weak_module
class _ConvNd(Module):
__constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
@ -74,7 +73,6 @@ class _ConvNd(Module):
self.padding_mode = 'zeros'
@weak_module
class Conv1d(_ConvNd):
r"""Applies a 1D convolution over an input signal composed of several input
planes.
@ -192,7 +190,6 @@ class Conv1d(_ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input):
if self.padding_mode == 'circular':
expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
@ -203,7 +200,6 @@ class Conv1d(_ConvNd):
self.padding, self.dilation, self.groups)
@weak_module
class Conv2d(_ConvNd):
r"""Applies a 2D convolution over an input signal composed of several input
planes.
@ -333,7 +329,6 @@ class Conv2d(_ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input):
if self.padding_mode == 'circular':
expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
@ -345,7 +340,6 @@ class Conv2d(_ConvNd):
self.padding, self.dilation, self.groups)
@weak_module
class Conv3d(_ConvNd):
r"""Applies a 3D convolution over an input signal composed of several input
planes.
@ -470,7 +464,6 @@ class Conv3d(_ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _triple(0), groups, bias, padding_mode)
@weak_script_method
def forward(self, input):
if self.padding_mode == 'circular':
expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
@ -483,9 +476,7 @@ class Conv3d(_ConvNd):
self.padding, self.dilation, self.groups)
@weak_module
class _ConvTransposeMixin(object):
@weak_script_method
def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
@ -497,7 +488,6 @@ class _ConvTransposeMixin(object):
else:
return func(input, self.weight, self.bias)
@weak_script_method
def _output_padding(self, input, output_size, stride, padding, kernel_size):
# type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int]
if output_size is None:
@ -537,7 +527,6 @@ class _ConvTransposeMixin(object):
return ret
@weak_module
class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
r"""Applies a 1D transposed convolution operator over an input image
composed of several input planes.
@ -638,7 +627,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros':
@ -650,7 +638,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
output_padding, self.groups, self.dilation)
@weak_module
class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
r"""Applies a 2D transposed convolution operator over an input image
composed of several input planes.
@ -786,7 +773,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros':
@ -799,7 +785,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
output_padding, self.groups, self.dilation)
@weak_module
class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
r"""Applies a 3D transposed convolution operator over an input image composed of several input
planes.
@ -931,7 +916,6 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode)
@weak_script_method
def forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros':
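Aside on the conv hunks: the MyPy-style "# type:" comments (on _output_padding and the transposed-conv forward methods) stay behind, and they are what gives TorchScript the Optional[List[int]] signatures once these modules are compiled recursively. A small illustrative sketch of that annotation style, outside the diff (pick_size is a made-up helper):

import torch
from typing import List, Optional

@torch.jit.script
def pick_size(output_size, fallback):
    # type: (Optional[List[int]], List[int]) -> List[int]
    if output_size is None:
        return fallback
    return output_size

print(pick_size(None, [3, 3]))    # [3, 3]
print(pick_size([5, 5], [3, 3]))  # [5, 5]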

View File

@ -1,9 +1,7 @@
from .module import Module
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class PairwiseDistance(Module):
r"""
Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
@ -35,12 +33,10 @@ class PairwiseDistance(Module):
self.eps = eps
self.keepdim = keepdim
@weak_script_method
def forward(self, x1, x2):
return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)
@weak_module
class CosineSimilarity(Module):
r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.
@ -68,6 +64,5 @@ class CosineSimilarity(Module):
self.dim = dim
self.eps = eps
@weak_script_method
def forward(self, x1, x2):
return F.cosine_similarity(x1, x2, self.dim, self.eps)

View File

@ -1,6 +1,5 @@
from .module import Module
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
class _DropoutNd(Module):
@ -18,7 +17,6 @@ class _DropoutNd(Module):
return 'p={}, inplace={}'.format(self.p, self.inplace)
@weak_module
class Dropout(_DropoutNd):
r"""During training, randomly zeroes some of the elements of the input
tensor with probability :attr:`p` using samples from a Bernoulli
@ -52,12 +50,10 @@ class Dropout(_DropoutNd):
detectors: https://arxiv.org/abs/1207.0580
"""
@weak_script_method
def forward(self, input):
return F.dropout(input, self.p, self.training, self.inplace)
@weak_module
class Dropout2d(_DropoutNd):
r"""Randomly zero out entire channels (a channel is a 2D feature map,
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
@ -96,12 +92,10 @@ class Dropout2d(_DropoutNd):
http://arxiv.org/abs/1411.4280
"""
@weak_script_method
def forward(self, input):
return F.dropout2d(input, self.p, self.training, self.inplace)
@weak_module
class Dropout3d(_DropoutNd):
r"""Randomly zero out entire channels (a channel is a 3D feature map,
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
@ -140,12 +134,10 @@ class Dropout3d(_DropoutNd):
http://arxiv.org/abs/1411.4280
"""
@weak_script_method
def forward(self, input):
return F.dropout3d(input, self.p, self.training, self.inplace)
@weak_module
class AlphaDropout(_DropoutNd):
r"""Applies Alpha Dropout over the input.
@ -184,14 +176,11 @@ class AlphaDropout(_DropoutNd):
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
"""
@weak_script_method
def forward(self, input):
return F.alpha_dropout(input, self.p, self.training)
@weak_module
class FeatureAlphaDropout(_DropoutNd):
@weak_script_method
def forward(self, input):
return F.feature_alpha_dropout(input, self.p, self.training)

View File

@ -1,10 +1,8 @@
# coding=utf-8
from .module import Module
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class Fold(Module):
r"""Combines an array of sliding local blocks into a large containing
tensor.
@ -101,7 +99,6 @@ class Fold(Module):
self.padding = padding
self.stride = stride
@weak_script_method
def forward(self, input):
return F.fold(input, self.output_size, self.kernel_size, self.dilation,
self.padding, self.stride)
@ -113,7 +110,6 @@ class Fold(Module):
)
@weak_module
class Unfold(Module):
r"""Extracts sliding local blocks from a batched input tensor.
@ -217,7 +213,6 @@ class Unfold(Module):
self.padding = padding
self.stride = stride
@weak_script_method
def forward(self, input):
return F.unfold(input, self.kernel_size, self.dilation,
self.padding, self.stride)

View File

@ -1,6 +1,5 @@
from .batchnorm import _BatchNorm
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
class _InstanceNorm(_BatchNorm):
@ -9,7 +8,6 @@ class _InstanceNorm(_BatchNorm):
super(_InstanceNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
@weak_script_method
def _check_input_dim(self, input):
raise NotImplementedError
@ -43,7 +41,6 @@ class _InstanceNorm(_BatchNorm):
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
@weak_script_method
def forward(self, input):
self._check_input_dim(input)
@ -52,7 +49,6 @@ class _InstanceNorm(_BatchNorm):
self.training or not self.track_running_stats, self.momentum, self.eps)
@weak_module
class InstanceNorm1d(_InstanceNorm):
r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
inputs with optional additional channel dimension) as described in the paper
@ -121,7 +117,6 @@ class InstanceNorm1d(_InstanceNorm):
https://arxiv.org/abs/1607.08022
"""
@weak_script_method
def _check_input_dim(self, input):
if input.dim() == 2:
raise ValueError(
@ -135,7 +130,6 @@ class InstanceNorm1d(_InstanceNorm):
.format(input.dim()))
@weak_module
class InstanceNorm2d(_InstanceNorm):
r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
with additional channel dimension) as described in the paper
@ -204,14 +198,12 @@ class InstanceNorm2d(_InstanceNorm):
https://arxiv.org/abs/1607.08022
"""
@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
@weak_module
class InstanceNorm3d(_InstanceNorm):
r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
with additional channel dimension) as described in the paper
@ -280,7 +272,6 @@ class InstanceNorm3d(_InstanceNorm):
https://arxiv.org/abs/1607.08022
"""
@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'

View File

@ -5,10 +5,8 @@ from torch.nn.parameter import Parameter
from .. import functional as F
from .. import init
from .module import Module
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class Identity(Module):
r"""A placeholder identity operator that is argument-insensitive.
@ -28,12 +26,10 @@ class Identity(Module):
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()
@weak_script_method
def forward(self, input):
return input
@weak_module
class Linear(Module):
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
@ -87,7 +83,6 @@ class Linear(Module):
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
@weak_script_method
def forward(self, input):
return F.linear(input, self.weight, self.bias)
@ -97,7 +92,6 @@ class Linear(Module):
)
@weak_module
class Bilinear(Module):
r"""Applies a bilinear transformation to the incoming data:
:math:`y = x_1 A x_2 + b`
@ -157,7 +151,6 @@ class Bilinear(Module):
if self.bias is not None:
init.uniform_(self.bias, -bound, bound)
@weak_script_method
def forward(self, input1, input2):
return F.bilinear(input1, input2, self.weight, self.bias)

View File

@ -3,7 +3,6 @@ import warnings
from .module import Module
from .. import functional as F
from .. import _reduction as _Reduction
from ..._jit_internal import weak_module, weak_script_method
class _Loss(Module):
@ -21,7 +20,6 @@ class _WeightedLoss(_Loss):
self.register_buffer('weight', weight)
@weak_module
class L1Loss(_Loss):
r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
the input :math:`x` and target :math:`y`.
@ -86,12 +84,10 @@ class L1Loss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(L1Loss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target):
return F.l1_loss(input, target, reduction=self.reduction)
@weak_module
class NLLLoss(_WeightedLoss):
r"""The negative log likelihood loss. It is useful to train a classification
problem with `C` classes.
@ -204,12 +200,10 @@ class NLLLoss(_WeightedLoss):
super(NLLLoss, self).__init__(weight, size_average, reduce, reduction)
self.ignore_index = ignore_index
@weak_script_method
def forward(self, input, target):
return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
@weak_module
class NLLLoss2d(NLLLoss):
def __init__(self, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'):
@ -219,7 +213,6 @@ class NLLLoss2d(NLLLoss):
super(NLLLoss2d, self).__init__(weight, size_average, ignore_index, reduce, reduction)
@weak_module
class PoissonNLLLoss(_Loss):
r"""Negative log likelihood loss with Poisson distribution of target.
@ -286,13 +279,11 @@ class PoissonNLLLoss(_Loss):
self.full = full
self.eps = eps
@weak_script_method
def forward(self, log_input, target):
return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full,
eps=self.eps, reduction=self.reduction)
@weak_module
class KLDivLoss(_Loss):
r"""The `Kullback-Leibler divergence`_ Loss
@ -370,12 +361,10 @@ class KLDivLoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(KLDivLoss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target):
return F.kl_div(input, target, reduction=self.reduction)
@weak_module
class MSELoss(_Loss):
r"""Creates a criterion that measures the mean squared error (squared L2 norm) between
each element in the input :math:`x` and target :math:`y`.
@ -438,12 +427,10 @@ class MSELoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(MSELoss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target):
return F.mse_loss(input, target, reduction=self.reduction)
@weak_module
class BCELoss(_WeightedLoss):
r"""Creates a criterion that measures the Binary Cross Entropy
between the target and the output:
@ -507,12 +494,10 @@ class BCELoss(_WeightedLoss):
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
super(BCELoss, self).__init__(weight, size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target):
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
@weak_module
class BCEWithLogitsLoss(_Loss):
r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single
class. This version is more numerically stable than using a plain `Sigmoid`
@ -609,7 +594,6 @@ class BCEWithLogitsLoss(_Loss):
self.register_buffer('weight', weight)
self.register_buffer('pos_weight', pos_weight)
@weak_script_method
def forward(self, input, target):
return F.binary_cross_entropy_with_logits(input, target,
self.weight,
@ -617,7 +601,6 @@ class BCEWithLogitsLoss(_Loss):
reduction=self.reduction)
@weak_module
class HingeEmbeddingLoss(_Loss):
r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`
(containing 1 or -1).
@ -673,12 +656,10 @@ class HingeEmbeddingLoss(_Loss):
super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin
@weak_script_method
def forward(self, input, target):
return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction)
@weak_module
class MultiLabelMarginLoss(_Loss):
r"""Creates a criterion that optimizes a multi-class multi-classification
hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
@ -739,12 +720,10 @@ class MultiLabelMarginLoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target):
return F.multilabel_margin_loss(input, target, reduction=self.reduction)
@weak_module
class SmoothL1Loss(_Loss):
r"""Creates a criterion that uses a squared term if the absolute
element-wise error falls below 1 and an L1 term otherwise.
@ -799,12 +778,10 @@ class SmoothL1Loss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(SmoothL1Loss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target):
return F.smooth_l1_loss(input, target, reduction=self.reduction)
@weak_module
class SoftMarginLoss(_Loss):
r"""Creates a criterion that optimizes a two-class classification
logistic loss between input tensor :math:`x` and target tensor :math:`y`
@ -842,12 +819,10 @@ class SoftMarginLoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(SoftMarginLoss, self).__init__(size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target):
return F.soft_margin_loss(input, target, reduction=self.reduction)
@weak_module
class CrossEntropyLoss(_WeightedLoss):
r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.
@ -936,13 +911,11 @@ class CrossEntropyLoss(_WeightedLoss):
super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
self.ignore_index = ignore_index
@weak_script_method
def forward(self, input, target):
return F.cross_entropy(input, target, weight=self.weight,
ignore_index=self.ignore_index, reduction=self.reduction)
@weak_module
class MultiLabelSoftMarginLoss(_WeightedLoss):
r"""Creates a criterion that optimizes a multi-label one-versus-all
loss based on max-entropy, between input :math:`x` and target :math:`y` of size
@ -986,12 +959,10 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction)
@weak_script_method
def forward(self, input, target):
return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction)
@weak_module
class CosineEmbeddingLoss(_Loss):
r"""Creates a criterion that measures the loss given input tensors
:math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1.
@ -1034,12 +1005,10 @@ class CosineEmbeddingLoss(_Loss):
super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin
@weak_script_method
def forward(self, input1, input2, target):
return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
@weak_module
class MarginRankingLoss(_Loss):
r"""Creates a criterion that measures the loss given
inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`,
@ -1082,12 +1051,10 @@ class MarginRankingLoss(_Loss):
super(MarginRankingLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin
@weak_script_method
def forward(self, input1, input2, target):
return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
@weak_module
class MultiMarginLoss(_WeightedLoss):
r"""Creates a criterion that optimizes a multi-class classification hinge
loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and
@ -1145,13 +1112,11 @@ class MultiMarginLoss(_WeightedLoss):
self.p = p
self.margin = margin
@weak_script_method
def forward(self, input, target):
return F.multi_margin_loss(input, target, p=self.p, margin=self.margin,
weight=self.weight, reduction=self.reduction)
@weak_module
class TripletMarginLoss(_Loss):
r"""Creates a criterion that measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
@ -1221,13 +1186,11 @@ class TripletMarginLoss(_Loss):
self.eps = eps
self.swap = swap
@weak_script_method
def forward(self, anchor, positive, negative):
return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p,
eps=self.eps, swap=self.swap, reduction=self.reduction)
@weak_module
class CTCLoss(_Loss):
r"""The Connectionist Temporal Classification loss.
@ -1327,7 +1290,6 @@ class CTCLoss(_Loss):
self.blank = blank
self.zero_infinity = zero_infinity
@weak_script_method
def forward(self, log_probs, targets, input_lengths, target_lengths):
return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
self.zero_infinity)

View File

@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
from .. import init
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class LocalResponseNorm(Module):
r"""Applies local response normalization over an input signal composed
of several input planes, where channels occupy the second dimension.
@ -45,7 +43,6 @@ class LocalResponseNorm(Module):
self.beta = beta
self.k = k
@weak_script_method
def forward(self, input):
return F.local_response_norm(input, self.size, self.alpha, self.beta,
self.k)
@ -71,7 +68,6 @@ class CrossMapLRN2d(Module):
return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
@weak_module
class LayerNorm(Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ .
@ -151,7 +147,6 @@ class LayerNorm(Module):
init.ones_(self.weight)
init.zeros_(self.bias)
@weak_script_method
def forward(self, input):
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps)
@ -161,7 +156,6 @@ class LayerNorm(Module):
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
@weak_module
class GroupNorm(Module):
r"""Applies Group Normalization over a mini-batch of inputs as described in
the paper `Group Normalization`_ .
@ -226,7 +220,6 @@ class GroupNorm(Module):
init.ones_(self.weight)
init.zeros_(self.bias)
@weak_script_method
def forward(self, input):
return F.group_norm(
input, self.num_groups, self.weight, self.bias, self.eps)

View File

@ -1,13 +1,11 @@
from .module import Module
from .utils import _pair, _quadruple, _ntuple
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
# TODO: grad_output size asserts in THNN
@weak_module
class _ConstantPadNd(Module):
__constants__ = ['padding', 'value']
@ -15,7 +13,6 @@ class _ConstantPadNd(Module):
super(_ConstantPadNd, self).__init__()
self.value = value
@weak_script_method
def forward(self, input):
return F.pad(input, self.padding, 'constant', self.value)
@ -23,7 +20,6 @@ class _ConstantPadNd(Module):
return 'padding={}, value={}'.format(self.padding, self.value)
@weak_module
class ConstantPad1d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value.
@ -73,7 +69,6 @@ class ConstantPad1d(_ConstantPadNd):
self.padding = _pair(padding)
@weak_module
class ConstantPad2d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value.
@ -123,7 +118,6 @@ class ConstantPad2d(_ConstantPadNd):
self.padding = _quadruple(padding)
@weak_module
class ConstantPad3d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value.
@ -162,11 +156,9 @@ class ConstantPad3d(_ConstantPadNd):
self.padding = _ntuple(6)(padding)
@weak_module
class _ReflectionPadNd(Module):
__constants__ = ['padding']
@weak_script_method
def forward(self, input):
return F.pad(input, self.padding, 'reflect')
@ -174,7 +166,6 @@ class _ReflectionPadNd(Module):
return '{}'.format(self.padding)
@weak_module
class ReflectionPad1d(_ReflectionPadNd):
r"""Pads the input tensor using the reflection of the input boundary.
@ -214,7 +205,6 @@ class ReflectionPad1d(_ReflectionPadNd):
self.padding = _pair(padding)
@weak_module
class ReflectionPad2d(_ReflectionPadNd):
r"""Pads the input tensor using the reflection of the input boundary.
@ -265,11 +255,9 @@ class ReflectionPad2d(_ReflectionPadNd):
self.padding = _quadruple(padding)
@weak_module
class _ReplicationPadNd(Module):
__constants__ = ['padding']
@weak_script_method
def forward(self, input):
return F.pad(input, self.padding, 'replicate')
@ -277,7 +265,6 @@ class _ReplicationPadNd(Module):
return '{}'.format(self.padding)
@weak_module
class ReplicationPad1d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary.
@ -317,7 +304,6 @@ class ReplicationPad1d(_ReplicationPadNd):
self.padding = _pair(padding)
@weak_module
class ReplicationPad2d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary.
@ -368,7 +354,6 @@ class ReplicationPad2d(_ReplicationPadNd):
self.padding = _quadruple(padding)
@weak_module
class ReplicationPad3d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary.
@ -407,7 +392,6 @@ class ReplicationPad3d(_ReplicationPadNd):
self.padding = _ntuple(6)(padding)
@weak_module
class ZeroPad2d(ConstantPad2d):
r"""Pads the input tensor boundaries with zero.

View File

@ -1,9 +1,7 @@
from .module import Module
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class PixelShuffle(Module):
r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
@ -41,7 +39,6 @@ class PixelShuffle(Module):
super(PixelShuffle, self).__init__()
self.upscale_factor = upscale_factor
@weak_script_method
def forward(self, input):
return F.pixel_shuffle(input, self.upscale_factor)

View File

@ -1,10 +1,8 @@
from .module import Module
from .utils import _single, _pair, _triple
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class _MaxPoolNd(Module):
__constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
'return_indices', 'ceil_mode']
@ -24,7 +22,6 @@ class _MaxPoolNd(Module):
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
@weak_module
class MaxPool1d(_MaxPoolNd):
r"""Applies a 1D max pooling over an input signal composed of several input
planes.
@ -68,7 +65,6 @@ class MaxPool1d(_MaxPoolNd):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
@weak_script_method
def forward(self, input):
return F.max_pool1d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode,
@ -79,7 +75,6 @@ class MaxPool1d(_MaxPoolNd):
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
@weak_module
class MaxPool2d(_MaxPoolNd):
r"""Applies a 2D max pooling over an input signal composed of several input
planes.
@ -139,14 +134,12 @@ class MaxPool2d(_MaxPoolNd):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
@weak_script_method
def forward(self, input):
return F.max_pool2d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode,
self.return_indices)
@weak_module
class MaxPool3d(_MaxPoolNd):
r"""Applies a 3D max pooling over an input signal composed of several input
planes.
@ -210,14 +203,12 @@ class MaxPool3d(_MaxPoolNd):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
""" # noqa: E501
@weak_script_method
def forward(self, input):
return F.max_pool3d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode,
self.return_indices)
@weak_module
class _MaxUnpoolNd(Module):
def extra_repr(self):
@ -226,7 +217,6 @@ class _MaxUnpoolNd(Module):
)
@weak_module
class MaxUnpool1d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool1d`.
@ -292,7 +282,6 @@ class MaxUnpool1d(_MaxUnpoolNd):
self.padding, output_size)
@weak_module
class MaxUnpool2d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool2d`.
@ -366,7 +355,6 @@ class MaxUnpool2d(_MaxUnpoolNd):
self.padding, output_size)
@weak_module
class MaxUnpool3d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool3d`.
@ -429,7 +417,6 @@ class MaxUnpool3d(_MaxUnpoolNd):
self.padding, output_size)
@weak_module
class _AvgPoolNd(Module):
__constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']
@ -439,7 +426,6 @@ class _AvgPoolNd(Module):
)
@weak_module
class AvgPool1d(_AvgPoolNd):
r"""Applies a 1D average pooling over an input signal composed of several
input planes.
@ -490,14 +476,12 @@ class AvgPool1d(_AvgPoolNd):
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad
@weak_script_method
def forward(self, input):
return F.avg_pool1d(
input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
self.count_include_pad)
@weak_module
class AvgPool2d(_AvgPoolNd):
r"""Applies a 2D average pooling over an input signal composed of several input
planes.
@ -557,13 +541,11 @@ class AvgPool2d(_AvgPoolNd):
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad
@weak_script_method
def forward(self, input):
return F.avg_pool2d(input, self.kernel_size, self.stride,
self.padding, self.ceil_mode, self.count_include_pad)
@weak_module
class AvgPool3d(_AvgPoolNd):
r"""Applies a 3D average pooling over an input signal composed of several input
planes.
@ -630,7 +612,6 @@ class AvgPool3d(_AvgPoolNd):
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad
@weak_script_method
def forward(self, input):
return F.avg_pool3d(input, self.kernel_size, self.stride,
self.padding, self.ceil_mode, self.count_include_pad)
@ -642,7 +623,6 @@ class AvgPool3d(_AvgPoolNd):
self.__dict__.setdefault('count_include_pad', True)
@weak_module
class FractionalMaxPool2d(Module):
r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.
@ -694,7 +674,6 @@ class FractionalMaxPool2d(Module):
raise ValueError("output_ratio must be between 0 and 1 (got {})"
.format(output_ratio))
@weak_script_method
def forward(self, input):
return F.fractional_max_pool2d(
input, self.kernel_size, self.output_size, self.output_ratio,
@ -702,7 +681,6 @@ class FractionalMaxPool2d(Module):
_random_samples=self._random_samples)
@weak_module
class FractionalMaxPool3d(Module):
r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.
@ -754,7 +732,6 @@ class FractionalMaxPool3d(Module):
raise ValueError("output_ratio must be between 0 and 1 (got {})"
.format(output_ratio))
@weak_script_method
def forward(self, input):
return F.fractional_max_pool3d(
input, self.kernel_size, self.output_size, self.output_ratio,
@ -762,7 +739,6 @@ class FractionalMaxPool3d(Module):
_random_samples=self._random_samples)
@weak_module
class _LPPoolNd(Module):
__constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode']
@ -778,7 +754,6 @@ class _LPPoolNd(Module):
'ceil_mode={ceil_mode}'.format(**self.__dict__)
@weak_module
class LPPool1d(_LPPoolNd):
r"""Applies a 1D power-average pooling over an input signal composed of several input
planes.
@ -814,14 +789,11 @@ class LPPool1d(_LPPoolNd):
>>> output = m(input)
"""
@weak_script_method
@weak_script_method
def forward(self, input):
return F.lp_pool1d(input, float(self.norm_type), self.kernel_size,
self.stride, self.ceil_mode)
@weak_module
class LPPool2d(_LPPoolNd):
r"""Applies a 2D power-average pooling over an input signal composed of several input
planes.
@ -871,13 +843,11 @@ class LPPool2d(_LPPoolNd):
"""
@weak_script_method
def forward(self, input):
return F.lp_pool2d(input, float(self.norm_type), self.kernel_size,
self.stride, self.ceil_mode)
@weak_module
class _AdaptiveMaxPoolNd(Module):
__constants__ = ['output_size', 'return_indices']
@ -893,7 +863,6 @@ class _AdaptiveMaxPoolNd(Module):
# output shapes are, and how the operation computes output.
@weak_module
class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
@ -913,12 +882,10 @@ class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
"""
@weak_script_method
def forward(self, input):
return F.adaptive_max_pool1d(input, self.output_size, self.return_indices)
@weak_module
class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
@ -949,12 +916,10 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
"""
@weak_script_method
def forward(self, input):
return F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
@weak_module
class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
@ -986,12 +951,10 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
"""
@weak_script_method
def forward(self, input):
return F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
@weak_module
class _AdaptiveAvgPoolNd(Module):
__constants__ = ['output_size']
@ -1003,7 +966,6 @@ class _AdaptiveAvgPoolNd(Module):
return 'output_size={}'.format(self.output_size)
@weak_module
class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
@ -1021,12 +983,10 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
"""
@weak_script_method
def forward(self, input):
return F.adaptive_avg_pool1d(input, self.output_size)
@weak_module
class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
@ -1055,12 +1015,10 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
"""
@weak_script_method
def forward(self, input):
return F.adaptive_avg_pool2d(input, self.output_size)
@weak_module
class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
@ -1089,6 +1047,5 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
"""
@weak_script_method
def forward(self, input):
return F.adaptive_avg_pool3d(input, self.output_size)

View File

@ -8,8 +8,7 @@ from ..parameter import Parameter
from ..utils.rnn import PackedSequence, get_packed_sequence
from .. import init
from .. import _VF
from ..._jit_internal import weak_module, weak_script_method, weak_script, \
_parameter_list
from ..._jit_internal import _parameter_list
_rnn_impls = {
'GRU': _VF.gru,
@ -18,7 +17,6 @@ _rnn_impls = {
}
@weak_script
def apply_permutation(tensor, permutation, dim=1):
# type: (Tensor, Tensor, int) -> Tensor
return tensor.index_select(dim, permutation)
@ -139,7 +137,6 @@ class RNNBase(Module):
def _get_flat_weights(self):
return self._flat_weights
@weak_script_method
def check_input(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> None
expected_input_dim = 2 if batch_sizes is not None else 3
@ -152,7 +149,6 @@ class RNNBase(Module):
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
self.input_size, input.size(-1)))
@weak_script_method
def get_expected_hidden_size(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
if batch_sizes is not None:
@ -165,7 +161,6 @@ class RNNBase(Module):
mini_batch, self.hidden_size)
return expected_hidden_size
@weak_script_method
def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
# type: (Tensor, Tuple[int, int, int], str) -> None
if hx.size() != expected_hidden_size:
@ -374,7 +369,6 @@ class RNN(RNNBase):
super(RNN, self).__init__(mode, *args, **kwargs)
@weak_module
class LSTM(RNNBase):
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
sequence.
@ -484,7 +478,6 @@ class LSTM(RNNBase):
def __init__(self, *args, **kwargs):
super(LSTM, self).__init__('LSTM', *args, **kwargs)
@weak_script_method
def check_forward_args(self, input, hidden, batch_sizes):
# type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
self.check_input(input, batch_sizes)
@ -495,14 +488,12 @@ class LSTM(RNNBase):
self.check_hidden_size(hidden[1], expected_hidden_size,
'Expected hidden[1] size {}, got {}')
@weak_script_method
def permute_hidden(self, hx, permutation):
# type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
if permutation is None:
return hx
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
@weak_script_method
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
if hx is None:
@ -528,7 +519,7 @@ class LSTM(RNNBase):
return output, hidden
@weak_script_method
@torch._jit_internal.export
def forward_tensor(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
batch_sizes = None
@ -540,7 +531,7 @@ class LSTM(RNNBase):
return output, self.permute_hidden(hidden, unsorted_indices)
@weak_script_method
@torch._jit_internal.export
def forward_packed(self, input, hx=None):
# type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
input, batch_sizes, sorted_indices, unsorted_indices = input
@ -552,6 +543,7 @@ class LSTM(RNNBase):
output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)
@torch._jit_internal.ignore
def forward(self, input, hx=None):
if isinstance(input, PackedSequence):
return self.forward_packed(input, hx)
@ -694,14 +686,12 @@ class RNNCellBase(Module):
s += ', nonlinearity={nonlinearity}'
return s.format(**self.__dict__)
@weak_script_method
def check_forward_input(self, input):
if input.size(1) != self.input_size:
raise RuntimeError(
"input has inconsistent input_size: got {}, expected {}".format(
input.size(1), self.input_size))
@weak_script_method
def check_forward_hidden(self, input, hx, hidden_label=''):
# type: (Tensor, Tensor, str) -> None
if input.size(0) != hx.size(0):
@ -720,7 +710,6 @@ class RNNCellBase(Module):
init.uniform_(weight, -stdv, stdv)
@weak_module
class RNNCell(RNNCellBase):
r"""An Elman RNN cell with tanh or ReLU non-linearity.
@ -784,7 +773,6 @@ class RNNCell(RNNCellBase):
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
self.nonlinearity = nonlinearity
@weak_script_method
def forward(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor
self.check_forward_input(input)
@ -810,7 +798,6 @@ class RNNCell(RNNCellBase):
return ret
@weak_module
class LSTMCell(RNNCellBase):
r"""A long short-term memory (LSTM) cell.
@ -875,7 +862,6 @@ class LSTMCell(RNNCellBase):
def __init__(self, input_size, hidden_size, bias=True):
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
@weak_script_method
def forward(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
self.check_forward_input(input)
@ -891,7 +877,6 @@ class LSTMCell(RNNCellBase):
)
@weak_module
class GRUCell(RNNCellBase):
r"""A gated recurrent unit (GRU) cell
@ -957,7 +942,6 @@ class GRUCell(RNNCellBase):
def __init__(self, input_size, hidden_size, bias=True):
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
@weak_script_method
def forward(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor
self.check_forward_input(input)
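The rnn.py hunks above are the one file the summary calls out as hand-edited: the overloaded LSTM forward is kept out of compilation with @torch._jit_internal.ignore and stays in Python, while the statically typed forward_tensor / forward_packed variants are compiled because they carry @torch._jit_internal.export. A condensed, hedged sketch of that division of labor (Dispatcher and its methods are invented; the real LSTM dispatches on PackedSequence, which is not shown here):

import torch
from typing import List

class Dispatcher(torch.nn.Module):
    @torch._jit_internal.export
    def forward_tensor(self, x):
        # type: (Tensor) -> Tensor
        return x * 2

    @torch._jit_internal.export
    def forward_list(self, xs):
        # type: (List[Tensor]) -> Tensor
        return torch.stack(xs).sum(dim=0) * 2

    @torch._jit_internal.ignore
    def forward(self, x):
        # plain Python dispatch: never compiled, so the overload is allowed
        if isinstance(x, list):
            return self.forward_list(x)
        return self.forward_tensor(x)

In eager mode all three methods run as ordinary Python; only when the module is reached from a ScriptModule do the exported methods get compiled, with forward left as the Python-only entry point.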

View File

@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
from .. import init
from torch._jit_internal import weak_module, weak_script_method
@weak_module
class Embedding(Module):
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
@ -110,7 +108,6 @@ class Embedding(Module):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
@weak_script_method
def forward(self, input):
return F.embedding(
input, self.weight, self.padding_idx, self.max_norm,
@ -173,7 +170,6 @@ class Embedding(Module):
return embedding
@weak_module
class EmbeddingBag(Module):
r"""Computes sums or means of 'bags' of embeddings, without instantiating the
intermediate embeddings.
@ -277,7 +273,6 @@ class EmbeddingBag(Module):
def reset_parameters(self):
init.normal_(self.weight)
@weak_script_method
def forward(self, input, offsets=None, per_sample_weights=None):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
return F.embedding_bag(input, self.weight, offsets,

View File

@ -1,9 +1,7 @@
from .module import Module
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method
@weak_module
class Upsample(Module):
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
@ -129,7 +127,6 @@ class Upsample(Module):
self.mode = mode
self.align_corners = align_corners
@weak_script_method
def forward(self, input):
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
@ -142,7 +139,6 @@ class Upsample(Module):
return info
@weak_module
class UpsamplingNearest2d(Upsample):
r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input
channels.
@ -188,7 +184,6 @@ class UpsamplingNearest2d(Upsample):
super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest')
@weak_module
class UpsamplingBilinear2d(Upsample):
r"""Applies a 2D bilinear upsampling to an input signal composed of several input
channels.

View File

@ -23,10 +23,9 @@ def _is_jit_enabled():
# Check if we can safely replicate the module.
# there are three types of module:
# there are two types of module:
# 1. python modules
# 2. weak python modules (nn.Module annotated by @weak_module)
# 3. ScriptModule
# 2. ScriptModule
#
# currently a module cannot be replicated properly if the descendants of
# any ScriptModule contains python module (type 1 above)

View File

@ -5,9 +5,7 @@ from __future__ import unicode_literals
from .. import functional as F
from ...modules.module import Module
from ...._jit_internal import weak_module, weak_script_method
@weak_module
class ReLU(Module):
r"""Applies quantized rectified linear unit function element-wise:
@ -37,7 +35,6 @@ class ReLU(Module):
assert not inplace, 'torch.nn.quantized.ReLU does not support inplace'
@weak_script_method
def forward(self, input):
return F.relu(input)

View File

@ -2,9 +2,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import torch
from ...modules.module import Module
from ...modules.linear import Linear as NNLinear
from ...._jit_internal import weak_module
@weak_module
class Quantize(Module):
r"""Quantizes an incoming tensor
Args:
@ -39,7 +37,6 @@ class Quantize(Module):
def from_float(mod):
return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8)
@weak_module
class DeQuantize(Module):
r"""Dequantizes an incoming tensor
@ -65,7 +62,6 @@ class DeQuantize(Module):
def from_float(mod):
return DeQuantize()
@weak_module
class Linear(NNLinear):
r"""
A quantized linear module with quantized tensor as inputs