Fixes floating point exception in torch.nn.PixelShuffle (#163154)

Fixes #162251

**Previous Output:**
`Floating point exception (core dumped)`

**Now Output:**
`RuntimeError: upscale factor is too large, (upscale_factor}^2 overflowed: upscale_factor=545460846592`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163154
Approved by: https://github.com/cyyever, https://github.com/albanD
This commit is contained in:
arkadip-maitra 2025-10-22 02:22:12 +00:00 committed by PyTorch MergeBot
parent 60992d98b2
commit 84d8d06fc3
3 changed files with 14 additions and 0 deletions

View File

@ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
"pixel_shuffle expects a positive upscale_factor, but got ", "pixel_shuffle expects a positive upscale_factor, but got ",
upscale_factor); upscale_factor);
int64_t c = self.size(-3); int64_t c = self.size(-3);
TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
"upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
int64_t upscale_factor_squared = upscale_factor * upscale_factor; int64_t upscale_factor_squared = upscale_factor * upscale_factor;
TORCH_CHECK(c % upscale_factor_squared == 0, TORCH_CHECK(c % upscale_factor_squared == 0,
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "

View File

@ -515,6 +515,11 @@ class TestPixelShuffle(TestCaseMPS):
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0) _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2) _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
def test_pixel_shuffle_large_upscale_factor():
with self.assertRaises(ValueError):
ps = nn.PixelShuffle(545460846592)
ps(torch.randn(2, 16, 9, 3))
def test_pixel_shuffle_unshuffle_1D(): def test_pixel_shuffle_unshuffle_1D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1) _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
@ -530,6 +535,7 @@ class TestPixelShuffle(TestCaseMPS):
def test_pixel_shuffle_unshuffle_5D(): def test_pixel_shuffle_unshuffle_5D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5) _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
test_pixel_shuffle_large_upscale_factor()
test_pixel_shuffle_unshuffle_1D() test_pixel_shuffle_unshuffle_1D()
test_pixel_shuffle_unshuffle_2D() test_pixel_shuffle_unshuffle_2D()
test_pixel_shuffle_unshuffle_3D() test_pixel_shuffle_unshuffle_3D()

View File

@ -4565,6 +4565,11 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0) _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2) _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
def test_pixel_shuffle_large_upscale_factor():
with self.assertRaises(ValueError):
ps = nn.PixelShuffle(545460846592)
ps(torch.randn(2, 16, 9, 3))
def test_pixel_shuffle_unshuffle_1D(): def test_pixel_shuffle_unshuffle_1D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1) _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
@ -4580,6 +4585,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
def test_pixel_shuffle_unshuffle_5D(): def test_pixel_shuffle_unshuffle_5D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5) _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
test_pixel_shuffle_large_upscale_factor()
test_pixel_shuffle_unshuffle_1D() test_pixel_shuffle_unshuffle_1D()
test_pixel_shuffle_unshuffle_2D() test_pixel_shuffle_unshuffle_2D()
test_pixel_shuffle_unshuffle_3D() test_pixel_shuffle_unshuffle_3D()