mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes floating point exception in torch.nn.PixelShuffle (#163154)
Fixes #162251 **Previous Output:** `Floating point exception (core dumped)` **Now Output:** `RuntimeError: upscale factor is too large, (upscale_factor}^2 overflowed: upscale_factor=545460846592` Pull Request resolved: https://github.com/pytorch/pytorch/pull/163154 Approved by: https://github.com/cyyever, https://github.com/albanD
This commit is contained in:
parent
60992d98b2
commit
84d8d06fc3
|
|
@ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
|
|||
"pixel_shuffle expects a positive upscale_factor, but got ",
|
||||
upscale_factor);
|
||||
int64_t c = self.size(-3);
|
||||
TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
|
||||
"upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
|
||||
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
|
||||
TORCH_CHECK(c % upscale_factor_squared == 0,
|
||||
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
|
||||
|
|
|
|||
|
|
@ -515,6 +515,11 @@ class TestPixelShuffle(TestCaseMPS):
|
|||
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
|
||||
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
|
||||
|
||||
def test_pixel_shuffle_large_upscale_factor():
|
||||
with self.assertRaises(ValueError):
|
||||
ps = nn.PixelShuffle(545460846592)
|
||||
ps(torch.randn(2, 16, 9, 3))
|
||||
|
||||
def test_pixel_shuffle_unshuffle_1D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
|
||||
|
||||
|
|
@ -530,6 +535,7 @@ class TestPixelShuffle(TestCaseMPS):
|
|||
def test_pixel_shuffle_unshuffle_5D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
|
||||
|
||||
test_pixel_shuffle_large_upscale_factor()
|
||||
test_pixel_shuffle_unshuffle_1D()
|
||||
test_pixel_shuffle_unshuffle_2D()
|
||||
test_pixel_shuffle_unshuffle_3D()
|
||||
|
|
|
|||
|
|
@ -4565,6 +4565,11 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
|
||||
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
|
||||
|
||||
def test_pixel_shuffle_large_upscale_factor():
|
||||
with self.assertRaises(ValueError):
|
||||
ps = nn.PixelShuffle(545460846592)
|
||||
ps(torch.randn(2, 16, 9, 3))
|
||||
|
||||
def test_pixel_shuffle_unshuffle_1D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
|
||||
|
||||
|
|
@ -4580,6 +4585,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
def test_pixel_shuffle_unshuffle_5D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
|
||||
|
||||
test_pixel_shuffle_large_upscale_factor()
|
||||
test_pixel_shuffle_unshuffle_1D()
|
||||
test_pixel_shuffle_unshuffle_2D()
|
||||
test_pixel_shuffle_unshuffle_3D()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user