diff --git a/aten/src/ATen/native/PixelShuffle.h b/aten/src/ATen/native/PixelShuffle.h index 49699107d9c..46ffa7ddb23 100644 --- a/aten/src/ATen/native/PixelShuffle.h +++ b/aten/src/ATen/native/PixelShuffle.h @@ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto "pixel_shuffle expects a positive upscale_factor, but got ", upscale_factor); int64_t c = self.size(-3); + TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits::max() / upscale_factor, + "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor); int64_t upscale_factor_squared = upscale_factor * upscale_factor; TORCH_CHECK(c % upscale_factor_squared == 0, "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " diff --git a/test/test_mps.py b/test/test_mps.py index 6f4e957aa9d..254af075d8c 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -515,6 +515,11 @@ class TestPixelShuffle(TestCaseMPS): _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0) _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2) + def test_pixel_shuffle_large_upscale_factor(): + with self.assertRaises(ValueError): + ps = nn.PixelShuffle(545460846592) + ps(torch.randn(2, 16, 9, 3)) + def test_pixel_shuffle_unshuffle_1D(): _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1) @@ -530,6 +535,7 @@ class TestPixelShuffle(TestCaseMPS): def test_pixel_shuffle_unshuffle_5D(): _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5) + test_pixel_shuffle_large_upscale_factor() test_pixel_shuffle_unshuffle_1D() test_pixel_shuffle_unshuffle_2D() test_pixel_shuffle_unshuffle_3D() diff --git a/test/test_nn.py b/test/test_nn.py index 49c503ac132..cb755992ffc 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4565,6 +4565,11 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0) _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2) + def test_pixel_shuffle_large_upscale_factor(): + with self.assertRaises(ValueError): + ps = nn.PixelShuffle(545460846592) + ps(torch.randn(2, 16, 9, 3)) + def test_pixel_shuffle_unshuffle_1D(): _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1) @@ -4580,6 +4585,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") def test_pixel_shuffle_unshuffle_5D(): _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5) + test_pixel_shuffle_large_upscale_factor() test_pixel_shuffle_unshuffle_1D() test_pixel_shuffle_unshuffle_2D() test_pixel_shuffle_unshuffle_3D()