Use nn module tests in test_jit (#14238)

Summary:
This PR adds weak modules for all activation modules and uses `test_nn` module tests to test weak modules that have been annotated with `weak_module` and therefore are in `torch._jit_internal._weak_types`

Also depends on #14379
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14238

Differential Revision: D13192230

Pulled By: driazati

fbshipit-source-id: 36488960b6c91448b38c0fa65422539a93af8c5e
This commit is contained in:
David Riazati 2018-11-27 21:17:51 -08:00 committed by Facebook Github Bot
parent a0def0b57e
commit 4cdcbbf410
6 changed files with 129 additions and 56 deletions

View File

@ -51,14 +51,14 @@ module_tests = [
),
dict(
module_name='Threshold',
constructor_args=(2, 1),
constructor_args=(2., 1.),
input_size=(2, 3, 4, 5),
check_inplace=True,
desc='threshold_value'
),
dict(
module_name='Threshold',
constructor_args=(2, 10),
constructor_args=(2., 10.),
input_size=(2, 3, 4, 5),
desc='large_value'
),

View File

@ -26,36 +26,37 @@ graph(%0 : Double(1, 3, 224, 224)
%25 : int[] = prim::ListConstruct(%24, %24), scope: AlexNet/Sequential[features]/Conv2d[0]
%26 : bool = prim::Constant[value=1](), scope: AlexNet/Sequential[features]/Conv2d[0]
%input.1 : Double(1, 64, 55, 55) = aten::_convolution(%0, %1, %2, %18, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[0]
%input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %24, %24), scope: AlexNet/Sequential[features]/ReLU[1]
%29 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
%30 : int[] = prim::ListConstruct(%29, %29), scope: AlexNet/Sequential[features]/MaxPool2d[2]
%31 : Double(1, 64, 27, 27), %32 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
%input.3 : Double(1, 192, 27, 27) = aten::_convolution(%31, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
%input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %24, %24), scope: AlexNet/Sequential[features]/ReLU[4]
%35 : Double(1, 192, 13, 13), %36 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
%input.5 : Double(1, 384, 13, 13) = aten::_convolution(%35, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
%38 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %24, %24), scope: AlexNet/Sequential[features]/ReLU[7]
%input.6 : Double(1, 256, 13, 13) = aten::_convolution(%38, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
%40 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %24, %24), scope: AlexNet/Sequential[features]/ReLU[9]
%input.7 : Double(1, 256, 13, 13) = aten::_convolution(%40, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
%input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %24, %24), scope: AlexNet/Sequential[features]/ReLU[11]
%43 : Double(1, 256, 6, 6), %44 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
%45 : int = aten::size(%43, %24), scope: AlexNet
%46 : Long() = prim::NumToTensor(%45), scope: AlexNet
%47 : int = prim::TensorToNum(%46), scope: AlexNet
%48 : int = prim::Constant[value=9216](), scope: AlexNet
%49 : int[] = prim::ListConstruct(%47, %48), scope: AlexNet
%input.9 : Double(1, 9216) = aten::view(%43, %49), scope: AlexNet
%51 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
%input.10 : Double(1, 9216) = aten::dropout(%input.9, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
%53 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
%input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %53, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
%input.12 : Double(1, 4096) = aten::threshold_(%input.11, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[2]
%input.13 : Double(1, 4096) = aten::dropout(%input.12, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
%57 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
%input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %57, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
%input : Double(1, 4096) = aten::threshold_(%input.14, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[5]
%60 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
%61 : Double(1, 1000) = aten::addmm(%16, %input, %60, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%61);
%28 : float = prim::Constant[value=0](), scope: AlexNet/Sequential[features]/ReLU[1]
%input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %28, %28), scope: AlexNet/Sequential[features]/ReLU[1]
%30 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
%31 : int[] = prim::ListConstruct(%30, %30), scope: AlexNet/Sequential[features]/MaxPool2d[2]
%32 : Double(1, 64, 27, 27), %33 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
%input.3 : Double(1, 192, 27, 27) = aten::_convolution(%32, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
%input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %28, %28), scope: AlexNet/Sequential[features]/ReLU[4]
%36 : Double(1, 192, 13, 13), %37 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
%input.5 : Double(1, 384, 13, 13) = aten::_convolution(%36, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
%39 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %28, %28), scope: AlexNet/Sequential[features]/ReLU[7]
%input.6 : Double(1, 256, 13, 13) = aten::_convolution(%39, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
%41 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %28, %28), scope: AlexNet/Sequential[features]/ReLU[9]
%input.7 : Double(1, 256, 13, 13) = aten::_convolution(%41, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
%input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %28, %28), scope: AlexNet/Sequential[features]/ReLU[11]
%44 : Double(1, 256, 6, 6), %45 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
%46 : int = aten::size(%44, %24), scope: AlexNet
%47 : Long() = prim::NumToTensor(%46), scope: AlexNet
%48 : int = prim::TensorToNum(%47), scope: AlexNet
%49 : int = prim::Constant[value=9216](), scope: AlexNet
%50 : int[] = prim::ListConstruct(%48, %49), scope: AlexNet
%input.9 : Double(1, 9216) = aten::view(%44, %50), scope: AlexNet
%52 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
%input.10 : Double(1, 9216) = aten::dropout(%input.9, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
%54 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
%input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %54, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
%input.12 : Double(1, 4096) = aten::threshold_(%input.11, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[2]
%input.13 : Double(1, 4096) = aten::dropout(%input.12, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
%58 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
%input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %58, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
%input : Double(1, 4096) = aten::threshold_(%input.14, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[5]
%61 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
%62 : Double(1, 1000) = aten::addmm(%16, %input, %61, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%62);
}

View File

@ -11,8 +11,10 @@ from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
from torch._six import inf, PY2
from common_utils import (TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN,
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE)
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
freeze_rng_state
from test_nn import module_tests, new_module_tests
from textwrap import dedent
import os
import io
@ -41,6 +43,7 @@ from torch.jit import BatchTensor
from test_module.future_div import div_int_future, div_float_future
from test_module.no_future_div import div_int_nofuture, div_float_nofuture
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@ -446,9 +449,8 @@ class JitTestCase(TestCase):
def runAndSaveRNG(self, func, inputs, kwargs=None):
kwargs = kwargs if kwargs else {}
initial_rng_state = torch.get_rng_state()
results = func(*inputs, **kwargs)
torch.set_rng_state(initial_rng_state)
with freeze_rng_state():
results = func(*inputs, **kwargs)
return results
@ -9396,6 +9398,11 @@ EXCLUDE_SCRIPT = {
'test_nn_max_unpool1d',
}
EXCLUDE_SCRIPT_MODULES = {
'test_nn_LPPool2d_norm',
'test_nn_LPPool1d_norm',
}
DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
'test_nn_avg_pool2d',
'test_nn_log_softmax',
@ -10199,10 +10206,37 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
post_add_test(test_name, skipTestIf, do_test)
def add_nn_module_test(module_name, constructor_args, call_args,
use_as_constant=False, skipTestIf=()):
def add_nn_module_test(*args, **kwargs):
if 'module_name' in kwargs:
name = kwargs['module_name']
elif 'fullname' in kwargs:
name = kwargs['fullname']
elif 'constructor' in kwargs:
name = kwargs['constructor'].__name__
class_name = name.split("_")[0]
module = getattr(torch.nn, class_name, None)
if module is None or torch._jit_internal._weak_types.get(module) is None:
return
if 'desc' in kwargs and 'eval' in kwargs['desc']:
# eval() is not supported, so skip these tests
return
test_name = name
if 'desc' in kwargs:
test_name = "{}_{}".format(test_name, kwargs['desc'])
test_name = 'test_nn_{}'.format(test_name)
def do_test(self):
nn_module = getattr(torch.nn, module_name)
if test_name in EXCLUDE_SCRIPT_MODULES:
return
if 'constructor' in kwargs:
nn_module = kwargs['constructor']
else:
nn_module = getattr(torch.nn, name)
constructor_args = kwargs.get('constructor_args', ())
# Construct a script module that passes arguments through
# to self.submodule
@ -10215,7 +10249,7 @@ def add_nn_module_test(module_name, constructor_args, call_args,
script = script_method_template.format(method_args, call)
submodule_constants = []
if use_as_constant:
if kwargs.get('is_constant'):
submodule_constants = ['submodule']
# Create module to use the script method
@ -10241,13 +10275,16 @@ def add_nn_module_test(module_name, constructor_args, call_args,
return module(*args)
# Check against Python module as reference
args_variable, kwargs_variable = create_input(call_args)
if 'input_fn' in kwargs:
input_size = tuple(kwargs['input_fn']().size())
else:
input_size = kwargs['input_size']
args_variable, kwargs_variable = create_input((input_size,))
f_args_variable = deepcopy(unpack_variables(args_variable))
check_against_reference(self, create_script_module, create_nn_module, f_args_variable)
test_name = 'test_nn_{}'.format(module_name)
post_add_test(test_name, skipTestIf, do_test)
post_add_test(test_name, (), do_test)
def post_add_test(test_name, skipTestIf, do_test):
@ -10411,8 +10448,8 @@ for test in autograd_method_tests:
for test in nn_functional_tests:
add_nn_functional_test(*test)
for test in nn_module_tests:
add_nn_module_test(*test)
for test in module_tests + new_module_tests:
add_nn_module_test(**test)
if __name__ == '__main__':
run_tests()

View File

@ -8321,7 +8321,7 @@ new_module_tests = [
),
dict(
module_name='LPPool2d',
constructor_args=(2, (2, 2), 2),
constructor_args=(2, 2, 2),
input_size=(1, 3, 7, 7),
),
dict(
@ -9005,7 +9005,7 @@ new_module_tests = [
),
dict(
module_name='Threshold',
constructor_args=(2, 1),
constructor_args=(2., 1.),
input_size=(),
check_inplace=True,
desc='threshold_value_scalar'

View File

@ -519,7 +519,7 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
@torch._jit_internal.weak_script
def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
# type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
# type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
r"""Applies a 2D power-average pooling over an input signal composed of
several input planes. If the sum of all inputs to the power of `p` is
zero, the gradient is set to zero as well.

View File

@ -9,8 +9,6 @@ from ..._jit_internal import weak_module, weak_script_method
@torch._jit_internal.weak_module
class Threshold(Module):
__constants__ = ['threshold', 'value', 'inplace']
r"""Thresholds each element of the input Tensor
Threshold is defined as:
@ -38,6 +36,7 @@ class Threshold(Module):
>>> input = torch.randn(2)
>>> output = m(input)
"""
__constants__ = ['threshold', 'value', 'inplace']
def __init__(self, threshold, value, inplace=False):
super(Threshold, self).__init__()
@ -57,6 +56,7 @@ class Threshold(Module):
)
@torch._jit_internal.weak_module
class ReLU(Threshold):
r"""Applies the rectified linear unit function element-wise
:math:`\text{ReLU}(x)= \max(0, x)`
@ -79,7 +79,7 @@ class ReLU(Threshold):
"""
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 0, inplace)
super(ReLU, self).__init__(0., 0., inplace)
def extra_repr(self):
inplace_str = 'inplace' if self.inplace else ''
@ -143,6 +143,7 @@ class RReLU(Module):
return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
@torch._jit_internal.weak_module
class Hardtanh(Module):
r"""Applies the HardTanh function element-wise
@ -179,8 +180,9 @@ class Hardtanh(Module):
>>> input = torch.randn(2)
>>> output = m(input)
"""
__constants__ = ['min_val', 'max_val', 'inplace']
def __init__(self, min_val=-1, max_val=1, inplace=False, min_value=None, max_value=None):
def __init__(self, min_val=-1., max_val=1., inplace=False, min_value=None, max_value=None):
super(Hardtanh, self).__init__()
if min_value is not None:
warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
@ -194,6 +196,7 @@ class Hardtanh(Module):
self.inplace = inplace
assert self.max_val > self.min_val
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
@ -204,6 +207,7 @@ class Hardtanh(Module):
)
@torch._jit_internal.weak_module
class ReLU6(Hardtanh):
r"""Applies the element-wise function:
@ -228,7 +232,7 @@ class ReLU6(Hardtanh):
"""
def __init__(self, inplace=False):
super(ReLU6, self).__init__(0, 6, inplace)
super(ReLU6, self).__init__(0., 6., inplace)
def extra_repr(self):
inplace_str = 'inplace' if self.inplace else ''
@ -288,6 +292,7 @@ class Tanh(Module):
return torch.tanh(input)
@torch._jit_internal.weak_module
class ELU(Module):
r"""Applies the element-wise function:
@ -311,12 +316,14 @@ class ELU(Module):
>>> input = torch.randn(2)
>>> output = m(input)
"""
__constants__ = ['alpha', 'inplace']
def __init__(self, alpha=1., inplace=False):
super(ELU, self).__init__()
self.alpha = alpha
self.inplace = inplace
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.elu(input, self.alpha, self.inplace)
@ -325,6 +332,7 @@ class ELU(Module):
return 'alpha={}{}'.format(self.alpha, inplace_str)
@torch._jit_internal.weak_module
class CELU(Module):
r"""Applies the element-wise function:
@ -353,12 +361,14 @@ class CELU(Module):
.. _`Continuously Differentiable Exponential Linear Units`:
https://arxiv.org/abs/1704.07483
"""
__constants__ = ['alpha', 'inplace']
def __init__(self, alpha=1., inplace=False):
super(CELU, self).__init__()
self.alpha = alpha
self.inplace = inplace
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.celu(input, self.alpha, self.inplace)
@ -367,6 +377,7 @@ class CELU(Module):
return 'alpha={}{}'.format(self.alpha, inplace_str)
@torch._jit_internal.weak_module
class SELU(Module):
r"""Applied element-wise, as:
@ -396,11 +407,13 @@ class SELU(Module):
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
"""
__constants__ = ['inplace']
def __init__(self, inplace=False):
super(SELU, self).__init__()
self.inplace = inplace
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.selu(input, self.inplace)
@ -409,6 +422,7 @@ class SELU(Module):
return inplace_str
@torch._jit_internal.weak_module
class GLU(Module):
r"""Applies the gated linear unit function
:math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
@ -428,11 +442,13 @@ class GLU(Module):
>>> input = torch.randn(4, 2)
>>> output = m(input)
"""
__constants__ = ['dim']
def __init__(self, dim=-1):
super(GLU, self).__init__()
self.dim = dim
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.glu(input, self.dim)
@ -482,6 +498,7 @@ class Hardshrink(Module):
return '{}'.format(self.lambd)
@torch._jit_internal.weak_module
class LeakyReLU(Module):
r"""Applies the element-wise function:
@ -515,12 +532,14 @@ class LeakyReLU(Module):
>>> input = torch.randn(2)
>>> output = m(input)
"""
__constants__ = ['inplace', 'negative_slope']
def __init__(self, negative_slope=1e-2, inplace=False):
super(LeakyReLU, self).__init__()
self.negative_slope = negative_slope
self.inplace = inplace
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.leaky_relu(input, self.negative_slope, self.inplace)
@ -529,6 +548,7 @@ class LeakyReLU(Module):
return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
@torch._jit_internal.weak_module
class LogSigmoid(Module):
r"""Applies the element-wise function:
@ -548,10 +568,12 @@ class LogSigmoid(Module):
>>> output = m(input)
"""
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.logsigmoid(input)
@torch._jit_internal.weak_module
class Softplus(Module):
r"""Applies the element-wise function:
@ -581,12 +603,14 @@ class Softplus(Module):
>>> input = torch.randn(2)
>>> output = m(input)
"""
__constants__ = ['beta', 'threshold']
def __init__(self, beta=1, threshold=20):
super(Softplus, self).__init__()
self.beta = beta
self.threshold = threshold
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.softplus(input, self.beta, self.threshold)
@ -753,6 +777,7 @@ class Tanhshrink(Module):
return F.tanhshrink(input)
@torch._jit_internal.weak_module
class Softmin(Module):
r"""Applies the Softmin function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
@ -779,15 +804,18 @@ class Softmin(Module):
>>> input = torch.randn(2, 3)
>>> output = m(input)
"""
__constants__ = ['dim']
def __init__(self, dim=None):
super(Softmin, self).__init__()
self.dim = dim
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.softmin(input, self.dim, _stacklevel=5)
@torch._jit_internal.weak_module
class Softmax(Module):
r"""Applies the Softmax function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
@ -821,6 +849,7 @@ class Softmax(Module):
>>> input = torch.randn(2, 3)
>>> output = m(input)
"""
__constants__ = ['dim']
def __init__(self, dim=None):
super(Softmax, self).__init__()
@ -831,10 +860,12 @@ class Softmax(Module):
if not hasattr(self, 'dim'):
self.dim = None
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.softmax(input, self.dim, _stacklevel=5)
@torch._jit_internal.weak_module
class Softmax2d(Module):
r"""Applies SoftMax over features to each spatial location.
@ -857,11 +888,13 @@ class Softmax2d(Module):
>>> output = m(input)
"""
@torch._jit_internal.weak_script_method
def forward(self, input):
assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
return F.softmax(input, 1, _stacklevel=5)
@torch._jit_internal.weak_module
class LogSoftmax(Module):
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
input Tensor. The LogSoftmax formulation can be simplified as:
@ -887,6 +920,7 @@ class LogSoftmax(Module):
>>> input = torch.randn(2, 3)
>>> output = m(input)
"""
__constants__ = ['dim']
def __init__(self, dim=None):
super(LogSoftmax, self).__init__()
@ -897,5 +931,6 @@ class LogSoftmax(Module):
if not hasattr(self, 'dim'):
self.dim = None
@torch._jit_internal.weak_script_method
def forward(self, input):
return F.log_softmax(input, self.dim, _stacklevel=5)