No-batch-dim support for ConvNd (#70506)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70506

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33355034

Pulled By: jbschlosser

fbshipit-source-id: 5a42645299b1d82cee7d461826acca1c5b35a71c
Authored by Joel Schlosser on 2022-01-06 16:52:12 -08:00; committed by Facebook GitHub Bot
parent 6896b2d734
commit 7b8f73dd32
7 changed files with 326 additions and 107 deletions


@ -575,6 +575,30 @@ static void check_shape_backward(
check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params);
}
// Given an input tensor and an expected number of spatial dimensions, checks that the
// input is a valid shape and returns the batched form of the input.
//
// Args:
// input (Tensor): Input tensor
// num_spatial_dims (int): Number of spatial dimensions expected for the input
// func_name (string): Function name to produce a nice error message for invalid input
//
// Returns a std::tuple containing:
// batched_input (Tensor): Input with a batch dimension
// is_batched (bool): Indicates whether the original input was already batched
static std::tuple<Tensor, bool> batchify(
const Tensor& input,
const int64_t num_spatial_dims,
const std::string& func_name) {
const auto dim_count_no_batch = num_spatial_dims + 1;
const auto dim_count_batch = dim_count_no_batch + 1;
const auto is_batched = (input.dim() == dim_count_batch);
TORCH_CHECK(input.dim() == dim_count_no_batch || is_batched,
"Expected ", dim_count_no_batch, "D (unbatched) or ", dim_count_batch,
"D (batched) input to ", func_name, ", but got input of size: ", input.sizes());
return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched);
}
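
From the Python side, batchify is what lets the functional conv entry points accept an input without a leading batch dimension: the input is unsqueezed to batched form here and squeezed back before returning. A minimal sketch of the resulting user-visible behavior, assuming a PyTorch build that includes this change:

import torch
import torch.nn.functional as F

# Unbatched input: (C_in, L_in); weight: (C_out, C_in, kernel_size).
x = torch.randn(4, 10)
w = torch.randn(8, 4, 3)

out = F.conv1d(x, w)   # unsqueezed to (1, 4, 10) internally, squeezed back on return
print(out.shape)       # torch.Size([8, 8])

# Anything that is neither 2D (unbatched) nor 3D (batched) is rejected by the TORCH_CHECK above.
try:
    F.conv1d(torch.randn(10), w)
except RuntimeError as e:
    print(e)           # Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [10]
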
static void check_input_same_type_as_parameters(
const Tensor& input,
const Tensor& weight,
@ -618,36 +642,45 @@ static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
at::Tensor conv1d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
return at::convolution(input, weight, bias, stride, padding, dilation,
false, {0}, groups);
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
return is_batched ? output : output.squeeze(0);
}
at::Tensor conv2d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
return at::convolution(input, weight, bias, stride, padding, dilation,
false, {{0, 0}}, groups);
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
return is_batched ? output : output.squeeze(0);
}
at::Tensor conv3d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
return at::convolution(input, weight, bias, stride, padding, dilation,
false, {{0, 0, 0}}, groups);
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
return is_batched ? output : output.squeeze(0);
}
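
Since the unbatched path is literally unsqueeze, convolve, squeeze, the result matches a manually batched call bit for bit. A small Python illustration, assuming this change is present:

import torch
import torch.nn.functional as F

x = torch.randn(3, 16, 16)     # (C_in, H, W), no batch dim
w = torch.randn(6, 3, 5, 5)    # (C_out, C_in, kH, kW)

out_unbatched = F.conv2d(x, w)
out_manual = F.conv2d(x.unsqueeze(0), w).squeeze(0)
print(out_unbatched.shape)                     # torch.Size([6, 12, 12])
print(torch.equal(out_unbatched, out_manual))  # True
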
@ -736,60 +769,84 @@ Tensor _convolution_mode(
}
at::Tensor conv1d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
int64_t groups) {
return at::_convolution_mode(
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
auto output = at::_convolution_mode(
input, weight, bias, stride, std::move(padding), dilation, groups);
return is_batched ? output : output.squeeze(0);
}
at::Tensor conv2d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
int64_t groups) {
return at::_convolution_mode(
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
auto output = at::_convolution_mode(
input, weight, bias, stride, std::move(padding), dilation, groups);
return is_batched ? output : output.squeeze(0);
}
at::Tensor conv3d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
int64_t groups) {
return at::_convolution_mode(
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
auto output = at::_convolution_mode(
input, weight, bias, stride, std::move(padding), dilation, groups);
return is_batched ? output : output.squeeze(0);
}
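
The string-padding overloads get the same batchify treatment, so padding='same' and padding='valid' also work on unbatched inputs. A brief sketch, assuming this change is present:

import torch
import torch.nn.functional as F

x = torch.randn(3, 32)    # (C_in, L_in), no batch dim
w = torch.randn(5, 3, 3)  # (C_out, C_in, kernel_size)

out = F.conv1d(x, w, padding='same')
print(out.shape)          # torch.Size([5, 32]) -- spatial size preserved
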
at::Tensor conv_transpose1d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
return at::convolution(input, weight, bias, stride, padding, dilation,
true, output_padding, groups);
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d");
auto output = at::convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
return is_batched ? output : output.squeeze(0);
}
at::Tensor conv_transpose2d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
return at::convolution(input, weight, bias, stride, padding, dilation,
true, output_padding, groups);
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
auto output = at::convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
return is_batched ? output : output.squeeze(0);
}
at::Tensor conv_transpose3d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
return at::convolution(input, weight, bias, stride, padding, dilation,
true, output_padding, groups);
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
auto output = at::convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
return is_batched ? output : output.squeeze(0);
}
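
The transposed variants follow the same pattern. A sketch assuming this change is present (note the transposed-conv weight layout):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, 8)     # (C_in, H, W), no batch dim
w = torch.randn(4, 6, 3, 3)  # (C_in, C_out, kH, kW) for conv_transpose2d

out = F.conv_transpose2d(x, w, stride=2)
print(out.shape)             # torch.Size([6, 17, 17])
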
at::Tensor convolution(


@ -159,7 +159,7 @@ class TestORTTensor(common.TestCase):
bias = torch.empty(6, device='ort')
# Make sure forward is overridden
out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)
self.assertEqual(ort_extension.get_test_int(), 2)
self.assertEqual(out.shape[0], input.shape[0])
self.assertEqual(out.shape[1], weight.shape[0])


@ -13005,7 +13005,7 @@ dedent """
return self.conv(x)
foo = Foo()
# testing that the correct error message propagates
with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"):
foo(torch.ones([123])) # wrong size
def test_builtin_error_messsage(self):


@ -12,7 +12,7 @@ from torch.testing._internal.common_device_type import (
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
from unittest.mock import patch
from unittest.mock import patch, call
class TestModule(TestCase):
@ -122,9 +122,9 @@ class TestModule(TestCase):
with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
m = module_cls(*args, **kwargs)
uninit_param_new.mock.assert_has_calls(
[mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
[call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
uninit_buffer_new.mock.assert_has_calls(
[mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
[call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
else:
# Check device placement and dtype for created parameters and buffers.
# Only verify floating point dtypes since that's what the kwarg applies to.
@ -421,9 +421,13 @@ class TestModule(TestCase):
params = tuple(m.parameters())
# === Perform gradient check on the input_args ===
# === Lazy modules need to see an input to initialize params before gradcheck is run. ===
input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
with torch.no_grad():
m(*input_args, **input_kwargs)
# === Perform gradient check on the input_args ===
other_kwargs = {}
kwarg_tensors = []
for name, obj in input_kwargs.items():


@ -895,8 +895,8 @@ class TestNN(NNTestCase):
w = torch.randn(6, 1, 5, 5)
with self.assertRaisesRegex(RuntimeError,
r'Expected 4-dimensional input for 4-dimensional weight \[6, 1, 5, 5\],' +
r' but got 5-dimensional input of size \[1, 10, 1, 28, 28\] instead'):
r'Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got ' +
r'input of size: \[1, 10, 1, 28, 28\]'):
F.conv2d(x, w)
@ -6172,9 +6172,9 @@ class TestNN(NNTestCase):
nn.Conv2d(3, 8, 3).to(dtype), nn.ConvTranspose2d(3, 8, 3).to(dtype),
nn.Conv3d(3, 8, 3).to(dtype), nn.ConvTranspose3d(3, 8, 3).to(dtype)]
invalid_input_dims = [(2, 4), (2, 4),
(3, 5), (3, 5),
(4, 6), (4, 6)]
invalid_input_dims = [(1, 4), (1, 4),
(2, 5), (2, 5),
(3, 6), (3, 6)]
for invalid_dims, module in zip(invalid_input_dims, modules):
for dims in invalid_dims:
@ -13402,7 +13402,7 @@ class TestNNDeviceType(NNTestCase):
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
z = F.conv1d(x, y, padding='same')
z = F.conv2d(x, y, padding='same')
z.sum().backward()
self.assertEqual(gx_expect, x.grad)
self.assertEqual(gy_expect, y.grad)


@ -233,8 +233,8 @@ class Conv1d(_ConvNd):
""".format(**reproducibility_notes, **convolution_notes) + r"""
Shape:
- Input: :math:`(N, C_{in}, L_{in})`
- Output: :math:`(N, C_{out}, L_{out})` where
- Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
- Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where
.. math::
L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
@ -370,8 +370,8 @@ class Conv2d(_ConvNd):
""".format(**reproducibility_notes, **convolution_notes) + r"""
Shape:
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
- Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where
.. math::
H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
@ -504,8 +504,9 @@ class Conv3d(_ConvNd):
""".format(**reproducibility_notes, **convolution_notes) + r"""
Shape:
- Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
- Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`,
where
.. math::
D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
@ -710,8 +711,8 @@ class ConvTranspose1d(_ConvTransposeNd):
""".format(**reproducibility_notes, **convolution_notes) + r"""
Shape:
- Input: :math:`(N, C_{in}, L_{in})`
- Output: :math:`(N, C_{out}, L_{out})` where
- Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
- Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where
.. math::
L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation}
@ -838,8 +839,8 @@ class ConvTranspose2d(_ConvTransposeNd):
""".format(**reproducibility_notes, **convolution_notes) + r"""
Shape:
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
- Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where
.. math::
H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
@ -991,8 +992,9 @@ class ConvTranspose3d(_ConvTransposeNd):
""".format(**reproducibility_notes, **convolution_notes) + r"""
Shape:
- Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
- Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or
:math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where
.. math::
D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
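
At the module level, the updated Shape sections mean the same module instance now accepts either form. A quick Python illustration, assuming a build with this change:

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

batched = conv(torch.randn(2, 3, 28, 28))   # (N, C_out, H_out, W_out)
unbatched = conv(torch.randn(3, 28, 28))    # (C_out, H_out, W_out)
print(batched.shape, unbatched.shape)       # torch.Size([2, 8, 26, 26]) torch.Size([8, 26, 26])
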
@ -1122,7 +1124,7 @@ class _LazyConvXdMixin(LazyModuleMixin):
def initialize_parameters(self, input) -> None: # type: ignore[override]
# defined by parent class but using a protocol
if self.has_uninitialized_params(): # type: ignore[misc]
self.in_channels = input.shape[1]
self.in_channels = self._get_in_channels(input)
if self.in_channels % self.groups != 0:
raise ValueError('in_channels must be divisible by groups')
assert isinstance(self.weight, UninitializedParameter)
@ -1137,6 +1139,22 @@ class _LazyConvXdMixin(LazyModuleMixin):
self.bias.materialize((self.out_channels,))
self.reset_parameters()
# Function to extract in_channels from first input.
def _get_in_channels(self, input: Tensor) -> int:
num_spatial_dims = self._get_num_spatial_dims()
num_dims_no_batch = num_spatial_dims + 1 # +1 for channels dim
num_dims_batch = num_dims_no_batch + 1
if input.dim() not in (num_dims_no_batch, num_dims_batch):
raise RuntimeError("Expected {}D (unbatched) or {}D (batched) input to {}, but "
"got input of size: {}".format(num_dims_no_batch, num_dims_batch,
self.__class__.__name__, input.shape))
return input.shape[1] if input.dim() == num_dims_batch else input.shape[0]
# Function to return the number of spatial dims expected for inputs to the module.
# This is expected to be implemented by subclasses.
def _get_num_spatial_dims(self) -> int:
raise NotImplementedError()
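
With _get_in_channels, the lazy variants can also infer in_channels from an unbatched first input, since the channels dimension is located from the expected number of spatial dims rather than assumed to be dim 1. A short sketch, assuming a build with this change:

import torch
import torch.nn as nn

m = nn.LazyConv2d(out_channels=8, kernel_size=3)

# First call materializes the parameters; the input here is unbatched (C_in, H, W).
m(torch.randn(4, 16, 16))
print(m.in_channels)    # 4
print(m.weight.shape)   # torch.Size([8, 4, 3, 3])
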
# LazyConv1d defines weight as a Tensor but derived class defines it as UninitializedParameter
class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
@ -1203,6 +1221,9 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
return 1
# LazyConv2d defines weight as a Tensor but derived class defines it as UninitializedParameter
class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
@ -1269,6 +1290,9 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
return 2
# LazyConv3d defines weight as a Tensor but derived class defines it as UninitializedParameter
class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
@ -1335,6 +1359,9 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
return 3
# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UninitializedParameter
class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc]
@ -1400,6 +1427,9 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
return 1
# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UninitializedParameter
class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc]
@ -1465,6 +1495,9 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
return 2
# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UninitializedParameter
class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc]
@ -1529,3 +1562,6 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
self.out_channels = out_channels
if bias:
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
return 3


@ -9,7 +9,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import floating_types
from torch.testing._internal.common_device_type import (
_TestParametrizer, _update_param_kwargs, skipIf, toleranceOverride, tol,
skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride)
skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta)
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_nn import nllloss_reference, get_reduction
from torch.testing._internal.common_utils import (
@ -408,28 +408,24 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad
forward_input=FunctionInput(make_input(shape=(2, 3, 4, 4, 4))))]
def module_inputs_torch_nn_Conv2d(module_info, device, dtype, requires_grad, **kwargs):
def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, **kwargs):
N = kwargs['N']
lazy = kwargs.get('lazy', False)
transposed = kwargs.get('transposed', False)
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}]
kernel_size, C_in, C_out = 3, 4, 5
input_no_batch_shape = (C_in,) + tuple((i + 3 for i in range(N)))
input_batch_shape = (2,) + input_no_batch_shape
return [
ModuleInput(constructor_input=FunctionInput(3, 4, 3),
forward_input=FunctionInput(make_input(shape=(2, 3, 7, 5))))]
def module_inputs_torch_nn_Conv3d(module_info, device, dtype, requires_grad, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
ModuleInput(constructor_input=FunctionInput(2, 3, (2, 3, 2)),
forward_input=FunctionInput(make_input(shape=(1, 2, 4, 5, 4))))]
def module_inputs_torch_nn_ConvTranspose2d(module_info, device, dtype, requires_grad, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
return [
ModuleInput(constructor_input=FunctionInput(3, 4, 3, (3, 2), 1, (1, 1)),
forward_input=FunctionInput(make_input(shape=(1, 3, 7, 6))))]
ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else
FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)),
forward_input=FunctionInput(make_input(
shape=(input_batch_shape if with_batch else input_no_batch_shape))),
desc=('' if with_batch else 'no_batch_dim'),
reference_fn=(None if with_batch else no_batch_dim_reference_fn))
for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list)
]
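
For context, the no_batch_dim ModuleInputs generated above are checked against a reference that re-adds the batch dimension. The helper below is only an illustration of that idea and is not the actual no_batch_dim_reference_fn used in common_modules:

import torch

def unbatched_reference(module, x):
    # Hypothetical helper: add a batch dim, run the module, drop the batch dim again.
    return module(x.unsqueeze(0)).squeeze(0)

conv = torch.nn.Conv1d(4, 5, 3)
x = torch.randn(4, 10)
print(torch.equal(conv(x), unbatched_reference(conv, x)))  # True
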
def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs):
@ -818,54 +814,84 @@ module_db: List[ModuleInfo] = [
# Failure on ROCM for BatchNorm3d float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),)
),
ModuleInfo(torch.nn.Conv2d,
module_inputs_func=module_inputs_torch_nn_Conv2d,
ModuleInfo(torch.nn.Conv1d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# NHWC is disabled for float64 input in CudNN Conv.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', dtypes=[torch.float64]),
# No channels_last support for Conv2d on cpu currently.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cpu'),),
decorators=(
# Conv2d channels_last support on cuda requires cudnn >= 7603
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for Conv2d float32 issue #70125
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'))
),
ModuleInfo(torch.nn.Conv3d,
module_inputs_func=module_inputs_torch_nn_Conv3d,
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.Conv2d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# NHWC is disabled for float64 input in CudNN Conv.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', dtypes=[torch.float64]),
# No channels_last support for Conv3d on cpu currently.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cpu'),
# Greatest difference was 0.05072784423828125 > atol of 0.05
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),),
decorators=(
# Conv3d channels_last support on cuda requires cudnn >= 8005
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
# Failure on ROCM for Conv3d float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]))
),
ModuleInfo(torch.nn.ConvTranspose2d,
module_inputs_func=module_inputs_torch_nn_ConvTranspose2d,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# NHWC is disabled for float64 input in CudNN Conv.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', dtypes=[torch.float64]),
# No channels_last support for ConvTranspose2d on cpu currently.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cpu'),),
decorators=(
# ConvTranspose2d channels_last support on cuda requires cudnn >= 7603
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for ConvTranspose2d float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]))
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.Conv3d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 8005
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.ConvTranspose1d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.ConvTranspose2d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.ConvTranspose3d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 8005
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.ELU,
module_inputs_func=module_inputs_torch_nn_ELU),
ModuleInfo(torch.nn.L1Loss,
@ -874,6 +900,102 @@ module_db: List[ModuleInfo] = [
# No channels_last support for loss functions.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
),
ModuleInfo(torch.nn.LazyConv1d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
DecorateInfo(skipMeta),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.LazyConv2d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
DecorateInfo(skipMeta),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.LazyConv3d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 8005
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
DecorateInfo(skipMeta),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.LazyConvTranspose1d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
DecorateInfo(skipMeta),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.LazyConvTranspose2d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 7603
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
DecorateInfo(skipMeta),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.LazyConvTranspose3d,
module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_memformat_affects_out=True,
skips=(
# channels_last support on cuda requires cudnn >= 8005
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
DecorateInfo(skipMeta),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
)),
ModuleInfo(torch.nn.Linear,
module_inputs_func=module_inputs_torch_nn_Linear,
skips=(