mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Improve input dimensions check for reflection_pad1d, reflection_pad2d and reflection_pad3d (#141670)
Fix https://github.com/pytorch/pytorch/issues/141447. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141670 Approved by: https://github.com/mingfeima, https://github.com/malfet
This commit is contained in:
parent
b588a78ca3
commit
863e6e4567
|
|
@ -35,9 +35,10 @@ inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
|
|||
int input_dim = input.dim();
|
||||
|
||||
bool is_batch_mode = input_dim == (dim + 2);
|
||||
bool is_non_batch_mode = input_dim == (dim + 1);
|
||||
|
||||
bool valid_batch_mode = is_batch_mode;
|
||||
bool valid_non_batch_mode = !is_batch_mode;
|
||||
bool valid_non_batch_mode = is_non_batch_mode;
|
||||
|
||||
if (is_batch_mode) {
|
||||
// allow batch size of 0-dim.
|
||||
|
|
|
|||
|
|
@ -8841,6 +8841,35 @@ class TestNNDeviceType(NNTestCase):
|
|||
inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
|
||||
mod(inp)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
def test_ReflectionPad_fails(self, device):
|
||||
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
|
||||
mod = torch.nn.ReflectionPad1d(2)
|
||||
inp = torch.randn(3, 3, 10, 10, device=device)
|
||||
mod(inp)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, '2D or 3D'):
|
||||
inp = torch.randn(3, 3, 10, 10, device=device)
|
||||
torch.ops.aten.reflection_pad1d(inp, (2, 2))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
|
||||
mod = torch.nn.ReflectionPad2d(2)
|
||||
inp = torch.randn(3, 3, 10, 10, 10, device=device)
|
||||
mod(inp)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
|
||||
inp = torch.randn(3, 3, 10, 10, 10, device=device)
|
||||
torch.ops.aten.reflection_pad2d(inp, (2, 2, 2, 2))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
|
||||
mod = torch.nn.ReflectionPad3d(3)
|
||||
inp = torch.randn(3, 3, 10, 10, 10, 10, device=device)
|
||||
mod(inp)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, '4D or 5D'):
|
||||
inp = torch.randn(3, 3, 10, 10, 10, 10, device=device)
|
||||
torch.ops.aten.reflection_pad3d(inp, (2, 2, 2, 2, 2, 2))
|
||||
|
||||
@onlyCUDA # Test if CPU and GPU results match
|
||||
def test_ReflectionPad2d_large(self, device):
|
||||
shapes = ([2, 65736, 6, 6], [65736, 2, 6, 6])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user