Remove weak script (#22212)

Summary:
* Deletes all weak script decorators, their associated data structures, and methods.
* To keep supporting the standard library in script, this enables recursive script on any function defined in `torch.nn`.
* Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`; only `rnn.py` needed manual editing, using the `ignore` and `export` decorators to continue supporting the overloaded `forward` methods.
* `Sequential`/`ModuleList` no longer need to be added to `__constants__`, since they are compiled on demand.

This should also fix https://github.com/pytorch/pytorch/issues/22212

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212
Differential Revision: D15988346
Pulled By: driazati
fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f

This commit is contained in:
parent b93f29ded3
commit 10c4b98ade
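For illustration, the behavior described in the last bullet looks roughly like this. This is a minimal sketch modeled on the updated `test_script_module_list_sequential` test in the diff below; the `ScriptModule` constructor argument follows that test, everything else is standard `torch.nn`.

import torch
import torch.nn as nn

class M(torch.jit.ScriptModule):
    def __init__(self, mod_list):
        super(M, self).__init__(False)
        # Previously this assignment required `__constants__ = ['mods']`;
        # after this change the Sequential/ModuleList is compiled on demand.
        self.mods = mod_list

    @torch.jit.script_method
    def forward(self, v):
        for m in self.mods:
            v = m(v)
        return v

m = M(nn.Sequential(nn.ReLU()))
print(m(torch.randn(2, 2)))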
@@ -6979,7 +6979,7 @@ a")
         with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
             M()
 
-    def test_script_module_list_sequential_error(self):
+    def test_script_module_list_sequential(self):
         class M(torch.jit.ScriptModule):
             def __init__(self, mod_list):
                 super(M, self).__init__(False)
@@ -6991,25 +6991,21 @@ a")
                     v = m(v)
                 return v
 
-        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
-            a = M(nn.Sequential(nn.ReLU()))
-        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
-            a = M(nn.ModuleList([nn.ReLU()]))
+        m = M(nn.Sequential(nn.ReLU()))
+        self.assertExportImportModule(m, (torch.randn(2, 2),))
 
-    def test_attr_module_constants_error(self):
+    def test_attr_module_constants(self):
         class M2(torch.jit.ScriptModule):
             def __init__(self, mod_list):
                 super(M2, self).__init__(False)
                 self.mods = mod_list
 
             @torch.jit.script_method
-            def forward(self, v):
+            def forward(self, x):
                 return self.mods.forward(x)
 
-        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
-            M2(nn.Sequential(nn.ReLU()))
-        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
-            M2(nn.ModuleList([nn.ReLU()]))
+        m = M2(nn.Sequential(nn.ReLU()))
+        self.assertExportImportModule(m, (torch.randn(2, 2),))
 
     def test_script_sequential_for(self):
         class Sub(torch.jit.ScriptModule):
@@ -11007,6 +11003,7 @@ a")
         with self.assertRaisesRegex(torch.jit.Error, "Exception"):
             foo(torch.tensor(0))
 
+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_script_function(self):
         outer_var = 10
         outer_var2 = 11
@@ -11086,6 +11083,7 @@ a")
         eg = torch.zeros(3, dtype=torch.uint8)
         self.assertEqual(foo_traced(eg), foo(eg))
 
+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module(self):
 
         @torch._jit_internal.weak_module
@@ -11161,6 +11159,7 @@ a")
         self.assertEqual(script_result, expected_result)
         self.assertEqual(script_result, script_result2)
 
+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_parameters_and_buffers(self):
         weights = torch.randn(10, 10)
         bias = torch.randn(10)
@@ -11219,6 +11218,7 @@ a")
         self.assertEqual(strong_mod(inp), expected_result)
         self.assertExportImportModule(strong_mod, (inp,))
 
+    @unittest.skipIf(True, "Removing weak script")
    def test_weak_module_nested(self):
        @torch._jit_internal.weak_module
        class OtherWeak(torch.nn.Module):
@@ -11280,6 +11280,7 @@ a")
             + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
         self.assertEqual(result, expected_result)
 
+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_submodule(self):
         @torch._jit_internal.weak_module
         class Weak(torch.nn.Module):
@@ -11319,6 +11320,7 @@ a")
         with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
             strong_mod = Strong()
 
+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_copying(self):
         class Submodule(torch.nn.Module):
             def __init__(self):
@@ -11385,6 +11387,7 @@ a")
 
         m = M()
 
+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_attributes(self):
         tester = self
 
@@ -11948,6 +11951,7 @@ a")
 
         FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
 
+    @unittest.skipIf(True, "Removing weak script")
    def test_overloading(self):
        @torch._jit_internal.weak_module
        class W(torch.nn.Module):
@@ -13623,6 +13627,9 @@ EXCLUDE_SCRIPT_MODULES = {
     'test_nn_AdaptiveAvgPool3d_tuple_none',
     'test_nn_AdaptiveMaxPool2d_tuple_none',
     'test_nn_AdaptiveMaxPool3d_tuple_none',
+
+    # Uses Module._backend, so this is not supported
+    'test_nn_CrossMapLRN2d',
 }
 
 
@@ -14552,10 +14559,6 @@ def add_nn_module_test(*args, **kwargs):
 
     module_name = name.split("_")[0]
 
-    module = getattr(torch.nn, module_name, None)
-    if module is None or torch._jit_internal.weak_types.get(module) is None:
-        return
-
     if 'desc' in kwargs and 'eval' in kwargs['desc']:
         # eval() is not supported, so skip these tests
         return
@@ -4,29 +4,14 @@ can be used in other places in torch/ (namely torch.nn) without running into
 circular dependency problems
 """
 
-import weakref
+import inspect
+import weakref
 from torch._six import builtins
 
-# Tracks standalone weak script functions
-compiled_weak_fns = weakref.WeakKeyDictionary()  # noqa: T484
-
-# Tracks which methods should be converted to strong methods
-weak_script_methods = weakref.WeakKeyDictionary()  # noqa: T484
-
-# Converted modules and their corresponding WeakScriptModuleProxy objects
-weak_modules = weakref.WeakKeyDictionary()  # noqa: T484
-
-# Types that have been declared as weak modules
-weak_types = weakref.WeakKeyDictionary()  # noqa: T484
-
 # Wrapper functions that can call either of 2 functions depending on a boolean
 # argument
 boolean_dispatched = weakref.WeakKeyDictionary()  # noqa: T484
 
-COMPILATION_PENDING = object()
-COMPILED = object()
-
 
 def createResolutionCallback(frames_up=0):
     """
@@ -71,51 +56,41 @@ def createResolutionCallback(frames_up=0):
             return f_globals[key]
         elif hasattr(builtins, key):
             return getattr(builtins, key)
         else:
             return None
 
     return env
 
 
-def weak_script(fn, _frames_up=0):
+def createResolutionCallbackFromClosure(fn):
     """
-    Marks a function as a weak script function. When used in a script function
-    or ScriptModule, the weak script function will be lazily compiled and
-    inlined in the graph. When not used in a script function, the weak script
-    annotation has no effect.
+    Create a resolutionCallback by introspecting the function instead of
+    looking up the stack for the enclosing scope
    """
-    compiled_weak_fns[fn] = {
-        "status": COMPILATION_PENDING,
-        "compiled_fn": None,
-        "rcb": createResolutionCallback(_frames_up + 1)
-    }
-    return fn
+    var_names = fn.__code__.co_freevars
 
+    # map of captured name -> value
+    free_vars = {}
 
-def weak_module(cls):
-    weak_types[cls] = {
-        "method_stubs": None
-    }
-    return cls
+    for index, name in enumerate(var_names):
+        free_vars[name] = fn.__closure__[index].cell_contents
+    f_globals = fn.__globals__
 
+    def env(key):
+        if key in free_vars:
+            return free_vars[key]
+        elif hasattr(builtins, key):
+            return getattr(builtins, key)
+        else:
+            return f_globals.get(key)
 
-def weak_script_method(fn):
-    weak_script_methods[fn] = {
-        "rcb": createResolutionCallback(frames_up=2),
-        "original_method": fn
-    }
-    return fn
+    return env
 
 
 def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name):
     """
-    Dispatches to either of 2 weak script functions based on a boolean argument.
+    Dispatches to either of 2 script functions based on a boolean argument.
     In TorchScript, the boolean argument must be constant so that the correct
     function to use can be determined at compile time.
     """
-    if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None:
-        raise RuntimeError("both functions must be weak script")
-
     def fn(*args, **kwargs):
         dispatch_flag = False
         if arg_name in kwargs:
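The `boolean_dispatch` helper keeps the same signature after this change; only its docstring and the weak-script precondition are touched. As a rough illustration of how such a dispatcher is built, here is a hedged sketch: `my_pool` and the two helper functions are made-up names for illustration, while the keyword arguments mirror the signature shown above.

import torch
from torch._jit_internal import boolean_dispatch

# Hypothetical pair of implementations, one per value of `return_indices`.
def _my_pool_with_indices(input, return_indices=True):
    r"""Returns the pooled tensor together with indices."""
    return input, input.argmax()

def _my_pool_no_indices(input, return_indices=False):
    return input

# The wrapper picks an implementation from the boolean argument; in
# TorchScript the argument must be a constant so the branch is known at
# compile time.
my_pool = boolean_dispatch(
    arg_name='return_indices',
    arg_index=1,
    default=False,
    if_true=_my_pool_with_indices,
    if_false=_my_pool_no_indices,
    module_name=__name__,
    func_name='my_pool')

out = my_pool(torch.randn(4))                            # no-indices branch
out, idx = my_pool(torch.randn(4), return_indices=True)  # with-indices branch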
@@ -23,7 +23,8 @@ Decl mergeTypesFromTypeComment(
         << "Number of type annotations ("
         << type_annotation_decl.params().size()
         << ") did not match the number of "
-        << "function parameters (" << expected_num_annotations << ")";
+        << (is_method ? "method" : "function")
+        << " parameters (" << expected_num_annotations << ")";
   }
   auto old = decl.params();
   auto _new = type_annotation_decl.params();
@@ -244,6 +244,11 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
       << err.str();
 }
 
+bool should_recurse(py::object obj) {
+  return py::cast<bool>(py::module::import("torch.jit")
+                            .attr("_is_recursive_script_enabled")(obj));
+}
+
 std::shared_ptr<SugaredValue> ModuleValue::attr(
     const SourceRange& loc,
     Function& m,
@@ -307,7 +312,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
 
   // If recursive script mode is on, create a ScriptModule and register it as
  // as submodule or register a python method as a script::Method
-  if (getRecursiveScriptMode()) {
+  if (should_recurse(attr)) {
    if (py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) {
      // If the module is a submodule of the py_module, convert it to a
      // ScriptModule and add it as a submodule to the script::Module. This
@@ -471,11 +476,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     }
   }
 
-  auto weak_obj =
-      py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
-  if (!weak_obj.is_none()) {
-    obj = weak_obj;
-  }
   if (auto callee = as_function(obj)) {
     return std::make_shared<FunctionValue>(callee);
   } else if (py::isinstance<py::module>(obj)) {
@@ -504,12 +504,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
           << "which is currently not supported in Torchscript."
           << "Please open a feature request to add it.";
     }
-
-    auto compiled_fn =
-        py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
-    if (auto callee = as_function(compiled_fn)) {
-      return std::make_shared<FunctionValue>(callee);
-    }
   }
 
   py::object dispatched_fn =
@@ -528,7 +522,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     }
   }
 
-  if (getRecursiveScriptMode() && py::isinstance<py::function>(obj)) {
+  if (should_recurse(obj) && py::isinstance<py::function>(obj)) {
     auto compiled_fn =
         py::module::import("torch.jit").attr("_try_compile_fn")(obj);
     if (auto callee = as_function(compiled_fn)) {
@@ -7,7 +7,7 @@ import torch.backends.cudnn as cudnn
 import torch.jit.annotations
 import torch._jit_internal as _jit_internal
 from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \
-    string_classes, builtins
+    string_classes
 from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
     _list_with_default
 import torch.testing
@@ -930,50 +930,10 @@ def _try_get_overloaded_fn(mod, field):
     return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
 
 
-def _try_compile_weak_script(fn):
-    entry = _jit_internal.compiled_weak_fns.get(fn)
-    if entry is None:
-        return None
-    if entry["status"] == _jit_internal.COMPILATION_PENDING:
-        compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
-        del entry["rcb"]
-        _jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
-        entry["status"] = _jit_internal.COMPILED
-        return compiled_fn
-        # TODO: use fn.__closure__
-        raise RuntimeError("Cannot make resolutionCallback in Python 2")
-    else:
-        return entry["compiled_fn"]
-
-
 class ScriptWarning(Warning):
     pass
 
 
-def createResolutionCallbackFromClosure(fn):
-    """
-    Create a resolutionCallback by introspecting the function instead of
-    looking up the stack for the enclosing scope
-    """
-    var_names = fn.__code__.co_freevars
-
-    # map of captured name -> value
-    free_vars = {}
-
-    for index, name in enumerate(var_names):
-        free_vars[name] = fn.__closure__[index].cell_contents
-    f_globals = fn.__globals__
-
-    def env(key):
-        if key in free_vars:
-            return free_vars[key]
-        elif hasattr(builtins, key):
-            return getattr(builtins, key)
-        else:
-            return f_globals.get(key)
-
-    return env
-
 def _create_constant_iterable_module(module):
     modules = OrderedDict()
 
@@ -1012,20 +972,20 @@ def _try_compile_fn(fn):
         # Don't do anything for @ignore'd functions
         return None
 
-    if not inspect.isfunction(fn) and not inspect.ismethod(fn):
-        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
-                           "Python functions or methods currently.\n"
-                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn))
-
+    if isinstance(fn, torch.nn.Module):
+        # Since modules are callable pybind recognizes them as functions, but
+        # don't do anything for them
+        return None
+
+    if not inspect.isfunction(fn) and not inspect.ismethod(fn):
+        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
+                           "Python functions or methods currently.\n"
+                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))
+
     # We don't have the actual scope where the function was defined, but we can
     # extract the necessary info from the closed over variables on the function
     # object
-    rcb = createResolutionCallbackFromClosure(fn)
+    rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
     return torch.jit.script(fn, _rcb=rcb)
 
 
@@ -1040,7 +1000,9 @@ def _disable_emit_hooks():
 def _create_method_from_fn(module, fn):
     if _jit_internal.is_ignored_fn(fn):
         return None
-    stub = script_method(fn, createResolutionCallbackFromClosure(fn))
+    if not inspect.ismethod(fn):
+        return None
+    stub = script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
     with _disable_emit_hooks():
         # We don't want to call the hooks here since the graph that is calling
         # this function is not yet complete
@@ -1101,6 +1063,15 @@ def _qualified_name(obj):
     return module_name + "." + name
 
 
+def _is_recursive_script_enabled(value):
+    # TODO: [enable recursive script]
+    # when recursive script is made the default, remove this method
+    enabled = torch._C._jit_recursive_script()
+    module = inspect.getmodule(value)
+    if module is not None and 'torch.nn' in module.__name__:
+        enabled = True
+    return enabled
+
 @contextlib.contextmanager
 def _enable_recursive_script():
     torch._C._jit_recursive_script(True)
@@ -1114,8 +1085,8 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None):
     if _rcb is None:
         _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
 
-    if torch._C._jit_recursive_script():
-        if isinstance(obj, torch.nn.Module):
+    if isinstance(obj, torch.nn.Module):
+        if _is_recursive_script_enabled(obj):
             return _convert_to_script_module(obj)
 
     if inspect.isclass(obj):
@@ -1158,21 +1129,6 @@ def script_method(fn, _rcb=None):
     return ScriptMethodStub(_rcb, ast, fn)
 
 
-def _try_get_weak_module(mod):
-    """
-    Get the WeakScriptModuleProxy corresponding to mod if it exists
-    """
-    if not isinstance(mod, Module):
-        return None
-    return _jit_internal.weak_modules.get(mod)
-
-
-def _is_weak_type(cls):
-    """
-    Check if a type has been annotated with `weak_module`
-    """
-    return cls in _jit_internal.weak_types
-
 
 # These OrderedDictWrapper classes replace the actual OrderedDicts in
 # module with versions that get/set properties inside of script::Module.
@@ -1569,9 +1525,9 @@ if _enabled:
 
         def __setattr__(self, attr, value):
             if attr not in self._constants_set:
-                if isinstance(value, Module) and _is_weak_type(type(value)):
+                if isinstance(value, Module) and _is_recursive_script_enabled(value):
                     # Compile weak script module
-                    value = _make_strong(value)
+                    value = _convert_to_script_module(value)
                 if attr == 'training':
                     if self._c._has_attribute('training'):
                         self.__dict__['training'] = value
@@ -1684,7 +1640,7 @@ if _enabled:
             if isinstance(item, (ModuleList, Sequential)):
                 # These are in __constants__, so ignore them here
 
-                if not torch._C._jit_recursive_script():
+                if not _is_recursive_script_enabled(item):
                     # For recursive script, these are constantified after
                     # they are used, so they don't need to be in constants.
                     # The `continue` here should be deleted along with
@@ -1774,34 +1730,28 @@ else:
             super(ScriptModule, self).__init__()
 
 
-def _get_weak_stubs(cls):
-    """
-    Calls script_method for each method that has been annotated with @weak_script
-    on the type of the object passed in and returns the generated ScriptMethodStubs.
-    """
-    stubs = []
-    for name in dir(cls):
-        func = get_function_from_type(cls, name)
-        if func in _jit_internal.weak_script_methods:
-            entry = _jit_internal.weak_script_methods[func]
-            stub = script_method(entry["original_method"], entry["rcb"])
-            stubs.append(stub)
-    return stubs
-
-
-def _convert_to_script_module(mod, methods=None):
+def _convert_to_script_module(mod):
     """
     Makes a ScriptModule from an nn.Module. If `_methods` is provided,
     these methods are treated as @script_methods. If not, it defaults to
     `('forward',)`. Methods accessed in forward are scripted on demand if
     `_enable_recursive_script()` is used.
     """
     if isinstance(mod, ScriptModule):
         return mod
 
     if isinstance(mod, (ModuleList, Sequential)):
         # Create constant versions for the iterable modules
         return _create_constant_iterable_module(mod)
 
-    if methods is None:
-        methods = ('forward',)
+    methods = ()
+    if hasattr(mod, 'forward'):
+        if mod.forward.__func__ == torch.nn.Module.forward:
+            # TODO: [enable recursive script]
+            # forward was not overrided
+            raise RuntimeError("No forward method was defined on {}".format(mod))
+        if not _jit_internal.is_ignored_fn(mod.forward):
+            methods = ('forward',)
+    exported = []
+    for name in dir(mod):
+        item = getattr(mod, name)
@@ -1812,36 +1762,12 @@ def _convert_to_script_module(mod, methods=None):
 
     def make_stub(method):
         func = get_function_from_type(type(mod), method)
-        return script_method(func, createResolutionCallbackFromClosure(func))
+        return script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
 
     stubs = list(map(make_stub, methods))
     return WeakScriptModuleProxy(mod, stubs)
 
 
-def _make_strong(mod):
-    """
-    Converts a weak module into a subclass of ScriptModule. If `_methods` is
-    provided, only these methods are treated as @script_methods.
-    """
-    if mod in _jit_internal.weak_modules:
-        return _jit_internal.weak_modules[mod]
-
-    cls = type(mod)
-    # Explicitly annotated weak script
-    stubs = _jit_internal.weak_types.get(cls)["method_stubs"]
-    if stubs is None:
-        # Generate stubs and and store on weak_types in case this type is
-        # used again
-        stubs = _get_weak_stubs(cls)
-        _jit_internal.weak_types[cls]["method_stubs"] = stubs
-
-    proxy = WeakScriptModuleProxy(mod, stubs)
-
-    _jit_internal.weak_modules[mod] = proxy
-
-    return proxy
-
-
 def _get_methods(cls):
     import inspect
     # In Python 3 unbound methods are functions, but in Python 2 they are methods
@@ -1937,13 +1863,13 @@ class _ConstModuleList(ScriptModule):
 
         if isinstance(modules, OrderedDict):
             for key, module in modules.items():
-                if _is_weak_type(type(module)):
-                    module = _make_strong(module)
+                if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module):
+                    module = _convert_to_script_module(module)
                 self.add_module(key, module)
         else:
             for i, module in enumerate(modules):
-                if _is_weak_type(type(module)):
-                    module = _make_strong(module)
+                if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module):
+                    module = _convert_to_script_module(module)
                 self.add_module(str(i), module)
 
    def __getitem__(self, idx):
@@ -1,9 +1,7 @@
 import torch
 import torch.backends.cudnn as cudnn
-from ..._jit_internal import weak_script
 
 
-@weak_script
 def affine_grid_generator(theta, size):
     # type: (Tensor, List[int]) -> Tensor
     if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4 and size[0] < 65536:
@@ -1,10 +1,8 @@
 import warnings
-from .._jit_internal import weak_script
 
 # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h
 
 
-@weak_script
 def get_enum(reduction):
     # type: (str) -> int
     if reduction == 'none':
@@ -26,7 +24,6 @@ def get_enum(reduction):
 
 
 # We use these functions in torch/legacy as well, in which case we'll silence the warning
-@weak_script
 def legacy_get_string(size_average, reduce, emit_warning=True):
     # type: (Optional[bool], Optional[bool], bool) -> str
     warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
@@ -47,7 +44,6 @@ def legacy_get_string(size_average, reduce, emit_warning=True):
     return ret
 
 
-@weak_script
 def legacy_get_enum(size_average, reduce, emit_warning=True):
     # type: (Optional[bool], Optional[bool], bool) -> int
     return get_enum(legacy_get_string(size_average, reduce, emit_warning))
@@ -12,7 +12,7 @@ from ._functions import vision
 from .modules.utils import _single, _pair, _triple, _list_with_default
 from . import grad  # noqa: F401
 from . import _VF
-from .._jit_internal import weak_script, List
+from .._jit_internal import boolean_dispatch, List
 
 
 conv1d = _add_docstr(torch.conv1d, r"""
|
|
@ -299,7 +299,6 @@ Args:
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
|
||||
output_ratio=None, return_indices=False,
|
||||
_random_samples=None):
|
||||
|
|
@ -346,7 +345,6 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
|
|||
return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)
|
||||
|
||||
|
||||
@weak_script
|
||||
def _fractional_max_pool2d(input, kernel_size, output_size=None,
|
||||
output_ratio=None, return_indices=False,
|
||||
_random_samples=None):
|
||||
|
|
@ -355,7 +353,7 @@ def _fractional_max_pool2d(input, kernel_size, output_size=None,
|
|||
output_ratio, return_indices,
|
||||
_random_samples)[0]
|
||||
|
||||
fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
|
||||
fractional_max_pool2d = boolean_dispatch(
|
||||
arg_name='return_indices',
|
||||
arg_index=4,
|
||||
default=False,
|
||||
|
|
@ -365,7 +363,6 @@ fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
|
|||
func_name='fractional_max_pool2d')
|
||||
|
||||
|
||||
@weak_script
|
||||
def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
|
||||
output_ratio=None, return_indices=False,
|
||||
_random_samples=None):
|
||||
|
|
@ -414,7 +411,6 @@ def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
|
|||
return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples)
|
||||
|
||||
|
||||
@weak_script
|
||||
def _fractional_max_pool3d(input, kernel_size, output_size=None,
|
||||
output_ratio=None, return_indices=False,
|
||||
_random_samples=None):
|
||||
|
|
@ -423,7 +419,7 @@ def _fractional_max_pool3d(input, kernel_size, output_size=None,
|
|||
output_ratio, return_indices,
|
||||
_random_samples)[0]
|
||||
|
||||
fractional_max_pool3d = torch._jit_internal.boolean_dispatch(
|
||||
fractional_max_pool3d = boolean_dispatch(
|
||||
arg_name='return_indices',
|
||||
arg_index=4,
|
||||
default=False,
|
||||
|
|
@ -433,7 +429,6 @@ fractional_max_pool3d = torch._jit_internal.boolean_dispatch(
|
|||
func_name='fractional_max_pool3d')
|
||||
|
||||
|
||||
@weak_script
|
||||
def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
|
||||
dilation=1, ceil_mode=False, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
||||
|
|
@ -448,7 +443,6 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
|
|||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
|
||||
|
||||
@weak_script
|
||||
def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
ceil_mode=False, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa
|
||||
|
|
@ -457,7 +451,7 @@ def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
|
|||
return torch.max_pool1d(
|
||||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
|
||||
max_pool1d = torch._jit_internal.boolean_dispatch(
|
||||
max_pool1d = boolean_dispatch(
|
||||
arg_name='return_indices',
|
||||
arg_index=6,
|
||||
default=False,
|
||||
|
|
@ -467,7 +461,6 @@ max_pool1d = torch._jit_internal.boolean_dispatch(
|
|||
func_name='max_pool1d')
|
||||
|
||||
|
||||
@weak_script
|
||||
def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
ceil_mode=False, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
||||
|
|
@ -481,7 +474,6 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation
|
|||
return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
|
||||
|
||||
@weak_script
|
||||
def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
ceil_mode=False, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa
|
||||
|
|
@ -490,7 +482,7 @@ def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
|
|||
return torch.max_pool2d(
|
||||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
|
||||
max_pool2d = torch._jit_internal.boolean_dispatch(
|
||||
max_pool2d = boolean_dispatch(
|
||||
arg_name='return_indices',
|
||||
arg_index=6,
|
||||
default=False,
|
||||
|
|
@ -500,7 +492,6 @@ max_pool2d = torch._jit_internal.boolean_dispatch(
|
|||
func_name='max_pool2d')
|
||||
|
||||
|
||||
@weak_script
|
||||
def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
|
||||
dilation=1, ceil_mode=False, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
||||
|
|
@ -515,7 +506,6 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
|
|||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
|
||||
|
||||
@weak_script
|
||||
def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
ceil_mode=False, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa
|
||||
|
|
@ -524,7 +514,7 @@ def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
|
|||
return torch.max_pool3d(
|
||||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
|
||||
max_pool3d = torch._jit_internal.boolean_dispatch(
|
||||
max_pool3d = boolean_dispatch(
|
||||
arg_name='return_indices',
|
||||
arg_index=6,
|
||||
default=False,
|
||||
|
|
@ -534,7 +524,6 @@ max_pool3d = torch._jit_internal.boolean_dispatch(
|
|||
func_name='max_pool3d')
|
||||
|
||||
|
||||
@weak_script
|
||||
def _unpool_output_size(input, kernel_size, stride, padding, output_size):
|
||||
# type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int]
|
||||
input_size = input.size()
|
||||
|
|
@ -564,7 +553,6 @@ def _unpool_output_size(input, kernel_size, stride, padding, output_size):
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
|
||||
output_size=None):
|
||||
# type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa
|
||||
|
|
@ -588,7 +576,6 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
|
|||
output_size).squeeze(3)
|
||||
|
||||
|
||||
@weak_script
|
||||
def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
|
||||
output_size=None):
|
||||
# type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa
|
||||
|
|
@ -607,7 +594,6 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
|
|||
return torch._C._nn.max_unpool2d(input, indices, output_size)
|
||||
|
||||
|
||||
@weak_script
|
||||
def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
|
||||
output_size=None):
|
||||
# type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa
|
||||
|
|
@ -627,7 +613,6 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
|
|||
input, indices, output_size, _stride, padding)
|
||||
|
||||
|
||||
@weak_script
|
||||
def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
||||
# type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
|
||||
r"""Applies a 2D power-average pooling over an input signal composed of
|
||||
|
|
@ -645,7 +630,6 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
|||
return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type)
|
||||
|
||||
|
||||
@weak_script
|
||||
def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
||||
# type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
|
||||
r"""Applies a 1D power-average pooling over an input signal composed of
|
||||
|
|
@ -662,7 +646,6 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
|||
return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
|
||||
|
||||
|
||||
@weak_script
|
||||
def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
|
||||
r"""Applies a 1D adaptive max pooling over an input signal composed of
|
||||
|
|
@ -677,12 +660,11 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
|
|||
return torch.adaptive_max_pool1d(input, output_size)
|
||||
|
||||
|
||||
@weak_script
|
||||
def _adaptive_max_pool1d(input, output_size, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList1[int], bool) -> Tensor
|
||||
return adaptive_max_pool1d_with_indices(input, output_size)[0]
|
||||
|
||||
adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
|
||||
adaptive_max_pool1d = boolean_dispatch(
|
||||
arg_name='return_indices',
|
||||
arg_index=2,
|
||||
default=False,
|
||||
|
|
@ -692,7 +674,6 @@ adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
|
|||
func_name='adaptive_max_pool1d')
|
||||
|
||||
|
||||
@weak_script
|
||||
def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList2[int], bool) -> Tuple[Tensor, Tensor]
|
||||
r"""Applies a 2D adaptive max pooling over an input signal composed of
|
||||
|
|
@ -709,12 +690,11 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
|
|||
return torch._C._nn.adaptive_max_pool2d(input, output_size)
|
||||
|
||||
|
||||
@weak_script
|
||||
def _adaptive_max_pool2d(input, output_size, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList2[int], bool) -> Tensor
|
||||
return adaptive_max_pool2d_with_indices(input, output_size)[0]
|
||||
|
||||
adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
|
||||
adaptive_max_pool2d = boolean_dispatch(
|
||||
arg_name='return_indices',
|
||||
arg_index=2,
|
||||
default=False,
|
||||
|
|
@ -724,7 +704,6 @@ adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
|
|||
func_name='adaptive_max_pool2d')
|
||||
|
||||
|
||||
@weak_script
|
||||
def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList3[int], bool) -> Tuple[Tensor, Tensor]
|
||||
r"""Applies a 3D adaptive max pooling over an input signal composed of
|
||||
|
|
@ -741,12 +720,11 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
|
|||
return torch._C._nn.adaptive_max_pool3d(input, output_size)
|
||||
|
||||
|
||||
@weak_script
|
||||
def _adaptive_max_pool3d(input, output_size, return_indices=False):
|
||||
# type: (Tensor, BroadcastingList3[int], bool) -> Tensor
|
||||
return adaptive_max_pool3d_with_indices(input, output_size)[0]
|
||||
|
||||
adaptive_max_pool3d = torch._jit_internal.boolean_dispatch(
|
||||
adaptive_max_pool3d = boolean_dispatch(
|
||||
arg_name='return_indices',
|
||||
arg_index=2,
|
||||
default=False,
|
||||
|
|
@ -769,7 +747,6 @@ Args:
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def adaptive_avg_pool2d(input, output_size):
|
||||
# type: (Tensor, BroadcastingList2[int]) -> Tensor
|
||||
r"""
|
||||
|
|
@ -786,7 +763,6 @@ def adaptive_avg_pool2d(input, output_size):
|
|||
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
|
||||
|
||||
|
||||
@weak_script
|
||||
def adaptive_avg_pool3d(input, output_size):
|
||||
# type: (Tensor, BroadcastingList3[int]) -> Tensor
|
||||
r"""
|
||||
|
|
@ -804,7 +780,6 @@ def adaptive_avg_pool3d(input, output_size):
|
|||
|
||||
|
||||
# Activation functions
|
||||
@weak_script
|
||||
def dropout(input, p=0.5, training=True, inplace=False):
|
||||
# type: (Tensor, float, bool, bool) -> Tensor
|
||||
r"""
|
||||
|
|
@ -827,7 +802,6 @@ def dropout(input, p=0.5, training=True, inplace=False):
|
|||
else _VF.dropout(input, p, training))
|
||||
|
||||
|
||||
@weak_script
|
||||
def alpha_dropout(input, p=0.5, training=False, inplace=False):
|
||||
# type: (Tensor, float, bool, bool) -> Tensor
|
||||
r"""Applies alpha dropout to the input.
|
||||
|
|
@ -842,7 +816,6 @@ def alpha_dropout(input, p=0.5, training=False, inplace=False):
|
|||
else _VF.alpha_dropout(input, p, training))
|
||||
|
||||
|
||||
@weak_script
|
||||
def dropout2d(input, p=0.5, training=True, inplace=False):
|
||||
# type: (Tensor, float, bool, bool) -> Tensor
|
||||
r"""
|
||||
|
|
@ -867,7 +840,6 @@ def dropout2d(input, p=0.5, training=True, inplace=False):
|
|||
else _VF.feature_dropout(input, p, training))
|
||||
|
||||
|
||||
@weak_script
|
||||
def dropout3d(input, p=0.5, training=True, inplace=False):
|
||||
# type: (Tensor, float, bool, bool) -> Tensor
|
||||
r"""
|
||||
|
|
@ -894,7 +866,6 @@ def dropout3d(input, p=0.5, training=True, inplace=False):
|
|||
else _VF.feature_dropout(input, p, training))
|
||||
|
||||
|
||||
@weak_script
|
||||
def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
|
||||
# type: (Tensor, float, bool, bool) -> Tensor
|
||||
if p < 0. or p > 1.:
|
||||
|
|
@ -905,7 +876,6 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
|
|||
else _VF.feature_alpha_dropout(input, p, training))
|
||||
|
||||
|
||||
@weak_script
|
||||
def threshold(input, threshold, value, inplace=False):
|
||||
# type: (Tensor, float, float, bool) -> Tensor
|
||||
r"""Thresholds each element of the input Tensor.
|
||||
|
|
@ -926,7 +896,6 @@ In-place version of :func:`~threshold`.
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def relu(input, inplace=False):
|
||||
# type: (Tensor, bool) -> Tensor
|
||||
r"""relu(input, inplace=False) -> Tensor
|
||||
|
|
@ -948,7 +917,6 @@ In-place version of :func:`~relu`.
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def glu(input, dim=-1):
|
||||
# type: (Tensor, int) -> Tensor
|
||||
r"""
|
||||
|
|
@ -973,7 +941,6 @@ def glu(input, dim=-1):
|
|||
return torch._C._nn.glu(input, dim)
|
||||
|
||||
|
||||
@weak_script
|
||||
def hardtanh(input, min_val=-1., max_val=1., inplace=False):
|
||||
# type: (Tensor, float, float, bool) -> Tensor
|
||||
r"""
|
||||
|
|
@ -996,7 +963,6 @@ In-place version of :func:`~hardtanh`.
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def relu6(input, inplace=False):
|
||||
# type: (Tensor, bool) -> Tensor
|
||||
r"""relu6(input, inplace=False) -> Tensor
|
||||
|
|
@ -1008,7 +974,6 @@ def relu6(input, inplace=False):
|
|||
return hardtanh(input, 0., 6., inplace)
|
||||
|
||||
|
||||
@weak_script
|
||||
def elu(input, alpha=1., inplace=False):
|
||||
# type: (Tensor, float, bool) -> Tensor
|
||||
r"""Applies element-wise,
|
||||
|
|
@ -1030,7 +995,6 @@ In-place version of :func:`~elu`.
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def selu(input, inplace=False):
|
||||
# type: (Tensor, bool) -> Tensor
|
||||
r"""selu(input, inplace=False) -> Tensor
|
||||
|
|
@ -1056,7 +1020,6 @@ In-place version of :func:`~selu`.
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def celu(input, alpha=1., inplace=False):
|
||||
# type: (Tensor, float, bool) -> Tensor
|
||||
r"""celu(input, alpha=1., inplace=False) -> Tensor
|
||||
|
|
@ -1079,7 +1042,6 @@ In-place version of :func:`~celu`.
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def leaky_relu(input, negative_slope=0.01, inplace=False):
|
||||
# type: (Tensor, float, bool) -> Tensor
|
||||
r"""
|
||||
|
|
@ -1104,7 +1066,6 @@ In-place version of :func:`~leaky_relu`.
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def prelu(input, weight):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
r"""prelu(input, weight) -> Tensor
|
||||
|
|
@ -1118,7 +1079,6 @@ def prelu(input, weight):
|
|||
return torch.prelu(input, weight)
|
||||
|
||||
|
||||
@weak_script
|
||||
def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
|
||||
# type: (Tensor, float, float, bool, bool) -> Tensor
|
||||
r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
|
||||
|
|
@ -1148,7 +1108,6 @@ Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \ex
|
|||
See :class:`~torch.nn.LogSigmoid` for more details.
|
||||
""")
|
||||
|
||||
@weak_script
|
||||
def gelu(input):
|
||||
r"""gelu(input) -> Tensor
|
||||
|
||||
|
|
@ -1162,7 +1121,6 @@ def gelu(input):
|
|||
return torch._C._nn.gelu(input)
|
||||
|
||||
|
||||
@weak_script
|
||||
def hardshrink(input, lambd=0.5):
|
||||
# type: (Tensor, float) -> Tensor
|
||||
r"""
|
||||
|
|
@ -1175,7 +1133,6 @@ def hardshrink(input, lambd=0.5):
|
|||
return torch.hardshrink(input, lambd)
|
||||
|
||||
|
||||
@weak_script
|
||||
def tanhshrink(input):
|
||||
r"""tanhshrink(input) -> Tensor
|
||||
|
||||
|
|
@ -1186,7 +1143,6 @@ def tanhshrink(input):
|
|||
return input - input.tanh()
|
||||
|
||||
|
||||
@weak_script
|
||||
def softsign(input):
|
||||
r"""softsign(input) -> Tensor
|
||||
|
||||
|
|
@ -1202,7 +1158,6 @@ softplus(input, beta=1, threshold=20) -> Tensor
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def _get_softmax_dim(name, ndim, stacklevel):
|
||||
# type: (str, int, int) -> int
|
||||
warnings.warn("Implicit dimension choice for {} has been deprecated. "
|
||||
|
|
@ -1214,7 +1169,6 @@ def _get_softmax_dim(name, ndim, stacklevel):
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def softmin(input, dim=None, _stacklevel=3, dtype=None):
|
||||
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
||||
r"""Applies a softmin function.
|
||||
|
|
@ -1240,7 +1194,6 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None):
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def softmax(input, dim=None, _stacklevel=3, dtype=None):
|
||||
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
||||
r"""Applies a softmax function.
|
||||
|
|
@ -1276,7 +1229,6 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None):
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
|
||||
# type: (Tensor, float, bool, float, int) -> Tensor
|
||||
r"""
|
||||
|
|
@ -1337,7 +1289,6 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def log_softmax(input, dim=None, _stacklevel=3, dtype=None):
|
||||
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
||||
r"""Applies a softmax followed by a logarithm.
|
||||
|
|
@ -1373,7 +1324,6 @@ See :class:`~torch.nn.Softshrink` for more details.
|
|||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def tanh(input):
|
||||
r"""tanh(input) -> Tensor
|
||||
|
||||
|
|
@ -1386,7 +1336,6 @@ def tanh(input):
|
|||
return input.tanh()
|
||||
|
||||
|
||||
@weak_script
|
||||
def sigmoid(input):
|
||||
r"""sigmoid(input) -> Tensor
|
||||
|
||||
|
|
@ -1398,7 +1347,6 @@ def sigmoid(input):
|
|||
return input.sigmoid()
|
||||
|
||||
|
||||
@weak_script
|
||||
def linear(input, weight, bias=None):
|
||||
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
|
||||
r"""
|
||||
|
|
@ -1423,7 +1371,6 @@ def linear(input, weight, bias=None):
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def bilinear(input1, input2, weight, bias=None):
|
||||
# type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor
|
||||
return torch.bilinear(input1, input2, weight, bias)
|
||||
|
|
@ -1435,7 +1382,6 @@ def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
|
|||
torch.embedding_renorm_(weight, input, max_norm, norm_type)
|
||||
|
||||
|
||||
@weak_script
|
||||
def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
|
||||
scale_grad_by_freq=False, sparse=False):
|
||||
# type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
|
||||
|
|
@ -1517,7 +1463,6 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
|
|||
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
|
||||
|
||||
|
||||
@weak_script
|
||||
def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
|
||||
scale_grad_by_freq=False, mode='mean', sparse=False,
|
||||
per_sample_weights=None):
|
||||
|
|
@ -1677,7 +1622,6 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def batch_norm(input, running_mean, running_var, weight=None, bias=None,
|
||||
training=False, momentum=0.1, eps=1e-5):
|
||||
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
|
||||
|
|
@ -1709,7 +1653,6 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None,
|
|||
)
|
||||
|
||||
|
||||
@weak_script
|
||||
def instance_norm(input, running_mean=None, running_var=None, weight=None,
|
||||
bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
|
||||
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
|
||||
|
|
@ -1725,7 +1668,6 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None,
|
|||
)
|
||||
|
||||
|
||||
@weak_script
|
||||
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
|
||||
# type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor
|
||||
r"""Applies Layer Normalization for last certain number of dimensions.
|
||||
|
|
@ -1736,7 +1678,6 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
|
|||
torch.backends.cudnn.enabled)
|
||||
|
||||
|
||||
@weak_script
|
||||
def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
|
||||
# type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
|
||||
r"""Applies Group Normalization for last certain number of dimensions.
|
||||
|
|
@ -1747,7 +1688,6 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
|
|||
torch.backends.cudnn.enabled)
|
||||
|
||||
|
||||
@weak_script
|
||||
def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
|
||||
# type: (Tensor, int, float, float, float) -> Tensor
|
||||
r"""Applies local response normalization over an input signal composed of
|
||||
|
|
@ -1776,7 +1716,6 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
|
|||
|
||||
# loss
|
||||
|
||||
@weak_script
|
||||
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
|
||||
reduction='mean', zero_infinity=False):
|
||||
# type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor
|
||||
|
|
@ -1824,7 +1763,6 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
|
|||
zero_infinity)
|
||||
|
||||
|
||||
@weak_script
|
||||
def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
|
||||
reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
|
||||
|
|
@ -1903,7 +1841,6 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8,
|
||||
reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor
|
||||
|
|
@ -1949,7 +1886,6 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||
r"""The `Kullback-Leibler divergence`_ Loss.
|
||||
|
|
@ -2007,7 +1943,6 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
|
|||
return reduced
|
||||
|
||||
|
||||
@weak_script
|
||||
def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
|
||||
reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
|
||||
|
|
@ -2056,7 +1991,6 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1
|
|||
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
|
||||
|
||||
|
||||
@weak_script
|
||||
def binary_cross_entropy(input, target, weight=None, size_average=None,
|
||||
reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
|
||||
|
|
@ -2113,7 +2047,6 @@ def binary_cross_entropy(input, target, weight=None, size_average=None,
|
|||
input, target, weight, reduction_enum)
|
||||
|
||||
|
||||
@weak_script
|
||||
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
|
||||
reduce=None, reduction='mean', pos_weight=None):
|
||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
|
||||
|
|
@ -2174,14 +2107,12 @@ def _pointwise_loss(lambd, lambd_optimized, input, target, reduction='mean'):
|
|||
return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction))
|
||||
|
||||
|
||||
@weak_script
|
||||
def _smooth_l1_loss(input, target):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
t = torch.abs(input - target)
|
||||
return torch.where(t < 1, 0.5 * t ** 2, t - 0.5)
|
||||
|
||||
|
||||
@weak_script
|
||||
def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||
r"""Function that uses a squared term if the absolute
|
||||
|
|
@ -2206,7 +2137,6 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||
r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
||||
|
|
@ -2232,7 +2162,6 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||
r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
||||
|
|
@ -2258,7 +2187,6 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
|||
return ret
|
||||
|
||||
|
||||
@weak_script
|
||||
def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
|
||||
reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
|
||||
|
|
@ -2276,7 +2204,6 @@ def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
|
|||
return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
|
||||
|
||||
|
||||
@weak_script
|
||||
def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
|
||||
reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
|
||||
|
|
@ -2291,7 +2218,6 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
|
|||
return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
|
||||
|
||||
|
||||
@weak_script
|
||||
def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||
r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
||||
|
|
@ -2305,7 +2231,6 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct
|
|||
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
|
||||
|
||||
|
||||
@weak_script
|
||||
 def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@@ -2319,7 +2244,6 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
-@weak_script
 def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
                                 reduce=None, reduction='mean'):
@@ -2349,7 +2273,6 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
     return ret
-@weak_script
 def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
                           reduce=None, reduction='mean'):
@@ -2364,7 +2287,6 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
     return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
-@weak_script
 def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None,
                       reduce=None, reduction='mean'):
@@ -2635,7 +2557,6 @@ GRID_SAMPLE_PADDING_MODES = {
 }
-@weak_script
 def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
@@ -2717,7 +2638,6 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
     return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum)
-@weak_script
 def affine_grid(theta, size):
@@ -2735,7 +2655,6 @@ def affine_grid(theta, size):
     return vision.affine_grid_generator(theta, size)
-@weak_script
 def pad(input, pad, mode='constant', value=0):
@@ -2844,7 +2763,6 @@ def pad(input, pad, mode='constant', value=0):
 # distance
-@weak_script
 def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False):
@@ -2952,7 +2870,6 @@ Examples:
 """)
-@weak_script
 def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
                         reduce=None, reduction="mean"):
@@ -2967,7 +2884,6 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
                                      swap, reduction_enum)
-@weak_script
 def normalize(input, p=2, dim=1, eps=1e-12, out=None):
@@ -3001,7 +2917,6 @@ def assert_int_or_pair(arg, arg_name, message):
     assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
-@weak_script
 def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
@@ -3036,7 +2951,6 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
     return ret
-@weak_script
 def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
@@ -3064,7 +2978,6 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
     return ret
-@weak_script
 def _pad_circular(input, padding):
@@ -3090,7 +3003,6 @@ def _pad_circular(input, padding):
     return input
-@weak_script
 def multi_head_attention_forward(query,  # type: Tensor
                                  key,    # type: Tensor
                                  value,  # type: Tensor
@@ -3135,8 +3047,8 @@ def multi_head_attention_forward(query,  # type: Tensor
         need_weights: output attn_output_weights.
         attn_mask: mask that prevents attention to certain positions. This is an additive mask
             (i.e. the values will be added to the attention layer).
         use_separate_proj_weight: the function accept the proj. weights for query, key,
             and value in differnt forms. If false, in_proj_weight will be used, which is
             a combination of q_proj_weight, k_proj_weight, v_proj_weight.
         q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
         static_k, static_v: static key and value used for attention operators.
@@ -3152,9 +3064,9 @@ def multi_head_attention_forward(query,  # type: Tensor
           the embedding dimension.
         - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
         - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
         - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
         - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

     Outputs:
@@ -3285,12 +3197,12 @@ def multi_head_attention_forward(query,  # type: Tensor
         v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

     if static_k is not None:
         assert static_k.size(0) == bsz * num_heads
         assert static_k.size(2) == head_dim
         k = static_k

     if static_v is not None:
         assert static_v.size(0) == bsz * num_heads
         assert static_v.size(2) == head_dim
         v = static_v
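For orientation, a minimal sketch (not part of the diff) of how these plain `torch.nn.functional` ops remain usable from TorchScript once the `@weak_script` decorators are gone; the wrapper function, tensor sizes and padding values below are illustrative assumptions, not code from this commit:

import torch
import torch.nn.functional as F

@torch.jit.script
def pad_and_normalize(x):
    # type: (Tensor) -> Tensor
    # F.pad and F.normalize are ordinary Python functions here; the script
    # compiler picks them up when this function is compiled.
    y = F.pad(x, [1, 1], mode='constant', value=0.)
    return F.normalize(y, p=2., dim=1)

print(pad_and_normalize(torch.randn(2, 3)).shape)  # torch.Size([2, 5])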
@@ -4,7 +4,6 @@ import math
 import warnings
 
 import torch
-from .._jit_internal import weak_script
 
 # These no_grad_* functions are necessary as wrappers around the parts of these
 # functions that use `with torch.no_grad()`. The JIT doesn't support context
@@ -72,7 +71,6 @@ def calculate_gain(nonlinearity, param=None):
     raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
-@weak_script
 def uniform_(tensor, a=0., b=1.):
@@ -90,7 +88,6 @@ def uniform_(tensor, a=0., b=1.):
     return _no_grad_uniform_(tensor, a, b)
-@weak_script
 def normal_(tensor, mean=0., std=1.):
@@ -108,7 +105,6 @@ def normal_(tensor, mean=0., std=1.):
     return _no_grad_normal_(tensor, mean, std)
-@weak_script
 def constant_(tensor, val):
@@ -124,7 +120,6 @@ def constant_(tensor, val):
     return _no_grad_fill_(tensor, val)
-@weak_script
 def ones_(tensor):
@@ -139,7 +134,6 @@ def ones_(tensor):
     return _no_grad_fill_(tensor, 1.)
-@weak_script
 def zeros_(tensor):
@@ -205,7 +199,6 @@ def dirac_(tensor):
     return tensor
-@weak_script
 def _calculate_fan_in_and_fan_out(tensor):
@@ -226,7 +219,6 @@ def _calculate_fan_in_and_fan_out(tensor):
     return fan_in, fan_out
-@weak_script
 def xavier_uniform_(tensor, gain=1.):
@@ -255,7 +247,6 @@ def xavier_uniform_(tensor, gain=1.):
     return _no_grad_uniform_(tensor, -a, a)
-@weak_script
 def xavier_normal_(tensor, gain=1.):
@@ -7,10 +7,8 @@ from torch.nn.init import xavier_normal_
 from torch.nn.parameter import Parameter
 from .module import Module
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method
 
 
-@weak_module
 class Threshold(Module):
     r"""Thresholds each element of the input Tensor.
@@ -48,7 +46,6 @@ class Threshold(Module):
-    @weak_script_method
     def forward(self, input):
@@ -59,7 +56,6 @@ class Threshold(Module):
-@weak_module
 class ReLU(Module):
@@ -94,7 +90,6 @@ class ReLU(Module):
-    @weak_script_method
     def forward(self, input):
@@ -103,7 +98,6 @@ class ReLU(Module):
-@weak_module
 class RReLU(Module):
@@ -151,7 +145,6 @@ class RReLU(Module):
-    @weak_script_method
     def forward(self, input):
@@ -160,7 +153,6 @@ class RReLU(Module):
-@weak_module
 class Hardtanh(Module):
@@ -213,7 +205,6 @@ class Hardtanh(Module):
-    @weak_script_method
     def forward(self, input):
@@ -224,7 +215,6 @@ class Hardtanh(Module):
-@weak_module
 class ReLU6(Hardtanh):
@@ -256,7 +246,6 @@ class ReLU6(Hardtanh):
-@weak_module
 class Sigmoid(Module):
@@ -278,12 +267,10 @@ class Sigmoid(Module):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class Tanh(Module):
@@ -304,12 +291,10 @@ class Tanh(Module):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class ELU(Module):
@@ -340,7 +325,6 @@ class ELU(Module):
-    @weak_script_method
     def forward(self, input):
@@ -349,7 +333,6 @@ class ELU(Module):
-@weak_module
 class CELU(Module):
@@ -385,7 +368,6 @@ class CELU(Module):
-    @weak_script_method
     def forward(self, input):
@@ -394,7 +376,6 @@ class CELU(Module):
-@weak_module
 class SELU(Module):
@@ -430,7 +411,6 @@ class SELU(Module):
-    @weak_script_method
     def forward(self, input):
@@ -439,7 +419,6 @@ class SELU(Module):
-@weak_module
 class GLU(Module):
@@ -465,7 +444,6 @@ class GLU(Module):
-    @weak_script_method
     def forward(self, input):
@@ -473,7 +451,6 @@ class GLU(Module):
-@weak_module
 class Hardshrink(Module):
@@ -507,7 +484,6 @@ class Hardshrink(Module):
-    @weak_script_method
     def forward(self, input):
@@ -515,7 +491,6 @@ class Hardshrink(Module):
-@weak_module
 class LeakyReLU(Module):
@@ -556,7 +531,6 @@ class LeakyReLU(Module):
-    @weak_script_method
     def forward(self, input):
@@ -565,7 +539,6 @@ class LeakyReLU(Module):
-@weak_module
 class LogSigmoid(Module):
@@ -586,12 +559,10 @@ class LogSigmoid(Module):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class Softplus(Module):
@@ -628,7 +599,6 @@ class Softplus(Module):
-    @weak_script_method
     def forward(self, input):
@@ -636,7 +606,6 @@ class Softplus(Module):
-@weak_module
 class Softshrink(Module):
@@ -670,7 +639,6 @@ class Softshrink(Module):
-    @weak_script_method
     def forward(self, input):
@@ -678,7 +646,6 @@ class Softshrink(Module):
-@weak_module
 class MultiheadAttention(Module):
@@ -759,7 +726,6 @@ class MultiheadAttention(Module):
-    @weak_script_method
     def forward(self, query, key, value, key_padding_mask=None,
                 need_weights=True, attn_mask=None):
@@ -817,7 +783,6 @@ class MultiheadAttention(Module):
-@weak_module
 class PReLU(Module):
@@ -874,7 +839,6 @@ class PReLU(Module):
-    @weak_script_method
     def forward(self, input):
@@ -882,7 +846,6 @@ class PReLU(Module):
-@weak_module
 class Softsign(Module):
@@ -903,12 +866,10 @@ class Softsign(Module):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class Tanhshrink(Module):
@@ -929,12 +890,10 @@ class Tanhshrink(Module):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class Softmin(Module):
@@ -970,12 +929,10 @@ class Softmin(Module):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class Softmax(Module):
@@ -1021,7 +978,6 @@ class Softmax(Module):
-    @weak_script_method
     def forward(self, input):
@@ -1029,7 +985,6 @@ class Softmax(Module):
-@weak_module
 class Softmax2d(Module):
@@ -1052,13 +1007,11 @@ class Softmax2d(Module):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class LogSoftmax(Module):
@@ -1095,6 +1048,5 @@ class LogSoftmax(Module):
-    @weak_script_method
     def forward(self, input):
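The activation modules above keep their eager behaviour (each `forward` simply calls into `torch.nn.functional`), so usage like the docstring examples is unchanged. As a hedged illustration only — the helper name and shapes below are made up, not taken from this commit — a scripted function can call the same functional ops directly:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Eager usage, exactly as in the docstrings above.
m = nn.ReLU()
out = m(torch.randn(2))

@torch.jit.script
def clamped_relu(x):
    # type: (Tensor) -> Tensor
    # relu/hardtanh are the functional counterparts of the ReLU/Hardtanh
    # modules shown in the hunks above.
    return F.hardtanh(F.relu(x), -1., 1.)

print(clamped_relu(torch.randn(4)))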
@@ -6,12 +6,10 @@ from .module import Module
 from torch.nn.parameter import Parameter
 from .. import functional as F
 from .. import init
-from ..._jit_internal import weak_module, weak_script_method
 
 # TODO: check contiguous in THNN
 # TODO: use separate backend functions?
-@weak_module
 class _BatchNorm(Module):
     _version = 2
     __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
@@ -57,7 +55,6 @@ class _BatchNorm(Module):
-    @weak_script_method
     def forward(self, input):
@@ -103,7 +100,6 @@ class _BatchNorm(Module):
-@weak_module
 class BatchNorm1d(_BatchNorm):
@@ -170,14 +166,12 @@ class BatchNorm1d(_BatchNorm):
-    @weak_script_method
     def _check_input_dim(self, input):
-@weak_module
 class BatchNorm2d(_BatchNorm):
@@ -244,14 +238,12 @@ class BatchNorm2d(_BatchNorm):
-    @weak_script_method
     def _check_input_dim(self, input):
-@weak_module
 class BatchNorm3d(_BatchNorm):
@@ -319,7 +311,6 @@ class BatchNorm3d(_BatchNorm):
-    @weak_script_method
     def _check_input_dim(self, input):
@@ -6,10 +6,9 @@ from .. import functional as F
 from .. import init
 from .module import Module
 from .utils import _single, _pair, _triple
-from ..._jit_internal import weak_module, weak_script_method, List
+from ..._jit_internal import List
 
 
-@weak_module
 class _ConvNd(Module):
     __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
@@ -74,7 +73,6 @@ class _ConvNd(Module):
-@weak_module
 class Conv1d(_ConvNd):
@@ -192,7 +190,6 @@ class Conv1d(_ConvNd):
-    @weak_script_method
     def forward(self, input):
@@ -203,7 +200,6 @@ class Conv1d(_ConvNd):
-@weak_module
 class Conv2d(_ConvNd):
@@ -333,7 +329,6 @@ class Conv2d(_ConvNd):
-    @weak_script_method
     def forward(self, input):
@@ -345,7 +340,6 @@ class Conv2d(_ConvNd):
-@weak_module
 class Conv3d(_ConvNd):
@@ -470,7 +464,6 @@ class Conv3d(_ConvNd):
-    @weak_script_method
     def forward(self, input):
@@ -483,9 +476,7 @@ class Conv3d(_ConvNd):
-@weak_module
 class _ConvTransposeMixin(object):
-    @weak_script_method
     def forward(self, input, output_size=None):
@@ -497,7 +488,6 @@ class _ConvTransposeMixin(object):
-    @weak_script_method
     def _output_padding(self, input, output_size, stride, padding, kernel_size):
@@ -537,7 +527,6 @@ class _ConvTransposeMixin(object):
-@weak_module
 class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
@@ -638,7 +627,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
-    @weak_script_method
     def forward(self, input, output_size=None):
@@ -650,7 +638,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
-@weak_module
 class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
@@ -786,7 +773,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
-    @weak_script_method
     def forward(self, input, output_size=None):
@@ -799,7 +785,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
-@weak_module
 class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
@@ -931,7 +916,6 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
-    @weak_script_method
     def forward(self, input, output_size=None):
@@ -1,9 +1,7 @@
 from .module import Module
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method
 
 
-@weak_module
 class PairwiseDistance(Module):
     r"""
     Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
@@ -35,12 +33,10 @@ class PairwiseDistance(Module):
-    @weak_script_method
     def forward(self, x1, x2):
-@weak_module
 class CosineSimilarity(Module):
@@ -68,6 +64,5 @@ class CosineSimilarity(Module):
-    @weak_script_method
     def forward(self, x1, x2):
@@ -1,6 +1,5 @@
 from .module import Module
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method
 
 
 class _DropoutNd(Module):
@@ -18,7 +17,6 @@ class _DropoutNd(Module):
-@weak_module
 class Dropout(_DropoutNd):
@@ -52,12 +50,10 @@ class Dropout(_DropoutNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class Dropout2d(_DropoutNd):
@@ -96,12 +92,10 @@ class Dropout2d(_DropoutNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class Dropout3d(_DropoutNd):
@@ -140,12 +134,10 @@ class Dropout3d(_DropoutNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class AlphaDropout(_DropoutNd):
@@ -184,14 +176,11 @@ class AlphaDropout(_DropoutNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class FeatureAlphaDropout(_DropoutNd):
-    @weak_script_method
     def forward(self, input):
@@ -1,10 +1,8 @@
 # coding=utf-8
 from .module import Module
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method
 
 
-@weak_module
 class Fold(Module):
     r"""Combines an array of sliding local blocks into a large containing
     tensor.
@@ -101,7 +99,6 @@ class Fold(Module):
-    @weak_script_method
     def forward(self, input):
@@ -113,7 +110,6 @@ class Fold(Module):
-@weak_module
 class Unfold(Module):
@@ -217,7 +213,6 @@ class Unfold(Module):
-    @weak_script_method
     def forward(self, input):
@@ -1,6 +1,5 @@
 from .batchnorm import _BatchNorm
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method
 
 
 class _InstanceNorm(_BatchNorm):
@@ -9,7 +8,6 @@ class _InstanceNorm(_BatchNorm):
-    @weak_script_method
     def _check_input_dim(self, input):
@@ -43,7 +41,6 @@ class _InstanceNorm(_BatchNorm):
-    @weak_script_method
     def forward(self, input):
@@ -52,7 +49,6 @@ class _InstanceNorm(_BatchNorm):
-@weak_module
 class InstanceNorm1d(_InstanceNorm):
@@ -121,7 +117,6 @@ class InstanceNorm1d(_InstanceNorm):
-    @weak_script_method
     def _check_input_dim(self, input):
@@ -135,7 +130,6 @@ class InstanceNorm1d(_InstanceNorm):
-@weak_module
 class InstanceNorm2d(_InstanceNorm):
@@ -204,14 +198,12 @@ class InstanceNorm2d(_InstanceNorm):
-    @weak_script_method
     def _check_input_dim(self, input):
-@weak_module
 class InstanceNorm3d(_InstanceNorm):
@@ -280,7 +272,6 @@ class InstanceNorm3d(_InstanceNorm):
-    @weak_script_method
     def _check_input_dim(self, input):
@@ -5,10 +5,8 @@ from torch.nn.parameter import Parameter
 from .. import functional as F
 from .. import init
 from .module import Module
-from ..._jit_internal import weak_module, weak_script_method
 
 
-@weak_module
 class Identity(Module):
     r"""A placeholder identity operator that is argument-insensitive.
@@ -28,12 +26,10 @@ class Identity(Module):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class Linear(Module):
@@ -87,7 +83,6 @@ class Linear(Module):
-    @weak_script_method
     def forward(self, input):
@@ -97,7 +92,6 @@ class Linear(Module):
-@weak_module
 class Bilinear(Module):
@@ -157,7 +151,6 @@ class Bilinear(Module):
-    @weak_script_method
     def forward(self, input1, input2):
@@ -3,7 +3,6 @@ import warnings
 from .module import Module
 from .. import functional as F
 from .. import _reduction as _Reduction
-from ..._jit_internal import weak_module, weak_script_method
 
 
 class _Loss(Module):
@@ -21,7 +20,6 @@ class _WeightedLoss(_Loss):
-@weak_module
 class L1Loss(_Loss):
@@ -86,12 +84,10 @@ class L1Loss(_Loss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class NLLLoss(_WeightedLoss):
@@ -204,12 +200,10 @@ class NLLLoss(_WeightedLoss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class NLLLoss2d(NLLLoss):
@@ -219,7 +213,6 @@ class NLLLoss2d(NLLLoss):
-@weak_module
 class PoissonNLLLoss(_Loss):
@@ -286,13 +279,11 @@ class PoissonNLLLoss(_Loss):
-    @weak_script_method
     def forward(self, log_input, target):
-@weak_module
 class KLDivLoss(_Loss):
@@ -370,12 +361,10 @@ class KLDivLoss(_Loss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class MSELoss(_Loss):
@@ -438,12 +427,10 @@ class MSELoss(_Loss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class BCELoss(_WeightedLoss):
@@ -507,12 +494,10 @@ class BCELoss(_WeightedLoss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class BCEWithLogitsLoss(_Loss):
@@ -609,7 +594,6 @@ class BCEWithLogitsLoss(_Loss):
-    @weak_script_method
     def forward(self, input, target):
@@ -617,7 +601,6 @@ class BCEWithLogitsLoss(_Loss):
-@weak_module
 class HingeEmbeddingLoss(_Loss):
@@ -673,12 +656,10 @@ class HingeEmbeddingLoss(_Loss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class MultiLabelMarginLoss(_Loss):
@@ -739,12 +720,10 @@ class MultiLabelMarginLoss(_Loss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class SmoothL1Loss(_Loss):
@@ -799,12 +778,10 @@ class SmoothL1Loss(_Loss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class SoftMarginLoss(_Loss):
@@ -842,12 +819,10 @@ class SoftMarginLoss(_Loss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class CrossEntropyLoss(_WeightedLoss):
@@ -936,13 +911,11 @@ class CrossEntropyLoss(_WeightedLoss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class MultiLabelSoftMarginLoss(_WeightedLoss):
@@ -986,12 +959,10 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class CosineEmbeddingLoss(_Loss):
@@ -1034,12 +1005,10 @@ class CosineEmbeddingLoss(_Loss):
-    @weak_script_method
     def forward(self, input1, input2, target):
-@weak_module
 class MarginRankingLoss(_Loss):
@@ -1082,12 +1051,10 @@ class MarginRankingLoss(_Loss):
-    @weak_script_method
     def forward(self, input1, input2, target):
-@weak_module
 class MultiMarginLoss(_WeightedLoss):
@@ -1145,13 +1112,11 @@ class MultiMarginLoss(_WeightedLoss):
-    @weak_script_method
     def forward(self, input, target):
-@weak_module
 class TripletMarginLoss(_Loss):
@@ -1221,13 +1186,11 @@ class TripletMarginLoss(_Loss):
-    @weak_script_method
     def forward(self, anchor, positive, negative):
-@weak_module
 class CTCLoss(_Loss):
@@ -1327,7 +1290,6 @@ class CTCLoss(_Loss):
-    @weak_script_method
     def forward(self, log_probs, targets, input_lengths, target_lengths):
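The loss modules follow the same pattern: each `forward` delegates to the matching `torch.nn.functional` call with the stored `reduction`. A small eager sketch for reference; the tensor shapes are illustrative and not taken from the diff:

import torch
import torch.nn as nn

criterion = nn.MSELoss(reduction='mean')   # forward calls F.mse_loss, as shown above
pred = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
loss = criterion(pred, target)
loss.backward()
print(loss.item())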
@@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
 from .module import Module
 from .. import functional as F
 from .. import init
-from ..._jit_internal import weak_module, weak_script_method
 
 
-@weak_module
 class LocalResponseNorm(Module):
     r"""Applies local response normalization over an input signal composed
     of several input planes, where channels occupy the second dimension.
@@ -45,7 +43,6 @@ class LocalResponseNorm(Module):
-    @weak_script_method
     def forward(self, input):
@@ -71,7 +68,6 @@ class CrossMapLRN2d(Module):
-@weak_module
 class LayerNorm(Module):
@@ -151,7 +147,6 @@ class LayerNorm(Module):
-    @weak_script_method
     def forward(self, input):
@@ -161,7 +156,6 @@ class LayerNorm(Module):
-@weak_module
 class GroupNorm(Module):
@@ -226,7 +220,6 @@ class GroupNorm(Module):
-    @weak_script_method
     def forward(self, input):
@@ -1,13 +1,11 @@
 from .module import Module
 from .utils import _pair, _quadruple, _ntuple
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method
 
 
 # TODO: grad_output size asserts in THNN
 
 
-@weak_module
 class _ConstantPadNd(Module):
     __constants__ = ['padding', 'value']
@@ -15,7 +13,6 @@ class _ConstantPadNd(Module):
-    @weak_script_method
     def forward(self, input):
@@ -23,7 +20,6 @@ class _ConstantPadNd(Module):
-@weak_module
 class ConstantPad1d(_ConstantPadNd):
@@ -73,7 +69,6 @@ class ConstantPad1d(_ConstantPadNd):
-@weak_module
 class ConstantPad2d(_ConstantPadNd):
@@ -123,7 +118,6 @@ class ConstantPad2d(_ConstantPadNd):
-@weak_module
 class ConstantPad3d(_ConstantPadNd):
@@ -162,11 +156,9 @@ class ConstantPad3d(_ConstantPadNd):
-@weak_module
 class _ReflectionPadNd(Module):
     __constants__ = ['padding']
-    @weak_script_method
     def forward(self, input):
@@ -174,7 +166,6 @@ class _ReflectionPadNd(Module):
-@weak_module
 class ReflectionPad1d(_ReflectionPadNd):
@@ -214,7 +205,6 @@ class ReflectionPad1d(_ReflectionPadNd):
-@weak_module
 class ReflectionPad2d(_ReflectionPadNd):
@@ -265,11 +255,9 @@ class ReflectionPad2d(_ReflectionPadNd):
-@weak_module
 class _ReplicationPadNd(Module):
     __constants__ = ['padding']
-    @weak_script_method
     def forward(self, input):
@@ -277,7 +265,6 @@ class _ReplicationPadNd(Module):
-@weak_module
 class ReplicationPad1d(_ReplicationPadNd):
@@ -317,7 +304,6 @@ class ReplicationPad1d(_ReplicationPadNd):
-@weak_module
 class ReplicationPad2d(_ReplicationPadNd):
@@ -368,7 +354,6 @@ class ReplicationPad2d(_ReplicationPadNd):
-@weak_module
 class ReplicationPad3d(_ReplicationPadNd):
@@ -407,7 +392,6 @@ class ReplicationPad3d(_ReplicationPadNd):
-@weak_module
 class ZeroPad2d(ConstantPad2d):
@@ -1,9 +1,7 @@
 from .module import Module
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method
 
 
-@weak_module
 class PixelShuffle(Module):
     r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
     to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
@@ -41,7 +39,6 @@ class PixelShuffle(Module):
-    @weak_script_method
     def forward(self, input):
@@ -1,10 +1,8 @@
 from .module import Module
 from .utils import _single, _pair, _triple
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method
 
 
-@weak_module
 class _MaxPoolNd(Module):
     __constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
                      'return_indices', 'ceil_mode']
@@ -24,7 +22,6 @@ class _MaxPoolNd(Module):
-@weak_module
 class MaxPool1d(_MaxPoolNd):
@@ -68,7 +65,6 @@ class MaxPool1d(_MaxPoolNd):
-    @weak_script_method
     def forward(self, input):
@@ -79,7 +75,6 @@ class MaxPool1d(_MaxPoolNd):
-@weak_module
 class MaxPool2d(_MaxPoolNd):
@@ -139,14 +134,12 @@ class MaxPool2d(_MaxPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class MaxPool3d(_MaxPoolNd):
@@ -210,14 +203,12 @@ class MaxPool3d(_MaxPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class _MaxUnpoolNd(Module):
@@ -226,7 +217,6 @@ class _MaxUnpoolNd(Module):
-@weak_module
 class MaxUnpool1d(_MaxUnpoolNd):
@@ -292,7 +282,6 @@ class MaxUnpool1d(_MaxUnpoolNd):
-@weak_module
 class MaxUnpool2d(_MaxUnpoolNd):
@@ -366,7 +355,6 @@ class MaxUnpool2d(_MaxUnpoolNd):
-@weak_module
 class MaxUnpool3d(_MaxUnpoolNd):
@@ -429,7 +417,6 @@ class MaxUnpool3d(_MaxUnpoolNd):
-@weak_module
 class _AvgPoolNd(Module):
@@ -439,7 +426,6 @@ class _AvgPoolNd(Module):
-@weak_module
 class AvgPool1d(_AvgPoolNd):
@@ -490,14 +476,12 @@ class AvgPool1d(_AvgPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class AvgPool2d(_AvgPoolNd):
@@ -557,13 +541,11 @@ class AvgPool2d(_AvgPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class AvgPool3d(_AvgPoolNd):
@@ -630,7 +612,6 @@ class AvgPool3d(_AvgPoolNd):
-    @weak_script_method
     def forward(self, input):
@@ -642,7 +623,6 @@ class AvgPool3d(_AvgPoolNd):
-@weak_module
 class FractionalMaxPool2d(Module):
@@ -694,7 +674,6 @@ class FractionalMaxPool2d(Module):
-    @weak_script_method
     def forward(self, input):
@@ -702,7 +681,6 @@ class FractionalMaxPool2d(Module):
-@weak_module
 class FractionalMaxPool3d(Module):
@@ -754,7 +732,6 @@ class FractionalMaxPool3d(Module):
-    @weak_script_method
     def forward(self, input):
@@ -762,7 +739,6 @@ class FractionalMaxPool3d(Module):
-@weak_module
 class _LPPoolNd(Module):
@@ -778,7 +754,6 @@ class _LPPoolNd(Module):
-@weak_module
 class LPPool1d(_LPPoolNd):
@@ -814,14 +789,11 @@ class LPPool1d(_LPPoolNd):
-    @weak_script_method
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class LPPool2d(_LPPoolNd):
@@ -871,13 +843,11 @@ class LPPool2d(_LPPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class _AdaptiveMaxPoolNd(Module):
@@ -893,7 +863,6 @@ class _AdaptiveMaxPoolNd(Module):
-@weak_module
 class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
@@ -913,12 +882,10 @@ class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
@@ -949,12 +916,10 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
@@ -986,12 +951,10 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class _AdaptiveAvgPoolNd(Module):
@@ -1003,7 +966,6 @@ class _AdaptiveAvgPoolNd(Module):
-@weak_module
 class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
@@ -1021,12 +983,10 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
-    @weak_script_method
     def forward(self, input):
-@weak_module
 class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
     r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
|
||||
|
||||
|
|
@ -1055,12 +1015,10 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
|
|||
|
||||
"""
|
||||
|
||||
@weak_script_method
|
||||
def forward(self, input):
|
||||
return F.adaptive_avg_pool2d(input, self.output_size)
|
||||
|
||||
|
||||
@weak_module
|
||||
class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
|
||||
r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
|
||||
|
||||
|
|
@ -1089,6 +1047,5 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
|
|||
|
||||
"""
|
||||
|
||||
@weak_script_method
|
||||
def forward(self, input):
|
||||
return F.adaptive_avg_pool3d(input, self.output_size)
|
||||
|
|
|
|||
|
|
@@ -8,8 +8,7 @@ from ..parameter import Parameter
 from ..utils.rnn import PackedSequence, get_packed_sequence
 from .. import init
 from .. import _VF
-from ..._jit_internal import weak_module, weak_script_method, weak_script, \
-    _parameter_list
+from ..._jit_internal import _parameter_list

 _rnn_impls = {
     'GRU': _VF.gru,

@@ -18,7 +17,6 @@ _rnn_impls = {
 }


-@weak_script
 def apply_permutation(tensor, permutation, dim=1):
     # type: (Tensor, Tensor, int) -> Tensor
     return tensor.index_select(dim, permutation)

@@ -139,7 +137,6 @@ class RNNBase(Module):
     def _get_flat_weights(self):
         return self._flat_weights

-    @weak_script_method
     def check_input(self, input, batch_sizes):
         # type: (Tensor, Optional[Tensor]) -> None
         expected_input_dim = 2 if batch_sizes is not None else 3

@@ -152,7 +149,6 @@ class RNNBase(Module):
                 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                     self.input_size, input.size(-1)))

-    @weak_script_method
     def get_expected_hidden_size(self, input, batch_sizes):
         # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
         if batch_sizes is not None:

@@ -165,7 +161,6 @@ class RNNBase(Module):
                                 mini_batch, self.hidden_size)
         return expected_hidden_size

-    @weak_script_method
     def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
         # type: (Tensor, Tuple[int, int, int], str) -> None
         if hx.size() != expected_hidden_size:

@@ -374,7 +369,6 @@ class RNN(RNNBase):
         super(RNN, self).__init__(mode, *args, **kwargs)


-@weak_module
 class LSTM(RNNBase):
     r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
     sequence.

@@ -484,7 +478,6 @@ class LSTM(RNNBase):
     def __init__(self, *args, **kwargs):
         super(LSTM, self).__init__('LSTM', *args, **kwargs)

-    @weak_script_method
     def check_forward_args(self, input, hidden, batch_sizes):
         # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
         self.check_input(input, batch_sizes)

@@ -495,14 +488,12 @@ class LSTM(RNNBase):
         self.check_hidden_size(hidden[1], expected_hidden_size,
                                'Expected hidden[1] size {}, got {}')

-    @weak_script_method
     def permute_hidden(self, hx, permutation):
         # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
         if permutation is None:
             return hx
         return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

-    @weak_script_method
     def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
         # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
         if hx is None:

@@ -528,7 +519,7 @@ class LSTM(RNNBase):

         return output, hidden

-    @weak_script_method
+    @torch._jit_internal.export
     def forward_tensor(self, input, hx=None):
         # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
         batch_sizes = None

@@ -540,7 +531,7 @@ class LSTM(RNNBase):

         return output, self.permute_hidden(hidden, unsorted_indices)

-    @weak_script_method
+    @torch._jit_internal.export
     def forward_packed(self, input, hx=None):
         # type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
         input, batch_sizes, sorted_indices, unsorted_indices = input

@@ -552,6 +543,7 @@ class LSTM(RNNBase):
         output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices)
         return output, self.permute_hidden(hidden, unsorted_indices)

+    @torch._jit_internal.ignore
     def forward(self, input, hx=None):
         if isinstance(input, PackedSequence):
             return self.forward_packed(input, hx)
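The hunks above swap @weak_script_method for @torch._jit_internal.export on forward_tensor and forward_packed, and mark the Python-level forward that dispatches on PackedSequence with @torch._jit_internal.ignore. The following is a hedged sketch of that export/ignore split on a hypothetical toy module (not part of this diff), written with the public spellings torch.jit.export and torch.jit.ignore rather than the internal ones used here:

    # Sketch only: export marks an extra method for compilation when the module
    # is scripted; ignore keeps the Python dispatcher out of compilation so it
    # can branch on PackedSequence at the Python level.
    import torch
    from torch.nn.utils.rnn import PackedSequence, pack_sequence


    class TwoWayForward(torch.nn.Module):
        @torch.jit.export
        def forward_tensor(self, input):
            return input * 2

        @torch.jit.ignore
        def forward(self, input):
            # Stays a Python call; routes both input kinds to the tensor path.
            if isinstance(input, PackedSequence):
                return self.forward_tensor(input.data)
            return self.forward_tensor(input)


    m = TwoWayForward()
    m(torch.ones(3))                                   # plain-tensor path
    m(pack_sequence([torch.ones(2), torch.ones(1)]))   # packed path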
@@ -694,14 +686,12 @@ class RNNCellBase(Module):
             s += ', nonlinearity={nonlinearity}'
         return s.format(**self.__dict__)

-    @weak_script_method
     def check_forward_input(self, input):
         if input.size(1) != self.input_size:
             raise RuntimeError(
                 "input has inconsistent input_size: got {}, expected {}".format(
                     input.size(1), self.input_size))

-    @weak_script_method
     def check_forward_hidden(self, input, hx, hidden_label=''):
         # type: (Tensor, Tensor, str) -> None
         if input.size(0) != hx.size(0):

@@ -720,7 +710,6 @@ class RNNCellBase(Module):
             init.uniform_(weight, -stdv, stdv)


-@weak_module
 class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.

@@ -784,7 +773,6 @@ class RNNCell(RNNCellBase):
         super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
         self.nonlinearity = nonlinearity

-    @weak_script_method
     def forward(self, input, hx=None):
         # type: (Tensor, Optional[Tensor]) -> Tensor
         self.check_forward_input(input)

@@ -810,7 +798,6 @@ class RNNCell(RNNCellBase):
         return ret


-@weak_module
 class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell.

@@ -875,7 +862,6 @@ class LSTMCell(RNNCellBase):
     def __init__(self, input_size, hidden_size, bias=True):
         super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)

-    @weak_script_method
     def forward(self, input, hx=None):
         # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
         self.check_forward_input(input)

@@ -891,7 +877,6 @@ class LSTMCell(RNNCellBase):
         )


-@weak_module
 class GRUCell(RNNCellBase):
     r"""A gated recurrent unit (GRU) cell

@@ -957,7 +942,6 @@ class GRUCell(RNNCellBase):
     def __init__(self, input_size, hidden_size, bias=True):
         super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)

-    @weak_script_method
     def forward(self, input, hx=None):
         # type: (Tensor, Optional[Tensor]) -> Tensor
         self.check_forward_input(input)

@@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
 from .module import Module
 from .. import functional as F
 from .. import init
-from torch._jit_internal import weak_module, weak_script_method


-@weak_module
 class Embedding(Module):
     r"""A simple lookup table that stores embeddings of a fixed dictionary and size.

@@ -110,7 +108,6 @@ class Embedding(Module):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)

-    @weak_script_method
     def forward(self, input):
         return F.embedding(
             input, self.weight, self.padding_idx, self.max_norm,

@@ -173,7 +170,6 @@ class Embedding(Module):
         return embedding


-@weak_module
 class EmbeddingBag(Module):
     r"""Computes sums or means of 'bags' of embeddings, without instantiating the
     intermediate embeddings.

@@ -277,7 +273,6 @@ class EmbeddingBag(Module):
     def reset_parameters(self):
         init.normal_(self.weight)

-    @weak_script_method
     def forward(self, input, offsets=None, per_sample_weights=None):
         # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
         return F.embedding_bag(input, self.weight, offsets,

@@ -1,9 +1,7 @@
 from .module import Module
 from .. import functional as F
-from ..._jit_internal import weak_module, weak_script_method


-@weak_module
 class Upsample(Module):
     r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.

@@ -129,7 +127,6 @@ class Upsample(Module):
         self.mode = mode
         self.align_corners = align_corners

-    @weak_script_method
     def forward(self, input):
         return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)

@@ -142,7 +139,6 @@ class Upsample(Module):
         return info


-@weak_module
 class UpsamplingNearest2d(Upsample):
     r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input
     channels.

@@ -188,7 +184,6 @@ class UpsamplingNearest2d(Upsample):
         super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest')


-@weak_module
 class UpsamplingBilinear2d(Upsample):
     r"""Applies a 2D bilinear upsampling to an input signal composed of several input
     channels.

@@ -23,10 +23,9 @@ def _is_jit_enabled():


 # Check if we can safely replicate the module.
-# there are three types of module:
+# there are two types of module:
 # 1. python modules
-# 2. weak python modules (nn.Module annotated by @weak_module)
-# 3. ScriptModule
+# 2. ScriptModule
 #
 # currently a module cannot be replicated properly if the descendants of
 # any ScriptModule contains python module (type 1 above)

@@ -5,9 +5,7 @@ from __future__ import unicode_literals

 from .. import functional as F
 from ...modules.module import Module
-from ...._jit_internal import weak_module, weak_script_method

-@weak_module
 class ReLU(Module):
     r"""Applies quantized rectified linear unit function element-wise:

@@ -37,7 +35,6 @@ class ReLU(Module):
         assert not inplace, 'torch.nn.quantized.ReLU does not support inplace'


-    @weak_script_method
     def forward(self, input):
         return F.relu(input)

@@ -2,9 +2,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import torch
 from ...modules.module import Module
 from ...modules.linear import Linear as NNLinear
-from ...._jit_internal import weak_module

-@weak_module
 class Quantize(Module):
     r"""Quantizes an incoming tensor
     Args:

@@ -39,7 +37,6 @@ class Quantize(Module):
     def from_float(mod):
         return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8)

-@weak_module
 class DeQuantize(Module):
     r"""Dequantizes an incoming tensor

@@ -65,7 +62,6 @@ class DeQuantize(Module):
     def from_float(mod):
         return DeQuantize()

-@weak_module
 class Linear(NNLinear):
     r"""
     A quantized linear module with quantized tensor as inputs