mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add 3D upsampling (nearest and trilinear) with tests
This commit is contained in:
parent
edd41d8d80
commit
b9ab26765e
|
|
@ -525,6 +525,12 @@ Vision layers
|
|||
.. autoclass:: PixelShuffle
|
||||
:members:
|
||||
|
||||
:hidden:`Upsample`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: UpsamplingNearest3d
|
||||
:members:
|
||||
|
||||
:hidden:`UpsamplingNearest2d`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
@ -873,6 +879,23 @@ Vision functions
|
|||
|
||||
.. autofunction:: pad
|
||||
|
||||
:hidden:`upsample`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: upsample
|
||||
|
||||
:hidden:`upsample_nearest`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: upsample_nearest
|
||||
|
||||
:hidden:`upsample_bilinear`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: upsample_bilinear
|
||||
|
||||
|
||||
|
||||
torch.nn.init
|
||||
=============
|
||||
|
||||
|
|
|
|||
118
test/test_nn.py
118
test/test_nn.py
|
|
@ -2128,6 +2128,42 @@ class TestNN(NNTestCase):
|
|||
self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2)))
|
||||
self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
|
||||
|
||||
def test_upsamplingNearest2d(self):
|
||||
m = nn.Upsample(size=4, mode='nearest')
|
||||
in_t = torch.ones(1, 1, 2, 2)
|
||||
out_t = m(Variable(in_t))
|
||||
self.assertEqual(torch.ones(1, 1, 4, 4), out_t.data)
|
||||
|
||||
input = Variable(torch.randn(1, 1, 2, 2), requires_grad=True)
|
||||
self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='nearest'), (input,)))
|
||||
|
||||
def test_upsamplingBilinear2d(self):
|
||||
m = nn.Upsample(size=4, mode='bilinear')
|
||||
in_t = torch.ones(1, 1, 2, 2)
|
||||
out_t = m(Variable(in_t))
|
||||
self.assertEqual(torch.ones(1, 1, 4, 4), out_t.data)
|
||||
|
||||
input = Variable(torch.randn(1, 1, 2, 2), requires_grad=True)
|
||||
self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='bilinear'), (input,)))
|
||||
|
||||
def test_upsamplingNearest3d(self):
|
||||
m = nn.Upsample(size=4, mode='nearest')
|
||||
in_t = torch.ones(1, 1, 2, 2, 2)
|
||||
out_t = m(Variable(in_t))
|
||||
self.assertEqual(torch.ones(1, 1, 4, 4, 4), out_t.data)
|
||||
|
||||
input = Variable(torch.randn(1, 1, 2, 2, 2), requires_grad=True)
|
||||
self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='nearest'), (input,)))
|
||||
|
||||
def test_upsamplingTrilinear3d(self):
|
||||
m = nn.Upsample(size=4, mode='trilinear')
|
||||
in_t = torch.ones(1, 1, 2, 2, 2)
|
||||
out_t = m(Variable(in_t))
|
||||
self.assertEqual(torch.ones(1, 1, 4, 4, 4), out_t.data)
|
||||
|
||||
input = Variable(torch.randn(1, 1, 2, 2, 2), requires_grad=True)
|
||||
self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='trilinear'), (input,)))
|
||||
|
||||
def test_bilinear(self):
|
||||
module = nn.Bilinear(10, 10, 8)
|
||||
module_legacy = legacy.Bilinear(10, 10, 8)
|
||||
|
|
@ -2965,50 +3001,88 @@ new_module_tests = [
|
|||
input_size=(1, 9, 4, 4),
|
||||
),
|
||||
dict(
|
||||
module_name='UpsamplingNearest2d',
|
||||
constructor_args=(12,),
|
||||
module_name='Upsample',
|
||||
constructor_args=(12, None, 'nearest'),
|
||||
input_size=(1, 2, 4, 4),
|
||||
desc='nearest_2d'
|
||||
),
|
||||
dict(
|
||||
module_name='UpsamplingNearest2d',
|
||||
constructor_args=((12, 16)),
|
||||
module_name='Upsample',
|
||||
constructor_args=((12, 16), None, 'nearest'),
|
||||
input_size=(1, 2, 3, 4),
|
||||
desc='tuple'
|
||||
desc='nearest_tuple_2d'
|
||||
),
|
||||
dict(
|
||||
module_name='UpsamplingNearest2d',
|
||||
constructor_args=(None, 4),
|
||||
module_name='Upsample',
|
||||
constructor_args=(None, 4, 'nearest'),
|
||||
input_size=(1, 2, 4, 4),
|
||||
desc='scale'
|
||||
desc='nearest_scale_2d'
|
||||
),
|
||||
dict(
|
||||
module_name='UpsamplingBilinear2d',
|
||||
constructor_args=(12,),
|
||||
module_name='Upsample',
|
||||
constructor_args=(12, None, 'bilinear'),
|
||||
input_size=(1, 2, 4, 4),
|
||||
desc='bilinear_2d'
|
||||
),
|
||||
dict(
|
||||
module_name='UpsamplingBilinear2d',
|
||||
constructor_args=((4, 6)),
|
||||
module_name='Upsample',
|
||||
constructor_args=((4, 6), None, 'bilinear'),
|
||||
input_size=(1, 2, 2, 3),
|
||||
desc='tuple'
|
||||
desc='bilinear_tuple_2d'
|
||||
),
|
||||
dict(
|
||||
module_name='UpsamplingBilinear2d',
|
||||
constructor_args=(None, 4),
|
||||
module_name='Upsample',
|
||||
constructor_args=(None, 4, 'bilinear'),
|
||||
input_size=(1, 2, 4, 4),
|
||||
desc='scale'
|
||||
desc='bilinear_scale_2d'
|
||||
),
|
||||
dict(
|
||||
module_name='UpsamplingBilinear2d',
|
||||
constructor_args=(None, (2, 2)),
|
||||
module_name='Upsample',
|
||||
constructor_args=(None, (2, 2), 'bilinear'),
|
||||
input_size=(1, 2, 4, 4),
|
||||
desc='scale_tuple_shared'
|
||||
desc='bilinear_scale_tuple_shared_2d'
|
||||
),
|
||||
dict(
|
||||
module_name='UpsamplingBilinear2d',
|
||||
constructor_args=(None, (2, 1)),
|
||||
module_name='Upsample',
|
||||
constructor_args=(None, (2, 1), 'bilinear'),
|
||||
input_size=(1, 2, 4, 4),
|
||||
desc='scale_tuple_skewed'
|
||||
desc='bilinear_scale_tuple_skewed_2d'
|
||||
),
|
||||
dict(
|
||||
module_name='Upsample',
|
||||
constructor_args=(12, None, 'nearest'),
|
||||
input_size=(1, 2, 4, 4, 4),
|
||||
desc='nearest_3d'
|
||||
),
|
||||
dict(
|
||||
module_name='Upsample',
|
||||
constructor_args=((12, 16, 16), None, 'nearest'),
|
||||
input_size=(1, 2, 3, 4, 4),
|
||||
desc='nearest_tuple_3d'
|
||||
),
|
||||
dict(
|
||||
module_name='Upsample',
|
||||
constructor_args=(None, 4, 'nearest'),
|
||||
input_size=(1, 2, 4, 4, 4),
|
||||
desc='nearest_scale_3d'
|
||||
),
|
||||
dict(
|
||||
module_name='Upsample',
|
||||
constructor_args=(12, None, 'trilinear'),
|
||||
input_size=(1, 2, 4, 4, 4),
|
||||
desc='trilinear_3d'
|
||||
),
|
||||
dict(
|
||||
module_name='Upsample',
|
||||
constructor_args=((4, 6, 6), None, 'trilinear'),
|
||||
input_size=(1, 2, 2, 3, 3),
|
||||
desc='trilinear_tuple_3d'
|
||||
),
|
||||
dict(
|
||||
module_name='Upsample',
|
||||
constructor_args=(None, 4, 'trilinear'),
|
||||
input_size=(1, 2, 4, 4, 4),
|
||||
desc='trilinear_scale_3d'
|
||||
),
|
||||
dict(
|
||||
module_name='AdaptiveMaxPool1d',
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ from torch.autograd import Function
|
|||
from torch._thnn import type2backend
|
||||
|
||||
from . import _all_functions
|
||||
from ...modules.utils import _pair
|
||||
from ...functional import _check_bilinear_2d_scale_factor
|
||||
from ...modules.utils import _pair, _triple
|
||||
|
||||
|
||||
class _UpsamplingBase(Function):
|
||||
|
|
@ -15,7 +14,7 @@ class _UpsamplingBase(Function):
|
|||
if size is None and scale_factor is None:
|
||||
raise ValueError('either size or scale_factor should be defined')
|
||||
if scale_factor is not None and not isinstance(scale_factor, (Integral, tuple)):
|
||||
raise ValueError('scale_factor must be of integer type or tuple of integer types')
|
||||
raise ValueError('scale_factor must be of integer type or a tuple of integer types')
|
||||
self.size = size
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
|
|
@ -26,9 +25,7 @@ class UpsamplingNearest2d(_UpsamplingBase):
|
|||
super(UpsamplingNearest2d, self).__init__(size, scale_factor)
|
||||
|
||||
if self.scale_factor is not None and not isinstance(scale_factor, Integral):
|
||||
raise ValueError('scale_factor must be of integer type for nearest neighbor sampling')
|
||||
|
||||
self.size = _pair(self.size) if self.size is not None else None
|
||||
raise ValueError('scale_factor must be a single Integer value for nearest neighbor sampling')
|
||||
|
||||
def forward(self, input):
|
||||
assert input.dim() == 4
|
||||
|
|
@ -36,7 +33,7 @@ class UpsamplingNearest2d(_UpsamplingBase):
|
|||
if self.scale_factor is None:
|
||||
if (self.size[0] % input.size(2) != 0 or
|
||||
self.size[1] % input.size(3) != 0):
|
||||
raise RuntimeError("output size specified in UpSamplingNearest "
|
||||
raise RuntimeError("output size specified in UpsamplingNearest "
|
||||
"({}) has to be divisible by the input size, but got: "
|
||||
"{}".format('x'.join(map(str, self.size)),
|
||||
'x'.join(map(str, input.size()))))
|
||||
|
|
@ -57,6 +54,8 @@ class UpsamplingNearest2d(_UpsamplingBase):
|
|||
return output
|
||||
|
||||
def backward(self, grad_output):
|
||||
assert grad_output.dim() == 4
|
||||
|
||||
input, = self.saved_tensors
|
||||
grad_input = grad_output.new()
|
||||
backend = type2backend[type(input)]
|
||||
|
|
@ -70,15 +69,31 @@ class UpsamplingNearest2d(_UpsamplingBase):
|
|||
return grad_input
|
||||
|
||||
|
||||
def _check_linear_scale_factor(scale_factor, dim=2):
|
||||
if dim == 2:
|
||||
scale_factor = _pair(scale_factor)
|
||||
elif dim == 3:
|
||||
scale_factor = _triple(scale_factor)
|
||||
else:
|
||||
raise ValueError("dim has to be 2 or 3")
|
||||
|
||||
try:
|
||||
assert len(scale_factor) == 2 or len(scale_factor) == 3
|
||||
assert all(isinstance(s, Integral) and s >= 1 for s in scale_factor)
|
||||
except AssertionError as e:
|
||||
raise ValueError('scale_factor must be a non-negative integer, '
|
||||
'or a tuple of non-negative integers for bilinear and trilinear upsampling, but got: '
|
||||
'{}'.format(scale_factor))
|
||||
return scale_factor
|
||||
|
||||
|
||||
class UpsamplingBilinear2d(_UpsamplingBase):
|
||||
|
||||
def __init__(self, size=None, scale_factor=None):
|
||||
super(UpsamplingBilinear2d, self).__init__(size, scale_factor)
|
||||
|
||||
if self.scale_factor is not None:
|
||||
self.scale_factor = _check_bilinear_2d_scale_factor(self.scale_factor)
|
||||
|
||||
self.size = _pair(self.size) if self.size is not None else None
|
||||
self.scale_factor = _check_linear_scale_factor(self.scale_factor, dim=2)
|
||||
|
||||
def forward(self, input):
|
||||
assert input.dim() == 4
|
||||
|
|
@ -126,5 +141,107 @@ class UpsamplingBilinear2d(_UpsamplingBase):
|
|||
self.__dict__.update(state)
|
||||
self.scale_factor = _tuple(self.scale_factor)
|
||||
|
||||
|
||||
class UpsamplingNearest3d(_UpsamplingBase):
|
||||
def __init__(self, size=None, scale_factor=None):
|
||||
super(UpsamplingNearest3d, self).__init__(size, scale_factor)
|
||||
|
||||
if self.scale_factor is not None and not isinstance(scale_factor, Integral):
|
||||
raise ValueError('scale_factor must be a single Integer value for nearest neighbor sampling')
|
||||
|
||||
def forward(self, input):
|
||||
assert input.dim() == 5
|
||||
|
||||
if self.scale_factor is None:
|
||||
if (self.size[0] % input.size(2) != 0 or self.size[1] % input.size(3) != 0 or
|
||||
self.size[2] % input.size(4) != 0):
|
||||
raise RuntimeError("output size specified in UpSamplingNearest "
|
||||
"({}) has to be divisible by the input size, but got: "
|
||||
"{}".format('x'.join(map(str, self.size)),
|
||||
'x'.join(map(str, input.size()))))
|
||||
self.scale_factor = self.size[0] // input.size(2)
|
||||
if (self.scale_factor != self.size[1] // input.size(3) or
|
||||
self.scale_factor != self.size[2] // input.size(4)):
|
||||
raise RuntimeError("input aspect ratio doesn't match the "
|
||||
"output ratio")
|
||||
|
||||
output = input.new()
|
||||
backend = type2backend[type(input)]
|
||||
self.save_for_backward(input)
|
||||
backend.VolumetricUpSamplingNearest_updateOutput(backend.library_state,
|
||||
input,
|
||||
output,
|
||||
self.scale_factor)
|
||||
return output
|
||||
|
||||
def backward(self, grad_output):
|
||||
assert grad_output.dim() == 5
|
||||
input, = self.saved_tensors
|
||||
grad_input = grad_output.new()
|
||||
backend = type2backend[type(input)]
|
||||
backend.VolumetricUpSamplingNearest_updateGradInput(backend.library_state,
|
||||
input,
|
||||
grad_output,
|
||||
grad_input,
|
||||
self.scale_factor)
|
||||
return grad_input
|
||||
|
||||
|
||||
class UpsamplingTrilinear3d(_UpsamplingBase):
|
||||
def __init__(self, size=None, scale_factor=None):
|
||||
super(UpsamplingTrilinear3d, self).__init__(size, scale_factor)
|
||||
|
||||
if self.scale_factor is not None:
|
||||
self.scale_factor = _check_linear_scale_factor(self.scale_factor, dim=3)
|
||||
|
||||
def forward(self, input):
|
||||
assert input.dim() == 5
|
||||
|
||||
if self.scale_factor is not None:
|
||||
self.output_size = (
|
||||
input.size(2) * self.scale_factor[0],
|
||||
input.size(3) * self.scale_factor[1],
|
||||
input.size(4) * self.scale_factor[2],
|
||||
)
|
||||
else:
|
||||
self.output_size = self.size
|
||||
|
||||
self.input_size = input.size()
|
||||
output = input.new()
|
||||
backend = type2backend[type(input)]
|
||||
backend.VolumetricUpSamplingTrilinear_updateOutput(
|
||||
backend.library_state,
|
||||
input,
|
||||
output,
|
||||
self.output_size[0],
|
||||
self.output_size[1],
|
||||
self.output_size[2]
|
||||
)
|
||||
return output
|
||||
|
||||
def backward(self, grad_output):
|
||||
assert grad_output.dim() == 5
|
||||
|
||||
grad_output = grad_output.contiguous()
|
||||
grad_input = grad_output.new()
|
||||
backend = type2backend[type(grad_output)]
|
||||
backend.VolumetricUpSamplingTrilinear_updateGradInput(
|
||||
backend.library_state,
|
||||
grad_output,
|
||||
grad_input,
|
||||
self.input_size[0],
|
||||
self.input_size[1],
|
||||
self.input_size[2],
|
||||
self.input_size[3],
|
||||
self.input_size[4],
|
||||
self.output_size[0],
|
||||
self.output_size[1],
|
||||
self.output_size[2]
|
||||
)
|
||||
return grad_input
|
||||
|
||||
|
||||
_all_functions.append(UpsamplingNearest2d)
|
||||
_all_functions.append(UpsamplingBilinear2d)
|
||||
_all_functions.append(UpsamplingNearest3d)
|
||||
_all_functions.append(UpsamplingTrilinear3d)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Functional interface"""
|
||||
|
||||
from numbers import Integral
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from . import _functions
|
||||
|
|
@ -628,44 +629,76 @@ def pixel_shuffle(input, upscale_factor):
|
|||
return shuffle_out.view(batch_size, channels, out_height, out_width)
|
||||
|
||||
|
||||
def upsample_nearest(input, size=None, scale_factor=None):
|
||||
"""Upsamples the input, using nearest neighbours' pixel values.
|
||||
def upsample(input, size=None, scale_factor=None, mode='nearest'):
|
||||
"""Upsamples the input to either the given :attr:`size` or the given :attr:`scale_factor`
|
||||
|
||||
Currently only spatial upsampling is supported (i.e. expected inputs
|
||||
are 4 dimensional).
|
||||
The algorithm used for upsampling is determined by :attr:`mode`.
|
||||
|
||||
Currently spatial and volumetric upsampling are supported, i.e.
|
||||
expected inputs are 4-D or 5-D in shape.
|
||||
|
||||
The input dimensions are interpreted in the form: `mini-batch x channels x [depth] x height x width`
|
||||
|
||||
The modes available for upsampling are: `nearest`, `bilinear` (4D-only), `trilinear` (5D-only)
|
||||
|
||||
Args:
|
||||
input (Variable): input
|
||||
size (int or Tuple[int, int]): output spatial size.
|
||||
size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial size.
|
||||
scale_factor (int): multiplier for spatial size. Has to be an integer.
|
||||
mode (string): algorithm used for upsampling: 'nearest' | 'bilinear' | 'trilinear'
|
||||
"""
|
||||
if input.dim() == 4 and mode == 'nearest':
|
||||
return _functions.thnn.UpsamplingNearest2d(_pair(size), scale_factor)(input)
|
||||
elif input.dim() == 5 and mode == 'nearest':
|
||||
return _functions.thnn.UpsamplingNearest3d(_triple(size), scale_factor)(input)
|
||||
elif input.dim() == 4 and mode == 'bilinear':
|
||||
return _functions.thnn.UpsamplingBilinear2d(_pair(size), scale_factor)(input)
|
||||
elif input.dim() == 4 and mode == 'trilinear':
|
||||
raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
|
||||
elif input.dim() == 5 and mode == 'bilinear':
|
||||
raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
|
||||
elif input.dim() == 5 and mode == 'trilinear':
|
||||
return _functions.thnn.UpsamplingTrilinear3d(_triple(size), scale_factor)(input)
|
||||
else:
|
||||
raise NotImplementedError("Input Error: Only 4D and 5D input Tensors supported"
|
||||
" (got {}D) for the modes: nearest | bilinear | trilinear"
|
||||
" (got {})".format(input.dim(), mode))
|
||||
|
||||
|
||||
def upsample_nearest(input, size=None, scale_factor=None):
|
||||
"""Upsamples the input, using nearest neighbours' pixel values.
|
||||
|
||||
**Note:: This function is deprecated. Use nn.functional.upsample instead**
|
||||
|
||||
Currently spatial and volumetric upsampling are supported (i.e. expected inputs
|
||||
are 4 or 5 dimensional).
|
||||
|
||||
Args:
|
||||
input (Variable): input
|
||||
size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial size.
|
||||
scale_factor (int): multiplier for spatial size. Has to be an integer.
|
||||
"""
|
||||
return _functions.thnn.UpsamplingNearest2d(size, scale_factor)(input)
|
||||
# DeprecationWarning is ignored by default
|
||||
warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.upsample instead.")
|
||||
return upsample(input, size, scale_factor, mode='nearest')
|
||||
|
||||
|
||||
def upsample_bilinear(input, size=None, scale_factor=None):
|
||||
"""Upscales the input, using the bilinear upsampling.
|
||||
"""Upscales the input, using bilinear upsampling.
|
||||
|
||||
Currently only spatial upsampling is supported (i.e. expected inputs
|
||||
are 4 dimensional).
|
||||
**Note:: This function is deprecated. Use nn.functional.upsample instead**
|
||||
|
||||
Expected inputs are spatial (4 dimensional). Use upsample_trilinear for volumetric (5 dimensional)
|
||||
inputs.
|
||||
|
||||
Args:
|
||||
input (Variable): input
|
||||
size (int or Tuple[int, int]): output spatial size.
|
||||
scale_factor (int or Tuple[int, int]): multiplier for spatial size
|
||||
"""
|
||||
return _functions.thnn.UpsamplingBilinear2d(size, scale_factor)(input)
|
||||
|
||||
|
||||
def _check_bilinear_2d_scale_factor(scale_factor):
|
||||
scale_factor = _pair(scale_factor)
|
||||
try:
|
||||
assert len(scale_factor) == 2
|
||||
assert all(isinstance(s, Integral) and s >= 1 for s in scale_factor)
|
||||
except AssertionError as e:
|
||||
raise ValueError('scale_factor must be a non-negative integer, '
|
||||
'or a tuple of non-negative integers for bilinear upsamplings, but got: '
|
||||
'{}'.format(scale_factor))
|
||||
return scale_factor
|
||||
# DeprecationWarning is ignored by default
|
||||
warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.upsample instead.")
|
||||
return upsample(input, size, scale_factor, mode='bilinear')
|
||||
|
||||
|
||||
def pad(input, pad, mode='constant', value=0):
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from .sparse import Embedding
|
|||
from .rnn import RNNBase, RNN, LSTM, GRU, \
|
||||
RNNCell, LSTMCell, GRUCell
|
||||
from .pixelshuffle import PixelShuffle
|
||||
from .upsampling import UpsamplingNearest2d, UpsamplingBilinear2d
|
||||
from .upsampling import UpsamplingNearest2d, UpsamplingBilinear2d, Upsample
|
||||
from .distance import PairwiseDistance, CosineSimilarity
|
||||
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ __all__ = [
|
|||
'InstanceNorm3d', 'Dropout', 'Dropout2d', 'Dropout3d', 'ReflectionPad2d',
|
||||
'ReplicationPad2d', 'ReplicationPad3d', 'CrossMapLRN2d',
|
||||
'Embedding', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell',
|
||||
'PixelShuffle', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'PairwiseDistance',
|
||||
'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'PairwiseDistance',
|
||||
'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d',
|
||||
'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad2d', 'Bilinear', 'CosineSimilarity',
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,30 +1,94 @@
|
|||
from numbers import Integral
|
||||
import warnings
|
||||
|
||||
from .module import Module
|
||||
from .. import functional as F
|
||||
from .utils import _pair
|
||||
from .utils import _pair, _triple
|
||||
|
||||
|
||||
class _UpsamplingBase(Module):
|
||||
class Upsample(Module):
|
||||
"""
|
||||
Upsamples a given multi-channel 2D (spatial) or 3D (volumetric) data.
|
||||
|
||||
def __init__(self, size=None, scale_factor=None):
|
||||
super(_UpsamplingBase, self).__init__()
|
||||
if size is None and scale_factor is None:
|
||||
raise ValueError('either size or scale_factor should be defined')
|
||||
if scale_factor is not None and not isinstance(scale_factor, (Integral, tuple)):
|
||||
raise ValueError('scale_factor must be of integer type or tuple of integer types')
|
||||
The input data is assumed to be of the form `minibatch x channels x [depth] x height x width.
|
||||
Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
|
||||
|
||||
The algorithms available for upsampling are nearest neighbor, bilinear and trilinear upsampling,
|
||||
with bilinear only available for 4D Tensor inputs and trilinear for 4D Tensor inputs.
|
||||
|
||||
One can either give a :attr:`scale_factor` or the target output :attr:`size` to
|
||||
calculate the output size. (You cannot give both, as it is ambiguous)
|
||||
|
||||
Args:
|
||||
size (tuple, optional): a tuple of ints ([D_out], H_out, W_out) output sizes
|
||||
scale_factor (int / tuple of ints, optional): the multiplier for the image height / width / depth
|
||||
mode (string, optional): the upsampling algorithm: nearest | bilinear | trilinear
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(N, C, D_{out}, H_{out}, W_{out})` where
|
||||
:math:`D_{out} = floor(D_{in} * scale\_factor)` or `size[-3]`
|
||||
:math:`H_{out} = floor(H_{in} * scale\_factor)` or `size[-2]`
|
||||
:math:`W_{out} = floor(W_{in} * scale\_factor)` or `size[-1]`
|
||||
|
||||
Examples::
|
||||
|
||||
>>> inp
|
||||
Variable containing:
|
||||
(0 ,0 ,.,.) =
|
||||
1 2
|
||||
3 4
|
||||
[torch.FloatTensor of size 1x1x2x2]
|
||||
|
||||
>>> m = nn.Upsample(scale_factor=2, mode='bilinear')
|
||||
>>> m(inp)
|
||||
Variable containing:
|
||||
(0 ,0 ,.,.) =
|
||||
1.0000 1.3333 1.6667 2.0000
|
||||
1.6667 2.0000 2.3333 2.6667
|
||||
2.3333 2.6667 3.0000 3.3333
|
||||
3.0000 3.3333 3.6667 4.0000
|
||||
[torch.FloatTensor of size 1x1x4x4]
|
||||
|
||||
>>> inp
|
||||
Variable containing:
|
||||
(0 ,0 ,.,.) =
|
||||
1 2
|
||||
3 4
|
||||
[torch.FloatTensor of size 1x1x2x2]
|
||||
|
||||
>>> m = nn.Upsample(scale_factor=2, mode='nearest')
|
||||
>>> m(inp)
|
||||
Variable containing:
|
||||
(0 ,0 ,.,.) =
|
||||
1 1 2 2
|
||||
1 1 2 2
|
||||
3 3 4 4
|
||||
3 3 4 4
|
||||
[torch.FloatTensor of size 1x1x4x4]
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None, mode='nearest'):
|
||||
super(Upsample, self).__init__()
|
||||
self.size = size
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, input):
|
||||
return F.upsample(input, self.size, self.scale_factor, self.mode)
|
||||
|
||||
def __repr__(self):
|
||||
if self.scale_factor is not None:
|
||||
info = 'scale_factor=' + str(self.scale_factor)
|
||||
else:
|
||||
info = 'size=' + str(self.size)
|
||||
info += ', mode=' + self.mode
|
||||
return self.__class__.__name__ + '(' + info + ')'
|
||||
|
||||
|
||||
class UpsamplingNearest2d(_UpsamplingBase):
|
||||
class UpsamplingNearest2d(Upsample):
|
||||
"""
|
||||
Applies a 2D nearest neighbor upsampling to an input signal composed of several input
|
||||
channels.
|
||||
|
|
@ -64,18 +128,15 @@ class UpsamplingNearest2d(_UpsamplingBase):
|
|||
[torch.FloatTensor of size 1x1x4x4]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None):
|
||||
super(UpsamplingNearest2d, self).__init__(size, scale_factor)
|
||||
if self.scale_factor is not None and not isinstance(scale_factor, Integral):
|
||||
raise ValueError('scale_factor must be of integer type for neighest neighbor sampling')
|
||||
self.size = _pair(self.size) if self.size is not None else None
|
||||
super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest')
|
||||
|
||||
def forward(self, input):
|
||||
return F.upsample_nearest(input, self.size, self.scale_factor)
|
||||
warnings.warn("nn.UpsamplingNearest2d is deprecated. Use nn.Upsample instead.")
|
||||
return super(UpsamplingNearest2d, self).forward(input)
|
||||
|
||||
|
||||
class UpsamplingBilinear2d(_UpsamplingBase):
|
||||
class UpsamplingBilinear2d(Upsample):
|
||||
"""
|
||||
Applies a 2D bilinear upsampling to an input signal composed of several input
|
||||
channels.
|
||||
|
|
@ -115,13 +176,9 @@ class UpsamplingBilinear2d(_UpsamplingBase):
|
|||
[torch.FloatTensor of size 1x1x4x4]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None):
|
||||
super(UpsamplingBilinear2d, self).__init__(size, scale_factor)
|
||||
|
||||
if self.scale_factor is not None:
|
||||
self.scale_factor = F._check_bilinear_2d_scale_factor(self.scale_factor)
|
||||
self.size = _pair(self.size) if self.size is not None else None
|
||||
super(UpsamplingBilinear2d, self).__init__(size, scale_factor, mode='bilinear')
|
||||
|
||||
def forward(self, input):
|
||||
return F.upsample_bilinear(input, self.size, self.scale_factor)
|
||||
warnings.warn("nn.UpsamplingBilinear2d is deprecated. Use nn.Upsample instead.")
|
||||
return super(UpsamplingBilinear2d, self).forward(input)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user