mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Remove weak script (#22212)
Summary: * Deletes all weak script decorators / associated data structures / methods * In order to keep supporting the standard library in script, this enables recursive script on any function defined in `torch.nn` * Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`, only `rnn.py` needed manual editing to use the `ignore` and `export` to continue supporting the overloaded `forward` methods * `Sequential`/`ModuleList` no longer need to be added to constants since they are compiled on demand This should also fix https://github.com/pytorch/pytorch/issues/22212 Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212 Differential Revision: D15988346 Pulled By: driazati fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
This commit is contained in:
parent
b93f29ded3
commit
10c4b98ade
|
|
@ -6979,7 +6979,7 @@ a")
|
||||||
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
|
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
|
||||||
M()
|
M()
|
||||||
|
|
||||||
def test_script_module_list_sequential_error(self):
|
def test_script_module_list_sequential(self):
|
||||||
class M(torch.jit.ScriptModule):
|
class M(torch.jit.ScriptModule):
|
||||||
def __init__(self, mod_list):
|
def __init__(self, mod_list):
|
||||||
super(M, self).__init__(False)
|
super(M, self).__init__(False)
|
||||||
|
|
@ -6991,25 +6991,21 @@ a")
|
||||||
v = m(v)
|
v = m(v)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
|
m = M(nn.Sequential(nn.ReLU()))
|
||||||
a = M(nn.Sequential(nn.ReLU()))
|
self.assertExportImportModule(m, (torch.randn(2, 2),))
|
||||||
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
|
|
||||||
a = M(nn.ModuleList([nn.ReLU()]))
|
|
||||||
|
|
||||||
def test_attr_module_constants_error(self):
|
def test_attr_module_constants(self):
|
||||||
class M2(torch.jit.ScriptModule):
|
class M2(torch.jit.ScriptModule):
|
||||||
def __init__(self, mod_list):
|
def __init__(self, mod_list):
|
||||||
super(M2, self).__init__(False)
|
super(M2, self).__init__(False)
|
||||||
self.mods = mod_list
|
self.mods = mod_list
|
||||||
|
|
||||||
@torch.jit.script_method
|
@torch.jit.script_method
|
||||||
def forward(self, v):
|
def forward(self, x):
|
||||||
return self.mods.forward(x)
|
return self.mods.forward(x)
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
|
m = M2(nn.Sequential(nn.ReLU()))
|
||||||
M2(nn.Sequential(nn.ReLU()))
|
self.assertExportImportModule(m, (torch.randn(2, 2),))
|
||||||
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
|
|
||||||
M2(nn.ModuleList([nn.ReLU()]))
|
|
||||||
|
|
||||||
def test_script_sequential_for(self):
|
def test_script_sequential_for(self):
|
||||||
class Sub(torch.jit.ScriptModule):
|
class Sub(torch.jit.ScriptModule):
|
||||||
|
|
@ -11007,6 +11003,7 @@ a")
|
||||||
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
|
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
|
||||||
foo(torch.tensor(0))
|
foo(torch.tensor(0))
|
||||||
|
|
||||||
|
@unittest.skipIf(True, "Removing weak script")
|
||||||
def test_weak_script_function(self):
|
def test_weak_script_function(self):
|
||||||
outer_var = 10
|
outer_var = 10
|
||||||
outer_var2 = 11
|
outer_var2 = 11
|
||||||
|
|
@ -11086,6 +11083,7 @@ a")
|
||||||
eg = torch.zeros(3, dtype=torch.uint8)
|
eg = torch.zeros(3, dtype=torch.uint8)
|
||||||
self.assertEqual(foo_traced(eg), foo(eg))
|
self.assertEqual(foo_traced(eg), foo(eg))
|
||||||
|
|
||||||
|
@unittest.skipIf(True, "Removing weak script")
|
||||||
def test_weak_module(self):
|
def test_weak_module(self):
|
||||||
|
|
||||||
@torch._jit_internal.weak_module
|
@torch._jit_internal.weak_module
|
||||||
|
|
@ -11161,6 +11159,7 @@ a")
|
||||||
self.assertEqual(script_result, expected_result)
|
self.assertEqual(script_result, expected_result)
|
||||||
self.assertEqual(script_result, script_result2)
|
self.assertEqual(script_result, script_result2)
|
||||||
|
|
||||||
|
@unittest.skipIf(True, "Removing weak script")
|
||||||
def test_weak_module_parameters_and_buffers(self):
|
def test_weak_module_parameters_and_buffers(self):
|
||||||
weights = torch.randn(10, 10)
|
weights = torch.randn(10, 10)
|
||||||
bias = torch.randn(10)
|
bias = torch.randn(10)
|
||||||
|
|
@ -11219,6 +11218,7 @@ a")
|
||||||
self.assertEqual(strong_mod(inp), expected_result)
|
self.assertEqual(strong_mod(inp), expected_result)
|
||||||
self.assertExportImportModule(strong_mod, (inp,))
|
self.assertExportImportModule(strong_mod, (inp,))
|
||||||
|
|
||||||
|
@unittest.skipIf(True, "Removing weak script")
|
||||||
def test_weak_module_nested(self):
|
def test_weak_module_nested(self):
|
||||||
@torch._jit_internal.weak_module
|
@torch._jit_internal.weak_module
|
||||||
class OtherWeak(torch.nn.Module):
|
class OtherWeak(torch.nn.Module):
|
||||||
|
|
@ -11280,6 +11280,7 @@ a")
|
||||||
+ F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
|
+ F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
|
||||||
self.assertEqual(result, expected_result)
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
@unittest.skipIf(True, "Removing weak script")
|
||||||
def test_weak_module_submodule(self):
|
def test_weak_module_submodule(self):
|
||||||
@torch._jit_internal.weak_module
|
@torch._jit_internal.weak_module
|
||||||
class Weak(torch.nn.Module):
|
class Weak(torch.nn.Module):
|
||||||
|
|
@ -11319,6 +11320,7 @@ a")
|
||||||
with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
|
with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
|
||||||
strong_mod = Strong()
|
strong_mod = Strong()
|
||||||
|
|
||||||
|
@unittest.skipIf(True, "Removing weak script")
|
||||||
def test_weak_module_copying(self):
|
def test_weak_module_copying(self):
|
||||||
class Submodule(torch.nn.Module):
|
class Submodule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -11385,6 +11387,7 @@ a")
|
||||||
|
|
||||||
m = M()
|
m = M()
|
||||||
|
|
||||||
|
@unittest.skipIf(True, "Removing weak script")
|
||||||
def test_weak_module_attributes(self):
|
def test_weak_module_attributes(self):
|
||||||
tester = self
|
tester = self
|
||||||
|
|
||||||
|
|
@ -11948,6 +11951,7 @@ a")
|
||||||
|
|
||||||
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
|
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
|
||||||
|
|
||||||
|
@unittest.skipIf(True, "Removing weak script")
|
||||||
def test_overloading(self):
|
def test_overloading(self):
|
||||||
@torch._jit_internal.weak_module
|
@torch._jit_internal.weak_module
|
||||||
class W(torch.nn.Module):
|
class W(torch.nn.Module):
|
||||||
|
|
@ -13623,6 +13627,9 @@ EXCLUDE_SCRIPT_MODULES = {
|
||||||
'test_nn_AdaptiveAvgPool3d_tuple_none',
|
'test_nn_AdaptiveAvgPool3d_tuple_none',
|
||||||
'test_nn_AdaptiveMaxPool2d_tuple_none',
|
'test_nn_AdaptiveMaxPool2d_tuple_none',
|
||||||
'test_nn_AdaptiveMaxPool3d_tuple_none',
|
'test_nn_AdaptiveMaxPool3d_tuple_none',
|
||||||
|
|
||||||
|
# Uses Module._backend, so this is not supported
|
||||||
|
'test_nn_CrossMapLRN2d',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14552,10 +14559,6 @@ def add_nn_module_test(*args, **kwargs):
|
||||||
|
|
||||||
module_name = name.split("_")[0]
|
module_name = name.split("_")[0]
|
||||||
|
|
||||||
module = getattr(torch.nn, module_name, None)
|
|
||||||
if module is None or torch._jit_internal.weak_types.get(module) is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if 'desc' in kwargs and 'eval' in kwargs['desc']:
|
if 'desc' in kwargs and 'eval' in kwargs['desc']:
|
||||||
# eval() is not supported, so skip these tests
|
# eval() is not supported, so skip these tests
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -4,29 +4,14 @@ can be used in other places in torch/ (namely torch.nn) without running into
|
||||||
circular dependency problems
|
circular dependency problems
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import weakref
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import weakref
|
||||||
from torch._six import builtins
|
from torch._six import builtins
|
||||||
|
|
||||||
# Tracks standalone weak script functions
|
|
||||||
compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484
|
|
||||||
|
|
||||||
# Tracks which methods should be converted to strong methods
|
|
||||||
weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484
|
|
||||||
|
|
||||||
# Converted modules and their corresponding WeakScriptModuleProxy objects
|
|
||||||
weak_modules = weakref.WeakKeyDictionary() # noqa: T484
|
|
||||||
|
|
||||||
# Types that have been declared as weak modules
|
|
||||||
weak_types = weakref.WeakKeyDictionary() # noqa: T484
|
|
||||||
|
|
||||||
# Wrapper functions that can call either of 2 functions depending on a boolean
|
# Wrapper functions that can call either of 2 functions depending on a boolean
|
||||||
# argument
|
# argument
|
||||||
boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484
|
boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484
|
||||||
|
|
||||||
COMPILATION_PENDING = object()
|
|
||||||
COMPILED = object()
|
|
||||||
|
|
||||||
|
|
||||||
def createResolutionCallback(frames_up=0):
|
def createResolutionCallback(frames_up=0):
|
||||||
"""
|
"""
|
||||||
|
|
@ -71,51 +56,41 @@ def createResolutionCallback(frames_up=0):
|
||||||
return f_globals[key]
|
return f_globals[key]
|
||||||
elif hasattr(builtins, key):
|
elif hasattr(builtins, key):
|
||||||
return getattr(builtins, key)
|
return getattr(builtins, key)
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
def weak_script(fn, _frames_up=0):
|
def createResolutionCallbackFromClosure(fn):
|
||||||
"""
|
"""
|
||||||
Marks a function as a weak script function. When used in a script function
|
Create a resolutionCallback by introspecting the function instead of
|
||||||
or ScriptModule, the weak script function will be lazily compiled and
|
looking up the stack for the enclosing scope
|
||||||
inlined in the graph. When not used in a script function, the weak script
|
|
||||||
annotation has no effect.
|
|
||||||
"""
|
"""
|
||||||
compiled_weak_fns[fn] = {
|
var_names = fn.__code__.co_freevars
|
||||||
"status": COMPILATION_PENDING,
|
|
||||||
"compiled_fn": None,
|
|
||||||
"rcb": createResolutionCallback(_frames_up + 1)
|
|
||||||
}
|
|
||||||
return fn
|
|
||||||
|
|
||||||
|
# map of captured name -> value
|
||||||
|
free_vars = {}
|
||||||
|
|
||||||
def weak_module(cls):
|
for index, name in enumerate(var_names):
|
||||||
weak_types[cls] = {
|
free_vars[name] = fn.__closure__[index].cell_contents
|
||||||
"method_stubs": None
|
f_globals = fn.__globals__
|
||||||
}
|
|
||||||
return cls
|
|
||||||
|
|
||||||
|
def env(key):
|
||||||
|
if key in free_vars:
|
||||||
|
return free_vars[key]
|
||||||
|
elif hasattr(builtins, key):
|
||||||
|
return getattr(builtins, key)
|
||||||
|
else:
|
||||||
|
return f_globals.get(key)
|
||||||
|
|
||||||
def weak_script_method(fn):
|
return env
|
||||||
weak_script_methods[fn] = {
|
|
||||||
"rcb": createResolutionCallback(frames_up=2),
|
|
||||||
"original_method": fn
|
|
||||||
}
|
|
||||||
return fn
|
|
||||||
|
|
||||||
|
|
||||||
def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name):
|
def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name):
|
||||||
"""
|
"""
|
||||||
Dispatches to either of 2 weak script functions based on a boolean argument.
|
Dispatches to either of 2 script functions based on a boolean argument.
|
||||||
In TorchScript, the boolean argument must be constant so that the correct
|
In TorchScript, the boolean argument must be constant so that the correct
|
||||||
function to use can be determined at compile time.
|
function to use can be determined at compile time.
|
||||||
"""
|
"""
|
||||||
if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None:
|
|
||||||
raise RuntimeError("both functions must be weak script")
|
|
||||||
|
|
||||||
def fn(*args, **kwargs):
|
def fn(*args, **kwargs):
|
||||||
dispatch_flag = False
|
dispatch_flag = False
|
||||||
if arg_name in kwargs:
|
if arg_name in kwargs:
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,8 @@ Decl mergeTypesFromTypeComment(
|
||||||
<< "Number of type annotations ("
|
<< "Number of type annotations ("
|
||||||
<< type_annotation_decl.params().size()
|
<< type_annotation_decl.params().size()
|
||||||
<< ") did not match the number of "
|
<< ") did not match the number of "
|
||||||
<< "function parameters (" << expected_num_annotations << ")";
|
<< (is_method ? "method" : "function")
|
||||||
|
<< " parameters (" << expected_num_annotations << ")";
|
||||||
}
|
}
|
||||||
auto old = decl.params();
|
auto old = decl.params();
|
||||||
auto _new = type_annotation_decl.params();
|
auto _new = type_annotation_decl.params();
|
||||||
|
|
|
||||||
|
|
@ -244,6 +244,11 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
|
||||||
<< err.str();
|
<< err.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool should_recurse(py::object obj) {
|
||||||
|
return py::cast<bool>(py::module::import("torch.jit")
|
||||||
|
.attr("_is_recursive_script_enabled")(obj));
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<SugaredValue> ModuleValue::attr(
|
std::shared_ptr<SugaredValue> ModuleValue::attr(
|
||||||
const SourceRange& loc,
|
const SourceRange& loc,
|
||||||
Function& m,
|
Function& m,
|
||||||
|
|
@ -307,7 +312,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
|
||||||
|
|
||||||
// If recursive script mode is on, create a ScriptModule and register it as
|
// If recursive script mode is on, create a ScriptModule and register it as
|
||||||
// as submodule or register a python method as a script::Method
|
// as submodule or register a python method as a script::Method
|
||||||
if (getRecursiveScriptMode()) {
|
if (should_recurse(attr)) {
|
||||||
if (py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) {
|
if (py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) {
|
||||||
// If the module is a submodule of the py_module, convert it to a
|
// If the module is a submodule of the py_module, convert it to a
|
||||||
// ScriptModule and add it as a submodule to the script::Module. This
|
// ScriptModule and add it as a submodule to the script::Module. This
|
||||||
|
|
@ -471,11 +476,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto weak_obj =
|
|
||||||
py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
|
|
||||||
if (!weak_obj.is_none()) {
|
|
||||||
obj = weak_obj;
|
|
||||||
}
|
|
||||||
if (auto callee = as_function(obj)) {
|
if (auto callee = as_function(obj)) {
|
||||||
return std::make_shared<FunctionValue>(callee);
|
return std::make_shared<FunctionValue>(callee);
|
||||||
} else if (py::isinstance<py::module>(obj)) {
|
} else if (py::isinstance<py::module>(obj)) {
|
||||||
|
|
@ -504,12 +504,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||||
<< "which is currently not supported in Torchscript."
|
<< "which is currently not supported in Torchscript."
|
||||||
<< "Please open a feature request to add it.";
|
<< "Please open a feature request to add it.";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto compiled_fn =
|
|
||||||
py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
|
|
||||||
if (auto callee = as_function(compiled_fn)) {
|
|
||||||
return std::make_shared<FunctionValue>(callee);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
py::object dispatched_fn =
|
py::object dispatched_fn =
|
||||||
|
|
@ -528,7 +522,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (getRecursiveScriptMode() && py::isinstance<py::function>(obj)) {
|
if (should_recurse(obj) && py::isinstance<py::function>(obj)) {
|
||||||
auto compiled_fn =
|
auto compiled_fn =
|
||||||
py::module::import("torch.jit").attr("_try_compile_fn")(obj);
|
py::module::import("torch.jit").attr("_try_compile_fn")(obj);
|
||||||
if (auto callee = as_function(compiled_fn)) {
|
if (auto callee = as_function(compiled_fn)) {
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import torch.backends.cudnn as cudnn
|
||||||
import torch.jit.annotations
|
import torch.jit.annotations
|
||||||
import torch._jit_internal as _jit_internal
|
import torch._jit_internal as _jit_internal
|
||||||
from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \
|
from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \
|
||||||
string_classes, builtins
|
string_classes
|
||||||
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
|
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
|
||||||
_list_with_default
|
_list_with_default
|
||||||
import torch.testing
|
import torch.testing
|
||||||
|
|
@ -930,50 +930,10 @@ def _try_get_overloaded_fn(mod, field):
|
||||||
return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
|
return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
|
||||||
|
|
||||||
|
|
||||||
def _try_compile_weak_script(fn):
|
|
||||||
entry = _jit_internal.compiled_weak_fns.get(fn)
|
|
||||||
if entry is None:
|
|
||||||
return None
|
|
||||||
if entry["status"] == _jit_internal.COMPILATION_PENDING:
|
|
||||||
compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
|
|
||||||
del entry["rcb"]
|
|
||||||
_jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
|
|
||||||
entry["status"] = _jit_internal.COMPILED
|
|
||||||
return compiled_fn
|
|
||||||
# TODO: use fn.__closure__
|
|
||||||
raise RuntimeError("Cannot make resolutionCallback in Python 2")
|
|
||||||
else:
|
|
||||||
return entry["compiled_fn"]
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptWarning(Warning):
|
class ScriptWarning(Warning):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def createResolutionCallbackFromClosure(fn):
|
|
||||||
"""
|
|
||||||
Create a resolutionCallback by introspecting the function instead of
|
|
||||||
looking up the stack for the enclosing scope
|
|
||||||
"""
|
|
||||||
var_names = fn.__code__.co_freevars
|
|
||||||
|
|
||||||
# map of captured name -> value
|
|
||||||
free_vars = {}
|
|
||||||
|
|
||||||
for index, name in enumerate(var_names):
|
|
||||||
free_vars[name] = fn.__closure__[index].cell_contents
|
|
||||||
f_globals = fn.__globals__
|
|
||||||
|
|
||||||
def env(key):
|
|
||||||
if key in free_vars:
|
|
||||||
return free_vars[key]
|
|
||||||
elif hasattr(builtins, key):
|
|
||||||
return getattr(builtins, key)
|
|
||||||
else:
|
|
||||||
return f_globals.get(key)
|
|
||||||
|
|
||||||
return env
|
|
||||||
|
|
||||||
def _create_constant_iterable_module(module):
|
def _create_constant_iterable_module(module):
|
||||||
modules = OrderedDict()
|
modules = OrderedDict()
|
||||||
|
|
||||||
|
|
@ -1012,20 +972,20 @@ def _try_compile_fn(fn):
|
||||||
# Don't do anything for @ignore'd functions
|
# Don't do anything for @ignore'd functions
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
|
|
||||||
raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
|
|
||||||
"Python functions or methods currently.\n"
|
|
||||||
"Consider manually annotating `{}` with @torch.jit.script.".format(fn))
|
|
||||||
|
|
||||||
if isinstance(fn, torch.nn.Module):
|
if isinstance(fn, torch.nn.Module):
|
||||||
# Since modules are callable pybind recognizes them as functions, but
|
# Since modules are callable pybind recognizes them as functions, but
|
||||||
# don't do anything for them
|
# don't do anything for them
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
|
||||||
|
raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
|
||||||
|
"Python functions or methods currently.\n"
|
||||||
|
"Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))
|
||||||
|
|
||||||
# We don't have the actual scope where the function was defined, but we can
|
# We don't have the actual scope where the function was defined, but we can
|
||||||
# extract the necessary info from the closed over variables on the function
|
# extract the necessary info from the closed over variables on the function
|
||||||
# object
|
# object
|
||||||
rcb = createResolutionCallbackFromClosure(fn)
|
rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
|
||||||
return torch.jit.script(fn, _rcb=rcb)
|
return torch.jit.script(fn, _rcb=rcb)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1040,7 +1000,9 @@ def _disable_emit_hooks():
|
||||||
def _create_method_from_fn(module, fn):
|
def _create_method_from_fn(module, fn):
|
||||||
if _jit_internal.is_ignored_fn(fn):
|
if _jit_internal.is_ignored_fn(fn):
|
||||||
return None
|
return None
|
||||||
stub = script_method(fn, createResolutionCallbackFromClosure(fn))
|
if not inspect.ismethod(fn):
|
||||||
|
return None
|
||||||
|
stub = script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn))
|
||||||
with _disable_emit_hooks():
|
with _disable_emit_hooks():
|
||||||
# We don't want to call the hooks here since the graph that is calling
|
# We don't want to call the hooks here since the graph that is calling
|
||||||
# this function is not yet complete
|
# this function is not yet complete
|
||||||
|
|
@ -1101,6 +1063,15 @@ def _qualified_name(obj):
|
||||||
return module_name + "." + name
|
return module_name + "." + name
|
||||||
|
|
||||||
|
|
||||||
|
def _is_recursive_script_enabled(value):
|
||||||
|
# TODO: [enable recursive script]
|
||||||
|
# when recursive script is made the default, remove this method
|
||||||
|
enabled = torch._C._jit_recursive_script()
|
||||||
|
module = inspect.getmodule(value)
|
||||||
|
if module is not None and 'torch.nn' in module.__name__:
|
||||||
|
enabled = True
|
||||||
|
return enabled
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _enable_recursive_script():
|
def _enable_recursive_script():
|
||||||
torch._C._jit_recursive_script(True)
|
torch._C._jit_recursive_script(True)
|
||||||
|
|
@ -1114,8 +1085,8 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None):
|
||||||
if _rcb is None:
|
if _rcb is None:
|
||||||
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
|
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
|
||||||
|
|
||||||
if torch._C._jit_recursive_script():
|
|
||||||
if isinstance(obj, torch.nn.Module):
|
if isinstance(obj, torch.nn.Module):
|
||||||
|
if _is_recursive_script_enabled(obj):
|
||||||
return _convert_to_script_module(obj)
|
return _convert_to_script_module(obj)
|
||||||
|
|
||||||
if inspect.isclass(obj):
|
if inspect.isclass(obj):
|
||||||
|
|
@ -1158,21 +1129,6 @@ def script_method(fn, _rcb=None):
|
||||||
return ScriptMethodStub(_rcb, ast, fn)
|
return ScriptMethodStub(_rcb, ast, fn)
|
||||||
|
|
||||||
|
|
||||||
def _try_get_weak_module(mod):
|
|
||||||
"""
|
|
||||||
Get the WeakScriptModuleProxy corresponding to mod if it exists
|
|
||||||
"""
|
|
||||||
if not isinstance(mod, Module):
|
|
||||||
return None
|
|
||||||
return _jit_internal.weak_modules.get(mod)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_weak_type(cls):
|
|
||||||
"""
|
|
||||||
Check if a type has been annotated with `weak_module`
|
|
||||||
"""
|
|
||||||
return cls in _jit_internal.weak_types
|
|
||||||
|
|
||||||
|
|
||||||
# These OrderedDictWrapper classes replace the actual OrderedDicts in
|
# These OrderedDictWrapper classes replace the actual OrderedDicts in
|
||||||
# module with versions that get/set properties inside of script::Module.
|
# module with versions that get/set properties inside of script::Module.
|
||||||
|
|
@ -1569,9 +1525,9 @@ if _enabled:
|
||||||
|
|
||||||
def __setattr__(self, attr, value):
|
def __setattr__(self, attr, value):
|
||||||
if attr not in self._constants_set:
|
if attr not in self._constants_set:
|
||||||
if isinstance(value, Module) and _is_weak_type(type(value)):
|
if isinstance(value, Module) and _is_recursive_script_enabled(value):
|
||||||
# Compile weak script module
|
# Compile weak script module
|
||||||
value = _make_strong(value)
|
value = _convert_to_script_module(value)
|
||||||
if attr == 'training':
|
if attr == 'training':
|
||||||
if self._c._has_attribute('training'):
|
if self._c._has_attribute('training'):
|
||||||
self.__dict__['training'] = value
|
self.__dict__['training'] = value
|
||||||
|
|
@ -1684,7 +1640,7 @@ if _enabled:
|
||||||
if isinstance(item, (ModuleList, Sequential)):
|
if isinstance(item, (ModuleList, Sequential)):
|
||||||
# These are in __constants__, so ignore them here
|
# These are in __constants__, so ignore them here
|
||||||
|
|
||||||
if not torch._C._jit_recursive_script():
|
if not _is_recursive_script_enabled(item):
|
||||||
# For recursive script, these are constantified after
|
# For recursive script, these are constantified after
|
||||||
# they are used, so they don't need to be in constants.
|
# they are used, so they don't need to be in constants.
|
||||||
# The `continue` here should be deleted along with
|
# The `continue` here should be deleted along with
|
||||||
|
|
@ -1774,33 +1730,27 @@ else:
|
||||||
super(ScriptModule, self).__init__()
|
super(ScriptModule, self).__init__()
|
||||||
|
|
||||||
|
|
||||||
def _get_weak_stubs(cls):
|
def _convert_to_script_module(mod):
|
||||||
"""
|
|
||||||
Calls script_method for each method that has been annotated with @weak_script
|
|
||||||
on the type of the object passed in and returns the generated ScriptMethodStubs.
|
|
||||||
"""
|
|
||||||
stubs = []
|
|
||||||
for name in dir(cls):
|
|
||||||
func = get_function_from_type(cls, name)
|
|
||||||
if func in _jit_internal.weak_script_methods:
|
|
||||||
entry = _jit_internal.weak_script_methods[func]
|
|
||||||
stub = script_method(entry["original_method"], entry["rcb"])
|
|
||||||
stubs.append(stub)
|
|
||||||
return stubs
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_script_module(mod, methods=None):
|
|
||||||
"""
|
"""
|
||||||
Makes a ScriptModule from an nn.Module. If `_methods` is provided,
|
Makes a ScriptModule from an nn.Module. If `_methods` is provided,
|
||||||
these methods are treated as @script_methods. If not, it defaults to
|
these methods are treated as @script_methods. If not, it defaults to
|
||||||
`('forward',)`. Methods accessed in forward are scripted on demand if
|
`('forward',)`. Methods accessed in forward are scripted on demand if
|
||||||
`_enable_recursive_script()` is used.
|
`_enable_recursive_script()` is used.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(mod, ScriptModule):
|
||||||
|
return mod
|
||||||
|
|
||||||
if isinstance(mod, (ModuleList, Sequential)):
|
if isinstance(mod, (ModuleList, Sequential)):
|
||||||
# Create constant versions for the iterable modules
|
# Create constant versions for the iterable modules
|
||||||
return _create_constant_iterable_module(mod)
|
return _create_constant_iterable_module(mod)
|
||||||
|
|
||||||
if methods is None:
|
methods = ()
|
||||||
|
if hasattr(mod, 'forward'):
|
||||||
|
if mod.forward.__func__ == torch.nn.Module.forward:
|
||||||
|
# TODO: [enable recursive script]
|
||||||
|
# forward was not overrided
|
||||||
|
raise RuntimeError("No forward method was defined on {}".format(mod))
|
||||||
|
if not _jit_internal.is_ignored_fn(mod.forward):
|
||||||
methods = ('forward',)
|
methods = ('forward',)
|
||||||
exported = []
|
exported = []
|
||||||
for name in dir(mod):
|
for name in dir(mod):
|
||||||
|
|
@ -1812,36 +1762,12 @@ def _convert_to_script_module(mod, methods=None):
|
||||||
|
|
||||||
def make_stub(method):
|
def make_stub(method):
|
||||||
func = get_function_from_type(type(mod), method)
|
func = get_function_from_type(type(mod), method)
|
||||||
return script_method(func, createResolutionCallbackFromClosure(func))
|
return script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
|
||||||
|
|
||||||
stubs = list(map(make_stub, methods))
|
stubs = list(map(make_stub, methods))
|
||||||
return WeakScriptModuleProxy(mod, stubs)
|
return WeakScriptModuleProxy(mod, stubs)
|
||||||
|
|
||||||
|
|
||||||
def _make_strong(mod):
|
|
||||||
"""
|
|
||||||
Converts a weak module into a subclass of ScriptModule. If `_methods` is
|
|
||||||
provided, only these methods are treated as @script_methods.
|
|
||||||
"""
|
|
||||||
if mod in _jit_internal.weak_modules:
|
|
||||||
return _jit_internal.weak_modules[mod]
|
|
||||||
|
|
||||||
cls = type(mod)
|
|
||||||
# Explicitly annotated weak script
|
|
||||||
stubs = _jit_internal.weak_types.get(cls)["method_stubs"]
|
|
||||||
if stubs is None:
|
|
||||||
# Generate stubs and and store on weak_types in case this type is
|
|
||||||
# used again
|
|
||||||
stubs = _get_weak_stubs(cls)
|
|
||||||
_jit_internal.weak_types[cls]["method_stubs"] = stubs
|
|
||||||
|
|
||||||
proxy = WeakScriptModuleProxy(mod, stubs)
|
|
||||||
|
|
||||||
_jit_internal.weak_modules[mod] = proxy
|
|
||||||
|
|
||||||
return proxy
|
|
||||||
|
|
||||||
|
|
||||||
def _get_methods(cls):
|
def _get_methods(cls):
|
||||||
import inspect
|
import inspect
|
||||||
# In Python 3 unbound methods are functions, but in Python 2 they are methods
|
# In Python 3 unbound methods are functions, but in Python 2 they are methods
|
||||||
|
|
@ -1937,13 +1863,13 @@ class _ConstModuleList(ScriptModule):
|
||||||
|
|
||||||
if isinstance(modules, OrderedDict):
|
if isinstance(modules, OrderedDict):
|
||||||
for key, module in modules.items():
|
for key, module in modules.items():
|
||||||
if _is_weak_type(type(module)):
|
if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module):
|
||||||
module = _make_strong(module)
|
module = _convert_to_script_module(module)
|
||||||
self.add_module(key, module)
|
self.add_module(key, module)
|
||||||
else:
|
else:
|
||||||
for i, module in enumerate(modules):
|
for i, module in enumerate(modules):
|
||||||
if _is_weak_type(type(module)):
|
if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module):
|
||||||
module = _make_strong(module)
|
module = _convert_to_script_module(module)
|
||||||
self.add_module(str(i), module)
|
self.add_module(str(i), module)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
from ..._jit_internal import weak_script
|
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def affine_grid_generator(theta, size):
|
def affine_grid_generator(theta, size):
|
||||||
# type: (Tensor, List[int]) -> Tensor
|
# type: (Tensor, List[int]) -> Tensor
|
||||||
if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4 and size[0] < 65536:
|
if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4 and size[0] < 65536:
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
import warnings
|
import warnings
|
||||||
from .._jit_internal import weak_script
|
|
||||||
|
|
||||||
# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h
|
# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def get_enum(reduction):
|
def get_enum(reduction):
|
||||||
# type: (str) -> int
|
# type: (str) -> int
|
||||||
if reduction == 'none':
|
if reduction == 'none':
|
||||||
|
|
@ -26,7 +24,6 @@ def get_enum(reduction):
|
||||||
|
|
||||||
|
|
||||||
# We use these functions in torch/legacy as well, in which case we'll silence the warning
|
# We use these functions in torch/legacy as well, in which case we'll silence the warning
|
||||||
@weak_script
|
|
||||||
def legacy_get_string(size_average, reduce, emit_warning=True):
|
def legacy_get_string(size_average, reduce, emit_warning=True):
|
||||||
# type: (Optional[bool], Optional[bool], bool) -> str
|
# type: (Optional[bool], Optional[bool], bool) -> str
|
||||||
warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
|
warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
|
||||||
|
|
@ -47,7 +44,6 @@ def legacy_get_string(size_average, reduce, emit_warning=True):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def legacy_get_enum(size_average, reduce, emit_warning=True):
|
def legacy_get_enum(size_average, reduce, emit_warning=True):
|
||||||
# type: (Optional[bool], Optional[bool], bool) -> int
|
# type: (Optional[bool], Optional[bool], bool) -> int
|
||||||
return get_enum(legacy_get_string(size_average, reduce, emit_warning))
|
return get_enum(legacy_get_string(size_average, reduce, emit_warning))
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from ._functions import vision
|
||||||
from .modules.utils import _single, _pair, _triple, _list_with_default
|
from .modules.utils import _single, _pair, _triple, _list_with_default
|
||||||
from . import grad # noqa: F401
|
from . import grad # noqa: F401
|
||||||
from . import _VF
|
from . import _VF
|
||||||
from .._jit_internal import weak_script, List
|
from .._jit_internal import boolean_dispatch, List
|
||||||
|
|
||||||
|
|
||||||
conv1d = _add_docstr(torch.conv1d, r"""
|
conv1d = _add_docstr(torch.conv1d, r"""
|
||||||
|
|
@ -299,7 +299,6 @@ Args:
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
|
def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
|
||||||
output_ratio=None, return_indices=False,
|
output_ratio=None, return_indices=False,
|
||||||
_random_samples=None):
|
_random_samples=None):
|
||||||
|
|
@ -346,7 +345,6 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
|
||||||
return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)
|
return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _fractional_max_pool2d(input, kernel_size, output_size=None,
|
def _fractional_max_pool2d(input, kernel_size, output_size=None,
|
||||||
output_ratio=None, return_indices=False,
|
output_ratio=None, return_indices=False,
|
||||||
_random_samples=None):
|
_random_samples=None):
|
||||||
|
|
@ -355,7 +353,7 @@ def _fractional_max_pool2d(input, kernel_size, output_size=None,
|
||||||
output_ratio, return_indices,
|
output_ratio, return_indices,
|
||||||
_random_samples)[0]
|
_random_samples)[0]
|
||||||
|
|
||||||
fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
|
fractional_max_pool2d = boolean_dispatch(
|
||||||
arg_name='return_indices',
|
arg_name='return_indices',
|
||||||
arg_index=4,
|
arg_index=4,
|
||||||
default=False,
|
default=False,
|
||||||
|
|
@ -365,7 +363,6 @@ fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
|
||||||
func_name='fractional_max_pool2d')
|
func_name='fractional_max_pool2d')
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
|
def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
|
||||||
output_ratio=None, return_indices=False,
|
output_ratio=None, return_indices=False,
|
||||||
_random_samples=None):
|
_random_samples=None):
|
||||||
|
|
@ -414,7 +411,6 @@ def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
|
||||||
return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples)
|
return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _fractional_max_pool3d(input, kernel_size, output_size=None,
|
def _fractional_max_pool3d(input, kernel_size, output_size=None,
|
||||||
output_ratio=None, return_indices=False,
|
output_ratio=None, return_indices=False,
|
||||||
_random_samples=None):
|
_random_samples=None):
|
||||||
|
|
@ -423,7 +419,7 @@ def _fractional_max_pool3d(input, kernel_size, output_size=None,
|
||||||
output_ratio, return_indices,
|
output_ratio, return_indices,
|
||||||
_random_samples)[0]
|
_random_samples)[0]
|
||||||
|
|
||||||
fractional_max_pool3d = torch._jit_internal.boolean_dispatch(
|
fractional_max_pool3d = boolean_dispatch(
|
||||||
arg_name='return_indices',
|
arg_name='return_indices',
|
||||||
arg_index=4,
|
arg_index=4,
|
||||||
default=False,
|
default=False,
|
||||||
|
|
@ -433,7 +429,6 @@ fractional_max_pool3d = torch._jit_internal.boolean_dispatch(
|
||||||
func_name='fractional_max_pool3d')
|
func_name='fractional_max_pool3d')
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
|
def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
|
||||||
dilation=1, ceil_mode=False, return_indices=False):
|
dilation=1, ceil_mode=False, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
||||||
|
|
@ -448,7 +443,6 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
|
||||||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
|
def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||||
ceil_mode=False, return_indices=False):
|
ceil_mode=False, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa
|
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa
|
||||||
|
|
@ -457,7 +451,7 @@ def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||||
return torch.max_pool1d(
|
return torch.max_pool1d(
|
||||||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||||
|
|
||||||
max_pool1d = torch._jit_internal.boolean_dispatch(
|
max_pool1d = boolean_dispatch(
|
||||||
arg_name='return_indices',
|
arg_name='return_indices',
|
||||||
arg_index=6,
|
arg_index=6,
|
||||||
default=False,
|
default=False,
|
||||||
|
|
@ -467,7 +461,6 @@ max_pool1d = torch._jit_internal.boolean_dispatch(
|
||||||
func_name='max_pool1d')
|
func_name='max_pool1d')
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
|
def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||||
ceil_mode=False, return_indices=False):
|
ceil_mode=False, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
||||||
|
|
@ -481,7 +474,6 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation
|
||||||
return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
|
return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
|
def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||||
ceil_mode=False, return_indices=False):
|
ceil_mode=False, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa
|
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa
|
||||||
|
|
@ -490,7 +482,7 @@ def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||||
return torch.max_pool2d(
|
return torch.max_pool2d(
|
||||||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||||
|
|
||||||
max_pool2d = torch._jit_internal.boolean_dispatch(
|
max_pool2d = boolean_dispatch(
|
||||||
arg_name='return_indices',
|
arg_name='return_indices',
|
||||||
arg_index=6,
|
arg_index=6,
|
||||||
default=False,
|
default=False,
|
||||||
|
|
@ -500,7 +492,6 @@ max_pool2d = torch._jit_internal.boolean_dispatch(
|
||||||
func_name='max_pool2d')
|
func_name='max_pool2d')
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
|
def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
|
||||||
dilation=1, ceil_mode=False, return_indices=False):
|
dilation=1, ceil_mode=False, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
|
||||||
|
|
@ -515,7 +506,6 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
|
||||||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
|
def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||||
ceil_mode=False, return_indices=False):
|
ceil_mode=False, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa
|
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa
|
||||||
|
|
@ -524,7 +514,7 @@ def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
|
||||||
return torch.max_pool3d(
|
return torch.max_pool3d(
|
||||||
input, kernel_size, stride, padding, dilation, ceil_mode)
|
input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||||
|
|
||||||
max_pool3d = torch._jit_internal.boolean_dispatch(
|
max_pool3d = boolean_dispatch(
|
||||||
arg_name='return_indices',
|
arg_name='return_indices',
|
||||||
arg_index=6,
|
arg_index=6,
|
||||||
default=False,
|
default=False,
|
||||||
|
|
@ -534,7 +524,6 @@ max_pool3d = torch._jit_internal.boolean_dispatch(
|
||||||
func_name='max_pool3d')
|
func_name='max_pool3d')
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _unpool_output_size(input, kernel_size, stride, padding, output_size):
|
def _unpool_output_size(input, kernel_size, stride, padding, output_size):
|
||||||
# type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int]
|
# type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int]
|
||||||
input_size = input.size()
|
input_size = input.size()
|
||||||
|
|
@ -564,7 +553,6 @@ def _unpool_output_size(input, kernel_size, stride, padding, output_size):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
|
def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
|
||||||
output_size=None):
|
output_size=None):
|
||||||
# type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa
|
# type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa
|
||||||
|
|
@ -588,7 +576,6 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
|
||||||
output_size).squeeze(3)
|
output_size).squeeze(3)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
|
def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
|
||||||
output_size=None):
|
output_size=None):
|
||||||
# type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa
|
# type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa
|
||||||
|
|
@ -607,7 +594,6 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
|
||||||
return torch._C._nn.max_unpool2d(input, indices, output_size)
|
return torch._C._nn.max_unpool2d(input, indices, output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
|
def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
|
||||||
output_size=None):
|
output_size=None):
|
||||||
# type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa
|
# type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa
|
||||||
|
|
@ -627,7 +613,6 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
|
||||||
input, indices, output_size, _stride, padding)
|
input, indices, output_size, _stride, padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
||||||
# type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
|
# type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
|
||||||
r"""Applies a 2D power-average pooling over an input signal composed of
|
r"""Applies a 2D power-average pooling over an input signal composed of
|
||||||
|
|
@ -645,7 +630,6 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
||||||
return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type)
|
return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
||||||
# type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
|
# type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
|
||||||
r"""Applies a 1D power-average pooling over an input signal composed of
|
r"""Applies a 1D power-average pooling over an input signal composed of
|
||||||
|
|
@ -662,7 +646,6 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
|
||||||
return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
|
return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
|
def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
|
# type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
|
||||||
r"""Applies a 1D adaptive max pooling over an input signal composed of
|
r"""Applies a 1D adaptive max pooling over an input signal composed of
|
||||||
|
|
@ -677,12 +660,11 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
|
||||||
return torch.adaptive_max_pool1d(input, output_size)
|
return torch.adaptive_max_pool1d(input, output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _adaptive_max_pool1d(input, output_size, return_indices=False):
|
def _adaptive_max_pool1d(input, output_size, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList1[int], bool) -> Tensor
|
# type: (Tensor, BroadcastingList1[int], bool) -> Tensor
|
||||||
return adaptive_max_pool1d_with_indices(input, output_size)[0]
|
return adaptive_max_pool1d_with_indices(input, output_size)[0]
|
||||||
|
|
||||||
adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
|
adaptive_max_pool1d = boolean_dispatch(
|
||||||
arg_name='return_indices',
|
arg_name='return_indices',
|
||||||
arg_index=2,
|
arg_index=2,
|
||||||
default=False,
|
default=False,
|
||||||
|
|
@ -692,7 +674,6 @@ adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
|
||||||
func_name='adaptive_max_pool1d')
|
func_name='adaptive_max_pool1d')
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
|
def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList2[int], bool) -> Tuple[Tensor, Tensor]
|
# type: (Tensor, BroadcastingList2[int], bool) -> Tuple[Tensor, Tensor]
|
||||||
r"""Applies a 2D adaptive max pooling over an input signal composed of
|
r"""Applies a 2D adaptive max pooling over an input signal composed of
|
||||||
|
|
@ -709,12 +690,11 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
|
||||||
return torch._C._nn.adaptive_max_pool2d(input, output_size)
|
return torch._C._nn.adaptive_max_pool2d(input, output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _adaptive_max_pool2d(input, output_size, return_indices=False):
|
def _adaptive_max_pool2d(input, output_size, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList2[int], bool) -> Tensor
|
# type: (Tensor, BroadcastingList2[int], bool) -> Tensor
|
||||||
return adaptive_max_pool2d_with_indices(input, output_size)[0]
|
return adaptive_max_pool2d_with_indices(input, output_size)[0]
|
||||||
|
|
||||||
adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
|
adaptive_max_pool2d = boolean_dispatch(
|
||||||
arg_name='return_indices',
|
arg_name='return_indices',
|
||||||
arg_index=2,
|
arg_index=2,
|
||||||
default=False,
|
default=False,
|
||||||
|
|
@ -724,7 +704,6 @@ adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
|
||||||
func_name='adaptive_max_pool2d')
|
func_name='adaptive_max_pool2d')
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
|
def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList3[int], bool) -> Tuple[Tensor, Tensor]
|
# type: (Tensor, BroadcastingList3[int], bool) -> Tuple[Tensor, Tensor]
|
||||||
r"""Applies a 3D adaptive max pooling over an input signal composed of
|
r"""Applies a 3D adaptive max pooling over an input signal composed of
|
||||||
|
|
@ -741,12 +720,11 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
|
||||||
return torch._C._nn.adaptive_max_pool3d(input, output_size)
|
return torch._C._nn.adaptive_max_pool3d(input, output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _adaptive_max_pool3d(input, output_size, return_indices=False):
|
def _adaptive_max_pool3d(input, output_size, return_indices=False):
|
||||||
# type: (Tensor, BroadcastingList3[int], bool) -> Tensor
|
# type: (Tensor, BroadcastingList3[int], bool) -> Tensor
|
||||||
return adaptive_max_pool3d_with_indices(input, output_size)[0]
|
return adaptive_max_pool3d_with_indices(input, output_size)[0]
|
||||||
|
|
||||||
adaptive_max_pool3d = torch._jit_internal.boolean_dispatch(
|
adaptive_max_pool3d = boolean_dispatch(
|
||||||
arg_name='return_indices',
|
arg_name='return_indices',
|
||||||
arg_index=2,
|
arg_index=2,
|
||||||
default=False,
|
default=False,
|
||||||
|
|
@ -769,7 +747,6 @@ Args:
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def adaptive_avg_pool2d(input, output_size):
|
def adaptive_avg_pool2d(input, output_size):
|
||||||
# type: (Tensor, BroadcastingList2[int]) -> Tensor
|
# type: (Tensor, BroadcastingList2[int]) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -786,7 +763,6 @@ def adaptive_avg_pool2d(input, output_size):
|
||||||
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
|
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def adaptive_avg_pool3d(input, output_size):
|
def adaptive_avg_pool3d(input, output_size):
|
||||||
# type: (Tensor, BroadcastingList3[int]) -> Tensor
|
# type: (Tensor, BroadcastingList3[int]) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -804,7 +780,6 @@ def adaptive_avg_pool3d(input, output_size):
|
||||||
|
|
||||||
|
|
||||||
# Activation functions
|
# Activation functions
|
||||||
@weak_script
|
|
||||||
def dropout(input, p=0.5, training=True, inplace=False):
|
def dropout(input, p=0.5, training=True, inplace=False):
|
||||||
# type: (Tensor, float, bool, bool) -> Tensor
|
# type: (Tensor, float, bool, bool) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -827,7 +802,6 @@ def dropout(input, p=0.5, training=True, inplace=False):
|
||||||
else _VF.dropout(input, p, training))
|
else _VF.dropout(input, p, training))
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def alpha_dropout(input, p=0.5, training=False, inplace=False):
|
def alpha_dropout(input, p=0.5, training=False, inplace=False):
|
||||||
# type: (Tensor, float, bool, bool) -> Tensor
|
# type: (Tensor, float, bool, bool) -> Tensor
|
||||||
r"""Applies alpha dropout to the input.
|
r"""Applies alpha dropout to the input.
|
||||||
|
|
@ -842,7 +816,6 @@ def alpha_dropout(input, p=0.5, training=False, inplace=False):
|
||||||
else _VF.alpha_dropout(input, p, training))
|
else _VF.alpha_dropout(input, p, training))
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def dropout2d(input, p=0.5, training=True, inplace=False):
|
def dropout2d(input, p=0.5, training=True, inplace=False):
|
||||||
# type: (Tensor, float, bool, bool) -> Tensor
|
# type: (Tensor, float, bool, bool) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -867,7 +840,6 @@ def dropout2d(input, p=0.5, training=True, inplace=False):
|
||||||
else _VF.feature_dropout(input, p, training))
|
else _VF.feature_dropout(input, p, training))
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def dropout3d(input, p=0.5, training=True, inplace=False):
|
def dropout3d(input, p=0.5, training=True, inplace=False):
|
||||||
# type: (Tensor, float, bool, bool) -> Tensor
|
# type: (Tensor, float, bool, bool) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -894,7 +866,6 @@ def dropout3d(input, p=0.5, training=True, inplace=False):
|
||||||
else _VF.feature_dropout(input, p, training))
|
else _VF.feature_dropout(input, p, training))
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
|
def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
|
||||||
# type: (Tensor, float, bool, bool) -> Tensor
|
# type: (Tensor, float, bool, bool) -> Tensor
|
||||||
if p < 0. or p > 1.:
|
if p < 0. or p > 1.:
|
||||||
|
|
@ -905,7 +876,6 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
|
||||||
else _VF.feature_alpha_dropout(input, p, training))
|
else _VF.feature_alpha_dropout(input, p, training))
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def threshold(input, threshold, value, inplace=False):
|
def threshold(input, threshold, value, inplace=False):
|
||||||
# type: (Tensor, float, float, bool) -> Tensor
|
# type: (Tensor, float, float, bool) -> Tensor
|
||||||
r"""Thresholds each element of the input Tensor.
|
r"""Thresholds each element of the input Tensor.
|
||||||
|
|
@ -926,7 +896,6 @@ In-place version of :func:`~threshold`.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def relu(input, inplace=False):
|
def relu(input, inplace=False):
|
||||||
# type: (Tensor, bool) -> Tensor
|
# type: (Tensor, bool) -> Tensor
|
||||||
r"""relu(input, inplace=False) -> Tensor
|
r"""relu(input, inplace=False) -> Tensor
|
||||||
|
|
@ -948,7 +917,6 @@ In-place version of :func:`~relu`.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def glu(input, dim=-1):
|
def glu(input, dim=-1):
|
||||||
# type: (Tensor, int) -> Tensor
|
# type: (Tensor, int) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -973,7 +941,6 @@ def glu(input, dim=-1):
|
||||||
return torch._C._nn.glu(input, dim)
|
return torch._C._nn.glu(input, dim)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def hardtanh(input, min_val=-1., max_val=1., inplace=False):
|
def hardtanh(input, min_val=-1., max_val=1., inplace=False):
|
||||||
# type: (Tensor, float, float, bool) -> Tensor
|
# type: (Tensor, float, float, bool) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -996,7 +963,6 @@ In-place version of :func:`~hardtanh`.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def relu6(input, inplace=False):
|
def relu6(input, inplace=False):
|
||||||
# type: (Tensor, bool) -> Tensor
|
# type: (Tensor, bool) -> Tensor
|
||||||
r"""relu6(input, inplace=False) -> Tensor
|
r"""relu6(input, inplace=False) -> Tensor
|
||||||
|
|
@ -1008,7 +974,6 @@ def relu6(input, inplace=False):
|
||||||
return hardtanh(input, 0., 6., inplace)
|
return hardtanh(input, 0., 6., inplace)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def elu(input, alpha=1., inplace=False):
|
def elu(input, alpha=1., inplace=False):
|
||||||
# type: (Tensor, float, bool) -> Tensor
|
# type: (Tensor, float, bool) -> Tensor
|
||||||
r"""Applies element-wise,
|
r"""Applies element-wise,
|
||||||
|
|
@ -1030,7 +995,6 @@ In-place version of :func:`~elu`.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def selu(input, inplace=False):
|
def selu(input, inplace=False):
|
||||||
# type: (Tensor, bool) -> Tensor
|
# type: (Tensor, bool) -> Tensor
|
||||||
r"""selu(input, inplace=False) -> Tensor
|
r"""selu(input, inplace=False) -> Tensor
|
||||||
|
|
@ -1056,7 +1020,6 @@ In-place version of :func:`~selu`.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def celu(input, alpha=1., inplace=False):
|
def celu(input, alpha=1., inplace=False):
|
||||||
# type: (Tensor, float, bool) -> Tensor
|
# type: (Tensor, float, bool) -> Tensor
|
||||||
r"""celu(input, alpha=1., inplace=False) -> Tensor
|
r"""celu(input, alpha=1., inplace=False) -> Tensor
|
||||||
|
|
@ -1079,7 +1042,6 @@ In-place version of :func:`~celu`.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def leaky_relu(input, negative_slope=0.01, inplace=False):
|
def leaky_relu(input, negative_slope=0.01, inplace=False):
|
||||||
# type: (Tensor, float, bool) -> Tensor
|
# type: (Tensor, float, bool) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -1104,7 +1066,6 @@ In-place version of :func:`~leaky_relu`.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def prelu(input, weight):
|
def prelu(input, weight):
|
||||||
# type: (Tensor, Tensor) -> Tensor
|
# type: (Tensor, Tensor) -> Tensor
|
||||||
r"""prelu(input, weight) -> Tensor
|
r"""prelu(input, weight) -> Tensor
|
||||||
|
|
@ -1118,7 +1079,6 @@ def prelu(input, weight):
|
||||||
return torch.prelu(input, weight)
|
return torch.prelu(input, weight)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
|
def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
|
||||||
# type: (Tensor, float, float, bool, bool) -> Tensor
|
# type: (Tensor, float, float, bool, bool) -> Tensor
|
||||||
r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
|
r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
|
||||||
|
|
@ -1148,7 +1108,6 @@ Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \ex
|
||||||
See :class:`~torch.nn.LogSigmoid` for more details.
|
See :class:`~torch.nn.LogSigmoid` for more details.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def gelu(input):
|
def gelu(input):
|
||||||
r"""gelu(input) -> Tensor
|
r"""gelu(input) -> Tensor
|
||||||
|
|
||||||
|
|
@ -1162,7 +1121,6 @@ def gelu(input):
|
||||||
return torch._C._nn.gelu(input)
|
return torch._C._nn.gelu(input)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def hardshrink(input, lambd=0.5):
|
def hardshrink(input, lambd=0.5):
|
||||||
# type: (Tensor, float) -> Tensor
|
# type: (Tensor, float) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -1175,7 +1133,6 @@ def hardshrink(input, lambd=0.5):
|
||||||
return torch.hardshrink(input, lambd)
|
return torch.hardshrink(input, lambd)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def tanhshrink(input):
|
def tanhshrink(input):
|
||||||
r"""tanhshrink(input) -> Tensor
|
r"""tanhshrink(input) -> Tensor
|
||||||
|
|
||||||
|
|
@ -1186,7 +1143,6 @@ def tanhshrink(input):
|
||||||
return input - input.tanh()
|
return input - input.tanh()
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def softsign(input):
|
def softsign(input):
|
||||||
r"""softsign(input) -> Tensor
|
r"""softsign(input) -> Tensor
|
||||||
|
|
||||||
|
|
@ -1202,7 +1158,6 @@ softplus(input, beta=1, threshold=20) -> Tensor
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _get_softmax_dim(name, ndim, stacklevel):
|
def _get_softmax_dim(name, ndim, stacklevel):
|
||||||
# type: (str, int, int) -> int
|
# type: (str, int, int) -> int
|
||||||
warnings.warn("Implicit dimension choice for {} has been deprecated. "
|
warnings.warn("Implicit dimension choice for {} has been deprecated. "
|
||||||
|
|
@ -1214,7 +1169,6 @@ def _get_softmax_dim(name, ndim, stacklevel):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def softmin(input, dim=None, _stacklevel=3, dtype=None):
|
def softmin(input, dim=None, _stacklevel=3, dtype=None):
|
||||||
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
||||||
r"""Applies a softmin function.
|
r"""Applies a softmin function.
|
||||||
|
|
@ -1240,7 +1194,6 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def softmax(input, dim=None, _stacklevel=3, dtype=None):
|
def softmax(input, dim=None, _stacklevel=3, dtype=None):
|
||||||
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
||||||
r"""Applies a softmax function.
|
r"""Applies a softmax function.
|
||||||
|
|
@ -1276,7 +1229,6 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
|
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
|
||||||
# type: (Tensor, float, bool, float, int) -> Tensor
|
# type: (Tensor, float, bool, float, int) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -1337,7 +1289,6 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def log_softmax(input, dim=None, _stacklevel=3, dtype=None):
|
def log_softmax(input, dim=None, _stacklevel=3, dtype=None):
|
||||||
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
# type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
|
||||||
r"""Applies a softmax followed by a logarithm.
|
r"""Applies a softmax followed by a logarithm.
|
||||||
|
|
@ -1373,7 +1324,6 @@ See :class:`~torch.nn.Softshrink` for more details.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def tanh(input):
|
def tanh(input):
|
||||||
r"""tanh(input) -> Tensor
|
r"""tanh(input) -> Tensor
|
||||||
|
|
||||||
|
|
@ -1386,7 +1336,6 @@ def tanh(input):
|
||||||
return input.tanh()
|
return input.tanh()
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def sigmoid(input):
|
def sigmoid(input):
|
||||||
r"""sigmoid(input) -> Tensor
|
r"""sigmoid(input) -> Tensor
|
||||||
|
|
||||||
|
|
@ -1398,7 +1347,6 @@ def sigmoid(input):
|
||||||
return input.sigmoid()
|
return input.sigmoid()
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def linear(input, weight, bias=None):
|
def linear(input, weight, bias=None):
|
||||||
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
|
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -1423,7 +1371,6 @@ def linear(input, weight, bias=None):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def bilinear(input1, input2, weight, bias=None):
|
def bilinear(input1, input2, weight, bias=None):
|
||||||
# type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor
|
# type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor
|
||||||
return torch.bilinear(input1, input2, weight, bias)
|
return torch.bilinear(input1, input2, weight, bias)
|
||||||
|
|
@ -1435,7 +1382,6 @@ def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
|
||||||
torch.embedding_renorm_(weight, input, max_norm, norm_type)
|
torch.embedding_renorm_(weight, input, max_norm, norm_type)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
|
def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
|
||||||
scale_grad_by_freq=False, sparse=False):
|
scale_grad_by_freq=False, sparse=False):
|
||||||
# type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
|
# type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
|
||||||
|
|
@ -1517,7 +1463,6 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
|
||||||
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
|
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
|
def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
|
||||||
scale_grad_by_freq=False, mode='mean', sparse=False,
|
scale_grad_by_freq=False, mode='mean', sparse=False,
|
||||||
per_sample_weights=None):
|
per_sample_weights=None):
|
||||||
|
|
@ -1677,7 +1622,6 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def batch_norm(input, running_mean, running_var, weight=None, bias=None,
|
def batch_norm(input, running_mean, running_var, weight=None, bias=None,
|
||||||
training=False, momentum=0.1, eps=1e-5):
|
training=False, momentum=0.1, eps=1e-5):
|
||||||
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
|
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
|
||||||
|
|
@ -1709,7 +1653,6 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def instance_norm(input, running_mean=None, running_var=None, weight=None,
|
def instance_norm(input, running_mean=None, running_var=None, weight=None,
|
||||||
bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
|
bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
|
||||||
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
|
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
|
||||||
|
|
@ -1725,7 +1668,6 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
|
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
|
||||||
# type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor
|
# type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor
|
||||||
r"""Applies Layer Normalization for last certain number of dimensions.
|
r"""Applies Layer Normalization for last certain number of dimensions.
|
||||||
|
|
@ -1736,7 +1678,6 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
|
||||||
torch.backends.cudnn.enabled)
|
torch.backends.cudnn.enabled)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
|
def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
|
||||||
# type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
|
# type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
|
||||||
r"""Applies Group Normalization for last certain number of dimensions.
|
r"""Applies Group Normalization for last certain number of dimensions.
|
||||||
|
|
@ -1747,7 +1688,6 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
|
||||||
torch.backends.cudnn.enabled)
|
torch.backends.cudnn.enabled)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
|
def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
|
||||||
# type: (Tensor, int, float, float, float) -> Tensor
|
# type: (Tensor, int, float, float, float) -> Tensor
|
||||||
r"""Applies local response normalization over an input signal composed of
|
r"""Applies local response normalization over an input signal composed of
|
||||||
|
|
@ -1776,7 +1716,6 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
|
||||||
|
|
||||||
# loss
|
# loss
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
|
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
|
||||||
reduction='mean', zero_infinity=False):
|
reduction='mean', zero_infinity=False):
|
||||||
# type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor
|
# type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor
|
||||||
|
|
@ -1824,7 +1763,6 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
|
||||||
zero_infinity)
|
zero_infinity)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
|
def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
|
||||||
|
|
@ -1903,7 +1841,6 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8,
|
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor
|
||||||
|
|
@ -1949,7 +1886,6 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
|
def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
r"""The `Kullback-Leibler divergence`_ Loss.
|
r"""The `Kullback-Leibler divergence`_ Loss.
|
||||||
|
|
@ -2007,7 +1943,6 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
return reduced
|
return reduced
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
|
def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
|
||||||
|
|
@ -2056,7 +1991,6 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1
|
||||||
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
|
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def binary_cross_entropy(input, target, weight=None, size_average=None,
|
def binary_cross_entropy(input, target, weight=None, size_average=None,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
|
||||||
|
|
@ -2113,7 +2047,6 @@ def binary_cross_entropy(input, target, weight=None, size_average=None,
|
||||||
input, target, weight, reduction_enum)
|
input, target, weight, reduction_enum)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
|
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
|
||||||
reduce=None, reduction='mean', pos_weight=None):
|
reduce=None, reduction='mean', pos_weight=None):
|
||||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
|
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
|
||||||
|
|
@ -2174,14 +2107,12 @@ def _pointwise_loss(lambd, lambd_optimized, input, target, reduction='mean'):
|
||||||
return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction))
|
return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction))
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _smooth_l1_loss(input, target):
|
def _smooth_l1_loss(input, target):
|
||||||
# type: (Tensor, Tensor) -> Tensor
|
# type: (Tensor, Tensor) -> Tensor
|
||||||
t = torch.abs(input - target)
|
t = torch.abs(input - target)
|
||||||
return torch.where(t < 1, 0.5 * t ** 2, t - 0.5)
|
return torch.where(t < 1, 0.5 * t ** 2, t - 0.5)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
r"""Function that uses a squared term if the absolute
|
r"""Function that uses a squared term if the absolute
|
||||||
|
|
@ -2206,7 +2137,6 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
||||||
|
|
@ -2232,7 +2162,6 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
||||||
|
|
@ -2258,7 +2187,6 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
|
def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
|
|
@ -2276,7 +2204,6 @@ def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
|
||||||
return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
|
return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
|
def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
|
|
@ -2291,7 +2218,6 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
|
||||||
return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
|
return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
||||||
|
|
@ -2305,7 +2231,6 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct
|
||||||
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
|
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
|
||||||
|
|
@ -2319,7 +2244,6 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m
|
||||||
return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
|
return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
|
def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
|
||||||
|
|
@ -2349,7 +2273,6 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
|
def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
|
|
@ -2364,7 +2287,6 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
|
||||||
return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
|
return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None,
|
def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
# type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
|
||||||
|
|
@ -2635,7 +2557,6 @@ GRID_SAMPLE_PADDING_MODES = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
|
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
|
||||||
# type: (Tensor, Tensor, str, str) -> Tensor
|
# type: (Tensor, Tensor, str, str) -> Tensor
|
||||||
r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the
|
r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the
|
||||||
|
|
@ -2717,7 +2638,6 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
|
||||||
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum)
|
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def affine_grid(theta, size):
|
def affine_grid(theta, size):
|
||||||
# type: (Tensor, List[int]) -> Tensor
|
# type: (Tensor, List[int]) -> Tensor
|
||||||
r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta`.
|
r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta`.
|
||||||
|
|
@ -2735,7 +2655,6 @@ def affine_grid(theta, size):
|
||||||
return vision.affine_grid_generator(theta, size)
|
return vision.affine_grid_generator(theta, size)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def pad(input, pad, mode='constant', value=0):
|
def pad(input, pad, mode='constant', value=0):
|
||||||
# type: (Tensor, List[int], str, float) -> Tensor
|
# type: (Tensor, List[int], str, float) -> Tensor
|
||||||
r"""Pads tensor.
|
r"""Pads tensor.
|
||||||
|
|
@ -2844,7 +2763,6 @@ def pad(input, pad, mode='constant', value=0):
|
||||||
# distance
|
# distance
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False):
|
def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False):
|
||||||
# type: (Tensor, Tensor, float, float, bool) -> Tensor
|
# type: (Tensor, Tensor, float, float, bool) -> Tensor
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -2952,7 +2870,6 @@ Examples:
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
|
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
|
||||||
reduce=None, reduction="mean"):
|
reduce=None, reduction="mean"):
|
||||||
# type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor
|
# type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor
|
||||||
|
|
@ -2967,7 +2884,6 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
|
||||||
swap, reduction_enum)
|
swap, reduction_enum)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def normalize(input, p=2, dim=1, eps=1e-12, out=None):
|
def normalize(input, p=2, dim=1, eps=1e-12, out=None):
|
||||||
# type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
|
# type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
|
||||||
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
|
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
|
||||||
|
|
@ -3001,7 +2917,6 @@ def assert_int_or_pair(arg, arg_name, message):
|
||||||
assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
|
assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
|
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
|
||||||
# type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
|
# type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
|
||||||
r"""Extracts sliding local blocks from an batched input tensor.
|
r"""Extracts sliding local blocks from an batched input tensor.
|
||||||
|
|
@ -3036,7 +2951,6 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
|
def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
|
||||||
# type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
|
# type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
|
||||||
r"""Combines an array of sliding local blocks into a large containing
|
r"""Combines an array of sliding local blocks into a large containing
|
||||||
|
|
@ -3064,7 +2978,6 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _pad_circular(input, padding):
|
def _pad_circular(input, padding):
|
||||||
# type: (Tensor, List[int]) -> Tensor
|
# type: (Tensor, List[int]) -> Tensor
|
||||||
"""
|
"""
|
||||||
|
|
@ -3090,7 +3003,6 @@ def _pad_circular(input, padding):
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def multi_head_attention_forward(query, # type: Tensor
|
def multi_head_attention_forward(query, # type: Tensor
|
||||||
key, # type: Tensor
|
key, # type: Tensor
|
||||||
value, # type: Tensor
|
value, # type: Tensor
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import math
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from .._jit_internal import weak_script
|
|
||||||
|
|
||||||
# These no_grad_* functions are necessary as wrappers around the parts of these
|
# These no_grad_* functions are necessary as wrappers around the parts of these
|
||||||
# functions that use `with torch.no_grad()`. The JIT doesn't support context
|
# functions that use `with torch.no_grad()`. The JIT doesn't support context
|
||||||
|
|
@ -72,7 +71,6 @@ def calculate_gain(nonlinearity, param=None):
|
||||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def uniform_(tensor, a=0., b=1.):
|
def uniform_(tensor, a=0., b=1.):
|
||||||
# type: (Tensor, float, float) -> Tensor
|
# type: (Tensor, float, float) -> Tensor
|
||||||
r"""Fills the input Tensor with values drawn from the uniform
|
r"""Fills the input Tensor with values drawn from the uniform
|
||||||
|
|
@ -90,7 +88,6 @@ def uniform_(tensor, a=0., b=1.):
|
||||||
return _no_grad_uniform_(tensor, a, b)
|
return _no_grad_uniform_(tensor, a, b)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def normal_(tensor, mean=0., std=1.):
|
def normal_(tensor, mean=0., std=1.):
|
||||||
# type: (Tensor, float, float) -> Tensor
|
# type: (Tensor, float, float) -> Tensor
|
||||||
r"""Fills the input Tensor with values drawn from the normal
|
r"""Fills the input Tensor with values drawn from the normal
|
||||||
|
|
@ -108,7 +105,6 @@ def normal_(tensor, mean=0., std=1.):
|
||||||
return _no_grad_normal_(tensor, mean, std)
|
return _no_grad_normal_(tensor, mean, std)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def constant_(tensor, val):
|
def constant_(tensor, val):
|
||||||
# type: (Tensor, float) -> Tensor
|
# type: (Tensor, float) -> Tensor
|
||||||
r"""Fills the input Tensor with the value :math:`\text{val}`.
|
r"""Fills the input Tensor with the value :math:`\text{val}`.
|
||||||
|
|
@ -124,7 +120,6 @@ def constant_(tensor, val):
|
||||||
return _no_grad_fill_(tensor, val)
|
return _no_grad_fill_(tensor, val)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def ones_(tensor):
|
def ones_(tensor):
|
||||||
# type: (Tensor) -> Tensor
|
# type: (Tensor) -> Tensor
|
||||||
r"""Fills the input Tensor with ones`.
|
r"""Fills the input Tensor with ones`.
|
||||||
|
|
@ -139,7 +134,6 @@ def ones_(tensor):
|
||||||
return _no_grad_fill_(tensor, 1.)
|
return _no_grad_fill_(tensor, 1.)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def zeros_(tensor):
|
def zeros_(tensor):
|
||||||
# type: (Tensor) -> Tensor
|
# type: (Tensor) -> Tensor
|
||||||
r"""Fills the input Tensor with zeros`.
|
r"""Fills the input Tensor with zeros`.
|
||||||
|
|
@ -205,7 +199,6 @@ def dirac_(tensor):
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def _calculate_fan_in_and_fan_out(tensor):
|
def _calculate_fan_in_and_fan_out(tensor):
|
||||||
dimensions = tensor.dim()
|
dimensions = tensor.dim()
|
||||||
if dimensions < 2:
|
if dimensions < 2:
|
||||||
|
|
@ -226,7 +219,6 @@ def _calculate_fan_in_and_fan_out(tensor):
|
||||||
return fan_in, fan_out
|
return fan_in, fan_out
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def xavier_uniform_(tensor, gain=1.):
|
def xavier_uniform_(tensor, gain=1.):
|
||||||
# type: (Tensor, float) -> Tensor
|
# type: (Tensor, float) -> Tensor
|
||||||
r"""Fills the input `Tensor` with values according to the method
|
r"""Fills the input `Tensor` with values according to the method
|
||||||
|
|
@ -255,7 +247,6 @@ def xavier_uniform_(tensor, gain=1.):
|
||||||
return _no_grad_uniform_(tensor, -a, a)
|
return _no_grad_uniform_(tensor, -a, a)
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def xavier_normal_(tensor, gain=1.):
|
def xavier_normal_(tensor, gain=1.):
|
||||||
# type: (Tensor, float) -> Tensor
|
# type: (Tensor, float) -> Tensor
|
||||||
r"""Fills the input `Tensor` with values according to the method
|
r"""Fills the input `Tensor` with values according to the method
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,8 @@ from torch.nn.init import xavier_normal_
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Threshold(Module):
|
class Threshold(Module):
|
||||||
r"""Thresholds each element of the input Tensor.
|
r"""Thresholds each element of the input Tensor.
|
||||||
|
|
||||||
|
|
@ -48,7 +46,6 @@ class Threshold(Module):
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
# TODO: check in THNN (if inplace == True, then assert value <= threshold)
|
# TODO: check in THNN (if inplace == True, then assert value <= threshold)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.threshold(input, self.threshold, self.value, self.inplace)
|
return F.threshold(input, self.threshold, self.value, self.inplace)
|
||||||
|
|
||||||
|
|
@ -59,7 +56,6 @@ class Threshold(Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ReLU(Module):
|
class ReLU(Module):
|
||||||
r"""Applies the rectified linear unit function element-wise:
|
r"""Applies the rectified linear unit function element-wise:
|
||||||
|
|
||||||
|
|
@ -94,7 +90,6 @@ class ReLU(Module):
|
||||||
super(ReLU, self).__init__()
|
super(ReLU, self).__init__()
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.relu(input, inplace=self.inplace)
|
return F.relu(input, inplace=self.inplace)
|
||||||
|
|
||||||
|
|
@ -103,7 +98,6 @@ class ReLU(Module):
|
||||||
return inplace_str
|
return inplace_str
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class RReLU(Module):
|
class RReLU(Module):
|
||||||
r"""Applies the randomized leaky rectified liner unit function, element-wise,
|
r"""Applies the randomized leaky rectified liner unit function, element-wise,
|
||||||
as described in the paper:
|
as described in the paper:
|
||||||
|
|
@ -151,7 +145,6 @@ class RReLU(Module):
|
||||||
self.upper = upper
|
self.upper = upper
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
|
return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
|
||||||
|
|
||||||
|
|
@ -160,7 +153,6 @@ class RReLU(Module):
|
||||||
return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
|
return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Hardtanh(Module):
|
class Hardtanh(Module):
|
||||||
r"""Applies the HardTanh function element-wise
|
r"""Applies the HardTanh function element-wise
|
||||||
|
|
||||||
|
|
@ -213,7 +205,6 @@ class Hardtanh(Module):
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
assert self.max_val > self.min_val
|
assert self.max_val > self.min_val
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
|
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
|
||||||
|
|
||||||
|
|
@ -224,7 +215,6 @@ class Hardtanh(Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ReLU6(Hardtanh):
|
class ReLU6(Hardtanh):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -256,7 +246,6 @@ class ReLU6(Hardtanh):
|
||||||
return inplace_str
|
return inplace_str
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Sigmoid(Module):
|
class Sigmoid(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -278,12 +267,10 @@ class Sigmoid(Module):
|
||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.sigmoid(input)
|
return torch.sigmoid(input)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Tanh(Module):
|
class Tanh(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -304,12 +291,10 @@ class Tanh(Module):
|
||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.tanh(input)
|
return torch.tanh(input)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ELU(Module):
|
class ELU(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -340,7 +325,6 @@ class ELU(Module):
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.elu(input, self.alpha, self.inplace)
|
return F.elu(input, self.alpha, self.inplace)
|
||||||
|
|
||||||
|
|
@ -349,7 +333,6 @@ class ELU(Module):
|
||||||
return 'alpha={}{}'.format(self.alpha, inplace_str)
|
return 'alpha={}{}'.format(self.alpha, inplace_str)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class CELU(Module):
|
class CELU(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -385,7 +368,6 @@ class CELU(Module):
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.celu(input, self.alpha, self.inplace)
|
return F.celu(input, self.alpha, self.inplace)
|
||||||
|
|
||||||
|
|
@ -394,7 +376,6 @@ class CELU(Module):
|
||||||
return 'alpha={}{}'.format(self.alpha, inplace_str)
|
return 'alpha={}{}'.format(self.alpha, inplace_str)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class SELU(Module):
|
class SELU(Module):
|
||||||
r"""Applied element-wise, as:
|
r"""Applied element-wise, as:
|
||||||
|
|
||||||
|
|
@ -430,7 +411,6 @@ class SELU(Module):
|
||||||
super(SELU, self).__init__()
|
super(SELU, self).__init__()
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.selu(input, self.inplace)
|
return F.selu(input, self.inplace)
|
||||||
|
|
||||||
|
|
@ -439,7 +419,6 @@ class SELU(Module):
|
||||||
return inplace_str
|
return inplace_str
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class GLU(Module):
|
class GLU(Module):
|
||||||
r"""Applies the gated linear unit function
|
r"""Applies the gated linear unit function
|
||||||
:math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
|
:math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
|
||||||
|
|
@ -465,7 +444,6 @@ class GLU(Module):
|
||||||
super(GLU, self).__init__()
|
super(GLU, self).__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.glu(input, self.dim)
|
return F.glu(input, self.dim)
|
||||||
|
|
||||||
|
|
@ -473,7 +451,6 @@ class GLU(Module):
|
||||||
return 'dim={}'.format(self.dim)
|
return 'dim={}'.format(self.dim)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Hardshrink(Module):
|
class Hardshrink(Module):
|
||||||
r"""Applies the hard shrinkage function element-wise:
|
r"""Applies the hard shrinkage function element-wise:
|
||||||
|
|
||||||
|
|
@ -507,7 +484,6 @@ class Hardshrink(Module):
|
||||||
super(Hardshrink, self).__init__()
|
super(Hardshrink, self).__init__()
|
||||||
self.lambd = lambd
|
self.lambd = lambd
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.hardshrink(input, self.lambd)
|
return F.hardshrink(input, self.lambd)
|
||||||
|
|
||||||
|
|
@ -515,7 +491,6 @@ class Hardshrink(Module):
|
||||||
return '{}'.format(self.lambd)
|
return '{}'.format(self.lambd)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LeakyReLU(Module):
|
class LeakyReLU(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -556,7 +531,6 @@ class LeakyReLU(Module):
|
||||||
self.negative_slope = negative_slope
|
self.negative_slope = negative_slope
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.leaky_relu(input, self.negative_slope, self.inplace)
|
return F.leaky_relu(input, self.negative_slope, self.inplace)
|
||||||
|
|
||||||
|
|
@ -565,7 +539,6 @@ class LeakyReLU(Module):
|
||||||
return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
|
return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LogSigmoid(Module):
|
class LogSigmoid(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -586,12 +559,10 @@ class LogSigmoid(Module):
|
||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.logsigmoid(input)
|
return F.logsigmoid(input)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Softplus(Module):
|
class Softplus(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -628,7 +599,6 @@ class Softplus(Module):
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.softplus(input, self.beta, self.threshold)
|
return F.softplus(input, self.beta, self.threshold)
|
||||||
|
|
||||||
|
|
@ -636,7 +606,6 @@ class Softplus(Module):
|
||||||
return 'beta={}, threshold={}'.format(self.beta, self.threshold)
|
return 'beta={}, threshold={}'.format(self.beta, self.threshold)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Softshrink(Module):
|
class Softshrink(Module):
|
||||||
r"""Applies the soft shrinkage function elementwise:
|
r"""Applies the soft shrinkage function elementwise:
|
||||||
|
|
||||||
|
|
@ -670,7 +639,6 @@ class Softshrink(Module):
|
||||||
super(Softshrink, self).__init__()
|
super(Softshrink, self).__init__()
|
||||||
self.lambd = lambd
|
self.lambd = lambd
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.softshrink(input, self.lambd)
|
return F.softshrink(input, self.lambd)
|
||||||
|
|
||||||
|
|
@ -678,7 +646,6 @@ class Softshrink(Module):
|
||||||
return str(self.lambd)
|
return str(self.lambd)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MultiheadAttention(Module):
|
class MultiheadAttention(Module):
|
||||||
r"""Allows the model to jointly attend to information
|
r"""Allows the model to jointly attend to information
|
||||||
from different representation subspaces.
|
from different representation subspaces.
|
||||||
|
|
@ -759,7 +726,6 @@ class MultiheadAttention(Module):
|
||||||
if self.bias_v is not None:
|
if self.bias_v is not None:
|
||||||
xavier_normal_(self.bias_v)
|
xavier_normal_(self.bias_v)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, query, key, value, key_padding_mask=None,
|
def forward(self, query, key, value, key_padding_mask=None,
|
||||||
need_weights=True, attn_mask=None):
|
need_weights=True, attn_mask=None):
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -817,7 +783,6 @@ class MultiheadAttention(Module):
|
||||||
attn_mask=attn_mask)
|
attn_mask=attn_mask)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class PReLU(Module):
|
class PReLU(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -874,7 +839,6 @@ class PReLU(Module):
|
||||||
super(PReLU, self).__init__()
|
super(PReLU, self).__init__()
|
||||||
self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
|
self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.prelu(input, self.weight)
|
return F.prelu(input, self.weight)
|
||||||
|
|
||||||
|
|
@ -882,7 +846,6 @@ class PReLU(Module):
|
||||||
return 'num_parameters={}'.format(self.num_parameters)
|
return 'num_parameters={}'.format(self.num_parameters)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Softsign(Module):
|
class Softsign(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -903,12 +866,10 @@ class Softsign(Module):
|
||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.softsign(input)
|
return F.softsign(input)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Tanhshrink(Module):
|
class Tanhshrink(Module):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
|
@ -929,12 +890,10 @@ class Tanhshrink(Module):
|
||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.tanhshrink(input)
|
return F.tanhshrink(input)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Softmin(Module):
|
class Softmin(Module):
|
||||||
r"""Applies the Softmin function to an n-dimensional input Tensor
|
r"""Applies the Softmin function to an n-dimensional input Tensor
|
||||||
rescaling them so that the elements of the n-dimensional output Tensor
|
rescaling them so that the elements of the n-dimensional output Tensor
|
||||||
|
|
@ -970,12 +929,10 @@ class Softmin(Module):
|
||||||
super(Softmin, self).__init__()
|
super(Softmin, self).__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.softmin(input, self.dim, _stacklevel=5)
|
return F.softmin(input, self.dim, _stacklevel=5)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Softmax(Module):
|
class Softmax(Module):
|
||||||
r"""Applies the Softmax function to an n-dimensional input Tensor
|
r"""Applies the Softmax function to an n-dimensional input Tensor
|
||||||
rescaling them so that the elements of the n-dimensional output Tensor
|
rescaling them so that the elements of the n-dimensional output Tensor
|
||||||
|
|
@ -1021,7 +978,6 @@ class Softmax(Module):
|
||||||
if not hasattr(self, 'dim'):
|
if not hasattr(self, 'dim'):
|
||||||
self.dim = None
|
self.dim = None
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.softmax(input, self.dim, _stacklevel=5)
|
return F.softmax(input, self.dim, _stacklevel=5)
|
||||||
|
|
||||||
|
|
@ -1029,7 +985,6 @@ class Softmax(Module):
|
||||||
return 'dim={dim}'.format(dim=self.dim)
|
return 'dim={dim}'.format(dim=self.dim)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Softmax2d(Module):
|
class Softmax2d(Module):
|
||||||
r"""Applies SoftMax over features to each spatial location.
|
r"""Applies SoftMax over features to each spatial location.
|
||||||
|
|
||||||
|
|
@ -1052,13 +1007,11 @@ class Softmax2d(Module):
|
||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
|
assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
|
||||||
return F.softmax(input, 1, _stacklevel=5)
|
return F.softmax(input, 1, _stacklevel=5)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LogSoftmax(Module):
|
class LogSoftmax(Module):
|
||||||
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
|
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
|
||||||
input Tensor. The LogSoftmax formulation can be simplified as:
|
input Tensor. The LogSoftmax formulation can be simplified as:
|
||||||
|
|
@ -1095,6 +1048,5 @@ class LogSoftmax(Module):
|
||||||
if not hasattr(self, 'dim'):
|
if not hasattr(self, 'dim'):
|
||||||
self.dim = None
|
self.dim = None
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.log_softmax(input, self.dim, _stacklevel=5)
|
return F.log_softmax(input, self.dim, _stacklevel=5)
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,10 @@ from .module import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import init
|
from .. import init
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: check contiguous in THNN
|
# TODO: check contiguous in THNN
|
||||||
# TODO: use separate backend functions?
|
# TODO: use separate backend functions?
|
||||||
@weak_module
|
|
||||||
class _BatchNorm(Module):
|
class _BatchNorm(Module):
|
||||||
_version = 2
|
_version = 2
|
||||||
__constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
|
__constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
|
||||||
|
|
@ -57,7 +55,6 @@ class _BatchNorm(Module):
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
self._check_input_dim(input)
|
self._check_input_dim(input)
|
||||||
|
|
||||||
|
|
@ -103,7 +100,6 @@ class _BatchNorm(Module):
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class BatchNorm1d(_BatchNorm):
|
class BatchNorm1d(_BatchNorm):
|
||||||
r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
|
r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
|
||||||
inputs with optional additional channel dimension) as described in the paper
|
inputs with optional additional channel dimension) as described in the paper
|
||||||
|
|
@ -170,14 +166,12 @@ class BatchNorm1d(_BatchNorm):
|
||||||
https://arxiv.org/abs/1502.03167
|
https://arxiv.org/abs/1502.03167
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
if input.dim() != 2 and input.dim() != 3:
|
if input.dim() != 2 and input.dim() != 3:
|
||||||
raise ValueError('expected 2D or 3D input (got {}D input)'
|
raise ValueError('expected 2D or 3D input (got {}D input)'
|
||||||
.format(input.dim()))
|
.format(input.dim()))
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class BatchNorm2d(_BatchNorm):
|
class BatchNorm2d(_BatchNorm):
|
||||||
r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
|
r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
|
||||||
with additional channel dimension) as described in the paper
|
with additional channel dimension) as described in the paper
|
||||||
|
|
@ -244,14 +238,12 @@ class BatchNorm2d(_BatchNorm):
|
||||||
https://arxiv.org/abs/1502.03167
|
https://arxiv.org/abs/1502.03167
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
if input.dim() != 4:
|
if input.dim() != 4:
|
||||||
raise ValueError('expected 4D input (got {}D input)'
|
raise ValueError('expected 4D input (got {}D input)'
|
||||||
.format(input.dim()))
|
.format(input.dim()))
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class BatchNorm3d(_BatchNorm):
|
class BatchNorm3d(_BatchNorm):
|
||||||
r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
|
r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
|
||||||
with additional channel dimension) as described in the paper
|
with additional channel dimension) as described in the paper
|
||||||
|
|
@ -319,7 +311,6 @@ class BatchNorm3d(_BatchNorm):
|
||||||
https://arxiv.org/abs/1502.03167
|
https://arxiv.org/abs/1502.03167
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
if input.dim() != 5:
|
if input.dim() != 5:
|
||||||
raise ValueError('expected 5D input (got {}D input)'
|
raise ValueError('expected 5D input (got {}D input)'
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,9 @@ from .. import functional as F
|
||||||
from .. import init
|
from .. import init
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .utils import _single, _pair, _triple
|
from .utils import _single, _pair, _triple
|
||||||
from ..._jit_internal import weak_module, weak_script_method, List
|
from ..._jit_internal import List
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _ConvNd(Module):
|
class _ConvNd(Module):
|
||||||
|
|
||||||
__constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
|
__constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
|
||||||
|
|
@ -74,7 +73,6 @@ class _ConvNd(Module):
|
||||||
self.padding_mode = 'zeros'
|
self.padding_mode = 'zeros'
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Conv1d(_ConvNd):
|
class Conv1d(_ConvNd):
|
||||||
r"""Applies a 1D convolution over an input signal composed of several input
|
r"""Applies a 1D convolution over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -192,7 +190,6 @@ class Conv1d(_ConvNd):
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
False, _single(0), groups, bias, padding_mode)
|
False, _single(0), groups, bias, padding_mode)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.padding_mode == 'circular':
|
if self.padding_mode == 'circular':
|
||||||
expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
|
expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
|
||||||
|
|
@ -203,7 +200,6 @@ class Conv1d(_ConvNd):
|
||||||
self.padding, self.dilation, self.groups)
|
self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Conv2d(_ConvNd):
|
class Conv2d(_ConvNd):
|
||||||
r"""Applies a 2D convolution over an input signal composed of several input
|
r"""Applies a 2D convolution over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -333,7 +329,6 @@ class Conv2d(_ConvNd):
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
False, _pair(0), groups, bias, padding_mode)
|
False, _pair(0), groups, bias, padding_mode)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.padding_mode == 'circular':
|
if self.padding_mode == 'circular':
|
||||||
expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
|
expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
|
||||||
|
|
@ -345,7 +340,6 @@ class Conv2d(_ConvNd):
|
||||||
self.padding, self.dilation, self.groups)
|
self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Conv3d(_ConvNd):
|
class Conv3d(_ConvNd):
|
||||||
r"""Applies a 3D convolution over an input signal composed of several input
|
r"""Applies a 3D convolution over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -470,7 +464,6 @@ class Conv3d(_ConvNd):
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
False, _triple(0), groups, bias, padding_mode)
|
False, _triple(0), groups, bias, padding_mode)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.padding_mode == 'circular':
|
if self.padding_mode == 'circular':
|
||||||
expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
|
expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
|
||||||
|
|
@ -483,9 +476,7 @@ class Conv3d(_ConvNd):
|
||||||
self.padding, self.dilation, self.groups)
|
self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _ConvTransposeMixin(object):
|
class _ConvTransposeMixin(object):
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, output_size=None):
|
def forward(self, input, output_size=None):
|
||||||
# type(Tensor, Optional[List[int]]) -> Tensor
|
# type(Tensor, Optional[List[int]]) -> Tensor
|
||||||
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
|
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
|
||||||
|
|
@ -497,7 +488,6 @@ class _ConvTransposeMixin(object):
|
||||||
else:
|
else:
|
||||||
return func(input, self.weight, self.bias)
|
return func(input, self.weight, self.bias)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def _output_padding(self, input, output_size, stride, padding, kernel_size):
|
def _output_padding(self, input, output_size, stride, padding, kernel_size):
|
||||||
# type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int]
|
# type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int]
|
||||||
if output_size is None:
|
if output_size is None:
|
||||||
|
|
@ -537,7 +527,6 @@ class _ConvTransposeMixin(object):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
|
class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
|
||||||
r"""Applies a 1D transposed convolution operator over an input image
|
r"""Applies a 1D transposed convolution operator over an input image
|
||||||
composed of several input planes.
|
composed of several input planes.
|
||||||
|
|
@ -638,7 +627,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
True, output_padding, groups, bias, padding_mode)
|
True, output_padding, groups, bias, padding_mode)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, output_size=None):
|
def forward(self, input, output_size=None):
|
||||||
# type: (Tensor, Optional[List[int]]) -> Tensor
|
# type: (Tensor, Optional[List[int]]) -> Tensor
|
||||||
if self.padding_mode != 'zeros':
|
if self.padding_mode != 'zeros':
|
||||||
|
|
@ -650,7 +638,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
|
||||||
output_padding, self.groups, self.dilation)
|
output_padding, self.groups, self.dilation)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
|
class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
|
||||||
r"""Applies a 2D transposed convolution operator over an input image
|
r"""Applies a 2D transposed convolution operator over an input image
|
||||||
composed of several input planes.
|
composed of several input planes.
|
||||||
|
|
@ -786,7 +773,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
True, output_padding, groups, bias, padding_mode)
|
True, output_padding, groups, bias, padding_mode)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, output_size=None):
|
def forward(self, input, output_size=None):
|
||||||
# type: (Tensor, Optional[List[int]]) -> Tensor
|
# type: (Tensor, Optional[List[int]]) -> Tensor
|
||||||
if self.padding_mode != 'zeros':
|
if self.padding_mode != 'zeros':
|
||||||
|
|
@ -799,7 +785,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
|
||||||
output_padding, self.groups, self.dilation)
|
output_padding, self.groups, self.dilation)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
|
class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
|
||||||
r"""Applies a 3D transposed convolution operator over an input image composed of several input
|
r"""Applies a 3D transposed convolution operator over an input image composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -931,7 +916,6 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
True, output_padding, groups, bias, padding_mode)
|
True, output_padding, groups, bias, padding_mode)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, output_size=None):
|
def forward(self, input, output_size=None):
|
||||||
# type: (Tensor, Optional[List[int]]) -> Tensor
|
# type: (Tensor, Optional[List[int]]) -> Tensor
|
||||||
if self.padding_mode != 'zeros':
|
if self.padding_mode != 'zeros':
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class PairwiseDistance(Module):
|
class PairwiseDistance(Module):
|
||||||
r"""
|
r"""
|
||||||
Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
|
Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
|
||||||
|
|
@ -35,12 +33,10 @@ class PairwiseDistance(Module):
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.keepdim = keepdim
|
self.keepdim = keepdim
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, x1, x2):
|
def forward(self, x1, x2):
|
||||||
return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)
|
return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class CosineSimilarity(Module):
|
class CosineSimilarity(Module):
|
||||||
r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.
|
r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.
|
||||||
|
|
||||||
|
|
@ -68,6 +64,5 @@ class CosineSimilarity(Module):
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, x1, x2):
|
def forward(self, x1, x2):
|
||||||
return F.cosine_similarity(x1, x2, self.dim, self.eps)
|
return F.cosine_similarity(x1, x2, self.dim, self.eps)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
class _DropoutNd(Module):
|
class _DropoutNd(Module):
|
||||||
|
|
@ -18,7 +17,6 @@ class _DropoutNd(Module):
|
||||||
return 'p={}, inplace={}'.format(self.p, self.inplace)
|
return 'p={}, inplace={}'.format(self.p, self.inplace)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Dropout(_DropoutNd):
|
class Dropout(_DropoutNd):
|
||||||
r"""During training, randomly zeroes some of the elements of the input
|
r"""During training, randomly zeroes some of the elements of the input
|
||||||
tensor with probability :attr:`p` using samples from a Bernoulli
|
tensor with probability :attr:`p` using samples from a Bernoulli
|
||||||
|
|
@ -52,12 +50,10 @@ class Dropout(_DropoutNd):
|
||||||
detectors: https://arxiv.org/abs/1207.0580
|
detectors: https://arxiv.org/abs/1207.0580
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.dropout(input, self.p, self.training, self.inplace)
|
return F.dropout(input, self.p, self.training, self.inplace)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Dropout2d(_DropoutNd):
|
class Dropout2d(_DropoutNd):
|
||||||
r"""Randomly zero out entire channels (a channel is a 2D feature map,
|
r"""Randomly zero out entire channels (a channel is a 2D feature map,
|
||||||
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
||||||
|
|
@ -96,12 +92,10 @@ class Dropout2d(_DropoutNd):
|
||||||
http://arxiv.org/abs/1411.4280
|
http://arxiv.org/abs/1411.4280
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.dropout2d(input, self.p, self.training, self.inplace)
|
return F.dropout2d(input, self.p, self.training, self.inplace)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Dropout3d(_DropoutNd):
|
class Dropout3d(_DropoutNd):
|
||||||
r"""Randomly zero out entire channels (a channel is a 3D feature map,
|
r"""Randomly zero out entire channels (a channel is a 3D feature map,
|
||||||
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
||||||
|
|
@ -140,12 +134,10 @@ class Dropout3d(_DropoutNd):
|
||||||
http://arxiv.org/abs/1411.4280
|
http://arxiv.org/abs/1411.4280
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.dropout3d(input, self.p, self.training, self.inplace)
|
return F.dropout3d(input, self.p, self.training, self.inplace)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AlphaDropout(_DropoutNd):
|
class AlphaDropout(_DropoutNd):
|
||||||
r"""Applies Alpha Dropout over the input.
|
r"""Applies Alpha Dropout over the input.
|
||||||
|
|
||||||
|
|
@ -184,14 +176,11 @@ class AlphaDropout(_DropoutNd):
|
||||||
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
|
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.alpha_dropout(input, self.p, self.training)
|
return F.alpha_dropout(input, self.p, self.training)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class FeatureAlphaDropout(_DropoutNd):
|
class FeatureAlphaDropout(_DropoutNd):
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.feature_alpha_dropout(input, self.p, self.training)
|
return F.feature_alpha_dropout(input, self.p, self.training)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Fold(Module):
|
class Fold(Module):
|
||||||
r"""Combines an array of sliding local blocks into a large containing
|
r"""Combines an array of sliding local blocks into a large containing
|
||||||
tensor.
|
tensor.
|
||||||
|
|
@ -101,7 +99,6 @@ class Fold(Module):
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.fold(input, self.output_size, self.kernel_size, self.dilation,
|
return F.fold(input, self.output_size, self.kernel_size, self.dilation,
|
||||||
self.padding, self.stride)
|
self.padding, self.stride)
|
||||||
|
|
@ -113,7 +110,6 @@ class Fold(Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Unfold(Module):
|
class Unfold(Module):
|
||||||
r"""Extracts sliding local blocks from a batched input tensor.
|
r"""Extracts sliding local blocks from a batched input tensor.
|
||||||
|
|
||||||
|
|
@ -217,7 +213,6 @@ class Unfold(Module):
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.unfold(input, self.kernel_size, self.dilation,
|
return F.unfold(input, self.kernel_size, self.dilation,
|
||||||
self.padding, self.stride)
|
self.padding, self.stride)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from .batchnorm import _BatchNorm
|
from .batchnorm import _BatchNorm
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
class _InstanceNorm(_BatchNorm):
|
class _InstanceNorm(_BatchNorm):
|
||||||
|
|
@ -9,7 +8,6 @@ class _InstanceNorm(_BatchNorm):
|
||||||
super(_InstanceNorm, self).__init__(
|
super(_InstanceNorm, self).__init__(
|
||||||
num_features, eps, momentum, affine, track_running_stats)
|
num_features, eps, momentum, affine, track_running_stats)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -43,7 +41,6 @@ class _InstanceNorm(_BatchNorm):
|
||||||
state_dict, prefix, local_metadata, strict,
|
state_dict, prefix, local_metadata, strict,
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
self._check_input_dim(input)
|
self._check_input_dim(input)
|
||||||
|
|
||||||
|
|
@ -52,7 +49,6 @@ class _InstanceNorm(_BatchNorm):
|
||||||
self.training or not self.track_running_stats, self.momentum, self.eps)
|
self.training or not self.track_running_stats, self.momentum, self.eps)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class InstanceNorm1d(_InstanceNorm):
|
class InstanceNorm1d(_InstanceNorm):
|
||||||
r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
|
r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
|
||||||
inputs with optional additional channel dimension) as described in the paper
|
inputs with optional additional channel dimension) as described in the paper
|
||||||
|
|
@ -121,7 +117,6 @@ class InstanceNorm1d(_InstanceNorm):
|
||||||
https://arxiv.org/abs/1607.08022
|
https://arxiv.org/abs/1607.08022
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
if input.dim() == 2:
|
if input.dim() == 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
@ -135,7 +130,6 @@ class InstanceNorm1d(_InstanceNorm):
|
||||||
.format(input.dim()))
|
.format(input.dim()))
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class InstanceNorm2d(_InstanceNorm):
|
class InstanceNorm2d(_InstanceNorm):
|
||||||
r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
|
r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
|
||||||
with additional channel dimension) as described in the paper
|
with additional channel dimension) as described in the paper
|
||||||
|
|
@ -204,14 +198,12 @@ class InstanceNorm2d(_InstanceNorm):
|
||||||
https://arxiv.org/abs/1607.08022
|
https://arxiv.org/abs/1607.08022
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
if input.dim() != 4:
|
if input.dim() != 4:
|
||||||
raise ValueError('expected 4D input (got {}D input)'
|
raise ValueError('expected 4D input (got {}D input)'
|
||||||
.format(input.dim()))
|
.format(input.dim()))
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class InstanceNorm3d(_InstanceNorm):
|
class InstanceNorm3d(_InstanceNorm):
|
||||||
r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
|
r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
|
||||||
with additional channel dimension) as described in the paper
|
with additional channel dimension) as described in the paper
|
||||||
|
|
@ -280,7 +272,6 @@ class InstanceNorm3d(_InstanceNorm):
|
||||||
https://arxiv.org/abs/1607.08022
|
https://arxiv.org/abs/1607.08022
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
if input.dim() != 5:
|
if input.dim() != 5:
|
||||||
raise ValueError('expected 5D input (got {}D input)'
|
raise ValueError('expected 5D input (got {}D input)'
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,8 @@ from torch.nn.parameter import Parameter
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import init
|
from .. import init
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Identity(Module):
|
class Identity(Module):
|
||||||
r"""A placeholder identity operator that is argument-insensitive.
|
r"""A placeholder identity operator that is argument-insensitive.
|
||||||
|
|
||||||
|
|
@ -28,12 +26,10 @@ class Identity(Module):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Identity, self).__init__()
|
super(Identity, self).__init__()
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Linear(Module):
|
class Linear(Module):
|
||||||
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
|
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
|
||||||
|
|
||||||
|
|
@ -87,7 +83,6 @@ class Linear(Module):
|
||||||
bound = 1 / math.sqrt(fan_in)
|
bound = 1 / math.sqrt(fan_in)
|
||||||
init.uniform_(self.bias, -bound, bound)
|
init.uniform_(self.bias, -bound, bound)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.linear(input, self.weight, self.bias)
|
return F.linear(input, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
@ -97,7 +92,6 @@ class Linear(Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Bilinear(Module):
|
class Bilinear(Module):
|
||||||
r"""Applies a bilinear transformation to the incoming data:
|
r"""Applies a bilinear transformation to the incoming data:
|
||||||
:math:`y = x_1 A x_2 + b`
|
:math:`y = x_1 A x_2 + b`
|
||||||
|
|
@ -157,7 +151,6 @@ class Bilinear(Module):
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
init.uniform_(self.bias, -bound, bound)
|
init.uniform_(self.bias, -bound, bound)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input1, input2):
|
def forward(self, input1, input2):
|
||||||
return F.bilinear(input1, input2, self.weight, self.bias)
|
return F.bilinear(input1, input2, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import warnings
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import _reduction as _Reduction
|
from .. import _reduction as _Reduction
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
class _Loss(Module):
|
class _Loss(Module):
|
||||||
|
|
@ -21,7 +20,6 @@ class _WeightedLoss(_Loss):
|
||||||
self.register_buffer('weight', weight)
|
self.register_buffer('weight', weight)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class L1Loss(_Loss):
|
class L1Loss(_Loss):
|
||||||
r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
|
r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
|
||||||
the input :math:`x` and target :math:`y`.
|
the input :math:`x` and target :math:`y`.
|
||||||
|
|
@ -86,12 +84,10 @@ class L1Loss(_Loss):
|
||||||
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
||||||
super(L1Loss, self).__init__(size_average, reduce, reduction)
|
super(L1Loss, self).__init__(size_average, reduce, reduction)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.l1_loss(input, target, reduction=self.reduction)
|
return F.l1_loss(input, target, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class NLLLoss(_WeightedLoss):
|
class NLLLoss(_WeightedLoss):
|
||||||
r"""The negative log likelihood loss. It is useful to train a classification
|
r"""The negative log likelihood loss. It is useful to train a classification
|
||||||
problem with `C` classes.
|
problem with `C` classes.
|
||||||
|
|
@ -204,12 +200,10 @@ class NLLLoss(_WeightedLoss):
|
||||||
super(NLLLoss, self).__init__(weight, size_average, reduce, reduction)
|
super(NLLLoss, self).__init__(weight, size_average, reduce, reduction)
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
|
return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class NLLLoss2d(NLLLoss):
|
class NLLLoss2d(NLLLoss):
|
||||||
def __init__(self, weight=None, size_average=None, ignore_index=-100,
|
def __init__(self, weight=None, size_average=None, ignore_index=-100,
|
||||||
reduce=None, reduction='mean'):
|
reduce=None, reduction='mean'):
|
||||||
|
|
@ -219,7 +213,6 @@ class NLLLoss2d(NLLLoss):
|
||||||
super(NLLLoss2d, self).__init__(weight, size_average, ignore_index, reduce, reduction)
|
super(NLLLoss2d, self).__init__(weight, size_average, ignore_index, reduce, reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class PoissonNLLLoss(_Loss):
|
class PoissonNLLLoss(_Loss):
|
||||||
r"""Negative log likelihood loss with Poisson distribution of target.
|
r"""Negative log likelihood loss with Poisson distribution of target.
|
||||||
|
|
||||||
|
|
@ -286,13 +279,11 @@ class PoissonNLLLoss(_Loss):
|
||||||
self.full = full
|
self.full = full
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, log_input, target):
|
def forward(self, log_input, target):
|
||||||
return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full,
|
return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full,
|
||||||
eps=self.eps, reduction=self.reduction)
|
eps=self.eps, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class KLDivLoss(_Loss):
|
class KLDivLoss(_Loss):
|
||||||
r"""The `Kullback-Leibler divergence`_ Loss
|
r"""The `Kullback-Leibler divergence`_ Loss
|
||||||
|
|
||||||
|
|
@ -370,12 +361,10 @@ class KLDivLoss(_Loss):
|
||||||
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
||||||
super(KLDivLoss, self).__init__(size_average, reduce, reduction)
|
super(KLDivLoss, self).__init__(size_average, reduce, reduction)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.kl_div(input, target, reduction=self.reduction)
|
return F.kl_div(input, target, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MSELoss(_Loss):
|
class MSELoss(_Loss):
|
||||||
r"""Creates a criterion that measures the mean squared error (squared L2 norm) between
|
r"""Creates a criterion that measures the mean squared error (squared L2 norm) between
|
||||||
each element in the input :math:`x` and target :math:`y`.
|
each element in the input :math:`x` and target :math:`y`.
|
||||||
|
|
@ -438,12 +427,10 @@ class MSELoss(_Loss):
|
||||||
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
||||||
super(MSELoss, self).__init__(size_average, reduce, reduction)
|
super(MSELoss, self).__init__(size_average, reduce, reduction)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.mse_loss(input, target, reduction=self.reduction)
|
return F.mse_loss(input, target, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class BCELoss(_WeightedLoss):
|
class BCELoss(_WeightedLoss):
|
||||||
r"""Creates a criterion that measures the Binary Cross Entropy
|
r"""Creates a criterion that measures the Binary Cross Entropy
|
||||||
between the target and the output:
|
between the target and the output:
|
||||||
|
|
@ -507,12 +494,10 @@ class BCELoss(_WeightedLoss):
|
||||||
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
|
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
|
||||||
super(BCELoss, self).__init__(weight, size_average, reduce, reduction)
|
super(BCELoss, self).__init__(weight, size_average, reduce, reduction)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
|
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class BCEWithLogitsLoss(_Loss):
|
class BCEWithLogitsLoss(_Loss):
|
||||||
r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single
|
r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single
|
||||||
class. This version is more numerically stable than using a plain `Sigmoid`
|
class. This version is more numerically stable than using a plain `Sigmoid`
|
||||||
|
|
@ -609,7 +594,6 @@ class BCEWithLogitsLoss(_Loss):
|
||||||
self.register_buffer('weight', weight)
|
self.register_buffer('weight', weight)
|
||||||
self.register_buffer('pos_weight', pos_weight)
|
self.register_buffer('pos_weight', pos_weight)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.binary_cross_entropy_with_logits(input, target,
|
return F.binary_cross_entropy_with_logits(input, target,
|
||||||
self.weight,
|
self.weight,
|
||||||
|
|
@ -617,7 +601,6 @@ class BCEWithLogitsLoss(_Loss):
|
||||||
reduction=self.reduction)
|
reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class HingeEmbeddingLoss(_Loss):
|
class HingeEmbeddingLoss(_Loss):
|
||||||
r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`
|
r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`
|
||||||
(containing 1 or -1).
|
(containing 1 or -1).
|
||||||
|
|
@ -673,12 +656,10 @@ class HingeEmbeddingLoss(_Loss):
|
||||||
super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction)
|
super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction)
|
return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MultiLabelMarginLoss(_Loss):
|
class MultiLabelMarginLoss(_Loss):
|
||||||
r"""Creates a criterion that optimizes a multi-class multi-classification
|
r"""Creates a criterion that optimizes a multi-class multi-classification
|
||||||
hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
|
hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
|
||||||
|
|
@ -739,12 +720,10 @@ class MultiLabelMarginLoss(_Loss):
|
||||||
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
||||||
super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction)
|
super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.multilabel_margin_loss(input, target, reduction=self.reduction)
|
return F.multilabel_margin_loss(input, target, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class SmoothL1Loss(_Loss):
|
class SmoothL1Loss(_Loss):
|
||||||
r"""Creates a criterion that uses a squared term if the absolute
|
r"""Creates a criterion that uses a squared term if the absolute
|
||||||
element-wise error falls below 1 and an L1 term otherwise.
|
element-wise error falls below 1 and an L1 term otherwise.
|
||||||
|
|
@ -799,12 +778,10 @@ class SmoothL1Loss(_Loss):
|
||||||
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
||||||
super(SmoothL1Loss, self).__init__(size_average, reduce, reduction)
|
super(SmoothL1Loss, self).__init__(size_average, reduce, reduction)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.smooth_l1_loss(input, target, reduction=self.reduction)
|
return F.smooth_l1_loss(input, target, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class SoftMarginLoss(_Loss):
|
class SoftMarginLoss(_Loss):
|
||||||
r"""Creates a criterion that optimizes a two-class classification
|
r"""Creates a criterion that optimizes a two-class classification
|
||||||
logistic loss between input tensor :math:`x` and target tensor :math:`y`
|
logistic loss between input tensor :math:`x` and target tensor :math:`y`
|
||||||
|
|
@ -842,12 +819,10 @@ class SoftMarginLoss(_Loss):
|
||||||
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
def __init__(self, size_average=None, reduce=None, reduction='mean'):
|
||||||
super(SoftMarginLoss, self).__init__(size_average, reduce, reduction)
|
super(SoftMarginLoss, self).__init__(size_average, reduce, reduction)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.soft_margin_loss(input, target, reduction=self.reduction)
|
return F.soft_margin_loss(input, target, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class CrossEntropyLoss(_WeightedLoss):
|
class CrossEntropyLoss(_WeightedLoss):
|
||||||
r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.
|
r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.
|
||||||
|
|
||||||
|
|
@ -936,13 +911,11 @@ class CrossEntropyLoss(_WeightedLoss):
|
||||||
super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
|
super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.cross_entropy(input, target, weight=self.weight,
|
return F.cross_entropy(input, target, weight=self.weight,
|
||||||
ignore_index=self.ignore_index, reduction=self.reduction)
|
ignore_index=self.ignore_index, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MultiLabelSoftMarginLoss(_WeightedLoss):
|
class MultiLabelSoftMarginLoss(_WeightedLoss):
|
||||||
r"""Creates a criterion that optimizes a multi-label one-versus-all
|
r"""Creates a criterion that optimizes a multi-label one-versus-all
|
||||||
loss based on max-entropy, between input :math:`x` and target :math:`y` of size
|
loss based on max-entropy, between input :math:`x` and target :math:`y` of size
|
||||||
|
|
@ -986,12 +959,10 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
|
||||||
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
|
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
|
||||||
super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction)
|
super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction)
|
return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class CosineEmbeddingLoss(_Loss):
|
class CosineEmbeddingLoss(_Loss):
|
||||||
r"""Creates a criterion that measures the loss given input tensors
|
r"""Creates a criterion that measures the loss given input tensors
|
||||||
:math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1.
|
:math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1.
|
||||||
|
|
@ -1034,12 +1005,10 @@ class CosineEmbeddingLoss(_Loss):
|
||||||
super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction)
|
super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input1, input2, target):
|
def forward(self, input1, input2, target):
|
||||||
return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
|
return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MarginRankingLoss(_Loss):
|
class MarginRankingLoss(_Loss):
|
||||||
r"""Creates a criterion that measures the loss given
|
r"""Creates a criterion that measures the loss given
|
||||||
inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`,
|
inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`,
|
||||||
|
|
@ -1082,12 +1051,10 @@ class MarginRankingLoss(_Loss):
|
||||||
super(MarginRankingLoss, self).__init__(size_average, reduce, reduction)
|
super(MarginRankingLoss, self).__init__(size_average, reduce, reduction)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input1, input2, target):
|
def forward(self, input1, input2, target):
|
||||||
return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
|
return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MultiMarginLoss(_WeightedLoss):
|
class MultiMarginLoss(_WeightedLoss):
|
||||||
r"""Creates a criterion that optimizes a multi-class classification hinge
|
r"""Creates a criterion that optimizes a multi-class classification hinge
|
||||||
loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and
|
loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and
|
||||||
|
|
@ -1145,13 +1112,11 @@ class MultiMarginLoss(_WeightedLoss):
|
||||||
self.p = p
|
self.p = p
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
return F.multi_margin_loss(input, target, p=self.p, margin=self.margin,
|
return F.multi_margin_loss(input, target, p=self.p, margin=self.margin,
|
||||||
weight=self.weight, reduction=self.reduction)
|
weight=self.weight, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class TripletMarginLoss(_Loss):
|
class TripletMarginLoss(_Loss):
|
||||||
r"""Creates a criterion that measures the triplet loss given an input
|
r"""Creates a criterion that measures the triplet loss given an input
|
||||||
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
|
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
|
||||||
|
|
@ -1221,13 +1186,11 @@ class TripletMarginLoss(_Loss):
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.swap = swap
|
self.swap = swap
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, anchor, positive, negative):
|
def forward(self, anchor, positive, negative):
|
||||||
return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p,
|
return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p,
|
||||||
eps=self.eps, swap=self.swap, reduction=self.reduction)
|
eps=self.eps, swap=self.swap, reduction=self.reduction)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class CTCLoss(_Loss):
|
class CTCLoss(_Loss):
|
||||||
r"""The Connectionist Temporal Classification loss.
|
r"""The Connectionist Temporal Classification loss.
|
||||||
|
|
||||||
|
|
@ -1327,7 +1290,6 @@ class CTCLoss(_Loss):
|
||||||
self.blank = blank
|
self.blank = blank
|
||||||
self.zero_infinity = zero_infinity
|
self.zero_infinity = zero_infinity
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, log_probs, targets, input_lengths, target_lengths):
|
def forward(self, log_probs, targets, input_lengths, target_lengths):
|
||||||
return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
|
return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
|
||||||
self.zero_infinity)
|
self.zero_infinity)
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import init
|
from .. import init
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LocalResponseNorm(Module):
|
class LocalResponseNorm(Module):
|
||||||
r"""Applies local response normalization over an input signal composed
|
r"""Applies local response normalization over an input signal composed
|
||||||
of several input planes, where channels occupy the second dimension.
|
of several input planes, where channels occupy the second dimension.
|
||||||
|
|
@ -45,7 +43,6 @@ class LocalResponseNorm(Module):
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.k = k
|
self.k = k
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.local_response_norm(input, self.size, self.alpha, self.beta,
|
return F.local_response_norm(input, self.size, self.alpha, self.beta,
|
||||||
self.k)
|
self.k)
|
||||||
|
|
@ -71,7 +68,6 @@ class CrossMapLRN2d(Module):
|
||||||
return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
|
return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LayerNorm(Module):
|
class LayerNorm(Module):
|
||||||
r"""Applies Layer Normalization over a mini-batch of inputs as described in
|
r"""Applies Layer Normalization over a mini-batch of inputs as described in
|
||||||
the paper `Layer Normalization`_ .
|
the paper `Layer Normalization`_ .
|
||||||
|
|
@ -151,7 +147,6 @@ class LayerNorm(Module):
|
||||||
init.ones_(self.weight)
|
init.ones_(self.weight)
|
||||||
init.zeros_(self.bias)
|
init.zeros_(self.bias)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.layer_norm(
|
return F.layer_norm(
|
||||||
input, self.normalized_shape, self.weight, self.bias, self.eps)
|
input, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
|
@ -161,7 +156,6 @@ class LayerNorm(Module):
|
||||||
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
|
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class GroupNorm(Module):
|
class GroupNorm(Module):
|
||||||
r"""Applies Group Normalization over a mini-batch of inputs as described in
|
r"""Applies Group Normalization over a mini-batch of inputs as described in
|
||||||
the paper `Group Normalization`_ .
|
the paper `Group Normalization`_ .
|
||||||
|
|
@ -226,7 +220,6 @@ class GroupNorm(Module):
|
||||||
init.ones_(self.weight)
|
init.ones_(self.weight)
|
||||||
init.zeros_(self.bias)
|
init.zeros_(self.bias)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.group_norm(
|
return F.group_norm(
|
||||||
input, self.num_groups, self.weight, self.bias, self.eps)
|
input, self.num_groups, self.weight, self.bias, self.eps)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .utils import _pair, _quadruple, _ntuple
|
from .utils import _pair, _quadruple, _ntuple
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: grad_output size asserts in THNN
|
# TODO: grad_output size asserts in THNN
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _ConstantPadNd(Module):
|
class _ConstantPadNd(Module):
|
||||||
__constants__ = ['padding', 'value']
|
__constants__ = ['padding', 'value']
|
||||||
|
|
||||||
|
|
@ -15,7 +13,6 @@ class _ConstantPadNd(Module):
|
||||||
super(_ConstantPadNd, self).__init__()
|
super(_ConstantPadNd, self).__init__()
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.pad(input, self.padding, 'constant', self.value)
|
return F.pad(input, self.padding, 'constant', self.value)
|
||||||
|
|
||||||
|
|
@ -23,7 +20,6 @@ class _ConstantPadNd(Module):
|
||||||
return 'padding={}, value={}'.format(self.padding, self.value)
|
return 'padding={}, value={}'.format(self.padding, self.value)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ConstantPad1d(_ConstantPadNd):
|
class ConstantPad1d(_ConstantPadNd):
|
||||||
r"""Pads the input tensor boundaries with a constant value.
|
r"""Pads the input tensor boundaries with a constant value.
|
||||||
|
|
||||||
|
|
@ -73,7 +69,6 @@ class ConstantPad1d(_ConstantPadNd):
|
||||||
self.padding = _pair(padding)
|
self.padding = _pair(padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ConstantPad2d(_ConstantPadNd):
|
class ConstantPad2d(_ConstantPadNd):
|
||||||
r"""Pads the input tensor boundaries with a constant value.
|
r"""Pads the input tensor boundaries with a constant value.
|
||||||
|
|
||||||
|
|
@ -123,7 +118,6 @@ class ConstantPad2d(_ConstantPadNd):
|
||||||
self.padding = _quadruple(padding)
|
self.padding = _quadruple(padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ConstantPad3d(_ConstantPadNd):
|
class ConstantPad3d(_ConstantPadNd):
|
||||||
r"""Pads the input tensor boundaries with a constant value.
|
r"""Pads the input tensor boundaries with a constant value.
|
||||||
|
|
||||||
|
|
@ -162,11 +156,9 @@ class ConstantPad3d(_ConstantPadNd):
|
||||||
self.padding = _ntuple(6)(padding)
|
self.padding = _ntuple(6)(padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _ReflectionPadNd(Module):
|
class _ReflectionPadNd(Module):
|
||||||
__constants__ = ['padding']
|
__constants__ = ['padding']
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.pad(input, self.padding, 'reflect')
|
return F.pad(input, self.padding, 'reflect')
|
||||||
|
|
||||||
|
|
@ -174,7 +166,6 @@ class _ReflectionPadNd(Module):
|
||||||
return '{}'.format(self.padding)
|
return '{}'.format(self.padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ReflectionPad1d(_ReflectionPadNd):
|
class ReflectionPad1d(_ReflectionPadNd):
|
||||||
r"""Pads the input tensor using the reflection of the input boundary.
|
r"""Pads the input tensor using the reflection of the input boundary.
|
||||||
|
|
||||||
|
|
@ -214,7 +205,6 @@ class ReflectionPad1d(_ReflectionPadNd):
|
||||||
self.padding = _pair(padding)
|
self.padding = _pair(padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ReflectionPad2d(_ReflectionPadNd):
|
class ReflectionPad2d(_ReflectionPadNd):
|
||||||
r"""Pads the input tensor using the reflection of the input boundary.
|
r"""Pads the input tensor using the reflection of the input boundary.
|
||||||
|
|
||||||
|
|
@ -265,11 +255,9 @@ class ReflectionPad2d(_ReflectionPadNd):
|
||||||
self.padding = _quadruple(padding)
|
self.padding = _quadruple(padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _ReplicationPadNd(Module):
|
class _ReplicationPadNd(Module):
|
||||||
__constants__ = ['padding']
|
__constants__ = ['padding']
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.pad(input, self.padding, 'replicate')
|
return F.pad(input, self.padding, 'replicate')
|
||||||
|
|
||||||
|
|
@ -277,7 +265,6 @@ class _ReplicationPadNd(Module):
|
||||||
return '{}'.format(self.padding)
|
return '{}'.format(self.padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ReplicationPad1d(_ReplicationPadNd):
|
class ReplicationPad1d(_ReplicationPadNd):
|
||||||
r"""Pads the input tensor using replication of the input boundary.
|
r"""Pads the input tensor using replication of the input boundary.
|
||||||
|
|
||||||
|
|
@ -317,7 +304,6 @@ class ReplicationPad1d(_ReplicationPadNd):
|
||||||
self.padding = _pair(padding)
|
self.padding = _pair(padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ReplicationPad2d(_ReplicationPadNd):
|
class ReplicationPad2d(_ReplicationPadNd):
|
||||||
r"""Pads the input tensor using replication of the input boundary.
|
r"""Pads the input tensor using replication of the input boundary.
|
||||||
|
|
||||||
|
|
@ -368,7 +354,6 @@ class ReplicationPad2d(_ReplicationPadNd):
|
||||||
self.padding = _quadruple(padding)
|
self.padding = _quadruple(padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ReplicationPad3d(_ReplicationPadNd):
|
class ReplicationPad3d(_ReplicationPadNd):
|
||||||
r"""Pads the input tensor using replication of the input boundary.
|
r"""Pads the input tensor using replication of the input boundary.
|
||||||
|
|
||||||
|
|
@ -407,7 +392,6 @@ class ReplicationPad3d(_ReplicationPadNd):
|
||||||
self.padding = _ntuple(6)(padding)
|
self.padding = _ntuple(6)(padding)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ZeroPad2d(ConstantPad2d):
|
class ZeroPad2d(ConstantPad2d):
|
||||||
r"""Pads the input tensor boundaries with zero.
|
r"""Pads the input tensor boundaries with zero.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class PixelShuffle(Module):
|
class PixelShuffle(Module):
|
||||||
r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
|
r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
|
||||||
to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
|
to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
|
||||||
|
|
@ -41,7 +39,6 @@ class PixelShuffle(Module):
|
||||||
super(PixelShuffle, self).__init__()
|
super(PixelShuffle, self).__init__()
|
||||||
self.upscale_factor = upscale_factor
|
self.upscale_factor = upscale_factor
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.pixel_shuffle(input, self.upscale_factor)
|
return F.pixel_shuffle(input, self.upscale_factor)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .utils import _single, _pair, _triple
|
from .utils import _single, _pair, _triple
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _MaxPoolNd(Module):
|
class _MaxPoolNd(Module):
|
||||||
__constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
|
__constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
|
||||||
'return_indices', 'ceil_mode']
|
'return_indices', 'ceil_mode']
|
||||||
|
|
@ -24,7 +22,6 @@ class _MaxPoolNd(Module):
|
||||||
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
|
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MaxPool1d(_MaxPoolNd):
|
class MaxPool1d(_MaxPoolNd):
|
||||||
r"""Applies a 1D max pooling over an input signal composed of several input
|
r"""Applies a 1D max pooling over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -68,7 +65,6 @@ class MaxPool1d(_MaxPoolNd):
|
||||||
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.max_pool1d(input, self.kernel_size, self.stride,
|
return F.max_pool1d(input, self.kernel_size, self.stride,
|
||||||
self.padding, self.dilation, self.ceil_mode,
|
self.padding, self.dilation, self.ceil_mode,
|
||||||
|
|
@ -79,7 +75,6 @@ class MaxPool1d(_MaxPoolNd):
|
||||||
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
|
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MaxPool2d(_MaxPoolNd):
|
class MaxPool2d(_MaxPoolNd):
|
||||||
r"""Applies a 2D max pooling over an input signal composed of several input
|
r"""Applies a 2D max pooling over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -139,14 +134,12 @@ class MaxPool2d(_MaxPoolNd):
|
||||||
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.max_pool2d(input, self.kernel_size, self.stride,
|
return F.max_pool2d(input, self.kernel_size, self.stride,
|
||||||
self.padding, self.dilation, self.ceil_mode,
|
self.padding, self.dilation, self.ceil_mode,
|
||||||
self.return_indices)
|
self.return_indices)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MaxPool3d(_MaxPoolNd):
|
class MaxPool3d(_MaxPoolNd):
|
||||||
r"""Applies a 3D max pooling over an input signal composed of several input
|
r"""Applies a 3D max pooling over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -210,14 +203,12 @@ class MaxPool3d(_MaxPoolNd):
|
||||||
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.max_pool3d(input, self.kernel_size, self.stride,
|
return F.max_pool3d(input, self.kernel_size, self.stride,
|
||||||
self.padding, self.dilation, self.ceil_mode,
|
self.padding, self.dilation, self.ceil_mode,
|
||||||
self.return_indices)
|
self.return_indices)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _MaxUnpoolNd(Module):
|
class _MaxUnpoolNd(Module):
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
|
|
@ -226,7 +217,6 @@ class _MaxUnpoolNd(Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MaxUnpool1d(_MaxUnpoolNd):
|
class MaxUnpool1d(_MaxUnpoolNd):
|
||||||
r"""Computes a partial inverse of :class:`MaxPool1d`.
|
r"""Computes a partial inverse of :class:`MaxPool1d`.
|
||||||
|
|
||||||
|
|
@ -292,7 +282,6 @@ class MaxUnpool1d(_MaxUnpoolNd):
|
||||||
self.padding, output_size)
|
self.padding, output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MaxUnpool2d(_MaxUnpoolNd):
|
class MaxUnpool2d(_MaxUnpoolNd):
|
||||||
r"""Computes a partial inverse of :class:`MaxPool2d`.
|
r"""Computes a partial inverse of :class:`MaxPool2d`.
|
||||||
|
|
||||||
|
|
@ -366,7 +355,6 @@ class MaxUnpool2d(_MaxUnpoolNd):
|
||||||
self.padding, output_size)
|
self.padding, output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class MaxUnpool3d(_MaxUnpoolNd):
|
class MaxUnpool3d(_MaxUnpoolNd):
|
||||||
r"""Computes a partial inverse of :class:`MaxPool3d`.
|
r"""Computes a partial inverse of :class:`MaxPool3d`.
|
||||||
|
|
||||||
|
|
@ -429,7 +417,6 @@ class MaxUnpool3d(_MaxUnpoolNd):
|
||||||
self.padding, output_size)
|
self.padding, output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _AvgPoolNd(Module):
|
class _AvgPoolNd(Module):
|
||||||
__constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']
|
__constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']
|
||||||
|
|
||||||
|
|
@ -439,7 +426,6 @@ class _AvgPoolNd(Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AvgPool1d(_AvgPoolNd):
|
class AvgPool1d(_AvgPoolNd):
|
||||||
r"""Applies a 1D average pooling over an input signal composed of several
|
r"""Applies a 1D average pooling over an input signal composed of several
|
||||||
input planes.
|
input planes.
|
||||||
|
|
@ -490,14 +476,12 @@ class AvgPool1d(_AvgPoolNd):
|
||||||
self.ceil_mode = ceil_mode
|
self.ceil_mode = ceil_mode
|
||||||
self.count_include_pad = count_include_pad
|
self.count_include_pad = count_include_pad
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.avg_pool1d(
|
return F.avg_pool1d(
|
||||||
input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
|
input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
|
||||||
self.count_include_pad)
|
self.count_include_pad)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AvgPool2d(_AvgPoolNd):
|
class AvgPool2d(_AvgPoolNd):
|
||||||
r"""Applies a 2D average pooling over an input signal composed of several input
|
r"""Applies a 2D average pooling over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -557,13 +541,11 @@ class AvgPool2d(_AvgPoolNd):
|
||||||
self.ceil_mode = ceil_mode
|
self.ceil_mode = ceil_mode
|
||||||
self.count_include_pad = count_include_pad
|
self.count_include_pad = count_include_pad
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.avg_pool2d(input, self.kernel_size, self.stride,
|
return F.avg_pool2d(input, self.kernel_size, self.stride,
|
||||||
self.padding, self.ceil_mode, self.count_include_pad)
|
self.padding, self.ceil_mode, self.count_include_pad)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AvgPool3d(_AvgPoolNd):
|
class AvgPool3d(_AvgPoolNd):
|
||||||
r"""Applies a 3D average pooling over an input signal composed of several input
|
r"""Applies a 3D average pooling over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -630,7 +612,6 @@ class AvgPool3d(_AvgPoolNd):
|
||||||
self.ceil_mode = ceil_mode
|
self.ceil_mode = ceil_mode
|
||||||
self.count_include_pad = count_include_pad
|
self.count_include_pad = count_include_pad
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.avg_pool3d(input, self.kernel_size, self.stride,
|
return F.avg_pool3d(input, self.kernel_size, self.stride,
|
||||||
self.padding, self.ceil_mode, self.count_include_pad)
|
self.padding, self.ceil_mode, self.count_include_pad)
|
||||||
|
|
@ -642,7 +623,6 @@ class AvgPool3d(_AvgPoolNd):
|
||||||
self.__dict__.setdefault('count_include_pad', True)
|
self.__dict__.setdefault('count_include_pad', True)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class FractionalMaxPool2d(Module):
|
class FractionalMaxPool2d(Module):
|
||||||
r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.
|
r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
|
@ -694,7 +674,6 @@ class FractionalMaxPool2d(Module):
|
||||||
raise ValueError("output_ratio must be between 0 and 1 (got {})"
|
raise ValueError("output_ratio must be between 0 and 1 (got {})"
|
||||||
.format(output_ratio))
|
.format(output_ratio))
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.fractional_max_pool2d(
|
return F.fractional_max_pool2d(
|
||||||
input, self.kernel_size, self.output_size, self.output_ratio,
|
input, self.kernel_size, self.output_size, self.output_ratio,
|
||||||
|
|
@ -702,7 +681,6 @@ class FractionalMaxPool2d(Module):
|
||||||
_random_samples=self._random_samples)
|
_random_samples=self._random_samples)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class FractionalMaxPool3d(Module):
|
class FractionalMaxPool3d(Module):
|
||||||
r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.
|
r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
|
@ -754,7 +732,6 @@ class FractionalMaxPool3d(Module):
|
||||||
raise ValueError("output_ratio must be between 0 and 1 (got {})"
|
raise ValueError("output_ratio must be between 0 and 1 (got {})"
|
||||||
.format(output_ratio))
|
.format(output_ratio))
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.fractional_max_pool3d(
|
return F.fractional_max_pool3d(
|
||||||
input, self.kernel_size, self.output_size, self.output_ratio,
|
input, self.kernel_size, self.output_size, self.output_ratio,
|
||||||
|
|
@ -762,7 +739,6 @@ class FractionalMaxPool3d(Module):
|
||||||
_random_samples=self._random_samples)
|
_random_samples=self._random_samples)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _LPPoolNd(Module):
|
class _LPPoolNd(Module):
|
||||||
__constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode']
|
__constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode']
|
||||||
|
|
||||||
|
|
@ -778,7 +754,6 @@ class _LPPoolNd(Module):
|
||||||
'ceil_mode={ceil_mode}'.format(**self.__dict__)
|
'ceil_mode={ceil_mode}'.format(**self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LPPool1d(_LPPoolNd):
|
class LPPool1d(_LPPoolNd):
|
||||||
r"""Applies a 1D power-average pooling over an input signal composed of several input
|
r"""Applies a 1D power-average pooling over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -814,14 +789,11 @@ class LPPool1d(_LPPoolNd):
|
||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.lp_pool1d(input, float(self.norm_type), self.kernel_size,
|
return F.lp_pool1d(input, float(self.norm_type), self.kernel_size,
|
||||||
self.stride, self.ceil_mode)
|
self.stride, self.ceil_mode)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LPPool2d(_LPPoolNd):
|
class LPPool2d(_LPPoolNd):
|
||||||
r"""Applies a 2D power-average pooling over an input signal composed of several input
|
r"""Applies a 2D power-average pooling over an input signal composed of several input
|
||||||
planes.
|
planes.
|
||||||
|
|
@ -871,13 +843,11 @@ class LPPool2d(_LPPoolNd):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.lp_pool2d(input, float(self.norm_type), self.kernel_size,
|
return F.lp_pool2d(input, float(self.norm_type), self.kernel_size,
|
||||||
self.stride, self.ceil_mode)
|
self.stride, self.ceil_mode)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _AdaptiveMaxPoolNd(Module):
|
class _AdaptiveMaxPoolNd(Module):
|
||||||
__constants__ = ['output_size', 'return_indices']
|
__constants__ = ['output_size', 'return_indices']
|
||||||
|
|
||||||
|
|
@ -893,7 +863,6 @@ class _AdaptiveMaxPoolNd(Module):
|
||||||
# output shapes are, and how the operation computes output.
|
# output shapes are, and how the operation computes output.
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
|
class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
|
||||||
r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
|
r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
|
@ -913,12 +882,10 @@ class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.adaptive_max_pool1d(input, self.output_size, self.return_indices)
|
return F.adaptive_max_pool1d(input, self.output_size, self.return_indices)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
|
class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
|
||||||
r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
|
r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
|
@ -949,12 +916,10 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
|
return F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
|
class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
|
||||||
r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
|
r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
|
@ -986,12 +951,10 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
|
return F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class _AdaptiveAvgPoolNd(Module):
|
class _AdaptiveAvgPoolNd(Module):
|
||||||
__constants__ = ['output_size']
|
__constants__ = ['output_size']
|
||||||
|
|
||||||
|
|
@ -1003,7 +966,6 @@ class _AdaptiveAvgPoolNd(Module):
|
||||||
return 'output_size={}'.format(self.output_size)
|
return 'output_size={}'.format(self.output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
|
class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
|
||||||
r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
|
r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
|
@ -1021,12 +983,10 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.adaptive_avg_pool1d(input, self.output_size)
|
return F.adaptive_avg_pool1d(input, self.output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
|
class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
|
||||||
r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
|
r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
|
@ -1055,12 +1015,10 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.adaptive_avg_pool2d(input, self.output_size)
|
return F.adaptive_avg_pool2d(input, self.output_size)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
|
class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
|
||||||
r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
|
r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
|
@ -1089,6 +1047,5 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.adaptive_avg_pool3d(input, self.output_size)
|
return F.adaptive_avg_pool3d(input, self.output_size)
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,7 @@ from ..parameter import Parameter
|
||||||
from ..utils.rnn import PackedSequence, get_packed_sequence
|
from ..utils.rnn import PackedSequence, get_packed_sequence
|
||||||
from .. import init
|
from .. import init
|
||||||
from .. import _VF
|
from .. import _VF
|
||||||
from ..._jit_internal import weak_module, weak_script_method, weak_script, \
|
from ..._jit_internal import _parameter_list
|
||||||
_parameter_list
|
|
||||||
|
|
||||||
_rnn_impls = {
|
_rnn_impls = {
|
||||||
'GRU': _VF.gru,
|
'GRU': _VF.gru,
|
||||||
|
|
@ -18,7 +17,6 @@ _rnn_impls = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@weak_script
|
|
||||||
def apply_permutation(tensor, permutation, dim=1):
|
def apply_permutation(tensor, permutation, dim=1):
|
||||||
# type: (Tensor, Tensor, int) -> Tensor
|
# type: (Tensor, Tensor, int) -> Tensor
|
||||||
return tensor.index_select(dim, permutation)
|
return tensor.index_select(dim, permutation)
|
||||||
|
|
@ -139,7 +137,6 @@ class RNNBase(Module):
|
||||||
def _get_flat_weights(self):
|
def _get_flat_weights(self):
|
||||||
return self._flat_weights
|
return self._flat_weights
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def check_input(self, input, batch_sizes):
|
def check_input(self, input, batch_sizes):
|
||||||
# type: (Tensor, Optional[Tensor]) -> None
|
# type: (Tensor, Optional[Tensor]) -> None
|
||||||
expected_input_dim = 2 if batch_sizes is not None else 3
|
expected_input_dim = 2 if batch_sizes is not None else 3
|
||||||
|
|
@ -152,7 +149,6 @@ class RNNBase(Module):
|
||||||
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
|
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
|
||||||
self.input_size, input.size(-1)))
|
self.input_size, input.size(-1)))
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def get_expected_hidden_size(self, input, batch_sizes):
|
def get_expected_hidden_size(self, input, batch_sizes):
|
||||||
# type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
|
# type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
|
||||||
if batch_sizes is not None:
|
if batch_sizes is not None:
|
||||||
|
|
@ -165,7 +161,6 @@ class RNNBase(Module):
|
||||||
mini_batch, self.hidden_size)
|
mini_batch, self.hidden_size)
|
||||||
return expected_hidden_size
|
return expected_hidden_size
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
|
def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
|
||||||
# type: (Tensor, Tuple[int, int, int], str) -> None
|
# type: (Tensor, Tuple[int, int, int], str) -> None
|
||||||
if hx.size() != expected_hidden_size:
|
if hx.size() != expected_hidden_size:
|
||||||
|
|
@ -374,7 +369,6 @@ class RNN(RNNBase):
|
||||||
super(RNN, self).__init__(mode, *args, **kwargs)
|
super(RNN, self).__init__(mode, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LSTM(RNNBase):
|
class LSTM(RNNBase):
|
||||||
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
|
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
|
||||||
sequence.
|
sequence.
|
||||||
|
|
@ -484,7 +478,6 @@ class LSTM(RNNBase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(LSTM, self).__init__('LSTM', *args, **kwargs)
|
super(LSTM, self).__init__('LSTM', *args, **kwargs)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def check_forward_args(self, input, hidden, batch_sizes):
|
def check_forward_args(self, input, hidden, batch_sizes):
|
||||||
# type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
|
# type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
|
||||||
self.check_input(input, batch_sizes)
|
self.check_input(input, batch_sizes)
|
||||||
|
|
@ -495,14 +488,12 @@ class LSTM(RNNBase):
|
||||||
self.check_hidden_size(hidden[1], expected_hidden_size,
|
self.check_hidden_size(hidden[1], expected_hidden_size,
|
||||||
'Expected hidden[1] size {}, got {}')
|
'Expected hidden[1] size {}, got {}')
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def permute_hidden(self, hx, permutation):
|
def permute_hidden(self, hx, permutation):
|
||||||
# type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
|
# type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
|
||||||
if permutation is None:
|
if permutation is None:
|
||||||
return hx
|
return hx
|
||||||
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
|
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
|
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
|
||||||
# type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
|
# type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
|
||||||
if hx is None:
|
if hx is None:
|
||||||
|
|
@ -528,7 +519,7 @@ class LSTM(RNNBase):
|
||||||
|
|
||||||
return output, hidden
|
return output, hidden
|
||||||
|
|
||||||
@weak_script_method
|
@torch._jit_internal.export
|
||||||
def forward_tensor(self, input, hx=None):
|
def forward_tensor(self, input, hx=None):
|
||||||
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
|
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
|
||||||
batch_sizes = None
|
batch_sizes = None
|
||||||
|
|
@ -540,7 +531,7 @@ class LSTM(RNNBase):
|
||||||
|
|
||||||
return output, self.permute_hidden(hidden, unsorted_indices)
|
return output, self.permute_hidden(hidden, unsorted_indices)
|
||||||
|
|
||||||
@weak_script_method
|
@torch._jit_internal.export
|
||||||
def forward_packed(self, input, hx=None):
|
def forward_packed(self, input, hx=None):
|
||||||
# type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
|
# type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
|
||||||
input, batch_sizes, sorted_indices, unsorted_indices = input
|
input, batch_sizes, sorted_indices, unsorted_indices = input
|
||||||
|
|
@ -552,6 +543,7 @@ class LSTM(RNNBase):
|
||||||
output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices)
|
output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices)
|
||||||
return output, self.permute_hidden(hidden, unsorted_indices)
|
return output, self.permute_hidden(hidden, unsorted_indices)
|
||||||
|
|
||||||
|
@torch._jit_internal.ignore
|
||||||
def forward(self, input, hx=None):
|
def forward(self, input, hx=None):
|
||||||
if isinstance(input, PackedSequence):
|
if isinstance(input, PackedSequence):
|
||||||
return self.forward_packed(input, hx)
|
return self.forward_packed(input, hx)
|
||||||
|
|
@ -694,14 +686,12 @@ class RNNCellBase(Module):
|
||||||
s += ', nonlinearity={nonlinearity}'
|
s += ', nonlinearity={nonlinearity}'
|
||||||
return s.format(**self.__dict__)
|
return s.format(**self.__dict__)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def check_forward_input(self, input):
|
def check_forward_input(self, input):
|
||||||
if input.size(1) != self.input_size:
|
if input.size(1) != self.input_size:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"input has inconsistent input_size: got {}, expected {}".format(
|
"input has inconsistent input_size: got {}, expected {}".format(
|
||||||
input.size(1), self.input_size))
|
input.size(1), self.input_size))
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def check_forward_hidden(self, input, hx, hidden_label=''):
|
def check_forward_hidden(self, input, hx, hidden_label=''):
|
||||||
# type: (Tensor, Tensor, str) -> None
|
# type: (Tensor, Tensor, str) -> None
|
||||||
if input.size(0) != hx.size(0):
|
if input.size(0) != hx.size(0):
|
||||||
|
|
@ -720,7 +710,6 @@ class RNNCellBase(Module):
|
||||||
init.uniform_(weight, -stdv, stdv)
|
init.uniform_(weight, -stdv, stdv)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class RNNCell(RNNCellBase):
|
class RNNCell(RNNCellBase):
|
||||||
r"""An Elman RNN cell with tanh or ReLU non-linearity.
|
r"""An Elman RNN cell with tanh or ReLU non-linearity.
|
||||||
|
|
||||||
|
|
@ -784,7 +773,6 @@ class RNNCell(RNNCellBase):
|
||||||
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
|
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
|
||||||
self.nonlinearity = nonlinearity
|
self.nonlinearity = nonlinearity
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, hx=None):
|
def forward(self, input, hx=None):
|
||||||
# type: (Tensor, Optional[Tensor]) -> Tensor
|
# type: (Tensor, Optional[Tensor]) -> Tensor
|
||||||
self.check_forward_input(input)
|
self.check_forward_input(input)
|
||||||
|
|
@ -810,7 +798,6 @@ class RNNCell(RNNCellBase):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class LSTMCell(RNNCellBase):
|
class LSTMCell(RNNCellBase):
|
||||||
r"""A long short-term memory (LSTM) cell.
|
r"""A long short-term memory (LSTM) cell.
|
||||||
|
|
||||||
|
|
@ -875,7 +862,6 @@ class LSTMCell(RNNCellBase):
|
||||||
def __init__(self, input_size, hidden_size, bias=True):
|
def __init__(self, input_size, hidden_size, bias=True):
|
||||||
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
|
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, hx=None):
|
def forward(self, input, hx=None):
|
||||||
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
|
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
|
||||||
self.check_forward_input(input)
|
self.check_forward_input(input)
|
||||||
|
|
@ -891,7 +877,6 @@ class LSTMCell(RNNCellBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class GRUCell(RNNCellBase):
|
class GRUCell(RNNCellBase):
|
||||||
r"""A gated recurrent unit (GRU) cell
|
r"""A gated recurrent unit (GRU) cell
|
||||||
|
|
||||||
|
|
@ -957,7 +942,6 @@ class GRUCell(RNNCellBase):
|
||||||
def __init__(self, input_size, hidden_size, bias=True):
|
def __init__(self, input_size, hidden_size, bias=True):
|
||||||
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
|
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, hx=None):
|
def forward(self, input, hx=None):
|
||||||
# type: (Tensor, Optional[Tensor]) -> Tensor
|
# type: (Tensor, Optional[Tensor]) -> Tensor
|
||||||
self.check_forward_input(input)
|
self.check_forward_input(input)
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import init
|
from .. import init
|
||||||
from torch._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Embedding(Module):
|
class Embedding(Module):
|
||||||
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
|
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
|
||||||
|
|
||||||
|
|
@ -110,7 +108,6 @@ class Embedding(Module):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[self.padding_idx].fill_(0)
|
self.weight[self.padding_idx].fill_(0)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.embedding(
|
return F.embedding(
|
||||||
input, self.weight, self.padding_idx, self.max_norm,
|
input, self.weight, self.padding_idx, self.max_norm,
|
||||||
|
|
@ -173,7 +170,6 @@ class Embedding(Module):
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class EmbeddingBag(Module):
|
class EmbeddingBag(Module):
|
||||||
r"""Computes sums or means of 'bags' of embeddings, without instantiating the
|
r"""Computes sums or means of 'bags' of embeddings, without instantiating the
|
||||||
intermediate embeddings.
|
intermediate embeddings.
|
||||||
|
|
@ -277,7 +273,6 @@ class EmbeddingBag(Module):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
init.normal_(self.weight)
|
init.normal_(self.weight)
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input, offsets=None, per_sample_weights=None):
|
def forward(self, input, offsets=None, per_sample_weights=None):
|
||||||
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
|
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
|
||||||
return F.embedding_bag(input, self.weight, offsets,
|
return F.embedding_bag(input, self.weight, offsets,
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Upsample(Module):
|
class Upsample(Module):
|
||||||
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||||
|
|
||||||
|
|
@ -129,7 +127,6 @@ class Upsample(Module):
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.align_corners = align_corners
|
self.align_corners = align_corners
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
|
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
|
||||||
|
|
||||||
|
|
@ -142,7 +139,6 @@ class Upsample(Module):
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class UpsamplingNearest2d(Upsample):
|
class UpsamplingNearest2d(Upsample):
|
||||||
r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input
|
r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input
|
||||||
channels.
|
channels.
|
||||||
|
|
@ -188,7 +184,6 @@ class UpsamplingNearest2d(Upsample):
|
||||||
super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest')
|
super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest')
|
||||||
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class UpsamplingBilinear2d(Upsample):
|
class UpsamplingBilinear2d(Upsample):
|
||||||
r"""Applies a 2D bilinear upsampling to an input signal composed of several input
|
r"""Applies a 2D bilinear upsampling to an input signal composed of several input
|
||||||
channels.
|
channels.
|
||||||
|
|
|
||||||
|
|
@ -23,10 +23,9 @@ def _is_jit_enabled():
|
||||||
|
|
||||||
|
|
||||||
# Check if we can safely replicate the module.
|
# Check if we can safely replicate the module.
|
||||||
# there are three types of module:
|
# there are two types of module:
|
||||||
# 1. python modules
|
# 1. python modules
|
||||||
# 2. weak python modules (nn.Module annotated by @weak_module)
|
# 2. ScriptModule
|
||||||
# 3. ScriptModule
|
|
||||||
#
|
#
|
||||||
# currently a module cannot be replicated properly if the descendants of
|
# currently a module cannot be replicated properly if the descendants of
|
||||||
# any ScriptModule contains python module (type 1 above)
|
# any ScriptModule contains python module (type 1 above)
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,7 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ...modules.module import Module
|
from ...modules.module import Module
|
||||||
from ...._jit_internal import weak_module, weak_script_method
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class ReLU(Module):
|
class ReLU(Module):
|
||||||
r"""Applies quantized rectified linear unit function element-wise:
|
r"""Applies quantized rectified linear unit function element-wise:
|
||||||
|
|
||||||
|
|
@ -37,7 +35,6 @@ class ReLU(Module):
|
||||||
assert not inplace, 'torch.nn.quantized.ReLU does not support inplace'
|
assert not inplace, 'torch.nn.quantized.ReLU does not support inplace'
|
||||||
|
|
||||||
|
|
||||||
@weak_script_method
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return F.relu(input)
|
return F.relu(input)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||||
import torch
|
import torch
|
||||||
from ...modules.module import Module
|
from ...modules.module import Module
|
||||||
from ...modules.linear import Linear as NNLinear
|
from ...modules.linear import Linear as NNLinear
|
||||||
from ...._jit_internal import weak_module
|
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Quantize(Module):
|
class Quantize(Module):
|
||||||
r"""Quantizes an incoming tensor
|
r"""Quantizes an incoming tensor
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -39,7 +37,6 @@ class Quantize(Module):
|
||||||
def from_float(mod):
|
def from_float(mod):
|
||||||
return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8)
|
return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8)
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class DeQuantize(Module):
|
class DeQuantize(Module):
|
||||||
r"""Dequantizes an incoming tensor
|
r"""Dequantizes an incoming tensor
|
||||||
|
|
||||||
|
|
@ -65,7 +62,6 @@ class DeQuantize(Module):
|
||||||
def from_float(mod):
|
def from_float(mod):
|
||||||
return DeQuantize()
|
return DeQuantize()
|
||||||
|
|
||||||
@weak_module
|
|
||||||
class Linear(NNLinear):
|
class Linear(NNLinear):
|
||||||
r"""
|
r"""
|
||||||
A quantized linear module with quantized tensor as inputs
|
A quantized linear module with quantized tensor as inputs
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user