Use chainer-style constructor for Conv2d
* Conv2d, MaxPool2d, and AvgPool2d have one argument for each of ksize, stride, and pad. This argument can be either a single number or a tuple of (h, w)
parent 1486d880b0
commit cd0929aa5e
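For orientation, here is a minimal usage sketch of the chainer-style arguments this commit introduces. The constructor signatures are taken from the diffs below; the `import torch.nn as nn` line is an assumption about the import path rather than something shown in this commit.

    import torch.nn as nn  # assumed import path for the new modules

    # A single number is expanded to (h, w) internally via _pair.
    conv = nn.Conv2d(3, 4, 3)      # 3x3 kernel, stride 1, pad 0
    pool = nn.MaxPool2d(2)         # 2x2 window, stride defaults to ksize

    # A tuple configures the height and width dimensions separately.
    conv = nn.Conv2d(3, 4, (3, 3), (2, 2), (1, 1))        # ksize, stride, pad
    avg = nn.AvgPool2d((2, 2), stride=(2, 2), pad=(1, 1))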
@@ -9,7 +9,8 @@ try:
     import torch.cuda
     import torch.legacy.cunn
     TEST_CUDA = True
-except ImportError:
+except Exception:
+    # TODO: catch ImportError once it works with "setup.py develop"
     TEST_CUDA = False
 
 PRECISION = 1e-5
@@ -21,23 +22,6 @@ module_tests = [
         input_size=(4, 10),
         reference_fn=lambda i,p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
     ),
-    dict(
-        module_name='Conv2d',
-        constructor_args=(3, 4, 3, 3),
-        input_size=(2, 3, 6, 6)
-    ),
-    dict(
-        module_name='Conv2d',
-        constructor_args=(3, 4, 3, 3, 2, 2),
-        input_size=(2, 3, 6, 6),
-        desc='strided'
-    ),
-    dict(
-        module_name='Conv2d',
-        constructor_args=(3, 4, 3, 3, 2, 2, 1, 1),
-        input_size=(2, 3, 6, 6),
-        desc='padding'
-    ),
     dict(
         module_name='Threshold',
         constructor_args=(2, 1),
@@ -74,28 +58,6 @@ module_tests = [
         module_name='Tanh',
         input_size=(2, 3, 4, 5)
     ),
-    dict(
-        module_name='MaxPool2d',
-        constructor_args=(3, 3, 2, 2, 1, 1),
-        input_size=(1, 3, 7, 7)
-    ),
-    dict(
-        module_name='AvgPool2d',
-        constructor_args=(2, 2),
-        input_size=(2, 3, 6, 6),
-    ),
-    dict(
-        module_name='AvgPool2d',
-        constructor_args=(2, 2, 2, 2),
-        input_size=(2, 3, 6, 6),
-        desc='stride',
-    ),
-    dict(
-        module_name='AvgPool2d',
-        constructor_args=(2, 2, 2, 2, 1, 1),
-        input_size=(2, 3, 6, 6),
-        desc='stride_pad',
-    ),
     dict(
         module_name='Softmax',
         input_size=(10, 20),
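The old-style entries removed above reappear later in `new_module_tests`, with the per-dimension integers folded into tuples. As a rough mapping of the constructor arguments used in these tests (`nn` is the assumed torch.nn namespace; the old calls no longer work after this commit):

    # before this commit: one positional int per dimension
    nn.Conv2d(3, 4, 3, 3, 2, 2, 1, 1)
    nn.MaxPool2d(3, 3, 2, 2, 1, 1)
    nn.AvgPool2d(2, 2, 2, 2, 1, 1)

    # after this commit: chainer-style ksize / stride / pad
    nn.Conv2d(3, 4, (3, 3), (2, 2), (1, 1))
    nn.MaxPool2d((3, 3), (2, 2), (1, 1))
    nn.AvgPool2d((2, 2), (2, 2), (1, 1))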
@@ -343,6 +343,17 @@ tests = [
                   (1,),
                   input_size=[(1,), (2,), (3,), (4,)],
                   reference_fn=lambda i,_: i[1]),
+    OldModuleTest(nn.SpatialAveragePooling,
+                  (2, 2),
+                  input_size=(2, 3, 6, 6)),
+    OldModuleTest(nn.SpatialAveragePooling,
+                  (2, 2, 2, 2),
+                  input_size=(2, 3, 6, 6),
+                  desc='stride'),
+    OldModuleTest(nn.SpatialAveragePooling,
+                  (2, 2, 2, 2, 1, 1),
+                  input_size=(2, 3, 6, 6),
+                  desc='stride_pad'),
     OldModuleTest(nn.SpatialAdaptiveMaxPooling,
                   (4, 4),
                   input_size=(2, 3, 8, 8)),
@@ -352,6 +363,17 @@ tests = [
                   desc='irregular'),
                   # TODO: enable after implementing MaxPooling
                   # reference_fn=lambda i,_: nn.SpatialMaxPooling(2, 2).forward(i)),
+    OldModuleTest(nn.SpatialConvolution,
+                  (3, 4, 3, 3),
+                  input_size=(2, 3, 6, 6)),
+    OldModuleTest(nn.SpatialConvolution,
+                  (3, 4, 3, 3, 2, 2),
+                  input_size=(2, 3, 6, 6),
+                  desc='strided'),
+    OldModuleTest(nn.SpatialConvolution,
+                  (3, 4, 3, 3, 2, 2, 1, 1),
+                  input_size=(2, 3, 6, 6),
+                  desc='padding'),
     OldModuleTest(nn.SpatialConvolutionLocal,
                   (3, 2, 4, 4, 2, 2),
                   input_size=(1, 3, 4, 4)),
@@ -380,6 +402,9 @@ tests = [
                   (3, 2, 3, 3, 2, 2, 1, 1, 2, 2),
                   input_size=(2, 3, 8, 8),
                   desc='stride_pad'),
+    OldModuleTest(nn.SpatialMaxPooling,
+                  (3, 3, 2, 2, 1, 1),
+                  input_size=(1, 3, 7, 7)),
     OldModuleTest(nn.SpatialReflectionPadding,
                   (1, 2, 3, 4),
                   input_size=(2, 3, 8, 8)),
@@ -147,7 +147,7 @@ class TestNN(NNTestCase):
         self.assertEqual(counter['backwards'], 7)
 
     def test_volatile(self):
-        module = nn.Conv2d(2, 5, 3, 3, padh=1, padw=1)
+        module = nn.Conv2d(2, 5, ksize=3, pad=1)
         input = torch.randn(1, 2, 10, 10)
         x = Variable(input)
         y = Variable(input.clone(), volatile=True)
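Under the new signature shown further down (in_channels, out_channels, ksize, stride=1, pad=0, nobias=False), the keyword call above is equivalent to passing the same values positionally; a small sketch:

    # Both build a Conv2d with a 3x3 kernel, stride 1, and padding of 1 on each side.
    nn.Conv2d(2, 5, ksize=3, pad=1)
    nn.Conv2d(2, 5, 3, 1, 1)    # in_channels, out_channels, ksize, stride, pad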
@@ -175,7 +175,49 @@ def add_test(test):
     setattr(TestNN, cuda_test_name, lambda self,test=test: test.test_cuda(self))
 
 
-for test_params in module_tests:
+new_module_tests = [
+    dict(
+        module_name='Conv2d',
+        constructor_args=(3, 4, (3, 3)),
+        input_size=(2, 3, 6, 6)
+    ),
+    dict(
+        module_name='Conv2d',
+        constructor_args=(3, 4, (3, 3), (2, 2)),
+        input_size=(2, 3, 6, 6),
+        desc='strided'
+    ),
+    dict(
+        module_name='Conv2d',
+        constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
+        input_size=(2, 3, 6, 6),
+        desc='padding'
+    ),
+    dict(
+        module_name='MaxPool2d',
+        constructor_args=((3, 3), (2, 2), (1, 1)),
+        input_size=(1, 3, 7, 7)
+    ),
+    dict(
+        module_name='AvgPool2d',
+        constructor_args=((2, 2),),
+        input_size=(2, 3, 6, 6),
+    ),
+    dict(
+        module_name='AvgPool2d',
+        constructor_args=((2, 2), (2, 2)),
+        input_size=(2, 3, 6, 6),
+        desc='stride',
+    ),
+    dict(
+        module_name='AvgPool2d',
+        constructor_args=((2, 2), (2, 2), (1, 1)),
+        input_size=(2, 3, 6, 6),
+        desc='stride_pad',
+    ),
+]
+
+for test_params in module_tests + new_module_tests:
     test_params = deepcopy(test_params)
     # TODO: CUDA is not implemented yet
     name = test_params.pop('module_name')
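For context, a hedged sketch of how one of these dicts can be turned into a module and exercised. The `build_and_run` helper below is hypothetical; the actual loop in this test file does more (reference functions, CUDA variants, parameter checks).

    import torch
    from torch.autograd import Variable
    import torch.nn as nn  # assumed import path

    def build_and_run(test_params):
        # Hypothetical helper: construct the named module with its
        # constructor_args and push a random input of the declared
        # size through its forward pass.
        params = dict(test_params)
        name = params.pop('module_name')
        module = getattr(nn, name)(*params.pop('constructor_args', ()))
        input = Variable(torch.randn(*params.pop('input_size')))
        return module.forward(input)

    out = build_and_run(dict(module_name='Conv2d',
                             constructor_args=(3, 4, (3, 3)),
                             input_size=(2, 3, 6, 6)))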
@@ -3,28 +3,31 @@ import torch
 from torch.autograd import Variable
 
 from .module import Module
+from .utils import _pair
 
 class Conv2d(Module):
-    def __init__(self, in_channels, out_channels, kh, kw, dh=1, dw=1, padh=0, padw=0):
+    def __init__(self, in_channels, out_channels, ksize, stride=1, pad=0,
+                 nobias=False):
         super(Conv2d, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.kh = kh
-        self.kw = kw
-        self.dh = dh
-        self.dw = dw
-        self.padh = padh
-        self.padw = padw
+        self.kh, self.kw = _pair(ksize)
+        self.dh, self.dw = _pair(stride)
+        self.padh, self.padw = _pair(pad)
 
         self.weight = Variable(torch.DoubleTensor(self.out_channels, self.in_channels, self.kh, self.kw))
-        self.bias = Variable(torch.DoubleTensor(self.out_channels))
+        if nobias:
+            self.bias = None
+        else:
+            self.bias = Variable(torch.DoubleTensor(self.out_channels))
 
         self.reset_parameters()
 
     def reset_parameters(self):
         stdv = 1. / math.sqrt(self.kh * self.kw * self.in_channels)
         self.weight.data.uniform_(-stdv, stdv)
-        self.bias.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
 
     def forward(self, input):
         conv2d = self._backend.Conv2d(self.kw, self.kh, self.dw, self.dh, self.padw, self.padh)
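Given the constructor above, a quick sketch of how `_pair` populates the per-dimension attributes (values follow directly from the assignments in __init__; `nn` is the assumed torch.nn namespace):

    conv = nn.Conv2d(3, 4, 3)
    # kh = kw = 3, dh = dw = 1, padh = padw = 0

    conv = nn.Conv2d(3, 4, (3, 5), stride=(2, 1), pad=(1, 2))
    # kh, kw = 3, 5   dh, dw = 2, 1   padh, padw = 1, 2

    conv = nn.Conv2d(3, 4, 3, nobias=True)
    assert conv.bias is None    # nobias=True skips the bias parameter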
@@ -2,19 +2,16 @@ import torch
 from torch.autograd import Variable
 
 from .module import Module
+from .utils import _pair
 
 class MaxPool2d(Module):
 
-    def __init__(self, kh, kw, dh=None, dw=None, padh=0, padw=0, dilh=1, dilw=1, ceil_mode=False):
+    def __init__(self, ksize, stride=None, pad=0, dil=1, ceil_mode=False):
         super(MaxPool2d, self).__init__()
-        self.kw = kw
-        self.kh = kh
-        self.dw = dw or kw
-        self.dh = dh or kh
-        self.padw = padw
-        self.padh = padh
-        self.dilh = dilh
-        self.dilw = dilw
+        self.kh, self.kw = _pair(ksize)
+        self.dh, self.dw = _pair(stride or ksize)
+        self.padh, self.padw = _pair(pad)
+        self.dilh, self.dilw = _pair(dil)
         self.ceil_mode = ceil_mode
 
     def forward(self, input):
@@ -22,14 +19,11 @@ class MaxPool2d(Module):
 
 class AvgPool2d(Module):
 
-    def __init__(self, kh, kw, dh=None, dw=None, padh=0, padw=0, ceil_mode=False, count_include_pad=True):
+    def __init__(self, ksize, stride=None, pad=0, ceil_mode=False, count_include_pad=True):
         super(AvgPool2d, self).__init__()
-        self.kw = kw
-        self.kh = kh
-        self.dw = dw or kw
-        self.dh = dh or kh
-        self.padw = padw
-        self.padh = padh
+        self.kh, self.kw = _pair(ksize)
+        self.dh, self.dw = _pair(stride or ksize)
+        self.padh, self.padw = _pair(pad)
         self.ceil_mode = ceil_mode
         self.count_include_pad = count_include_pad
 
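Because the stride falls back to the window size when omitted (`_pair(stride or ksize)` above), the default is non-overlapping windows; a short sketch:

    pool = nn.MaxPool2d(3)              # kh = kw = 3 and dh = dw = 3 (stride defaults to ksize)
    pool = nn.MaxPool2d(3, stride=1)    # overlapping 3x3 windows with dh = dw = 1
    avg = nn.AvgPool2d((2, 2))          # dh, dw likewise default to 2, 2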
torch/nn/modules/utils.py (new file)
@@ -0,0 +1,6 @@
+import collections
+
+def _pair(x):
+    if isinstance(x, collections.Iterable):
+        return x
+    return x, x
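The helper duplicates scalars and passes any iterable through unchanged; for example:

    _pair(3)        # -> (3, 3)
    _pair((3, 5))   # -> (3, 5), returned as-is
    _pair([2, 4])   # -> [2, 4]; lists and other iterables also pass through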