diff --git a/aten/src/ATen/native/AveragePool2d.cpp b/aten/src/ATen/native/AveragePool2d.cpp index 368dc02c283..035228285bc 100644 --- a/aten/src/ATen/native/AveragePool2d.cpp +++ b/aten/src/ATen/native/AveragePool2d.cpp @@ -25,18 +25,19 @@ TORCH_PRECOMPUTE_META_FUNC(avg_pool2d) // #20866, #22032: Guarantee this for the official C++ API? TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints"); - const int64_t kH = kernel_size[0]; - const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1]; + const int kH = safe_downcast(kernel_size[0]); + const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2, "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints"); - const int64_t dH = stride.empty() ? kH : stride[0]; - const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1]; + const int dH = stride.empty() ? kH : safe_downcast(stride[0]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dH : safe_downcast(stride[1]); TORCH_CHECK(padding.size() == 1 || padding.size() == 2, "avg_pool2d: padding must either be a single int, or a tuple of two ints"); - const int64_t padH = padding[0]; - const int64_t padW = padding.size() == 1 ? padH : padding[1]; + const int padH = safe_downcast(padding[0]); + const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero"); diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index c3a7b829b2b..f20ee2a29d5 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -898,6 +898,16 @@ torch.cuda.synchronize() inp = torch.ones(1, 0, 50, 44, 31, device=device) mod(inp) + @onlyCPU + def test_LPPool1d_kernel_size_overflow_large(self, device): + avgpool = torch.nn.LPPool1d( + -1.38119e150, 7879455037536781369, ceil_mode=True + ).to(device) + inp = torch.randn(3, 15, device=device) + + with self.assertRaisesRegex(RuntimeError, "integer out of range"): + avgpool(inp) + @onlyNativeDeviceTypes def test_AvgPool2d_empty(self, device): avgpool = torch.nn.AvgPool2d(3, stride=2).to(device)