No-batch-dim support for ConvNd (#70506)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70506
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33355034
Pulled By: jbschlosser
fbshipit-source-id: 5a42645299b1d82cee7d461826acca1c5b35a71c
This commit is contained in:
parent 6896b2d734
commit 7b8f73dd32
@@ -575,6 +575,30 @@ static void check_shape_backward(
   check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params);
 }

+// Given an input tensor and an expected number of spatial dimensions, checks that the
+// input is a valid shape and returns the batched form of the input.
+//
+// Args:
+//     input (Tensor): Input tensor
+//     num_spatial_dims (int): Number of spatial dimensions expected for the input
+//     func_name (string): Function name to produce a nice error message for invalid input
+//
+// Returns a std::tuple containing:
+//     batched_input (Tensor): Input with a batch dimension
+//     is_batched (bool): Indicates whether the original input was already batched
+static std::tuple<Tensor, bool> batchify(
+    const Tensor& input,
+    const int64_t num_spatial_dims,
+    const std::string& func_name) {
+  const auto dim_count_no_batch = num_spatial_dims + 1;
+  const auto dim_count_batch = dim_count_no_batch + 1;
+  const auto is_batched = (input.dim() == dim_count_batch);
+  TORCH_CHECK(input.dim() == dim_count_no_batch || is_batched,
+      "Expected ", dim_count_no_batch, "D (unbatched) or ", dim_count_batch,
+      "D (batched) input to ", func_name, ", but got input of size: ", input.sizes());
+  return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched);
+}
+
 static void check_input_same_type_as_parameters(
     const Tensor& input,
     const Tensor& weight,
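For readers following along outside the C++ sources, here is a rough Python sketch of what the new batchify helper does. This is illustrative only; `batchify` is an internal ATen function and the Python function below is not part of any PyTorch API.

    import torch

    # Illustrative re-implementation of the batchify helper added above
    # (assumption: this mirrors the C++ logic, it is not the real implementation).
    def batchify(input: torch.Tensor, num_spatial_dims: int, func_name: str):
        dim_count_no_batch = num_spatial_dims + 1  # channels + spatial dims
        dim_count_batch = dim_count_no_batch + 1   # batch + channels + spatial dims
        is_batched = input.dim() == dim_count_batch
        if input.dim() != dim_count_no_batch and not is_batched:
            raise RuntimeError(
                f"Expected {dim_count_no_batch}D (unbatched) or {dim_count_batch}D (batched) "
                f"input to {func_name}, but got input of size: {tuple(input.shape)}")
        # Unbatched inputs get a size-1 batch dim prepended; callers squeeze it back off.
        return (input if is_batched else input.unsqueeze(0)), is_batched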
@@ -618,36 +642,45 @@ static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {


 at::Tensor conv1d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;

-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         false, {0}, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
+  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
+  return is_batched ? output : output.squeeze(0);
 }

 at::Tensor conv2d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;

-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         false, {{0, 0}}, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
+  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
+  return is_batched ? output : output.squeeze(0);
 }

 at::Tensor conv3d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;

-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         false, {{0, 0, 0}}, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
+  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
+  return is_batched ? output : output.squeeze(0);
 }

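The effect of the change above, in user-facing terms: the functional conv ops accept an input without a leading batch dimension and return an output without one. An illustrative snippet (not part of the diff; requires a PyTorch build that includes this change):

    import torch
    import torch.nn.functional as F

    weight = torch.randn(8, 3, 3, 3)        # (out_channels, in_channels, kH, kW)

    batched = torch.randn(4, 3, 32, 32)     # (N, C, H, W)
    unbatched = torch.randn(3, 32, 32)      # (C, H, W) -- no batch dim

    print(F.conv2d(batched, weight).shape)    # torch.Size([4, 8, 30, 30])
    print(F.conv2d(unbatched, weight).shape)  # torch.Size([8, 30, 30])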
@@ -736,60 +769,84 @@ Tensor _convolution_mode(
 }

 at::Tensor conv1d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
     IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
     int64_t groups) {
-  return at::_convolution_mode(
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
+  auto output = at::_convolution_mode(
       input, weight, bias, stride, std::move(padding), dilation, groups);
+  return is_batched ? output : output.squeeze(0);
 }

 at::Tensor conv2d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
     IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
     int64_t groups) {
-  return at::_convolution_mode(
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
+  auto output = at::_convolution_mode(
       input, weight, bias, stride, std::move(padding), dilation, groups);
+  return is_batched ? output : output.squeeze(0);
 }

 at::Tensor conv3d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
     IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
     int64_t groups) {
-  return at::_convolution_mode(
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
+  auto output = at::_convolution_mode(
       input, weight, bias, stride, std::move(padding), dilation, groups);
+  return is_batched ? output : output.squeeze(0);
 }

 at::Tensor conv_transpose1d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;

-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         true, output_padding, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d");
+  auto output = at::convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  return is_batched ? output : output.squeeze(0);
 }

 at::Tensor conv_transpose2d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;

-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         true, output_padding, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
+  auto output = at::convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  return is_batched ? output : output.squeeze(0);
 }

 at::Tensor conv_transpose3d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;

-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         true, output_padding, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
+  auto output = at::convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  return is_batched ? output : output.squeeze(0);
 }

 at::Tensor convolution(
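The same batchify treatment applies to the string-padding overloads and to the transposed convolutions, so unbatched inputs work there too. An illustrative snippet (not from the diff):

    import torch
    import torch.nn.functional as F

    x = torch.randn(3, 32, 32)      # unbatched (C, H, W) input
    w = torch.randn(8, 3, 3, 3)

    # String padding goes through at::_convolution_mode; the unbatched input is
    # batchified the same way, so the output keeps the unbatched layout.
    print(F.conv2d(x, w, padding='same').shape)   # torch.Size([8, 32, 32])

    # Transposed convolution: weight is (in_channels, out_channels, kH, kW).
    wt = torch.randn(3, 8, 3, 3)
    print(F.conv_transpose2d(x, wt).shape)        # torch.Size([8, 34, 34])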
@@ -159,7 +159,7 @@ class TestORTTensor(common.TestCase):
         bias = torch.empty(6, device='ort')

         # Make sure forward is overriden
-        out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
+        out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)
         self.assertEqual(ort_extension.get_test_int(), 2)
         self.assertEqual(out.shape[0], input.shape[0])
         self.assertEqual(out.shape[1], weight.shape[0])
@@ -13005,7 +13005,7 @@ dedent """
                 return self.conv(x)
         foo = Foo()
         # testing that the correct error message propagates
-        with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
+        with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"):
             foo(torch.ones([123]))  # wrong size

     def test_builtin_error_messsage(self):
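For reference, an illustrative snippet (not from the diff) showing the kind of error text these tests now expect when neither the unbatched nor the batched rank matches:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, kernel_size=3)
    try:
        conv(torch.ones(123))  # neither 3D (unbatched) nor 4D (batched)
    except RuntimeError as e:
        # e.g. "Expected 3D (unbatched) or 4D (batched) input to conv2d,
        #       but got input of size: [123]"
        print(e)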
@@ -12,7 +12,7 @@ from torch.testing._internal.common_device_type import (
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
-from unittest.mock import patch
+from unittest.mock import patch, call


 class TestModule(TestCase):
@@ -122,9 +122,9 @@ class TestModule(TestCase):
                 with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                     m = module_cls(*args, **kwargs)
                 uninit_param_new.mock.assert_has_calls(
-                    [mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
+                    [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                 uninit_buffer_new.mock.assert_has_calls(
-                    [mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
+                    [call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
             else:
                 # Check device placement and dtype for created parameters and buffers.
                 # Only verify floating point dtypes since that's what the kwarg applies to.
@@ -421,9 +421,13 @@ class TestModule(TestCase):

             params = tuple(m.parameters())

-            # === Perform gradient check on the input_args ===
+            # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
             input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
+                with torch.no_grad():
+                    m(*input_args, **input_kwargs)
+
+            # === Perform gradient check on the input_args ===
             other_kwargs = {}
             kwarg_tensors = []
             for name, obj in input_kwargs.items():
@@ -895,8 +895,8 @@ class TestNN(NNTestCase):
         w = torch.randn(6, 1, 5, 5)

         with self.assertRaisesRegex(RuntimeError,
-                                    r'Expected 4-dimensional input for 4-dimensional weight \[6, 1, 5, 5\],' +
-                                    r' but got 5-dimensional input of size \[1, 10, 1, 28, 28\] instead'):
+                                    r'Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got ' +
+                                    r'input of size: \[1, 10, 1, 28, 28\]'):

             F.conv2d(x, w)

@@ -6172,9 +6172,9 @@ class TestNN(NNTestCase):
                    nn.Conv2d(3, 8, 3).to(dtype), nn.ConvTranspose2d(3, 8, 3).to(dtype),
                    nn.Conv3d(3, 8, 3).to(dtype), nn.ConvTranspose3d(3, 8, 3).to(dtype)]

-        invalid_input_dims = [(2, 4), (2, 4),
-                              (3, 5), (3, 5),
-                              (4, 6), (4, 6)]
+        invalid_input_dims = [(1, 4), (1, 4),
+                              (2, 5), (2, 5),
+                              (3, 6), (3, 6)]

         for invalid_dims, module in zip(invalid_input_dims, modules):
             for dims in invalid_dims:
@@ -13402,7 +13402,7 @@ class TestNNDeviceType(NNTestCase):
             gx_expect, gy_expect = x.grad, y.grad
             x.grad, y.grad = None, None

-            z = F.conv1d(x, y, padding='same')
+            z = F.conv2d(x, y, padding='same')
             z.sum().backward()
             self.assertEqual(gx_expect, x.grad)
             self.assertEqual(gy_expect, y.grad)
@@ -233,8 +233,8 @@ class Conv1d(_ConvNd):
     """.format(**reproducibility_notes, **convolution_notes) + r"""

     Shape:
-        - Input: :math:`(N, C_{in}, L_{in})`
-        - Output: :math:`(N, C_{out}, L_{out})` where
+        - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
+        - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where

           .. math::
               L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
@@ -370,8 +370,8 @@ class Conv2d(_ConvNd):
     """.format(**reproducibility_notes, **convolution_notes) + r"""

     Shape:
-        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
-        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
+        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where

           .. math::
               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
@@ -504,8 +504,9 @@ class Conv3d(_ConvNd):
     """.format(**reproducibility_notes, **convolution_notes) + r"""

     Shape:
-        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
-        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
+        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or
+          :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where

           .. math::
               D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
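The documented Shape change in action for the module API (illustrative snippet, not from the diff): Conv modules now accept inputs with or without the leading batch dimension.

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

    print(conv(torch.randn(4, 3, 32, 32)).shape)  # torch.Size([4, 8, 30, 30]) -- batched
    print(conv(torch.randn(3, 32, 32)).shape)     # torch.Size([8, 30, 30])    -- unbatched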
@@ -710,8 +711,8 @@ class ConvTranspose1d(_ConvTransposeNd):
     """.format(**reproducibility_notes, **convolution_notes) + r"""

     Shape:
-        - Input: :math:`(N, C_{in}, L_{in})`
-        - Output: :math:`(N, C_{out}, L_{out})` where
+        - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
+        - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where

           .. math::
               L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation}
@@ -838,8 +839,8 @@ class ConvTranspose2d(_ConvTransposeNd):
     """.format(**reproducibility_notes, **convolution_notes) + r"""

     Shape:
-        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
-        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
+        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where

           .. math::
               H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
@@ -991,8 +992,9 @@ class ConvTranspose3d(_ConvTransposeNd):
     """.format(**reproducibility_notes, **convolution_notes) + r"""

     Shape:
-        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
-        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
+        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or
+          :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where

           .. math::
               D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
@@ -1122,7 +1124,7 @@ class _LazyConvXdMixin(LazyModuleMixin):
     def initialize_parameters(self, input) -> None:  # type: ignore[override]
         # defined by parent class but using a protocol
         if self.has_uninitialized_params():  # type: ignore[misc]
-            self.in_channels = input.shape[1]
+            self.in_channels = self._get_in_channels(input)
             if self.in_channels % self.groups != 0:
                 raise ValueError('in_channels must be divisible by groups')
             assert isinstance(self.weight, UninitializedParameter)
@@ -1137,6 +1139,22 @@ class _LazyConvXdMixin(LazyModuleMixin):
                 self.bias.materialize((self.out_channels,))
             self.reset_parameters()

+    # Function to extract in_channels from first input.
+    def _get_in_channels(self, input: Tensor) -> int:
+        num_spatial_dims = self._get_num_spatial_dims()
+        num_dims_no_batch = num_spatial_dims + 1  # +1 for channels dim
+        num_dims_batch = num_dims_no_batch + 1
+        if input.dim() not in (num_dims_no_batch, num_dims_batch):
+            raise RuntimeError("Expected {}D (unbatched) or {}D (batched) input to {}, but "
+                               "got input of size: {}".format(num_dims_no_batch, num_dims_batch,
+                                                              self.__class__.__name__, input.shape))
+        return input.shape[1] if input.dim() == num_dims_batch else input.shape[0]
+
+    # Function to return the number of spatial dims expected for inputs to the module.
+    # This is expected to be implemented by subclasses.
+    def _get_num_spatial_dims(self) -> int:
+        raise NotImplementedError()
+

 # LazyConv1d defines weight as a Tensor but derived class defines it as UnitializeParameter
 class LazyConv1d(_LazyConvXdMixin, Conv1d):  # type: ignore[misc]
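A brief illustrative example (not part of the diff) of how the lazy modules now infer `in_channels` from either input layout, per the `_get_in_channels` helper above:

    import torch
    import torch.nn as nn

    # Unbatched (C, H, W) input: in_channels is read from dim 0.
    m = nn.LazyConv2d(out_channels=8, kernel_size=3)
    m(torch.randn(3, 32, 32))
    print(m.in_channels)  # 3

    # Batched (N, C, H, W) input: in_channels is read from dim 1.
    m = nn.LazyConv2d(out_channels=8, kernel_size=3)
    m(torch.randn(4, 5, 32, 32))
    print(m.in_channels)  # 5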
@@ -1203,6 +1221,9 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d):  # type: ignore[misc]
         if bias:
             self.bias = UninitializedParameter(**factory_kwargs)

+    def _get_num_spatial_dims(self) -> int:
+        return 1
+

 # LazyConv2d defines weight as a Tensor but derived class defines it as UnitializeParameter
 class LazyConv2d(_LazyConvXdMixin, Conv2d):  # type: ignore[misc]
@@ -1269,6 +1290,9 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d):  # type: ignore[misc]
         if bias:
             self.bias = UninitializedParameter(**factory_kwargs)

+    def _get_num_spatial_dims(self) -> int:
+        return 2
+

 # LazyConv3d defines weight as a Tensor but derived class defines it as UnitializeParameter
 class LazyConv3d(_LazyConvXdMixin, Conv3d):  # type: ignore[misc]
@@ -1335,6 +1359,9 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d):  # type: ignore[misc]
         if bias:
             self.bias = UninitializedParameter(**factory_kwargs)

+    def _get_num_spatial_dims(self) -> int:
+        return 3
+

 # LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UnitializeParameter
 class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d):  # type: ignore[misc]
@@ -1400,6 +1427,9 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d):  # type: ignore[misc]
         if bias:
             self.bias = UninitializedParameter(**factory_kwargs)

+    def _get_num_spatial_dims(self) -> int:
+        return 1
+

 # LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UnitializeParameter
 class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d):  # type: ignore[misc]
@@ -1465,6 +1495,9 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d):  # type: ignore[misc]
         if bias:
             self.bias = UninitializedParameter(**factory_kwargs)

+    def _get_num_spatial_dims(self) -> int:
+        return 2
+

 # LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UnitializeParameter
 class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d):  # type: ignore[misc]
@@ -1529,3 +1562,6 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d):  # type: ignore[misc]
         self.out_channels = out_channels
         if bias:
             self.bias = UninitializedParameter(**factory_kwargs)
+
+    def _get_num_spatial_dims(self) -> int:
+        return 3
@@ -9,7 +9,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import floating_types
 from torch.testing._internal.common_device_type import (
     _TestParametrizer, _update_param_kwargs, skipIf, toleranceOverride, tol,
-    skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride)
+    skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta)
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_nn import nllloss_reference, get_reduction
 from torch.testing._internal.common_utils import (
@@ -408,28 +408,24 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad
                     forward_input=FunctionInput(make_input(shape=(2, 3, 4, 4, 4))))]


-def module_inputs_torch_nn_Conv2d(module_info, device, dtype, requires_grad, **kwargs):
-    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    return [
-        ModuleInput(constructor_input=FunctionInput(3, 4, 3),
-                    forward_input=FunctionInput(make_input(shape=(2, 3, 7, 5))))]
-
-
-def module_inputs_torch_nn_Conv3d(module_info, device, dtype, requires_grad, **kwargs):
-    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    return [
-        ModuleInput(constructor_input=FunctionInput(2, 3, (2, 3, 2)),
-                    forward_input=FunctionInput(make_input(shape=(1, 2, 4, 5, 4))))]
-
-
-def module_inputs_torch_nn_ConvTranspose2d(module_info, device, dtype, requires_grad, **kwargs):
-    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    return [
-        ModuleInput(constructor_input=FunctionInput(3, 4, 3, (3, 2), 1, (1, 1)),
-                    forward_input=FunctionInput(make_input(shape=(1, 3, 7, 6))))]
+def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, **kwargs):
+    N = kwargs['N']
+    lazy = kwargs.get('lazy', False)
+    transposed = kwargs.get('transposed', False)
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}]
+    kernel_size, C_in, C_out = 3, 4, 5
+    input_no_batch_shape = (C_in,) + tuple((i + 3 for i in range(N)))
+    input_batch_shape = (2,) + input_no_batch_shape
+    return [
+        ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else
+                                       FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)),
+                    forward_input=FunctionInput(make_input(
+                        shape=(input_batch_shape if with_batch else input_no_batch_shape))),
+                    desc=('' if with_batch else 'no_batch_dim'),
+                    reference_fn=(None if with_batch else no_batch_dim_reference_fn))
+        for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list)
+    ]


 def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs):
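The unbatched ModuleInputs above use no_batch_dim_reference_fn, which is defined elsewhere in common_modules.py. Conceptually it gives the test a batched reference to compare the unbatched path against; a rough sketch of that idea follows (an assumption about its behavior, not the actual helper):

    import torch

    def no_batch_dim_reference_sketch(module, inp):
        # Reference for an unbatched sample: add a fake batch dim, run the module,
        # then strip the batch dim from the result.
        return module(inp.unsqueeze(0)).squeeze(0)

    conv = torch.nn.Conv2d(4, 5, 3)
    x = torch.randn(4, 8, 8)
    torch.testing.assert_close(conv(x), no_batch_dim_reference_sketch(conv, x))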
@@ -818,54 +814,84 @@ module_db: List[ModuleInfo] = [
                    # Failure on ROCM for BatchNorm3d float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),)
                ),
-    ModuleInfo(torch.nn.Conv2d,
-               module_inputs_func=module_inputs_torch_nn_Conv2d,
-               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-               module_memformat_affects_out=True,
-               skips=(
-                   # NHWC is disabled for float64 input in CudNN Conv.
-                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', dtypes=[torch.float64]),
-                   # No channels_last support for Conv2d on cpu currently.
-                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cpu'),),
-               decorators=(
-                   # Conv2d channels_last support on cuda requires cudnn >= 7603
-                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
-                   # Failure on ROCM for Conv2d float32 issue #70125
-                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]))
-               ),
-    ModuleInfo(torch.nn.Conv3d,
-               module_inputs_func=module_inputs_torch_nn_Conv3d,
-               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-               module_memformat_affects_out=True,
-               skips=(
-                   # NHWC is disabled for float64 input in CudNN Conv.
-                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', dtypes=[torch.float64]),
-                   # No channels_last support for Conv3d on cpu currently.
-                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cpu'),
-                   # Greatest difference was 0.05072784423828125 > atol of 0.05
-                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),),
-               decorators=(
-                   # Conv3d channels_last support on cuda requires cudnn >= 8005
-                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
-                   # Failure on ROCM for Conv3d float32 issue #70125
-                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]))
-               ),
-    ModuleInfo(torch.nn.ConvTranspose2d,
-               module_inputs_func=module_inputs_torch_nn_ConvTranspose2d,
-               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-               module_memformat_affects_out=True,
-               skips=(
-                   # NHWC is disabled for float64 input in CudNN Conv.
-                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', dtypes=[torch.float64]),
-                   # No channels_last support for ConvTranspose2d on cpu currently.
-                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cpu'),),
-               decorators=(
-                   # ConvTranspose2d channels_last support on cuda requires cudnn >= 7603
-                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
-                   # Failure on ROCM for ConvTranspose2d float32 issue #70125
-                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]))
-               ),
+    ModuleInfo(torch.nn.Conv1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 7603
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Conv2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 7603
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Conv3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 8005
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 7603
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 7603
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 8005
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
     ModuleInfo(torch.nn.ELU,
                module_inputs_func=module_inputs_torch_nn_ELU),
     ModuleInfo(torch.nn.L1Loss,
@@ -874,6 +900,102 @@ module_db: List[ModuleInfo] = [
                    # No channels_last support for loss functions.
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
                ),
+    ModuleInfo(torch.nn.LazyConv1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 7603
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConv2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 7603
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConv3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 8005
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 7603
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 7603
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # channels_last support on cuda requires cudnn >= 8005
+                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
     ModuleInfo(torch.nn.Linear,
                module_inputs_func=module_inputs_torch_nn_Linear,
                skips=(