Use nn module tests in test_jit (#14238)
Summary: This PR adds weak modules for all activation modules and uses the `test_nn` module tests to exercise weak modules that have been annotated with `weak_module` and are therefore in `torch._jit_internal._weak_types`. Also depends on #14379.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14238
Differential Revision: D13192230
Pulled By: driazati
fbshipit-source-id: 36488960b6c91448b38c0fa65422539a93af8c5e
This commit is contained in:
parent a0def0b57e
commit 4cdcbbf410
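
Note on the mechanism described in the summary: a class decorated with `weak_module` is recorded in `torch._jit_internal._weak_types`, and the rewritten test generator in test_jit only emits a JIT test for a `test_nn` entry whose `torch.nn` class appears in that registry. Below is a minimal sketch of that check, assuming a PyTorch build from around this commit (where `torch._jit_internal._weak_types` exists); the example entry mirrors the Threshold spec updated in the diff.

import torch
import torch.nn
import torch._jit_internal

# Illustrative test_nn-style spec (same shape as the entries in module_tests).
example_test = dict(
    module_name='Threshold',
    constructor_args=(2., 1.),
    input_size=(2, 3, 4, 5),
    desc='threshold_value',
)

def is_weak_module_test(test_spec):
    # Derive the torch.nn class name the same way the new add_nn_module_test does.
    name = test_spec.get('module_name') or test_spec.get('fullname')
    class_name = name.split('_')[0]
    module = getattr(torch.nn, class_name, None)
    # Only classes decorated with @weak_module appear in _weak_types.
    return module is not None and torch._jit_internal._weak_types.get(module) is not None

print(is_weak_module_test(example_test))  # True once nn.Threshold is annotated as a weak module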
@@ -51,14 +51,14 @@ module_tests = [
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2, 1),
+        constructor_args=(2., 1.),
         input_size=(2, 3, 4, 5),
         check_inplace=True,
         desc='threshold_value'
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2, 10),
+        constructor_args=(2., 10.),
         input_size=(2, 3, 4, 5),
         desc='large_value'
     ),
@@ -26,36 +26,37 @@ graph(%0 : Double(1, 3, 224, 224)
   %25 : int[] = prim::ListConstruct(%24, %24), scope: AlexNet/Sequential[features]/Conv2d[0]
   %26 : bool = prim::Constant[value=1](), scope: AlexNet/Sequential[features]/Conv2d[0]
   %input.1 : Double(1, 64, 55, 55) = aten::_convolution(%0, %1, %2, %18, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[0]
-  %input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %24, %24), scope: AlexNet/Sequential[features]/ReLU[1]
-  %29 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %30 : int[] = prim::ListConstruct(%29, %29), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %31 : Double(1, 64, 27, 27), %32 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %input.3 : Double(1, 192, 27, 27) = aten::_convolution(%31, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
-  %input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %24, %24), scope: AlexNet/Sequential[features]/ReLU[4]
-  %35 : Double(1, 192, 13, 13), %36 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
-  %input.5 : Double(1, 384, 13, 13) = aten::_convolution(%35, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
-  %38 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %24, %24), scope: AlexNet/Sequential[features]/ReLU[7]
-  %input.6 : Double(1, 256, 13, 13) = aten::_convolution(%38, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
-  %40 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %24, %24), scope: AlexNet/Sequential[features]/ReLU[9]
-  %input.7 : Double(1, 256, 13, 13) = aten::_convolution(%40, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
-  %input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %24, %24), scope: AlexNet/Sequential[features]/ReLU[11]
-  %43 : Double(1, 256, 6, 6), %44 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
-  %45 : int = aten::size(%43, %24), scope: AlexNet
-  %46 : Long() = prim::NumToTensor(%45), scope: AlexNet
-  %47 : int = prim::TensorToNum(%46), scope: AlexNet
-  %48 : int = prim::Constant[value=9216](), scope: AlexNet
-  %49 : int[] = prim::ListConstruct(%47, %48), scope: AlexNet
-  %input.9 : Double(1, 9216) = aten::view(%43, %49), scope: AlexNet
-  %51 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
-  %input.10 : Double(1, 9216) = aten::dropout(%input.9, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
-  %53 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %53, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %input.12 : Double(1, 4096) = aten::threshold_(%input.11, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[2]
-  %input.13 : Double(1, 4096) = aten::dropout(%input.12, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
-  %57 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %57, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %input : Double(1, 4096) = aten::threshold_(%input.14, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[5]
-  %60 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
-  %61 : Double(1, 1000) = aten::addmm(%16, %input, %60, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
-  return (%61);
+  %28 : float = prim::Constant[value=0](), scope: AlexNet/Sequential[features]/ReLU[1]
+  %input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %28, %28), scope: AlexNet/Sequential[features]/ReLU[1]
+  %30 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %31 : int[] = prim::ListConstruct(%30, %30), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %32 : Double(1, 64, 27, 27), %33 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %input.3 : Double(1, 192, 27, 27) = aten::_convolution(%32, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
+  %input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %28, %28), scope: AlexNet/Sequential[features]/ReLU[4]
+  %36 : Double(1, 192, 13, 13), %37 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
+  %input.5 : Double(1, 384, 13, 13) = aten::_convolution(%36, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
+  %39 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %28, %28), scope: AlexNet/Sequential[features]/ReLU[7]
+  %input.6 : Double(1, 256, 13, 13) = aten::_convolution(%39, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
+  %41 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %28, %28), scope: AlexNet/Sequential[features]/ReLU[9]
+  %input.7 : Double(1, 256, 13, 13) = aten::_convolution(%41, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
+  %input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %28, %28), scope: AlexNet/Sequential[features]/ReLU[11]
+  %44 : Double(1, 256, 6, 6), %45 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
+  %46 : int = aten::size(%44, %24), scope: AlexNet
+  %47 : Long() = prim::NumToTensor(%46), scope: AlexNet
+  %48 : int = prim::TensorToNum(%47), scope: AlexNet
+  %49 : int = prim::Constant[value=9216](), scope: AlexNet
+  %50 : int[] = prim::ListConstruct(%48, %49), scope: AlexNet
+  %input.9 : Double(1, 9216) = aten::view(%44, %50), scope: AlexNet
+  %52 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
+  %input.10 : Double(1, 9216) = aten::dropout(%input.9, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
+  %54 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %54, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %input.12 : Double(1, 4096) = aten::threshold_(%input.11, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[2]
+  %input.13 : Double(1, 4096) = aten::dropout(%input.12, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
+  %58 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %58, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %input : Double(1, 4096) = aten::threshold_(%input.14, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[5]
+  %61 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
+  %62 : Double(1, 1000) = aten::addmm(%16, %input, %61, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
+  return (%62);
 }
@@ -11,8 +11,10 @@ from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
 from torch._six import inf, PY2
-from common_utils import (TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN,
-                          skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE)
+from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
+    skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
+    freeze_rng_state
+from test_nn import module_tests, new_module_tests
 from textwrap import dedent
 import os
 import io
@@ -41,6 +43,7 @@ from torch.jit import BatchTensor
 from test_module.future_div import div_int_future, div_float_future
 from test_module.no_future_div import div_int_nofuture, div_float_nofuture
 
+
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
@@ -446,9 +449,8 @@ class JitTestCase(TestCase):
 
     def runAndSaveRNG(self, func, inputs, kwargs=None):
         kwargs = kwargs if kwargs else {}
-        initial_rng_state = torch.get_rng_state()
-        results = func(*inputs, **kwargs)
-        torch.set_rng_state(initial_rng_state)
+        with freeze_rng_state():
+            results = func(*inputs, **kwargs)
         return results
 
 
@@ -9396,6 +9398,11 @@ EXCLUDE_SCRIPT = {
     'test_nn_max_unpool1d',
 }
 
+EXCLUDE_SCRIPT_MODULES = {
+    'test_nn_LPPool2d_norm',
+    'test_nn_LPPool1d_norm',
+}
+
 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
     'test_nn_avg_pool2d',
     'test_nn_log_softmax',
@@ -10199,10 +10206,37 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
     post_add_test(test_name, skipTestIf, do_test)
 
 
-def add_nn_module_test(module_name, constructor_args, call_args,
-                       use_as_constant=False, skipTestIf=()):
+def add_nn_module_test(*args, **kwargs):
+    if 'module_name' in kwargs:
+        name = kwargs['module_name']
+    elif 'fullname' in kwargs:
+        name = kwargs['fullname']
+    elif 'constructor' in kwargs:
+        name = kwargs['constructor'].__name__
+
+    class_name = name.split("_")[0]
+
+    module = getattr(torch.nn, class_name, None)
+    if module is None or torch._jit_internal._weak_types.get(module) is None:
+        return
+
+    if 'desc' in kwargs and 'eval' in kwargs['desc']:
+        # eval() is not supported, so skip these tests
+        return
+
+    test_name = name
+    if 'desc' in kwargs:
+        test_name = "{}_{}".format(test_name, kwargs['desc'])
+    test_name = 'test_nn_{}'.format(test_name)
+
     def do_test(self):
-        nn_module = getattr(torch.nn, module_name)
+        if test_name in EXCLUDE_SCRIPT_MODULES:
+            return
+        if 'constructor' in kwargs:
+            nn_module = kwargs['constructor']
+        else:
+            nn_module = getattr(torch.nn, name)
+        constructor_args = kwargs.get('constructor_args', ())
 
         # Construct a script module that passes arguments through
         # to self.submodule
@@ -10215,7 +10249,7 @@ def add_nn_module_test(module_name, constructor_args, call_args,
         script = script_method_template.format(method_args, call)
 
         submodule_constants = []
-        if use_as_constant:
+        if kwargs.get('is_constant'):
             submodule_constants = ['submodule']
 
         # Create module to use the script method
@@ -10241,13 +10275,16 @@ def add_nn_module_test(module_name, constructor_args, call_args,
             return module(*args)
 
         # Check against Python module as reference
-        args_variable, kwargs_variable = create_input(call_args)
+        if 'input_fn' in kwargs:
+            input_size = tuple(kwargs['input_fn']().size())
+        else:
+            input_size = kwargs['input_size']
+        args_variable, kwargs_variable = create_input((input_size,))
         f_args_variable = deepcopy(unpack_variables(args_variable))
 
         check_against_reference(self, create_script_module, create_nn_module, f_args_variable)
 
-    test_name = 'test_nn_{}'.format(module_name)
-    post_add_test(test_name, skipTestIf, do_test)
+    post_add_test(test_name, (), do_test)
 
 
 def post_add_test(test_name, skipTestIf, do_test):
@@ -10411,8 +10448,8 @@ for test in autograd_method_tests:
 for test in nn_functional_tests:
     add_nn_functional_test(*test)
 
-for test in nn_module_tests:
-    add_nn_module_test(*test)
+for test in module_tests + new_module_tests:
+    add_nn_module_test(**test)
 
 if __name__ == '__main__':
     run_tests()
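
For reference, a rough sketch of how one `module_tests` entry is expanded by the keyword-driven helper above. This is simplified (the real `do_test` also builds a scripted wrapper around the module and calls `check_against_reference`), and the entry shown is illustrative rather than copied from test_nn:

import torch

test = dict(module_name='ELU', constructor_args=(2.,), input_size=(3, 2, 5), desc='alpha')

name = test['module_name']                               # 'ELU'
test_name = 'test_nn_{}_{}'.format(name, test['desc'])   # 'test_nn_ELU_alpha'
nn_module = getattr(torch.nn, name)                      # torch.nn.ELU, must be a weak module
constructor_args = test.get('constructor_args', ())      # (2.,)
input_size = test['input_size']                          # fed to create_input((input_size,))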
@@ -8321,7 +8321,7 @@ new_module_tests = [
     ),
     dict(
         module_name='LPPool2d',
-        constructor_args=(2, (2, 2), 2),
+        constructor_args=(2, 2, 2),
         input_size=(1, 3, 7, 7),
     ),
     dict(
@@ -9005,7 +9005,7 @@ new_module_tests = [
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2, 1),
+        constructor_args=(2., 1.),
         input_size=(),
         check_inplace=True,
         desc='threshold_value_scalar'
@@ -519,7 +519,7 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
 
 @torch._jit_internal.weak_script
 def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
-    # type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
+    # type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
     r"""Applies a 2D power-average pooling over an input signal composed of
     several input planes. If the sum of all inputs to the power of `p` is
     zero, the gradient is set to zero as well.
@@ -9,8 +9,6 @@ from ..._jit_internal import weak_module, weak_script_method
 
 @torch._jit_internal.weak_module
 class Threshold(Module):
-    __constants__ = ['threshold', 'value', 'inplace']
-
     r"""Thresholds each element of the input Tensor
 
     Threshold is defined as:
@@ -38,6 +36,7 @@ class Threshold(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['threshold', 'value', 'inplace']
 
     def __init__(self, threshold, value, inplace=False):
         super(Threshold, self).__init__()
@@ -57,6 +56,7 @@ class Threshold(Module):
         )
 
 
+@torch._jit_internal.weak_module
 class ReLU(Threshold):
     r"""Applies the rectified linear unit function element-wise
     :math:`\text{ReLU}(x)= \max(0, x)`
@@ -79,7 +79,7 @@ class ReLU(Threshold):
     """
 
     def __init__(self, inplace=False):
-        super(ReLU, self).__init__(0, 0, inplace)
+        super(ReLU, self).__init__(0., 0., inplace)
 
     def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''
@@ -143,6 +143,7 @@ class RReLU(Module):
         return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
 
 
+@torch._jit_internal.weak_module
 class Hardtanh(Module):
     r"""Applies the HardTanh function element-wise
 
@@ -179,8 +180,9 @@ class Hardtanh(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['min_val', 'max_val', 'inplace']
 
-    def __init__(self, min_val=-1, max_val=1, inplace=False, min_value=None, max_value=None):
+    def __init__(self, min_val=-1., max_val=1., inplace=False, min_value=None, max_value=None):
         super(Hardtanh, self).__init__()
         if min_value is not None:
             warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
@@ -194,6 +196,7 @@ class Hardtanh(Module):
         self.inplace = inplace
         assert self.max_val > self.min_val
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
 
@@ -204,6 +207,7 @@ class Hardtanh(Module):
         )
 
 
+@torch._jit_internal.weak_module
 class ReLU6(Hardtanh):
     r"""Applies the element-wise function:
 
@@ -228,7 +232,7 @@ class ReLU6(Hardtanh):
     """
 
     def __init__(self, inplace=False):
-        super(ReLU6, self).__init__(0, 6, inplace)
+        super(ReLU6, self).__init__(0., 6., inplace)
 
     def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''
@@ -288,6 +292,7 @@ class Tanh(Module):
         return torch.tanh(input)
 
 
+@torch._jit_internal.weak_module
 class ELU(Module):
     r"""Applies the element-wise function:
 
@@ -311,12 +316,14 @@ class ELU(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['alpha', 'inplace']
 
     def __init__(self, alpha=1., inplace=False):
         super(ELU, self).__init__()
         self.alpha = alpha
         self.inplace = inplace
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.elu(input, self.alpha, self.inplace)
 
@@ -325,6 +332,7 @@ class ELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
+@torch._jit_internal.weak_module
 class CELU(Module):
     r"""Applies the element-wise function:
 
@@ -353,12 +361,14 @@ class CELU(Module):
     .. _`Continuously Differentiable Exponential Linear Units`:
         https://arxiv.org/abs/1704.07483
     """
+    __constants__ = ['alpha', 'inplace']
 
     def __init__(self, alpha=1., inplace=False):
         super(CELU, self).__init__()
         self.alpha = alpha
         self.inplace = inplace
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.celu(input, self.alpha, self.inplace)
 
@@ -367,6 +377,7 @@ class CELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
+@torch._jit_internal.weak_module
 class SELU(Module):
     r"""Applied element-wise, as:
 
@@ -396,11 +407,13 @@ class SELU(Module):
 
     .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
     """
+    __constants__ = ['inplace']
 
     def __init__(self, inplace=False):
         super(SELU, self).__init__()
         self.inplace = inplace
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.selu(input, self.inplace)
 
@@ -409,6 +422,7 @@ class SELU(Module):
         return inplace_str
 
 
+@torch._jit_internal.weak_module
 class GLU(Module):
     r"""Applies the gated linear unit function
     :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
@@ -428,11 +442,13 @@ class GLU(Module):
         >>> input = torch.randn(4, 2)
         >>> output = m(input)
     """
+    __constants__ = ['dim']
 
     def __init__(self, dim=-1):
         super(GLU, self).__init__()
         self.dim = dim
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.glu(input, self.dim)
 
@@ -482,6 +498,7 @@ class Hardshrink(Module):
         return '{}'.format(self.lambd)
 
 
+@torch._jit_internal.weak_module
 class LeakyReLU(Module):
     r"""Applies the element-wise function:
 
@@ -515,12 +532,14 @@ class LeakyReLU(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['inplace', 'negative_slope']
 
     def __init__(self, negative_slope=1e-2, inplace=False):
         super(LeakyReLU, self).__init__()
         self.negative_slope = negative_slope
         self.inplace = inplace
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.leaky_relu(input, self.negative_slope, self.inplace)
 
@@ -529,6 +548,7 @@ class LeakyReLU(Module):
         return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
 
 
+@torch._jit_internal.weak_module
 class LogSigmoid(Module):
     r"""Applies the element-wise function:
 
@@ -548,10 +568,12 @@ class LogSigmoid(Module):
         >>> output = m(input)
     """
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.logsigmoid(input)
 
 
+@torch._jit_internal.weak_module
 class Softplus(Module):
     r"""Applies the element-wise function:
 
|
@ -581,12 +603,14 @@ class Softplus(Module):
|
|||
>>> input = torch.randn(2)
|
||||
>>> output = m(input)
|
||||
"""
|
||||
__constants__ = ['beta', 'threshold']
|
||||
|
||||
def __init__(self, beta=1, threshold=20):
|
||||
super(Softplus, self).__init__()
|
||||
self.beta = beta
|
||||
self.threshold = threshold
|
||||
|
||||
@torch._jit_internal.weak_script_method
|
||||
def forward(self, input):
|
||||
return F.softplus(input, self.beta, self.threshold)
|
||||
|
||||
|
|
@@ -753,6 +777,7 @@ class Tanhshrink(Module):
         return F.tanhshrink(input)
 
 
+@torch._jit_internal.weak_module
 class Softmin(Module):
     r"""Applies the Softmin function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -779,15 +804,18 @@ class Softmin(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
+    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(Softmin, self).__init__()
         self.dim = dim
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.softmin(input, self.dim, _stacklevel=5)
 
 
+@torch._jit_internal.weak_module
 class Softmax(Module):
     r"""Applies the Softmax function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -821,6 +849,7 @@ class Softmax(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
+    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(Softmax, self).__init__()
@@ -831,10 +860,12 @@ class Softmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.softmax(input, self.dim, _stacklevel=5)
 
 
+@torch._jit_internal.weak_module
 class Softmax2d(Module):
     r"""Applies SoftMax over features to each spatial location.
 
@@ -857,11 +888,13 @@ class Softmax2d(Module):
         >>> output = m(input)
     """
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
         return F.softmax(input, 1, _stacklevel=5)
 
 
+@torch._jit_internal.weak_module
 class LogSoftmax(Module):
     r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
     input Tensor. The LogSoftmax formulation can be simplified as:
@@ -887,6 +920,7 @@ class LogSoftmax(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
+    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(LogSoftmax, self).__init__()
@@ -897,5 +931,6 @@ class LogSoftmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.log_softmax(input, self.dim, _stacklevel=5)