diff --git a/test/test_jit.py b/test/test_jit.py index 4eb7fdf3250..618e258ff98 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -6979,7 +6979,7 @@ a") with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): M() - def test_script_module_list_sequential_error(self): + def test_script_module_list_sequential(self): class M(torch.jit.ScriptModule): def __init__(self, mod_list): super(M, self).__init__(False) @@ -6991,25 +6991,21 @@ a") v = m(v) return v - with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"): - a = M(nn.Sequential(nn.ReLU())) - with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"): - a = M(nn.ModuleList([nn.ReLU()])) + m = M(nn.Sequential(nn.ReLU())) + self.assertExportImportModule(m, (torch.randn(2, 2),)) - def test_attr_module_constants_error(self): + def test_attr_module_constants(self): class M2(torch.jit.ScriptModule): def __init__(self, mod_list): super(M2, self).__init__(False) self.mods = mod_list @torch.jit.script_method - def forward(self, v): + def forward(self, x): return self.mods.forward(x) - with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"): - M2(nn.Sequential(nn.ReLU())) - with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"): - M2(nn.ModuleList([nn.ReLU()])) + m = M2(nn.Sequential(nn.ReLU())) + self.assertExportImportModule(m, (torch.randn(2, 2),)) def test_script_sequential_for(self): class Sub(torch.jit.ScriptModule): @@ -11007,6 +11003,7 @@ a") with self.assertRaisesRegex(torch.jit.Error, "Exception"): foo(torch.tensor(0)) + @unittest.skipIf(True, "Removing weak script") def test_weak_script_function(self): outer_var = 10 outer_var2 = 11 @@ -11086,6 +11083,7 @@ a") eg = torch.zeros(3, dtype=torch.uint8) self.assertEqual(foo_traced(eg), foo(eg)) + @unittest.skipIf(True, "Removing weak script") def test_weak_module(self): @torch._jit_internal.weak_module @@ -11161,6 +11159,7 @@ a") self.assertEqual(script_result, expected_result) self.assertEqual(script_result, script_result2) + @unittest.skipIf(True, "Removing weak script") def test_weak_module_parameters_and_buffers(self): weights = torch.randn(10, 10) bias = torch.randn(10) @@ -11219,6 +11218,7 @@ a") self.assertEqual(strong_mod(inp), expected_result) self.assertExportImportModule(strong_mod, (inp,)) + @unittest.skipIf(True, "Removing weak script") def test_weak_module_nested(self): @torch._jit_internal.weak_module class OtherWeak(torch.nn.Module): @@ -11280,6 +11280,7 @@ a") + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10)) self.assertEqual(result, expected_result) + @unittest.skipIf(True, "Removing weak script") def test_weak_module_submodule(self): @torch._jit_internal.weak_module class Weak(torch.nn.Module): @@ -11319,6 +11320,7 @@ a") with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): strong_mod = Strong() + @unittest.skipIf(True, "Removing weak script") def test_weak_module_copying(self): class Submodule(torch.nn.Module): def __init__(self): @@ -11385,6 +11387,7 @@ a") m = M() + @unittest.skipIf(True, "Removing weak script") def test_weak_module_attributes(self): tester = self @@ -11948,6 +11951,7 @@ a") FileCheck().check_not("prim::PythonOp").run(cu.test.graph) + @unittest.skipIf(True, "Removing weak script") def test_overloading(self): @torch._jit_internal.weak_module class W(torch.nn.Module): @@ -13623,6 +13627,9 @@ EXCLUDE_SCRIPT_MODULES = { 
'test_nn_AdaptiveAvgPool3d_tuple_none', 'test_nn_AdaptiveMaxPool2d_tuple_none', 'test_nn_AdaptiveMaxPool3d_tuple_none', + + # Uses Module._backend, so this is not supported + 'test_nn_CrossMapLRN2d', } @@ -14552,10 +14559,6 @@ def add_nn_module_test(*args, **kwargs): module_name = name.split("_")[0] - module = getattr(torch.nn, module_name, None) - if module is None or torch._jit_internal.weak_types.get(module) is None: - return - if 'desc' in kwargs and 'eval' in kwargs['desc']: # eval() is not supported, so skip these tests return diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 8c484135362..8cb77f2eebd 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -4,29 +4,14 @@ can be used in other places in torch/ (namely torch.nn) without running into circular dependency problems """ -import weakref import inspect +import weakref from torch._six import builtins -# Tracks standalone weak script functions -compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484 - -# Tracks which methods should be converted to strong methods -weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484 - -# Converted modules and their corresponding WeakScriptModuleProxy objects -weak_modules = weakref.WeakKeyDictionary() # noqa: T484 - -# Types that have been declared as weak modules -weak_types = weakref.WeakKeyDictionary() # noqa: T484 - # Wrapper functions that can call either of 2 functions depending on a boolean # argument boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484 -COMPILATION_PENDING = object() -COMPILED = object() - def createResolutionCallback(frames_up=0): """ @@ -71,51 +56,41 @@ def createResolutionCallback(frames_up=0): return f_globals[key] elif hasattr(builtins, key): return getattr(builtins, key) - else: - return None return env -def weak_script(fn, _frames_up=0): +def createResolutionCallbackFromClosure(fn): """ - Marks a function as a weak script function. When used in a script function - or ScriptModule, the weak script function will be lazily compiled and - inlined in the graph. When not used in a script function, the weak script - annotation has no effect. + Create a resolutionCallback by introspecting the function instead of + looking up the stack for the enclosing scope """ - compiled_weak_fns[fn] = { - "status": COMPILATION_PENDING, - "compiled_fn": None, - "rcb": createResolutionCallback(_frames_up + 1) - } - return fn + var_names = fn.__code__.co_freevars + # map of captured name -> value + free_vars = {} -def weak_module(cls): - weak_types[cls] = { - "method_stubs": None - } - return cls + for index, name in enumerate(var_names): + free_vars[name] = fn.__closure__[index].cell_contents + f_globals = fn.__globals__ + def env(key): + if key in free_vars: + return free_vars[key] + elif hasattr(builtins, key): + return getattr(builtins, key) + else: + return f_globals.get(key) -def weak_script_method(fn): - weak_script_methods[fn] = { - "rcb": createResolutionCallback(frames_up=2), - "original_method": fn - } - return fn + return env def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name): """ - Dispatches to either of 2 weak script functions based on a boolean argument. + Dispatches to either of 2 script functions based on a boolean argument. In TorchScript, the boolean argument must be constant so that the correct function to use can be determined at compile time. 
""" - if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None: - raise RuntimeError("both functions must be weak script") - def fn(*args, **kwargs): dispatch_flag = False if arg_name in kwargs: diff --git a/torch/csrc/jit/script/parser.cpp b/torch/csrc/jit/script/parser.cpp index e5948f7e958..4eab39c9f8a 100644 --- a/torch/csrc/jit/script/parser.cpp +++ b/torch/csrc/jit/script/parser.cpp @@ -23,7 +23,8 @@ Decl mergeTypesFromTypeComment( << "Number of type annotations (" << type_annotation_decl.params().size() << ") did not match the number of " - << "function parameters (" << expected_num_annotations << ")"; + << (is_method ? "method" : "function") + << " parameters (" << expected_num_annotations << ")"; } auto old = decl.params(); auto _new = type_annotation_decl.params(); diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp index 20b2c09a1a8..aedfb1185f8 100644 --- a/torch/csrc/jit/script/python_sugared_value.cpp +++ b/torch/csrc/jit/script/python_sugared_value.cpp @@ -244,6 +244,11 @@ std::shared_ptr OverloadedMethodValue::call( << err.str(); } +bool should_recurse(py::object obj) { + return py::cast(py::module::import("torch.jit") + .attr("_is_recursive_script_enabled")(obj)); +} + std::shared_ptr ModuleValue::attr( const SourceRange& loc, Function& m, @@ -307,7 +312,7 @@ std::shared_ptr ModuleValue::attr( // If recursive script mode is on, create a ScriptModule and register it as // as submodule or register a python method as a script::Method - if (getRecursiveScriptMode()) { + if (should_recurse(attr)) { if (py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) { // If the module is a submodule of the py_module, convert it to a // ScriptModule and add it as a submodule to the script::Module. This @@ -471,11 +476,6 @@ std::shared_ptr toSugaredValue( } } - auto weak_obj = - py::module::import("torch.jit").attr("_try_get_weak_module")(obj); - if (!weak_obj.is_none()) { - obj = weak_obj; - } if (auto callee = as_function(obj)) { return std::make_shared(callee); } else if (py::isinstance(obj)) { @@ -504,12 +504,6 @@ std::shared_ptr toSugaredValue( << "which is currently not supported in Torchscript." 
<< "Please open a feature request to add it."; } - - auto compiled_fn = - py::module::import("torch.jit").attr("_try_compile_weak_script")(obj); - if (auto callee = as_function(compiled_fn)) { - return std::make_shared(callee); - } } py::object dispatched_fn = @@ -528,7 +522,7 @@ std::shared_ptr toSugaredValue( } } - if (getRecursiveScriptMode() && py::isinstance(obj)) { + if (should_recurse(obj) && py::isinstance(obj)) { auto compiled_fn = py::module::import("torch.jit").attr("_try_compile_fn")(obj); if (auto callee = as_function(compiled_fn)) { diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 09a490c2197..4691acbb558 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -7,7 +7,7 @@ import torch.backends.cudnn as cudnn import torch.jit.annotations import torch._jit_internal as _jit_internal from torch._six import PY2, PY37, with_metaclass, get_function_from_type, \ - string_classes, builtins + string_classes from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \ _list_with_default import torch.testing @@ -930,50 +930,10 @@ def _try_get_overloaded_fn(mod, field): return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None -def _try_compile_weak_script(fn): - entry = _jit_internal.compiled_weak_fns.get(fn) - if entry is None: - return None - if entry["status"] == _jit_internal.COMPILATION_PENDING: - compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"]) - del entry["rcb"] - _jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn - entry["status"] = _jit_internal.COMPILED - return compiled_fn - # TODO: use fn.__closure__ - raise RuntimeError("Cannot make resolutionCallback in Python 2") - else: - return entry["compiled_fn"] - - class ScriptWarning(Warning): pass -def createResolutionCallbackFromClosure(fn): - """ - Create a resolutionCallback by introspecting the function instead of - looking up the stack for the enclosing scope - """ - var_names = fn.__code__.co_freevars - - # map of captured name -> value - free_vars = {} - - for index, name in enumerate(var_names): - free_vars[name] = fn.__closure__[index].cell_contents - f_globals = fn.__globals__ - - def env(key): - if key in free_vars: - return free_vars[key] - elif hasattr(builtins, key): - return getattr(builtins, key) - else: - return f_globals.get(key) - - return env - def _create_constant_iterable_module(module): modules = OrderedDict() @@ -1012,20 +972,20 @@ def _try_compile_fn(fn): # Don't do anything for @ignore'd functions return None - if not inspect.isfunction(fn) and not inspect.ismethod(fn): - raise RuntimeError("`{}` is not a function. Recursive scripting only supports " - "Python functions or methods currently.\n" - "Consider manually annotating `{}` with @torch.jit.script.".format(fn)) - if isinstance(fn, torch.nn.Module): # Since modules are callable pybind recognizes them as functions, but # don't do anything for them return None + if not inspect.isfunction(fn) and not inspect.ismethod(fn): + raise RuntimeError("`{}` is not a function. 
Recursive scripting only supports " + "Python functions or methods currently.\n" + "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn)) + # We don't have the actual scope where the function was defined, but we can # extract the necessary info from the closed over variables on the function # object - rcb = createResolutionCallbackFromClosure(fn) + rcb = _jit_internal.createResolutionCallbackFromClosure(fn) return torch.jit.script(fn, _rcb=rcb) @@ -1040,7 +1000,9 @@ def _disable_emit_hooks(): def _create_method_from_fn(module, fn): if _jit_internal.is_ignored_fn(fn): return None - stub = script_method(fn, createResolutionCallbackFromClosure(fn)) + if not inspect.ismethod(fn): + return None + stub = script_method(fn, _jit_internal.createResolutionCallbackFromClosure(fn)) with _disable_emit_hooks(): # We don't want to call the hooks here since the graph that is calling # this function is not yet complete @@ -1101,6 +1063,15 @@ def _qualified_name(obj): return module_name + "." + name +def _is_recursive_script_enabled(value): + # TODO: [enable recursive script] + # when recursive script is made the default, remove this method + enabled = torch._C._jit_recursive_script() + module = inspect.getmodule(value) + if module is not None and 'torch.nn' in module.__name__: + enabled = True + return enabled + @contextlib.contextmanager def _enable_recursive_script(): torch._C._jit_recursive_script(True) @@ -1114,8 +1085,8 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None): if _rcb is None: _rcb = _jit_internal.createResolutionCallback(_frames_up + 1) - if torch._C._jit_recursive_script(): - if isinstance(obj, torch.nn.Module): + if isinstance(obj, torch.nn.Module): + if _is_recursive_script_enabled(obj): return _convert_to_script_module(obj) if inspect.isclass(obj): @@ -1158,21 +1129,6 @@ def script_method(fn, _rcb=None): return ScriptMethodStub(_rcb, ast, fn) -def _try_get_weak_module(mod): - """ - Get the WeakScriptModuleProxy corresponding to mod if it exists - """ - if not isinstance(mod, Module): - return None - return _jit_internal.weak_modules.get(mod) - - -def _is_weak_type(cls): - """ - Check if a type has been annotated with `weak_module` - """ - return cls in _jit_internal.weak_types - # These OrderedDictWrapper classes replace the actual OrderedDicts in # module with versions that get/set properties inside of script::Module. @@ -1569,9 +1525,9 @@ if _enabled: def __setattr__(self, attr, value): if attr not in self._constants_set: - if isinstance(value, Module) and _is_weak_type(type(value)): + if isinstance(value, Module) and _is_recursive_script_enabled(value): # Compile weak script module - value = _make_strong(value) + value = _convert_to_script_module(value) if attr == 'training': if self._c._has_attribute('training'): self.__dict__['training'] = value @@ -1684,7 +1640,7 @@ if _enabled: if isinstance(item, (ModuleList, Sequential)): # These are in __constants__, so ignore them here - if not torch._C._jit_recursive_script(): + if not _is_recursive_script_enabled(item): # For recursive script, these are constantified after # they are used, so they don't need to be in constants. # The `continue` here should be deleted along with @@ -1774,34 +1730,28 @@ else: super(ScriptModule, self).__init__() -def _get_weak_stubs(cls): - """ - Calls script_method for each method that has been annotated with @weak_script - on the type of the object passed in and returns the generated ScriptMethodStubs. 
- """ - stubs = [] - for name in dir(cls): - func = get_function_from_type(cls, name) - if func in _jit_internal.weak_script_methods: - entry = _jit_internal.weak_script_methods[func] - stub = script_method(entry["original_method"], entry["rcb"]) - stubs.append(stub) - return stubs - - -def _convert_to_script_module(mod, methods=None): +def _convert_to_script_module(mod): """ Makes a ScriptModule from an nn.Module. If `_methods` is provided, these methods are treated as @script_methods. If not, it defaults to `('forward',)`. Methods accessed in forward are scripted on demand if `_enable_recursive_script()` is used. """ + if isinstance(mod, ScriptModule): + return mod + if isinstance(mod, (ModuleList, Sequential)): # Create constant versions for the iterable modules return _create_constant_iterable_module(mod) - if methods is None: - methods = ('forward',) + methods = () + if hasattr(mod, 'forward'): + if mod.forward.__func__ == torch.nn.Module.forward: + # TODO: [enable recursive script] + # forward was not overrided + raise RuntimeError("No forward method was defined on {}".format(mod)) + if not _jit_internal.is_ignored_fn(mod.forward): + methods = ('forward',) exported = [] for name in dir(mod): item = getattr(mod, name) @@ -1812,36 +1762,12 @@ def _convert_to_script_module(mod, methods=None): def make_stub(method): func = get_function_from_type(type(mod), method) - return script_method(func, createResolutionCallbackFromClosure(func)) + return script_method(func, _jit_internal.createResolutionCallbackFromClosure(func)) stubs = list(map(make_stub, methods)) return WeakScriptModuleProxy(mod, stubs) -def _make_strong(mod): - """ - Converts a weak module into a subclass of ScriptModule. If `_methods` is - provided, only these methods are treated as @script_methods. 
- """ - if mod in _jit_internal.weak_modules: - return _jit_internal.weak_modules[mod] - - cls = type(mod) - # Explicitly annotated weak script - stubs = _jit_internal.weak_types.get(cls)["method_stubs"] - if stubs is None: - # Generate stubs and and store on weak_types in case this type is - # used again - stubs = _get_weak_stubs(cls) - _jit_internal.weak_types[cls]["method_stubs"] = stubs - - proxy = WeakScriptModuleProxy(mod, stubs) - - _jit_internal.weak_modules[mod] = proxy - - return proxy - - def _get_methods(cls): import inspect # In Python 3 unbound methods are functions, but in Python 2 they are methods @@ -1937,13 +1863,13 @@ class _ConstModuleList(ScriptModule): if isinstance(modules, OrderedDict): for key, module in modules.items(): - if _is_weak_type(type(module)): - module = _make_strong(module) + if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module): + module = _convert_to_script_module(module) self.add_module(key, module) else: for i, module in enumerate(modules): - if _is_weak_type(type(module)): - module = _make_strong(module) + if isinstance(module, torch.nn.Module) and _is_recursive_script_enabled(module): + module = _convert_to_script_module(module) self.add_module(str(i), module) def __getitem__(self, idx): diff --git a/torch/nn/_functions/vision.py b/torch/nn/_functions/vision.py index 018e93bf0ba..f179f3e43c7 100644 --- a/torch/nn/_functions/vision.py +++ b/torch/nn/_functions/vision.py @@ -1,9 +1,7 @@ import torch import torch.backends.cudnn as cudnn -from ..._jit_internal import weak_script -@weak_script def affine_grid_generator(theta, size): # type: (Tensor, List[int]) -> Tensor if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4 and size[0] < 65536: diff --git a/torch/nn/_reduction.py b/torch/nn/_reduction.py index cc84a9bca95..35cdc694460 100644 --- a/torch/nn/_reduction.py +++ b/torch/nn/_reduction.py @@ -1,10 +1,8 @@ import warnings -from .._jit_internal import weak_script # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h -@weak_script def get_enum(reduction): # type: (str) -> int if reduction == 'none': @@ -26,7 +24,6 @@ def get_enum(reduction): # We use these functions in torch/legacy as well, in which case we'll silence the warning -@weak_script def legacy_get_string(size_average, reduce, emit_warning=True): # type: (Optional[bool], Optional[bool], bool) -> str warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." @@ -47,7 +44,6 @@ def legacy_get_string(size_average, reduce, emit_warning=True): return ret -@weak_script def legacy_get_enum(size_average, reduce, emit_warning=True): # type: (Optional[bool], Optional[bool], bool) -> int return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index fd1359dfcc3..c7dc4ec9ca7 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -12,7 +12,7 @@ from ._functions import vision from .modules.utils import _single, _pair, _triple, _list_with_default from . import grad # noqa: F401 from . 
import _VF -from .._jit_internal import weak_script, List +from .._jit_internal import boolean_dispatch, List conv1d = _add_docstr(torch.conv1d, r""" @@ -299,7 +299,6 @@ Args: """) -@weak_script def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None): @@ -346,7 +345,6 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) -@weak_script def _fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None): @@ -355,7 +353,7 @@ def _fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio, return_indices, _random_samples)[0] -fractional_max_pool2d = torch._jit_internal.boolean_dispatch( +fractional_max_pool2d = boolean_dispatch( arg_name='return_indices', arg_index=4, default=False, @@ -365,7 +363,6 @@ fractional_max_pool2d = torch._jit_internal.boolean_dispatch( func_name='fractional_max_pool2d') -@weak_script def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None): @@ -414,7 +411,6 @@ def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples) -@weak_script def _fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None): @@ -423,7 +419,7 @@ def _fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio, return_indices, _random_samples)[0] -fractional_max_pool3d = torch._jit_internal.boolean_dispatch( +fractional_max_pool3d = boolean_dispatch( arg_name='return_indices', arg_index=4, default=False, @@ -433,7 +429,6 @@ fractional_max_pool3d = torch._jit_internal.boolean_dispatch( func_name='fractional_max_pool3d') -@weak_script def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa @@ -448,7 +443,6 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, input, kernel_size, stride, padding, dilation, ceil_mode) -@weak_script def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa @@ -457,7 +451,7 @@ def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, return torch.max_pool1d( input, kernel_size, stride, padding, dilation, ceil_mode) -max_pool1d = torch._jit_internal.boolean_dispatch( +max_pool1d = boolean_dispatch( arg_name='return_indices', arg_index=6, default=False, @@ -467,7 +461,6 @@ max_pool1d = torch._jit_internal.boolean_dispatch( func_name='max_pool1d') -@weak_script def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa @@ -481,7 +474,6 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation return 
torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -@weak_script def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa @@ -490,7 +482,7 @@ def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, return torch.max_pool2d( input, kernel_size, stride, padding, dilation, ceil_mode) -max_pool2d = torch._jit_internal.boolean_dispatch( +max_pool2d = boolean_dispatch( arg_name='return_indices', arg_index=6, default=False, @@ -500,7 +492,6 @@ max_pool2d = torch._jit_internal.boolean_dispatch( func_name='max_pool2d') -@weak_script def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa @@ -515,7 +506,6 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, input, kernel_size, stride, padding, dilation, ceil_mode) -@weak_script def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa @@ -524,7 +514,7 @@ def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, return torch.max_pool3d( input, kernel_size, stride, padding, dilation, ceil_mode) -max_pool3d = torch._jit_internal.boolean_dispatch( +max_pool3d = boolean_dispatch( arg_name='return_indices', arg_index=6, default=False, @@ -534,7 +524,6 @@ max_pool3d = torch._jit_internal.boolean_dispatch( func_name='max_pool3d') -@weak_script def _unpool_output_size(input, kernel_size, stride, padding, output_size): # type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int] input_size = input.size() @@ -564,7 +553,6 @@ def _unpool_output_size(input, kernel_size, stride, padding, output_size): return ret -@weak_script def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa @@ -588,7 +576,6 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_size).squeeze(3) -@weak_script def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa @@ -607,7 +594,6 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, return torch._C._nn.max_unpool2d(input, indices, output_size) -@weak_script def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa @@ -627,7 +613,6 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, input, indices, output_size, _stride, padding) -@weak_script def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): # type: (Tensor, float, int, 
Optional[BroadcastingList2[int]], bool) -> Tensor r"""Applies a 2D power-average pooling over an input signal composed of @@ -645,7 +630,6 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type) -@weak_script def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): # type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor r"""Applies a 1D power-average pooling over an input signal composed of @@ -662,7 +646,6 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type) -@weak_script def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor] r"""Applies a 1D adaptive max pooling over an input signal composed of @@ -677,12 +660,11 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): return torch.adaptive_max_pool1d(input, output_size) -@weak_script def _adaptive_max_pool1d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList1[int], bool) -> Tensor return adaptive_max_pool1d_with_indices(input, output_size)[0] -adaptive_max_pool1d = torch._jit_internal.boolean_dispatch( +adaptive_max_pool1d = boolean_dispatch( arg_name='return_indices', arg_index=2, default=False, @@ -692,7 +674,6 @@ adaptive_max_pool1d = torch._jit_internal.boolean_dispatch( func_name='adaptive_max_pool1d') -@weak_script def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList2[int], bool) -> Tuple[Tensor, Tensor] r"""Applies a 2D adaptive max pooling over an input signal composed of @@ -709,12 +690,11 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): return torch._C._nn.adaptive_max_pool2d(input, output_size) -@weak_script def _adaptive_max_pool2d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList2[int], bool) -> Tensor return adaptive_max_pool2d_with_indices(input, output_size)[0] -adaptive_max_pool2d = torch._jit_internal.boolean_dispatch( +adaptive_max_pool2d = boolean_dispatch( arg_name='return_indices', arg_index=2, default=False, @@ -724,7 +704,6 @@ adaptive_max_pool2d = torch._jit_internal.boolean_dispatch( func_name='adaptive_max_pool2d') -@weak_script def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList3[int], bool) -> Tuple[Tensor, Tensor] r"""Applies a 3D adaptive max pooling over an input signal composed of @@ -741,12 +720,11 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): return torch._C._nn.adaptive_max_pool3d(input, output_size) -@weak_script def _adaptive_max_pool3d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList3[int], bool) -> Tensor return adaptive_max_pool3d_with_indices(input, output_size)[0] -adaptive_max_pool3d = torch._jit_internal.boolean_dispatch( +adaptive_max_pool3d = boolean_dispatch( arg_name='return_indices', arg_index=2, default=False, @@ -769,7 +747,6 @@ Args: """) -@weak_script def adaptive_avg_pool2d(input, output_size): # type: (Tensor, BroadcastingList2[int]) -> Tensor r""" @@ -786,7 +763,6 @@ def adaptive_avg_pool2d(input, output_size): return torch._C._nn.adaptive_avg_pool2d(input, _output_size) -@weak_script def adaptive_avg_pool3d(input, output_size): # type: 
(Tensor, BroadcastingList3[int]) -> Tensor r""" @@ -804,7 +780,6 @@ def adaptive_avg_pool3d(input, output_size): # Activation functions -@weak_script def dropout(input, p=0.5, training=True, inplace=False): # type: (Tensor, float, bool, bool) -> Tensor r""" @@ -827,7 +802,6 @@ def dropout(input, p=0.5, training=True, inplace=False): else _VF.dropout(input, p, training)) -@weak_script def alpha_dropout(input, p=0.5, training=False, inplace=False): # type: (Tensor, float, bool, bool) -> Tensor r"""Applies alpha dropout to the input. @@ -842,7 +816,6 @@ def alpha_dropout(input, p=0.5, training=False, inplace=False): else _VF.alpha_dropout(input, p, training)) -@weak_script def dropout2d(input, p=0.5, training=True, inplace=False): # type: (Tensor, float, bool, bool) -> Tensor r""" @@ -867,7 +840,6 @@ def dropout2d(input, p=0.5, training=True, inplace=False): else _VF.feature_dropout(input, p, training)) -@weak_script def dropout3d(input, p=0.5, training=True, inplace=False): # type: (Tensor, float, bool, bool) -> Tensor r""" @@ -894,7 +866,6 @@ def dropout3d(input, p=0.5, training=True, inplace=False): else _VF.feature_dropout(input, p, training)) -@weak_script def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): # type: (Tensor, float, bool, bool) -> Tensor if p < 0. or p > 1.: @@ -905,7 +876,6 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): else _VF.feature_alpha_dropout(input, p, training)) -@weak_script def threshold(input, threshold, value, inplace=False): # type: (Tensor, float, float, bool) -> Tensor r"""Thresholds each element of the input Tensor. @@ -926,7 +896,6 @@ In-place version of :func:`~threshold`. """) -@weak_script def relu(input, inplace=False): # type: (Tensor, bool) -> Tensor r"""relu(input, inplace=False) -> Tensor @@ -948,7 +917,6 @@ In-place version of :func:`~relu`. """) -@weak_script def glu(input, dim=-1): # type: (Tensor, int) -> Tensor r""" @@ -973,7 +941,6 @@ def glu(input, dim=-1): return torch._C._nn.glu(input, dim) -@weak_script def hardtanh(input, min_val=-1., max_val=1., inplace=False): # type: (Tensor, float, float, bool) -> Tensor r""" @@ -996,7 +963,6 @@ In-place version of :func:`~hardtanh`. """) -@weak_script def relu6(input, inplace=False): # type: (Tensor, bool) -> Tensor r"""relu6(input, inplace=False) -> Tensor @@ -1008,7 +974,6 @@ def relu6(input, inplace=False): return hardtanh(input, 0., 6., inplace) -@weak_script def elu(input, alpha=1., inplace=False): # type: (Tensor, float, bool) -> Tensor r"""Applies element-wise, @@ -1030,7 +995,6 @@ In-place version of :func:`~elu`. """) -@weak_script def selu(input, inplace=False): # type: (Tensor, bool) -> Tensor r"""selu(input, inplace=False) -> Tensor @@ -1056,7 +1020,6 @@ In-place version of :func:`~selu`. """) -@weak_script def celu(input, alpha=1., inplace=False): # type: (Tensor, float, bool) -> Tensor r"""celu(input, alpha=1., inplace=False) -> Tensor @@ -1079,7 +1042,6 @@ In-place version of :func:`~celu`. """) -@weak_script def leaky_relu(input, negative_slope=0.01, inplace=False): # type: (Tensor, float, bool) -> Tensor r""" @@ -1104,7 +1066,6 @@ In-place version of :func:`~leaky_relu`. """) -@weak_script def prelu(input, weight): # type: (Tensor, Tensor) -> Tensor r"""prelu(input, weight) -> Tensor @@ -1118,7 +1079,6 @@ def prelu(input, weight): return torch.prelu(input, weight) -@weak_script def rrelu(input, lower=1. / 8, upper=1. 
/ 3, training=False, inplace=False): # type: (Tensor, float, float, bool, bool) -> Tensor r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor @@ -1148,7 +1108,6 @@ Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \ex See :class:`~torch.nn.LogSigmoid` for more details. """) -@weak_script def gelu(input): r"""gelu(input) -> Tensor @@ -1162,7 +1121,6 @@ def gelu(input): return torch._C._nn.gelu(input) -@weak_script def hardshrink(input, lambd=0.5): # type: (Tensor, float) -> Tensor r""" @@ -1175,7 +1133,6 @@ def hardshrink(input, lambd=0.5): return torch.hardshrink(input, lambd) -@weak_script def tanhshrink(input): r"""tanhshrink(input) -> Tensor @@ -1186,7 +1143,6 @@ def tanhshrink(input): return input - input.tanh() -@weak_script def softsign(input): r"""softsign(input) -> Tensor @@ -1202,7 +1158,6 @@ softplus(input, beta=1, threshold=20) -> Tensor """) -@weak_script def _get_softmax_dim(name, ndim, stacklevel): # type: (str, int, int) -> int warnings.warn("Implicit dimension choice for {} has been deprecated. " @@ -1214,7 +1169,6 @@ def _get_softmax_dim(name, ndim, stacklevel): return ret -@weak_script def softmin(input, dim=None, _stacklevel=3, dtype=None): # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor r"""Applies a softmin function. @@ -1240,7 +1194,6 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None): return ret -@weak_script def softmax(input, dim=None, _stacklevel=3, dtype=None): # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor r"""Applies a softmax function. @@ -1276,7 +1229,6 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None): return ret -@weak_script def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): # type: (Tensor, float, bool, float, int) -> Tensor r""" @@ -1337,7 +1289,6 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): return ret -@weak_script def log_softmax(input, dim=None, _stacklevel=3, dtype=None): # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor r"""Applies a softmax followed by a logarithm. @@ -1373,7 +1324,6 @@ See :class:`~torch.nn.Softshrink` for more details. 
""") -@weak_script def tanh(input): r"""tanh(input) -> Tensor @@ -1386,7 +1336,6 @@ def tanh(input): return input.tanh() -@weak_script def sigmoid(input): r"""sigmoid(input) -> Tensor @@ -1398,7 +1347,6 @@ def sigmoid(input): return input.sigmoid() -@weak_script def linear(input, weight, bias=None): # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor r""" @@ -1423,7 +1371,6 @@ def linear(input, weight, bias=None): return ret -@weak_script def bilinear(input1, input2, weight, bias=None): # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor return torch.bilinear(input1, input2, weight, bias) @@ -1435,7 +1382,6 @@ def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type): torch.embedding_renorm_(weight, input, max_norm, norm_type) -@weak_script def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False): # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor @@ -1517,7 +1463,6 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) -@weak_script def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None): @@ -1677,7 +1622,6 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, return ret -@weak_script def batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-5): # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa @@ -1709,7 +1653,6 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, ) -@weak_script def instance_norm(input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-5): # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa @@ -1725,7 +1668,6 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None, ) -@weak_script def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): # type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor r"""Applies Layer Normalization for last certain number of dimensions. @@ -1736,7 +1678,6 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): torch.backends.cudnn.enabled) -@weak_script def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor r"""Applies Group Normalization for last certain number of dimensions. 
@@ -1747,7 +1688,6 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): torch.backends.cudnn.enabled) -@weak_script def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): # type: (Tensor, int, float, float, float) -> Tensor r"""Applies local response normalization over an input signal composed of @@ -1776,7 +1716,6 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): # loss -@weak_script def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False): # type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor @@ -1824,7 +1763,6 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, zero_infinity) -@weak_script def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor @@ -1903,7 +1841,6 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, return ret -@weak_script def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8, reduce=None, reduction='mean'): # type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor @@ -1949,7 +1886,6 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non return ret -@weak_script def kl_div(input, target, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor r"""The `Kullback-Leibler divergence`_ Loss. @@ -2007,7 +1943,6 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean'): return reduced -@weak_script def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor @@ -2056,7 +1991,6 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) -@weak_script def binary_cross_entropy(input, target, weight=None, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor @@ -2113,7 +2047,6 @@ def binary_cross_entropy(input, target, weight=None, size_average=None, input, target, weight, reduction_enum) -@weak_script def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None): # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor @@ -2174,14 +2107,12 @@ def _pointwise_loss(lambd, lambd_optimized, input, target, reduction='mean'): return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction)) -@weak_script def _smooth_l1_loss(input, target): # type: (Tensor, Tensor) -> Tensor t = torch.abs(input - target) return torch.where(t < 1, 0.5 * t ** 2, t - 0.5) -@weak_script def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor r"""Function that uses a squared term if the absolute @@ -2206,7 +2137,6 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea return ret -@weak_script def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, 
Optional[bool], Optional[bool], str) -> Tensor r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor @@ -2232,7 +2162,6 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): return ret -@weak_script def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor @@ -2258,7 +2187,6 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): return ret -@weak_script def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor @@ -2276,7 +2204,6 @@ def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) -@weak_script def hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor @@ -2291,7 +2218,6 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None, return torch.hinge_embedding_loss(input, target, margin, reduction_enum) -@weak_script def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor @@ -2305,7 +2231,6 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) -@weak_script def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor @@ -2319,7 +2244,6 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m return torch._C._nn.soft_margin_loss(input, target, reduction_enum) -@weak_script def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor @@ -2349,7 +2273,6 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, return ret -@weak_script def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor @@ -2364,7 +2287,6 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) -@weak_script def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None, reduce=None, reduction='mean'): # type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor @@ -2635,7 +2557,6 @@ GRID_SAMPLE_PADDING_MODES = { } -@weak_script def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'): # type: (Tensor, Tensor, str, str) -> Tensor r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the @@ -2717,7 +2638,6 @@ def grid_sample(input, grid, mode='bilinear', 
padding_mode='zeros'): return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum) -@weak_script def affine_grid(theta, size): # type: (Tensor, List[int]) -> Tensor r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta`. @@ -2735,7 +2655,6 @@ def affine_grid(theta, size): return vision.affine_grid_generator(theta, size) -@weak_script def pad(input, pad, mode='constant', value=0): # type: (Tensor, List[int], str, float) -> Tensor r"""Pads tensor. @@ -2844,7 +2763,6 @@ def pad(input, pad, mode='constant', value=0): # distance -@weak_script def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): # type: (Tensor, Tensor, float, float, bool) -> Tensor r""" @@ -2952,7 +2870,6 @@ Examples: """) -@weak_script def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None, reduce=None, reduction="mean"): # type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor @@ -2967,7 +2884,6 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s swap, reduction_enum) -@weak_script def normalize(input, p=2, dim=1, eps=1e-12, out=None): # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor r"""Performs :math:`L_p` normalization of inputs over specified dimension. @@ -3001,7 +2917,6 @@ def assert_int_or_pair(arg, arg_name, message): assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) -@weak_script def unfold(input, kernel_size, dilation=1, padding=0, stride=1): # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa r"""Extracts sliding local blocks from an batched input tensor. @@ -3036,7 +2951,6 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1): return ret -@weak_script def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa r"""Combines an array of sliding local blocks into a large containing @@ -3064,7 +2978,6 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): return ret -@weak_script def _pad_circular(input, padding): # type: (Tensor, List[int]) -> Tensor """ @@ -3090,7 +3003,6 @@ def _pad_circular(input, padding): return input -@weak_script def multi_head_attention_forward(query, # type: Tensor key, # type: Tensor value, # type: Tensor @@ -3135,8 +3047,8 @@ def multi_head_attention_forward(query, # type: Tensor need_weights: output attn_output_weights. attn_mask: mask that prevents attention to certain positions. This is an additive mask (i.e. the values will be added to the attention layer). - use_separate_proj_weight: the function accept the proj. weights for query, key, - and value in differnt forms. If false, in_proj_weight will be used, which is + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in differnt forms. If false, in_proj_weight will be used, which is a combination of q_proj_weight, k_proj_weight, v_proj_weight. q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. static_k, static_v: static key and value used for attention operators. @@ -3152,9 +3064,9 @@ def multi_head_attention_forward(query, # type: Tensor the embedding dimension. 
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. - - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. Outputs: @@ -3285,12 +3197,12 @@ def multi_head_attention_forward(query, # type: Tensor v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if static_k is not None: - assert static_k.size(0) == bsz * num_heads + assert static_k.size(0) == bsz * num_heads assert static_k.size(2) == head_dim k = static_k if static_v is not None: - assert static_v.size(0) == bsz * num_heads + assert static_v.size(0) == bsz * num_heads assert static_v.size(2) == head_dim v = static_v diff --git a/torch/nn/init.py b/torch/nn/init.py index 5c8863b9d00..2a7a9e8e47e 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -4,7 +4,6 @@ import math import warnings import torch -from .._jit_internal import weak_script # These no_grad_* functions are necessary as wrappers around the parts of these # functions that use `with torch.no_grad()`. The JIT doesn't support context @@ -72,7 +71,6 @@ def calculate_gain(nonlinearity, param=None): raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) -@weak_script def uniform_(tensor, a=0., b=1.): # type: (Tensor, float, float) -> Tensor r"""Fills the input Tensor with values drawn from the uniform @@ -90,7 +88,6 @@ def uniform_(tensor, a=0., b=1.): return _no_grad_uniform_(tensor, a, b) -@weak_script def normal_(tensor, mean=0., std=1.): # type: (Tensor, float, float) -> Tensor r"""Fills the input Tensor with values drawn from the normal @@ -108,7 +105,6 @@ def normal_(tensor, mean=0., std=1.): return _no_grad_normal_(tensor, mean, std) -@weak_script def constant_(tensor, val): # type: (Tensor, float) -> Tensor r"""Fills the input Tensor with the value :math:`\text{val}`. @@ -124,7 +120,6 @@ def constant_(tensor, val): return _no_grad_fill_(tensor, val) -@weak_script def ones_(tensor): # type: (Tensor) -> Tensor r"""Fills the input Tensor with ones`. @@ -139,7 +134,6 @@ def ones_(tensor): return _no_grad_fill_(tensor, 1.) -@weak_script def zeros_(tensor): # type: (Tensor) -> Tensor r"""Fills the input Tensor with zeros`. 
@@ -205,7 +199,6 @@ def dirac_(tensor): return tensor -@weak_script def _calculate_fan_in_and_fan_out(tensor): dimensions = tensor.dim() if dimensions < 2: @@ -226,7 +219,6 @@ def _calculate_fan_in_and_fan_out(tensor): return fan_in, fan_out -@weak_script def xavier_uniform_(tensor, gain=1.): # type: (Tensor, float) -> Tensor r"""Fills the input `Tensor` with values according to the method @@ -255,7 +247,6 @@ def xavier_uniform_(tensor, gain=1.): return _no_grad_uniform_(tensor, -a, a) -@weak_script def xavier_normal_(tensor, gain=1.): # type: (Tensor, float) -> Tensor r"""Fills the input `Tensor` with values according to the method diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index c6b165728ba..b337775f705 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -7,10 +7,8 @@ from torch.nn.init import xavier_normal_ from torch.nn.parameter import Parameter from .module import Module from .. import functional as F -from ..._jit_internal import weak_module, weak_script_method -@weak_module class Threshold(Module): r"""Thresholds each element of the input Tensor. @@ -48,7 +46,6 @@ class Threshold(Module): self.inplace = inplace # TODO: check in THNN (if inplace == True, then assert value <= threshold) - @weak_script_method def forward(self, input): return F.threshold(input, self.threshold, self.value, self.inplace) @@ -59,7 +56,6 @@ class Threshold(Module): ) -@weak_module class ReLU(Module): r"""Applies the rectified linear unit function element-wise: @@ -94,7 +90,6 @@ class ReLU(Module): super(ReLU, self).__init__() self.inplace = inplace - @weak_script_method def forward(self, input): return F.relu(input, inplace=self.inplace) @@ -103,7 +98,6 @@ class ReLU(Module): return inplace_str -@weak_module class RReLU(Module): r"""Applies the randomized leaky rectified liner unit function, element-wise, as described in the paper: @@ -151,7 +145,6 @@ class RReLU(Module): self.upper = upper self.inplace = inplace - @weak_script_method def forward(self, input): return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) @@ -160,7 +153,6 @@ class RReLU(Module): return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str) -@weak_module class Hardtanh(Module): r"""Applies the HardTanh function element-wise @@ -213,7 +205,6 @@ class Hardtanh(Module): self.inplace = inplace assert self.max_val > self.min_val - @weak_script_method def forward(self, input): return F.hardtanh(input, self.min_val, self.max_val, self.inplace) @@ -224,7 +215,6 @@ class Hardtanh(Module): ) -@weak_module class ReLU6(Hardtanh): r"""Applies the element-wise function: @@ -256,7 +246,6 @@ class ReLU6(Hardtanh): return inplace_str -@weak_module class Sigmoid(Module): r"""Applies the element-wise function: @@ -278,12 +267,10 @@ class Sigmoid(Module): >>> output = m(input) """ - @weak_script_method def forward(self, input): return torch.sigmoid(input) -@weak_module class Tanh(Module): r"""Applies the element-wise function: @@ -304,12 +291,10 @@ class Tanh(Module): >>> output = m(input) """ - @weak_script_method def forward(self, input): return torch.tanh(input) -@weak_module class ELU(Module): r"""Applies the element-wise function: @@ -340,7 +325,6 @@ class ELU(Module): self.alpha = alpha self.inplace = inplace - @weak_script_method def forward(self, input): return F.elu(input, self.alpha, self.inplace) @@ -349,7 +333,6 @@ class ELU(Module): return 'alpha={}{}'.format(self.alpha, inplace_str) -@weak_module class CELU(Module): r"""Applies the 
element-wise function: @@ -385,7 +368,6 @@ class CELU(Module): self.alpha = alpha self.inplace = inplace - @weak_script_method def forward(self, input): return F.celu(input, self.alpha, self.inplace) @@ -394,7 +376,6 @@ class CELU(Module): return 'alpha={}{}'.format(self.alpha, inplace_str) -@weak_module class SELU(Module): r"""Applied element-wise, as: @@ -430,7 +411,6 @@ class SELU(Module): super(SELU, self).__init__() self.inplace = inplace - @weak_script_method def forward(self, input): return F.selu(input, self.inplace) @@ -439,7 +419,6 @@ class SELU(Module): return inplace_str -@weak_module class GLU(Module): r"""Applies the gated linear unit function :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half @@ -465,7 +444,6 @@ class GLU(Module): super(GLU, self).__init__() self.dim = dim - @weak_script_method def forward(self, input): return F.glu(input, self.dim) @@ -473,7 +451,6 @@ class GLU(Module): return 'dim={}'.format(self.dim) -@weak_module class Hardshrink(Module): r"""Applies the hard shrinkage function element-wise: @@ -507,7 +484,6 @@ class Hardshrink(Module): super(Hardshrink, self).__init__() self.lambd = lambd - @weak_script_method def forward(self, input): return F.hardshrink(input, self.lambd) @@ -515,7 +491,6 @@ class Hardshrink(Module): return '{}'.format(self.lambd) -@weak_module class LeakyReLU(Module): r"""Applies the element-wise function: @@ -556,7 +531,6 @@ class LeakyReLU(Module): self.negative_slope = negative_slope self.inplace = inplace - @weak_script_method def forward(self, input): return F.leaky_relu(input, self.negative_slope, self.inplace) @@ -565,7 +539,6 @@ class LeakyReLU(Module): return 'negative_slope={}{}'.format(self.negative_slope, inplace_str) -@weak_module class LogSigmoid(Module): r"""Applies the element-wise function: @@ -586,12 +559,10 @@ class LogSigmoid(Module): >>> output = m(input) """ - @weak_script_method def forward(self, input): return F.logsigmoid(input) -@weak_module class Softplus(Module): r"""Applies the element-wise function: @@ -628,7 +599,6 @@ class Softplus(Module): self.beta = beta self.threshold = threshold - @weak_script_method def forward(self, input): return F.softplus(input, self.beta, self.threshold) @@ -636,7 +606,6 @@ class Softplus(Module): return 'beta={}, threshold={}'.format(self.beta, self.threshold) -@weak_module class Softshrink(Module): r"""Applies the soft shrinkage function elementwise: @@ -670,7 +639,6 @@ class Softshrink(Module): super(Softshrink, self).__init__() self.lambd = lambd - @weak_script_method def forward(self, input): return F.softshrink(input, self.lambd) @@ -678,7 +646,6 @@ class Softshrink(Module): return str(self.lambd) -@weak_module class MultiheadAttention(Module): r"""Allows the model to jointly attend to information from different representation subspaces. 
@@ -759,7 +726,6 @@ class MultiheadAttention(Module): if self.bias_v is not None: xavier_normal_(self.bias_v) - @weak_script_method def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None): r""" @@ -817,7 +783,6 @@ class MultiheadAttention(Module): attn_mask=attn_mask) -@weak_module class PReLU(Module): r"""Applies the element-wise function: @@ -874,7 +839,6 @@ class PReLU(Module): super(PReLU, self).__init__() self.weight = Parameter(torch.Tensor(num_parameters).fill_(init)) - @weak_script_method def forward(self, input): return F.prelu(input, self.weight) @@ -882,7 +846,6 @@ class PReLU(Module): return 'num_parameters={}'.format(self.num_parameters) -@weak_module class Softsign(Module): r"""Applies the element-wise function: @@ -903,12 +866,10 @@ class Softsign(Module): >>> output = m(input) """ - @weak_script_method def forward(self, input): return F.softsign(input) -@weak_module class Tanhshrink(Module): r"""Applies the element-wise function: @@ -929,12 +890,10 @@ class Tanhshrink(Module): >>> output = m(input) """ - @weak_script_method def forward(self, input): return F.tanhshrink(input) -@weak_module class Softmin(Module): r"""Applies the Softmin function to an n-dimensional input Tensor rescaling them so that the elements of the n-dimensional output Tensor @@ -970,12 +929,10 @@ class Softmin(Module): super(Softmin, self).__init__() self.dim = dim - @weak_script_method def forward(self, input): return F.softmin(input, self.dim, _stacklevel=5) -@weak_module class Softmax(Module): r"""Applies the Softmax function to an n-dimensional input Tensor rescaling them so that the elements of the n-dimensional output Tensor @@ -1021,7 +978,6 @@ class Softmax(Module): if not hasattr(self, 'dim'): self.dim = None - @weak_script_method def forward(self, input): return F.softmax(input, self.dim, _stacklevel=5) @@ -1029,7 +985,6 @@ class Softmax(Module): return 'dim={dim}'.format(dim=self.dim) -@weak_module class Softmax2d(Module): r"""Applies SoftMax over features to each spatial location. @@ -1052,13 +1007,11 @@ class Softmax2d(Module): >>> output = m(input) """ - @weak_script_method def forward(self, input): assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input' return F.softmax(input, 1, _stacklevel=5) -@weak_module class LogSoftmax(Module): r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor. The LogSoftmax formulation can be simplified as: @@ -1095,6 +1048,5 @@ class LogSoftmax(Module): if not hasattr(self, 'dim'): self.dim = None - @weak_script_method def forward(self, input): return F.log_softmax(input, self.dim, _stacklevel=5) diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 8728dd4dd62..163adb9ee65 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -6,12 +6,10 @@ from .module import Module from torch.nn.parameter import Parameter from .. import functional as F from .. import init -from ..._jit_internal import weak_module, weak_script_method # TODO: check contiguous in THNN # TODO: use separate backend functions? 
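For context only (not applied by this diff), a small sketch of the Softmax module shown above; the `dim` argument is the one passed through to F.softmax in the forward.

import torch
import torch.nn as nn

# Softmax normalizes along the given dimension, so each row sums to 1.
m = nn.Softmax(dim=1)
x = torch.randn(2, 3)
y = m(x)
print(y.sum(dim=1))  # tensor([1., 1.]) up to floating-point error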
-@weak_module class _BatchNorm(Module): _version = 2 __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias', @@ -57,7 +55,6 @@ class _BatchNorm(Module): def _check_input_dim(self, input): raise NotImplementedError - @weak_script_method def forward(self, input): self._check_input_dim(input) @@ -103,7 +100,6 @@ class _BatchNorm(Module): missing_keys, unexpected_keys, error_msgs) -@weak_module class BatchNorm1d(_BatchNorm): r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper @@ -170,14 +166,12 @@ class BatchNorm1d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ - @weak_script_method def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' .format(input.dim())) -@weak_module class BatchNorm2d(_BatchNorm): r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper @@ -244,14 +238,12 @@ class BatchNorm2d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ - @weak_script_method def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) -@weak_module class BatchNorm3d(_BatchNorm): r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper @@ -319,7 +311,6 @@ class BatchNorm3d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ - @weak_script_method def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)' diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 7bb10c7b22b..c5be9c16d6c 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -6,10 +6,9 @@ from .. import functional as F from .. import init from .module import Module from .utils import _single, _pair, _triple -from ..._jit_internal import weak_module, weak_script_method, List +from ..._jit_internal import List -@weak_module class _ConvNd(Module): __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', @@ -74,7 +73,6 @@ class _ConvNd(Module): self.padding_mode = 'zeros' -@weak_module class Conv1d(_ConvNd): r"""Applies a 1D convolution over an input signal composed of several input planes. @@ -192,7 +190,6 @@ class Conv1d(_ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, False, _single(0), groups, bias, padding_mode) - @weak_script_method def forward(self, input): if self.padding_mode == 'circular': expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2) @@ -203,7 +200,6 @@ class Conv1d(_ConvNd): self.padding, self.dilation, self.groups) -@weak_module class Conv2d(_ConvNd): r"""Applies a 2D convolution over an input signal composed of several input planes. @@ -333,7 +329,6 @@ class Conv2d(_ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode) - @weak_script_method def forward(self, input): if self.padding_mode == 'circular': expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, @@ -345,7 +340,6 @@ class Conv2d(_ConvNd): self.padding, self.dilation, self.groups) -@weak_module class Conv3d(_ConvNd): r"""Applies a 3D convolution over an input signal composed of several input planes. 
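A hedged usage sketch (not part of the patch) of the circular-padding branch visible in the Conv2d forward above; the tensor sizes are arbitrary examples.

import torch
import torch.nn as nn

# With padding_mode='circular', forward pads via F.pad(..., mode='circular')
# and then convolves with zero padding, as in the branch shown above.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='circular')
x = torch.randn(1, 3, 16, 16)
print(conv(x).shape)  # torch.Size([1, 8, 16, 16])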
@@ -470,7 +464,6 @@ class Conv3d(_ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, False, _triple(0), groups, bias, padding_mode) - @weak_script_method def forward(self, input): if self.padding_mode == 'circular': expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2, @@ -483,9 +476,7 @@ class Conv3d(_ConvNd): self.padding, self.dilation, self.groups) -@weak_module class _ConvTransposeMixin(object): - @weak_script_method def forward(self, input, output_size=None): # type(Tensor, Optional[List[int]]) -> Tensor output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size) @@ -497,7 +488,6 @@ class _ConvTransposeMixin(object): else: return func(input, self.weight, self.bias) - @weak_script_method def _output_padding(self, input, output_size, stride, padding, kernel_size): # type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int] if output_size is None: @@ -537,7 +527,6 @@ class _ConvTransposeMixin(object): return ret -@weak_module class ConvTranspose1d(_ConvTransposeMixin, _ConvNd): r"""Applies a 1D transposed convolution operator over an input image composed of several input planes. @@ -638,7 +627,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, True, output_padding, groups, bias, padding_mode) - @weak_script_method def forward(self, input, output_size=None): # type: (Tensor, Optional[List[int]]) -> Tensor if self.padding_mode != 'zeros': @@ -650,7 +638,6 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd): output_padding, self.groups, self.dilation) -@weak_module class ConvTranspose2d(_ConvTransposeMixin, _ConvNd): r"""Applies a 2D transposed convolution operator over an input image composed of several input planes. @@ -786,7 +773,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, True, output_padding, groups, bias, padding_mode) - @weak_script_method def forward(self, input, output_size=None): # type: (Tensor, Optional[List[int]]) -> Tensor if self.padding_mode != 'zeros': @@ -799,7 +785,6 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd): output_padding, self.groups, self.dilation) -@weak_module class ConvTranspose3d(_ConvTransposeMixin, _ConvNd): r"""Applies a 3D transposed convolution operator over an input image composed of several input planes. @@ -931,7 +916,6 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, True, output_padding, groups, bias, padding_mode) - @weak_script_method def forward(self, input, output_size=None): # type: (Tensor, Optional[List[int]]) -> Tensor if self.padding_mode != 'zeros': diff --git a/torch/nn/modules/distance.py b/torch/nn/modules/distance.py index 29edfdd3843..c9cf0580ec6 100644 --- a/torch/nn/modules/distance.py +++ b/torch/nn/modules/distance.py @@ -1,9 +1,7 @@ from .module import Module from .. 
import functional as F -from ..._jit_internal import weak_module, weak_script_method -@weak_module class PairwiseDistance(Module): r""" Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm: @@ -35,12 +33,10 @@ class PairwiseDistance(Module): self.eps = eps self.keepdim = keepdim - @weak_script_method def forward(self, x1, x2): return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim) -@weak_module class CosineSimilarity(Module): r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim. @@ -68,6 +64,5 @@ class CosineSimilarity(Module): self.dim = dim self.eps = eps - @weak_script_method def forward(self, x1, x2): return F.cosine_similarity(x1, x2, self.dim, self.eps) diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py index eaf27217ca8..1f499b01a7a 100644 --- a/torch/nn/modules/dropout.py +++ b/torch/nn/modules/dropout.py @@ -1,6 +1,5 @@ from .module import Module from .. import functional as F -from ..._jit_internal import weak_module, weak_script_method class _DropoutNd(Module): @@ -18,7 +17,6 @@ class _DropoutNd(Module): return 'p={}, inplace={}'.format(self.p, self.inplace) -@weak_module class Dropout(_DropoutNd): r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli @@ -52,12 +50,10 @@ class Dropout(_DropoutNd): detectors: https://arxiv.org/abs/1207.0580 """ - @weak_script_method def forward(self, input): return F.dropout(input, self.p, self.training, self.inplace) -@weak_module class Dropout2d(_DropoutNd): r"""Randomly zero out entire channels (a channel is a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the @@ -96,12 +92,10 @@ class Dropout2d(_DropoutNd): http://arxiv.org/abs/1411.4280 """ - @weak_script_method def forward(self, input): return F.dropout2d(input, self.p, self.training, self.inplace) -@weak_module class Dropout3d(_DropoutNd): r"""Randomly zero out entire channels (a channel is a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the @@ -140,12 +134,10 @@ class Dropout3d(_DropoutNd): http://arxiv.org/abs/1411.4280 """ - @weak_script_method def forward(self, input): return F.dropout3d(input, self.p, self.training, self.inplace) -@weak_module class AlphaDropout(_DropoutNd): r"""Applies Alpha Dropout over the input. @@ -184,14 +176,11 @@ class AlphaDropout(_DropoutNd): .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 """ - @weak_script_method def forward(self, input): return F.alpha_dropout(input, self.p, self.training) -@weak_module class FeatureAlphaDropout(_DropoutNd): - @weak_script_method def forward(self, input): return F.feature_alpha_dropout(input, self.p, self.training) diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index 8c3d95f254b..ecbdec2ca82 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -1,10 +1,8 @@ # coding=utf-8 from .module import Module from .. import functional as F -from ..._jit_internal import weak_module, weak_script_method -@weak_module class Fold(Module): r"""Combines an array of sliding local blocks into a large containing tensor. 
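An illustrative eager-mode sketch (not introduced by this diff) of the Dropout behavior referenced above; p=0.5 is an arbitrary choice.

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(4, 4)
drop.train()
y = drop(x)   # roughly half the entries zeroed, survivors scaled by 1/(1-p) = 2
drop.eval()
z = drop(x)   # identity in eval mode
print(torch.equal(z, x))  # True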
@@ -101,7 +99,6 @@ class Fold(Module): self.padding = padding self.stride = stride - @weak_script_method def forward(self, input): return F.fold(input, self.output_size, self.kernel_size, self.dilation, self.padding, self.stride) @@ -113,7 +110,6 @@ class Fold(Module): ) -@weak_module class Unfold(Module): r"""Extracts sliding local blocks from a batched input tensor. @@ -217,7 +213,6 @@ class Unfold(Module): self.padding = padding self.stride = stride - @weak_script_method def forward(self, input): return F.unfold(input, self.kernel_size, self.dilation, self.padding, self.stride) diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index b767ceaec3f..0c58b44aa9f 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -1,6 +1,5 @@ from .batchnorm import _BatchNorm from .. import functional as F -from ..._jit_internal import weak_module, weak_script_method class _InstanceNorm(_BatchNorm): @@ -9,7 +8,6 @@ class _InstanceNorm(_BatchNorm): super(_InstanceNorm, self).__init__( num_features, eps, momentum, affine, track_running_stats) - @weak_script_method def _check_input_dim(self, input): raise NotImplementedError @@ -43,7 +41,6 @@ class _InstanceNorm(_BatchNorm): state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - @weak_script_method def forward(self, input): self._check_input_dim(input) @@ -52,7 +49,6 @@ class _InstanceNorm(_BatchNorm): self.training or not self.track_running_stats, self.momentum, self.eps) -@weak_module class InstanceNorm1d(_InstanceNorm): r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper @@ -121,7 +117,6 @@ class InstanceNorm1d(_InstanceNorm): https://arxiv.org/abs/1607.08022 """ - @weak_script_method def _check_input_dim(self, input): if input.dim() == 2: raise ValueError( @@ -135,7 +130,6 @@ class InstanceNorm1d(_InstanceNorm): .format(input.dim())) -@weak_module class InstanceNorm2d(_InstanceNorm): r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper @@ -204,14 +198,12 @@ class InstanceNorm2d(_InstanceNorm): https://arxiv.org/abs/1607.08022 """ - @weak_script_method def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) -@weak_module class InstanceNorm3d(_InstanceNorm): r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper @@ -280,7 +272,6 @@ class InstanceNorm3d(_InstanceNorm): https://arxiv.org/abs/1607.08022 """ - @weak_script_method def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)' diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index f5d11e11bfb..70de79dcaa5 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -5,10 +5,8 @@ from torch.nn.parameter import Parameter from .. import functional as F from .. import init from .module import Module -from ..._jit_internal import weak_module, weak_script_method -@weak_module class Identity(Module): r"""A placeholder identity operator that is argument-insensitive. 
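For reference only (not part of the patch), a small sketch of the Unfold module whose decorators are removed above; the shapes are worked out from its documented output layout.

import torch
import torch.nn as nn

# Unfold extracts every 3x3 block: a 5x5 map has (5-3+1)**2 = 9 blocks,
# and each block is flattened to C * 3 * 3 = 18 values.
unfold = nn.Unfold(kernel_size=3)
x = torch.randn(1, 2, 5, 5)
patches = unfold(x)
print(patches.shape)  # torch.Size([1, 18, 9])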
@@ -28,12 +26,10 @@ class Identity(Module): def __init__(self, *args, **kwargs): super(Identity, self).__init__() - @weak_script_method def forward(self, input): return input -@weak_module class Linear(Module): r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` @@ -87,7 +83,6 @@ class Linear(Module): bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) - @weak_script_method def forward(self, input): return F.linear(input, self.weight, self.bias) @@ -97,7 +92,6 @@ class Linear(Module): ) -@weak_module class Bilinear(Module): r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1 A x_2 + b` @@ -157,7 +151,6 @@ class Bilinear(Module): if self.bias is not None: init.uniform_(self.bias, -bound, bound) - @weak_script_method def forward(self, input1, input2): return F.bilinear(input1, input2, self.weight, self.bias) diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index ce05a0eabd4..905190eb985 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -3,7 +3,6 @@ import warnings from .module import Module from .. import functional as F from .. import _reduction as _Reduction -from ..._jit_internal import weak_module, weak_script_method class _Loss(Module): @@ -21,7 +20,6 @@ class _WeightedLoss(_Loss): self.register_buffer('weight', weight) -@weak_module class L1Loss(_Loss): r"""Creates a criterion that measures the mean absolute error (MAE) between each element in the input :math:`x` and target :math:`y`. @@ -86,12 +84,10 @@ class L1Loss(_Loss): def __init__(self, size_average=None, reduce=None, reduction='mean'): super(L1Loss, self).__init__(size_average, reduce, reduction) - @weak_script_method def forward(self, input, target): return F.l1_loss(input, target, reduction=self.reduction) -@weak_module class NLLLoss(_WeightedLoss): r"""The negative log likelihood loss. It is useful to train a classification problem with `C` classes. @@ -204,12 +200,10 @@ class NLLLoss(_WeightedLoss): super(NLLLoss, self).__init__(weight, size_average, reduce, reduction) self.ignore_index = ignore_index - @weak_script_method def forward(self, input, target): return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) -@weak_module class NLLLoss2d(NLLLoss): def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): @@ -219,7 +213,6 @@ class NLLLoss2d(NLLLoss): super(NLLLoss2d, self).__init__(weight, size_average, ignore_index, reduce, reduction) -@weak_module class PoissonNLLLoss(_Loss): r"""Negative log likelihood loss with Poisson distribution of target. @@ -286,13 +279,11 @@ class PoissonNLLLoss(_Loss): self.full = full self.eps = eps - @weak_script_method def forward(self, log_input, target): return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full, eps=self.eps, reduction=self.reduction) -@weak_module class KLDivLoss(_Loss): r"""The `Kullback-Leibler divergence`_ Loss @@ -370,12 +361,10 @@ class KLDivLoss(_Loss): def __init__(self, size_average=None, reduce=None, reduction='mean'): super(KLDivLoss, self).__init__(size_average, reduce, reduction) - @weak_script_method def forward(self, input, target): return F.kl_div(input, target, reduction=self.reduction) -@weak_module class MSELoss(_Loss): r"""Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input :math:`x` and target :math:`y`. 
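A minimal eager sketch (not part of the diff) of the Linear module shown above; the dimensions are arbitrary.

import torch
import torch.nn as nn

# forward computes input @ weight.t() + bias via F.linear, as shown above.
lin = nn.Linear(in_features=20, out_features=5)
x = torch.randn(3, 20)
print(lin(x).shape)  # torch.Size([3, 5])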
@@ -438,12 +427,10 @@ class MSELoss(_Loss): def __init__(self, size_average=None, reduce=None, reduction='mean'): super(MSELoss, self).__init__(size_average, reduce, reduction) - @weak_script_method def forward(self, input, target): return F.mse_loss(input, target, reduction=self.reduction) -@weak_module class BCELoss(_WeightedLoss): r"""Creates a criterion that measures the Binary Cross Entropy between the target and the output: @@ -507,12 +494,10 @@ class BCELoss(_WeightedLoss): def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): super(BCELoss, self).__init__(weight, size_average, reduce, reduction) - @weak_script_method def forward(self, input, target): return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) -@weak_module class BCEWithLogitsLoss(_Loss): r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single class. This version is more numerically stable than using a plain `Sigmoid` @@ -609,7 +594,6 @@ class BCEWithLogitsLoss(_Loss): self.register_buffer('weight', weight) self.register_buffer('pos_weight', pos_weight) - @weak_script_method def forward(self, input, target): return F.binary_cross_entropy_with_logits(input, target, self.weight, @@ -617,7 +601,6 @@ class BCEWithLogitsLoss(_Loss): reduction=self.reduction) -@weak_module class HingeEmbeddingLoss(_Loss): r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` (containing 1 or -1). @@ -673,12 +656,10 @@ class HingeEmbeddingLoss(_Loss): super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction) self.margin = margin - @weak_script_method def forward(self, input, target): return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction) -@weak_module class MultiLabelMarginLoss(_Loss): r"""Creates a criterion that optimizes a multi-class multi-classification hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) @@ -739,12 +720,10 @@ class MultiLabelMarginLoss(_Loss): def __init__(self, size_average=None, reduce=None, reduction='mean'): super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction) - @weak_script_method def forward(self, input, target): return F.multilabel_margin_loss(input, target, reduction=self.reduction) -@weak_module class SmoothL1Loss(_Loss): r"""Creates a criterion that uses a squared term if the absolute element-wise error falls below 1 and an L1 term otherwise. @@ -799,12 +778,10 @@ class SmoothL1Loss(_Loss): def __init__(self, size_average=None, reduce=None, reduction='mean'): super(SmoothL1Loss, self).__init__(size_average, reduce, reduction) - @weak_script_method def forward(self, input, target): return F.smooth_l1_loss(input, target, reduction=self.reduction) -@weak_module class SoftMarginLoss(_Loss): r"""Creates a criterion that optimizes a two-class classification logistic loss between input tensor :math:`x` and target tensor :math:`y` @@ -842,12 +819,10 @@ class SoftMarginLoss(_Loss): def __init__(self, size_average=None, reduce=None, reduction='mean'): super(SoftMarginLoss, self).__init__(size_average, reduce, reduction) - @weak_script_method def forward(self, input, target): return F.soft_margin_loss(input, target, reduction=self.reduction) -@weak_module class CrossEntropyLoss(_WeightedLoss): r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class. 
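As a hedged aside (not part of the patch), a numerical check of the Sigmoid plus BCELoss equivalence stated in the BCEWithLogitsLoss docstring above.

import torch
import torch.nn as nn

logits = torch.randn(4, 1)
target = torch.empty(4, 1).random_(2)
a = nn.BCEWithLogitsLoss()(logits, target)
b = nn.BCELoss()(torch.sigmoid(logits), target)
print(torch.allclose(a, b))  # True, up to floating-point error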
@@ -936,13 +911,11 @@ class CrossEntropyLoss(_WeightedLoss): super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction) self.ignore_index = ignore_index - @weak_script_method def forward(self, input, target): return F.cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) -@weak_module class MultiLabelSoftMarginLoss(_WeightedLoss): r"""Creates a criterion that optimizes a multi-label one-versus-all loss based on max-entropy, between input :math:`x` and target :math:`y` of size @@ -986,12 +959,10 @@ class MultiLabelSoftMarginLoss(_WeightedLoss): def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction) - @weak_script_method def forward(self, input, target): return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction) -@weak_module class CosineEmbeddingLoss(_Loss): r"""Creates a criterion that measures the loss given input tensors :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1. @@ -1034,12 +1005,10 @@ class CosineEmbeddingLoss(_Loss): super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction) self.margin = margin - @weak_script_method def forward(self, input1, input2, target): return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) -@weak_module class MarginRankingLoss(_Loss): r"""Creates a criterion that measures the loss given inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`, @@ -1082,12 +1051,10 @@ class MarginRankingLoss(_Loss): super(MarginRankingLoss, self).__init__(size_average, reduce, reduction) self.margin = margin - @weak_script_method def forward(self, input1, input2, target): return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) -@weak_module class MultiMarginLoss(_WeightedLoss): r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and @@ -1145,13 +1112,11 @@ class MultiMarginLoss(_WeightedLoss): self.p = p self.margin = margin - @weak_script_method def forward(self, input, target): return F.multi_margin_loss(input, target, p=self.p, margin=self.margin, weight=self.weight, reduction=self.reduction) -@weak_module class TripletMarginLoss(_Loss): r"""Creates a criterion that measures the triplet loss given an input tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. @@ -1221,13 +1186,11 @@ class TripletMarginLoss(_Loss): self.eps = eps self.swap = swap - @weak_script_method def forward(self, anchor, positive, negative): return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p, eps=self.eps, swap=self.swap, reduction=self.reduction) -@weak_module class CTCLoss(_Loss): r"""The Connectionist Temporal Classification loss. 
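For illustration only (not applied by this diff), a check of the LogSoftmax plus NLLLoss composition described in the CrossEntropyLoss docstring above.

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
print(torch.allclose(ce, nll))  # True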
@@ -1327,7 +1290,6 @@ class CTCLoss(_Loss): self.blank = blank self.zero_infinity = zero_infinity - @weak_script_method def forward(self, log_probs, targets, input_lengths, target_lengths): return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity) diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 60a4ab21e7e..61ca8a1b2f5 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter from .module import Module from .. import functional as F from .. import init -from ..._jit_internal import weak_module, weak_script_method -@weak_module class LocalResponseNorm(Module): r"""Applies local response normalization over an input signal composed of several input planes, where channels occupy the second dimension. @@ -45,7 +43,6 @@ class LocalResponseNorm(Module): self.beta = beta self.k = k - @weak_script_method def forward(self, input): return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k) @@ -71,7 +68,6 @@ class CrossMapLRN2d(Module): return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__) -@weak_module class LayerNorm(Module): r"""Applies Layer Normalization over a mini-batch of inputs as described in the paper `Layer Normalization`_ . @@ -151,7 +147,6 @@ class LayerNorm(Module): init.ones_(self.weight) init.zeros_(self.bias) - @weak_script_method def forward(self, input): return F.layer_norm( input, self.normalized_shape, self.weight, self.bias, self.eps) @@ -161,7 +156,6 @@ class LayerNorm(Module): 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) -@weak_module class GroupNorm(Module): r"""Applies Group Normalization over a mini-batch of inputs as described in the paper `Group Normalization`_ . @@ -226,7 +220,6 @@ class GroupNorm(Module): init.ones_(self.weight) init.zeros_(self.bias) - @weak_script_method def forward(self, input): return F.group_norm( input, self.num_groups, self.weight, self.bias, self.eps) diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 9af29b88492..58c3d441fc2 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -1,13 +1,11 @@ from .module import Module from .utils import _pair, _quadruple, _ntuple from .. import functional as F -from ..._jit_internal import weak_module, weak_script_method # TODO: grad_output size asserts in THNN -@weak_module class _ConstantPadNd(Module): __constants__ = ['padding', 'value'] @@ -15,7 +13,6 @@ class _ConstantPadNd(Module): super(_ConstantPadNd, self).__init__() self.value = value - @weak_script_method def forward(self, input): return F.pad(input, self.padding, 'constant', self.value) @@ -23,7 +20,6 @@ class _ConstantPadNd(Module): return 'padding={}, value={}'.format(self.padding, self.value) -@weak_module class ConstantPad1d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. @@ -73,7 +69,6 @@ class ConstantPad1d(_ConstantPadNd): self.padding = _pair(padding) -@weak_module class ConstantPad2d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. @@ -123,7 +118,6 @@ class ConstantPad2d(_ConstantPadNd): self.padding = _quadruple(padding) -@weak_module class ConstantPad3d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. 
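An eager-mode sketch (not part of the patch) of the LayerNorm module whose decorators are removed above; normalized_shape=10 is an arbitrary example.

import torch
import torch.nn as nn

# LayerNorm normalizes over the trailing normalized_shape dimensions,
# so each length-10 slice ends up with roughly zero mean and unit variance.
ln = nn.LayerNorm(10)
x = torch.randn(2, 5, 10)
y = ln(x)
print(y.mean(dim=-1).abs().max())  # close to 0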
@@ -162,11 +156,9 @@ class ConstantPad3d(_ConstantPadNd): self.padding = _ntuple(6)(padding) -@weak_module class _ReflectionPadNd(Module): __constants__ = ['padding'] - @weak_script_method def forward(self, input): return F.pad(input, self.padding, 'reflect') @@ -174,7 +166,6 @@ class _ReflectionPadNd(Module): return '{}'.format(self.padding) -@weak_module class ReflectionPad1d(_ReflectionPadNd): r"""Pads the input tensor using the reflection of the input boundary. @@ -214,7 +205,6 @@ class ReflectionPad1d(_ReflectionPadNd): self.padding = _pair(padding) -@weak_module class ReflectionPad2d(_ReflectionPadNd): r"""Pads the input tensor using the reflection of the input boundary. @@ -265,11 +255,9 @@ class ReflectionPad2d(_ReflectionPadNd): self.padding = _quadruple(padding) -@weak_module class _ReplicationPadNd(Module): __constants__ = ['padding'] - @weak_script_method def forward(self, input): return F.pad(input, self.padding, 'replicate') @@ -277,7 +265,6 @@ class _ReplicationPadNd(Module): return '{}'.format(self.padding) -@weak_module class ReplicationPad1d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. @@ -317,7 +304,6 @@ class ReplicationPad1d(_ReplicationPadNd): self.padding = _pair(padding) -@weak_module class ReplicationPad2d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. @@ -368,7 +354,6 @@ class ReplicationPad2d(_ReplicationPadNd): self.padding = _quadruple(padding) -@weak_module class ReplicationPad3d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. @@ -407,7 +392,6 @@ class ReplicationPad3d(_ReplicationPadNd): self.padding = _ntuple(6)(padding) -@weak_module class ZeroPad2d(ConstantPad2d): r"""Pads the input tensor boundaries with zero. diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index 23213731872..69d1b677c91 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -1,9 +1,7 @@ from .module import Module from .. import functional as F -from ..._jit_internal import weak_module, weak_script_method -@weak_module class PixelShuffle(Module): r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a tensor of shape :math:`(*, C, H \times r, W \times r)`. @@ -41,7 +39,6 @@ class PixelShuffle(Module): super(PixelShuffle, self).__init__() self.upscale_factor = upscale_factor - @weak_script_method def forward(self, input): return F.pixel_shuffle(input, self.upscale_factor) diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index e4ed16dcb1a..984cdb56d11 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -1,10 +1,8 @@ from .module import Module from .utils import _single, _pair, _triple from .. import functional as F -from ..._jit_internal import weak_module, weak_script_method -@weak_module class _MaxPoolNd(Module): __constants__ = ['kernel_size', 'stride', 'padding', 'dilation', 'return_indices', 'ceil_mode'] @@ -24,7 +22,6 @@ class _MaxPoolNd(Module): ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__) -@weak_module class MaxPool1d(_MaxPoolNd): r"""Applies a 1D max pooling over an input signal composed of several input planes. 
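For context (not introduced by this diff), a shape sketch of PixelShuffle as documented above; r=2 and C=3 are arbitrary choices.

import torch
import torch.nn as nn

# (N, C*r^2, H, W) -> (N, C, H*r, W*r)
ps = nn.PixelShuffle(upscale_factor=2)
x = torch.randn(1, 3 * 2 ** 2, 8, 8)
print(ps(x).shape)  # torch.Size([1, 3, 16, 16])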
@@ -68,7 +65,6 @@ class MaxPool1d(_MaxPoolNd): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ - @weak_script_method def forward(self, input): return F.max_pool1d(input, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, @@ -79,7 +75,6 @@ class MaxPool1d(_MaxPoolNd): ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__) -@weak_module class MaxPool2d(_MaxPoolNd): r"""Applies a 2D max pooling over an input signal composed of several input planes. @@ -139,14 +134,12 @@ class MaxPool2d(_MaxPoolNd): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ - @weak_script_method def forward(self, input): return F.max_pool2d(input, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices) -@weak_module class MaxPool3d(_MaxPoolNd): r"""Applies a 3D max pooling over an input signal composed of several input planes. @@ -210,14 +203,12 @@ class MaxPool3d(_MaxPoolNd): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ # noqa: E501 - @weak_script_method def forward(self, input): return F.max_pool3d(input, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices) -@weak_module class _MaxUnpoolNd(Module): def extra_repr(self): @@ -226,7 +217,6 @@ class _MaxUnpoolNd(Module): ) -@weak_module class MaxUnpool1d(_MaxUnpoolNd): r"""Computes a partial inverse of :class:`MaxPool1d`. @@ -292,7 +282,6 @@ class MaxUnpool1d(_MaxUnpoolNd): self.padding, output_size) -@weak_module class MaxUnpool2d(_MaxUnpoolNd): r"""Computes a partial inverse of :class:`MaxPool2d`. @@ -366,7 +355,6 @@ class MaxUnpool2d(_MaxUnpoolNd): self.padding, output_size) -@weak_module class MaxUnpool3d(_MaxUnpoolNd): r"""Computes a partial inverse of :class:`MaxPool3d`. @@ -429,7 +417,6 @@ class MaxUnpool3d(_MaxUnpoolNd): self.padding, output_size) -@weak_module class _AvgPoolNd(Module): __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad'] @@ -439,7 +426,6 @@ class _AvgPoolNd(Module): ) -@weak_module class AvgPool1d(_AvgPoolNd): r"""Applies a 1D average pooling over an input signal composed of several input planes. @@ -490,14 +476,12 @@ class AvgPool1d(_AvgPoolNd): self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad - @weak_script_method def forward(self, input): return F.avg_pool1d( input, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) -@weak_module class AvgPool2d(_AvgPoolNd): r"""Applies a 2D average pooling over an input signal composed of several input planes. @@ -557,13 +541,11 @@ class AvgPool2d(_AvgPoolNd): self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad - @weak_script_method def forward(self, input): return F.avg_pool2d(input, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) -@weak_module class AvgPool3d(_AvgPoolNd): r"""Applies a 3D average pooling over an input signal composed of several input planes. @@ -630,7 +612,6 @@ class AvgPool3d(_AvgPoolNd): self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad - @weak_script_method def forward(self, input): return F.avg_pool3d(input, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) @@ -642,7 +623,6 @@ class AvgPool3d(_AvgPoolNd): self.__dict__.setdefault('count_include_pad', True) -@weak_module class FractionalMaxPool2d(Module): r"""Applies a 2D fractional max pooling over an input signal composed of several input planes. 
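A hedged sketch (not part of the patch) of the MaxPool2d / MaxUnpool2d pairing shown above; return_indices supplies the indices the unpooling step needs.

import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(2, stride=2)
x = torch.randn(1, 1, 4, 4)
out, idx = pool(x)
restored = unpool(out, idx)   # non-zero only at the argmax positions
print(restored.shape)         # torch.Size([1, 1, 4, 4])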
@@ -694,7 +674,6 @@ class FractionalMaxPool2d(Module): raise ValueError("output_ratio must be between 0 and 1 (got {})" .format(output_ratio)) - @weak_script_method def forward(self, input): return F.fractional_max_pool2d( input, self.kernel_size, self.output_size, self.output_ratio, @@ -702,7 +681,6 @@ class FractionalMaxPool2d(Module): _random_samples=self._random_samples) -@weak_module class FractionalMaxPool3d(Module): r"""Applies a 3D fractional max pooling over an input signal composed of several input planes. @@ -754,7 +732,6 @@ class FractionalMaxPool3d(Module): raise ValueError("output_ratio must be between 0 and 1 (got {})" .format(output_ratio)) - @weak_script_method def forward(self, input): return F.fractional_max_pool3d( input, self.kernel_size, self.output_size, self.output_ratio, @@ -762,7 +739,6 @@ class FractionalMaxPool3d(Module): _random_samples=self._random_samples) -@weak_module class _LPPoolNd(Module): __constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode'] @@ -778,7 +754,6 @@ class _LPPoolNd(Module): 'ceil_mode={ceil_mode}'.format(**self.__dict__) -@weak_module class LPPool1d(_LPPoolNd): r"""Applies a 1D power-average pooling over an input signal composed of several input planes. @@ -814,14 +789,11 @@ class LPPool1d(_LPPoolNd): >>> output = m(input) """ - @weak_script_method - @weak_script_method def forward(self, input): return F.lp_pool1d(input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode) -@weak_module class LPPool2d(_LPPoolNd): r"""Applies a 2D power-average pooling over an input signal composed of several input planes. @@ -871,13 +843,11 @@ class LPPool2d(_LPPoolNd): """ - @weak_script_method def forward(self, input): return F.lp_pool2d(input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode) -@weak_module class _AdaptiveMaxPoolNd(Module): __constants__ = ['output_size', 'return_indices'] @@ -893,7 +863,6 @@ class _AdaptiveMaxPoolNd(Module): # output shapes are, and how the operation computes output. -@weak_module class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. @@ -913,12 +882,10 @@ class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): """ - @weak_script_method def forward(self, input): return F.adaptive_max_pool1d(input, self.output_size, self.return_indices) -@weak_module class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. @@ -949,12 +916,10 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): """ - @weak_script_method def forward(self, input): return F.adaptive_max_pool2d(input, self.output_size, self.return_indices) -@weak_module class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. @@ -986,12 +951,10 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): """ - @weak_script_method def forward(self, input): return F.adaptive_max_pool3d(input, self.output_size, self.return_indices) -@weak_module class _AdaptiveAvgPoolNd(Module): __constants__ = ['output_size'] @@ -1003,7 +966,6 @@ class _AdaptiveAvgPoolNd(Module): return 'output_size={}'.format(self.output_size) -@weak_module class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes. 
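For reference only (not part of the diff), a shape sketch of adaptive average pooling as used above; the spatial sizes are arbitrary.

import torch
import torch.nn as nn

# AdaptiveAvgPool2d maps any input spatial size to the requested output size.
gap = nn.AdaptiveAvgPool2d((1, 1))
x = torch.randn(2, 64, 13, 17)
print(gap(x).shape)  # torch.Size([2, 64, 1, 1])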
@@ -1021,12 +983,10 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): """ - @weak_script_method def forward(self, input): return F.adaptive_avg_pool1d(input, self.output_size) -@weak_module class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes. @@ -1055,12 +1015,10 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): """ - @weak_script_method def forward(self, input): return F.adaptive_avg_pool2d(input, self.output_size) -@weak_module class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes. @@ -1089,6 +1047,5 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): """ - @weak_script_method def forward(self, input): return F.adaptive_avg_pool3d(input, self.output_size) diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index f2c6f70c241..ee23762095c 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -8,8 +8,7 @@ from ..parameter import Parameter from ..utils.rnn import PackedSequence, get_packed_sequence from .. import init from .. import _VF -from ..._jit_internal import weak_module, weak_script_method, weak_script, \ - _parameter_list +from ..._jit_internal import _parameter_list _rnn_impls = { 'GRU': _VF.gru, @@ -18,7 +17,6 @@ _rnn_impls = { } -@weak_script def apply_permutation(tensor, permutation, dim=1): # type: (Tensor, Tensor, int) -> Tensor return tensor.index_select(dim, permutation) @@ -139,7 +137,6 @@ class RNNBase(Module): def _get_flat_weights(self): return self._flat_weights - @weak_script_method def check_input(self, input, batch_sizes): # type: (Tensor, Optional[Tensor]) -> None expected_input_dim = 2 if batch_sizes is not None else 3 @@ -152,7 +149,6 @@ class RNNBase(Module): 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format( self.input_size, input.size(-1))) - @weak_script_method def get_expected_hidden_size(self, input, batch_sizes): # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int] if batch_sizes is not None: @@ -165,7 +161,6 @@ class RNNBase(Module): mini_batch, self.hidden_size) return expected_hidden_size - @weak_script_method def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'): # type: (Tensor, Tuple[int, int, int], str) -> None if hx.size() != expected_hidden_size: @@ -374,7 +369,6 @@ class RNN(RNNBase): super(RNN, self).__init__(mode, *args, **kwargs) -@weak_module class LSTM(RNNBase): r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. 
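A minimal eager sketch (not part of the patch) of the LSTM module introduced above; the packed-sequence dispatch handled in the forwards below is skipped here, and all sizes are illustrative.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)            # (seq_len, batch, input_size)
h0 = torch.zeros(2, 3, 20)           # (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 3, 20)
out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)                     # torch.Size([5, 3, 20])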
@@ -484,7 +478,6 @@ class LSTM(RNNBase): def __init__(self, *args, **kwargs): super(LSTM, self).__init__('LSTM', *args, **kwargs) - @weak_script_method def check_forward_args(self, input, hidden, batch_sizes): # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None self.check_input(input, batch_sizes) @@ -495,14 +488,12 @@ class LSTM(RNNBase): self.check_hidden_size(hidden[1], expected_hidden_size, 'Expected hidden[1] size {}, got {}') - @weak_script_method def permute_hidden(self, hx, permutation): # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor] if permutation is None: return hx return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) - @weak_script_method def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa if hx is None: @@ -528,7 +519,7 @@ class LSTM(RNNBase): return output, hidden - @weak_script_method + @torch._jit_internal.export def forward_tensor(self, input, hx=None): # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] batch_sizes = None @@ -540,7 +531,7 @@ class LSTM(RNNBase): return output, self.permute_hidden(hidden, unsorted_indices) - @weak_script_method + @torch._jit_internal.export def forward_packed(self, input, hx=None): # type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa input, batch_sizes, sorted_indices, unsorted_indices = input @@ -552,6 +543,7 @@ class LSTM(RNNBase): output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices) return output, self.permute_hidden(hidden, unsorted_indices) + @torch._jit_internal.ignore def forward(self, input, hx=None): if isinstance(input, PackedSequence): return self.forward_packed(input, hx) @@ -694,14 +686,12 @@ class RNNCellBase(Module): s += ', nonlinearity={nonlinearity}' return s.format(**self.__dict__) - @weak_script_method def check_forward_input(self, input): if input.size(1) != self.input_size: raise RuntimeError( "input has inconsistent input_size: got {}, expected {}".format( input.size(1), self.input_size)) - @weak_script_method def check_forward_hidden(self, input, hx, hidden_label=''): # type: (Tensor, Tensor, str) -> None if input.size(0) != hx.size(0): @@ -720,7 +710,6 @@ class RNNCellBase(Module): init.uniform_(weight, -stdv, stdv) -@weak_module class RNNCell(RNNCellBase): r"""An Elman RNN cell with tanh or ReLU non-linearity. @@ -784,7 +773,6 @@ class RNNCell(RNNCellBase): super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1) self.nonlinearity = nonlinearity - @weak_script_method def forward(self, input, hx=None): # type: (Tensor, Optional[Tensor]) -> Tensor self.check_forward_input(input) @@ -810,7 +798,6 @@ class RNNCell(RNNCellBase): return ret -@weak_module class LSTMCell(RNNCellBase): r"""A long short-term memory (LSTM) cell. 
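For illustration only (not part of the diff), an unrolled use of the LSTMCell described above; the loop mirrors how a cell consumes one timestep at a time.

import torch
import torch.nn as nn

cell = nn.LSTMCell(input_size=10, hidden_size=20)
x = torch.randn(6, 3, 10)            # (time, batch, input_size)
hx = torch.zeros(3, 20)
cx = torch.zeros(3, 20)
outputs = []
for t in range(x.size(0)):
    hx, cx = cell(x[t], (hx, cx))
    outputs.append(hx)
print(torch.stack(outputs).shape)    # torch.Size([6, 3, 20])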
@@ -875,7 +862,6 @@ class LSTMCell(RNNCellBase): def __init__(self, input_size, hidden_size, bias=True): super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4) - @weak_script_method def forward(self, input, hx=None): # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] self.check_forward_input(input) @@ -891,7 +877,6 @@ class LSTMCell(RNNCellBase): ) -@weak_module class GRUCell(RNNCellBase): r"""A gated recurrent unit (GRU) cell @@ -957,7 +942,6 @@ class GRUCell(RNNCellBase): def __init__(self, input_size, hidden_size, bias=True): super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3) - @weak_script_method def forward(self, input, hx=None): # type: (Tensor, Optional[Tensor]) -> Tensor self.check_forward_input(input) diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index f96348de077..dd063b4a429 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter from .module import Module from .. import functional as F from .. import init -from torch._jit_internal import weak_module, weak_script_method -@weak_module class Embedding(Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -110,7 +108,6 @@ class Embedding(Module): with torch.no_grad(): self.weight[self.padding_idx].fill_(0) - @weak_script_method def forward(self, input): return F.embedding( input, self.weight, self.padding_idx, self.max_norm, @@ -173,7 +170,6 @@ class Embedding(Module): return embedding -@weak_module class EmbeddingBag(Module): r"""Computes sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings. @@ -277,7 +273,6 @@ class EmbeddingBag(Module): def reset_parameters(self): init.normal_(self.weight) - @weak_script_method def forward(self, input, offsets=None, per_sample_weights=None): # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor return F.embedding_bag(input, self.weight, offsets, diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index 55b6f93c084..43140b734a5 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -1,9 +1,7 @@ from .module import Module from .. import functional as F -from ..._jit_internal import weak_module, weak_script_method -@weak_module class Upsample(Module): r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. @@ -129,7 +127,6 @@ class Upsample(Module): self.mode = mode self.align_corners = align_corners - @weak_script_method def forward(self, input): return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners) @@ -142,7 +139,6 @@ class Upsample(Module): return info -@weak_module class UpsamplingNearest2d(Upsample): r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels. @@ -188,7 +184,6 @@ class UpsamplingNearest2d(Upsample): super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest') -@weak_module class UpsamplingBilinear2d(Upsample): r"""Applies a 2D bilinear upsampling to an input signal composed of several input channels. diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index d6ecac659cf..ba6a44fa7db 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -23,10 +23,9 @@ def _is_jit_enabled(): # Check if we can safely replicate the module. -# there are three types of module: +# there are two types of module: # 1. 
python modules -# 2. weak python modules (nn.Module annotated by @weak_module) -# 3. ScriptModule +# 2. ScriptModule # # currently a module cannot be replicated properly if the descendants of # any ScriptModule contains python module (type 1 above) diff --git a/torch/nn/quantized/modules/activation.py b/torch/nn/quantized/modules/activation.py index 11a16e8a9ff..bf0aa8ca3ed 100644 --- a/torch/nn/quantized/modules/activation.py +++ b/torch/nn/quantized/modules/activation.py @@ -5,9 +5,7 @@ from __future__ import unicode_literals from .. import functional as F from ...modules.module import Module -from ...._jit_internal import weak_module, weak_script_method -@weak_module class ReLU(Module): r"""Applies quantized rectified linear unit function element-wise: @@ -37,7 +35,6 @@ class ReLU(Module): assert not inplace, 'torch.nn.quantized.ReLU does not support inplace' - @weak_script_method def forward(self, input): return F.relu(input) diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index e3bdb2c4ca3..33e1edfff0f 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -2,9 +2,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera import torch from ...modules.module import Module from ...modules.linear import Linear as NNLinear -from ...._jit_internal import weak_module -@weak_module class Quantize(Module): r"""Quantizes an incoming tensor Args: @@ -39,7 +37,6 @@ class Quantize(Module): def from_float(mod): return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8) -@weak_module class DeQuantize(Module): r"""Dequantizes an incoming tensor @@ -65,7 +62,6 @@ class DeQuantize(Module): def from_float(mod): return DeQuantize() -@weak_module class Linear(NNLinear): r""" A quantized linear module with quantized tensor as inputs
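Closing editorial note (not part of the patch): with the weak-module registry removed, standard nn modules are handled by the ordinary scripting path rather than a separate weak path. A hedged sketch, assuming the recursive scripting entry point on this branch accepts module instances the way later releases do:

import torch
import torch.nn as nn

# Hypothetical check; if torch.jit.script does not yet accept module
# instances on this branch, torch.jit.trace(model, example_input) is the
# fallback for getting a ScriptModule.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
scripted = torch.jit.script(model)
print(scripted(torch.randn(2, 4)).shape)  # torch.Size([2, 4])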