Improve input dimensions check for reflection_pad1d, reflection_pad2d and reflection_pad3d (#141670)

Fix https://github.com/pytorch/pytorch/issues/141447.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141670
Approved by: https://github.com/mingfeima, https://github.com/malfet
This commit is contained in:
Sun, Jiayi 2024-12-18 03:27:39 +00:00 committed by PyTorch MergeBot
parent b588a78ca3
commit 863e6e4567
2 changed files with 31 additions and 1 deletions

View File

@ -35,9 +35,10 @@ inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
int input_dim = input.dim();
bool is_batch_mode = input_dim == (dim + 2);
bool is_non_batch_mode = input_dim == (dim + 1);
bool valid_batch_mode = is_batch_mode;
bool valid_non_batch_mode = !is_batch_mode;
bool valid_non_batch_mode = is_non_batch_mode;
if (is_batch_mode) {
// allow batch size of 0-dim.

View File

@ -8841,6 +8841,35 @@ class TestNNDeviceType(NNTestCase):
inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
mod(inp)
@onlyNativeDeviceTypes
def test_ReflectionPad_fails(self, device):
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
mod = torch.nn.ReflectionPad1d(2)
inp = torch.randn(3, 3, 10, 10, device=device)
mod(inp)
with self.assertRaisesRegex(RuntimeError, '2D or 3D'):
inp = torch.randn(3, 3, 10, 10, device=device)
torch.ops.aten.reflection_pad1d(inp, (2, 2))
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
mod = torch.nn.ReflectionPad2d(2)
inp = torch.randn(3, 3, 10, 10, 10, device=device)
mod(inp)
with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
inp = torch.randn(3, 3, 10, 10, 10, device=device)
torch.ops.aten.reflection_pad2d(inp, (2, 2, 2, 2))
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
mod = torch.nn.ReflectionPad3d(3)
inp = torch.randn(3, 3, 10, 10, 10, 10, device=device)
mod(inp)
with self.assertRaisesRegex(RuntimeError, '4D or 5D'):
inp = torch.randn(3, 3, 10, 10, 10, 10, device=device)
torch.ops.aten.reflection_pad3d(inp, (2, 2, 2, 2, 2, 2))
@onlyCUDA # Test if CPU and GPU results match
def test_ReflectionPad2d_large(self, device):
shapes = ([2, 65736, 6, 6], [65736, 2, 6, 6])